File tree Expand file tree Collapse file tree 8 files changed +31
-5
lines changed
Expand file tree Collapse file tree 8 files changed +31
-5
lines changed Original file line number Diff line number Diff line change @@ -154,6 +154,10 @@ impl Env for GridWorld {
154154 }
155155 }
156156
157+ fn success ( & self ) -> bool {
158+ self . at_goal ( )
159+ }
160+
157161 fn observe ( & self ) -> Vec < usize > {
158162 self . get_state ( ) . iter ( ) . enumerate ( ) . map ( |( i, v) | i * self . height * self . width + v) . collect ( )
159163 }
Original file line number Diff line number Diff line change 3535 },
3636 {
3737 "cell_type" : " code" ,
38- "execution_count" : 3 ,
38+ "execution_count" : null ,
3939 "metadata" : {},
4040 "outputs" : [],
4141 "source" : [
4242 " import numpy as np\n " ,
4343 " \n " ,
44+ " \n " ,
4445 " def display(state):\n " ,
4546 " padding = len(str(max(state)))\n " ,
4647 " txt = \"\"\n " ,
4748 " for row in np.array(state).reshape([config[\" env\" ][\" width\" ], config[\" env\" ][\" height\" ]]):\n " ,
48- " txt += \" |\" + \" |\" .join((\" {:\" + str(padding)+ \" d}\" ).format(num) if num > 0 else \" \" * padding for num in row) + \" |\\ n\"\n " ,
49+ " txt += \" |\" + \" |\" .join((\" {:\" + str(padding) + \" d}\" ).format(num) if num > 0 else \" \" * padding for num in row) + \" |\\ n\"\n " ,
4950 " print(txt)"
5051 ]
5152 },
121122 },
122123 {
123124 "cell_type" : " code" ,
124- "execution_count" : 6 ,
125+ "execution_count" : null ,
125126 "metadata" : {},
126127 "outputs" : [
127128 {
Original file line number Diff line number Diff line change @@ -156,6 +156,7 @@ mod tests {
156156 fn is_final ( & self ) -> bool { self . step >= 1 }
157157 fn reward ( & self ) -> f32 { 1.0 }
158158 fn observe ( & self ) -> Vec < usize > { vec ! [ 0 ] }
159+ fn success ( & self ) -> bool { true }
159160 }
160161
161162 fn dummy_policy ( ) -> Policy {
Original file line number Diff line number Diff line change @@ -152,6 +152,7 @@ mod tests {
152152 fn is_final ( & self ) -> bool { self . step >= 1 }
153153 fn reward ( & self ) -> f32 { 1.0 }
154154 fn observe ( & self ) -> Vec < usize > { vec ! [ 0 ] }
155+ fn success ( & self ) -> bool { true }
155156 }
156157
157158 fn dummy_policy ( ) -> Policy {
Original file line number Diff line number Diff line change @@ -176,6 +176,10 @@ impl Env for Puzzle {
176176 }
177177 }
178178
179+ fn success ( & self ) -> bool {
180+ self . solved ( )
181+ }
182+
179183fn observe ( & self , ) -> Vec < usize > {
180184 self . state . iter ( ) . enumerate ( ) . map ( |( i, v) | i * self . height * self . width + v) . collect ( )
181185 }
Original file line number Diff line number Diff line change @@ -138,6 +138,16 @@ impl Env for PyEnvImpl {
138138 } )
139139 }
140140
141+ fn success ( & self ) -> bool {
142+ Python :: with_gil ( |py| {
143+ let py_env = self . py_env . borrow ( ) ;
144+ py_env
145+ . call_method0 ( py, "success" )
146+ . and_then ( |val| val. extract :: < bool > ( py) )
147+ . expect ( "Python `success` method must return a bool." )
148+ } )
149+ }
150+
141151 fn observe ( & self ) -> Vec < usize > {
142152 Python :: with_gil ( |py| {
143153 let py_env = self . py_env . borrow ( ) ;
Original file line number Diff line number Diff line change @@ -46,6 +46,9 @@ pub trait Env : DynClone + Send + Sync {
4646 // Returns True if the given state is a terminal state
4747 fn is_final ( & self ) -> bool ;
4848
49+ // Returns True if the goal is achieved
50+ fn success ( & self ) -> bool ;
51+
4952 // Returns the value of current state
5053 fn reward ( & self ) -> f32 ;
5154
@@ -56,4 +59,4 @@ pub trait Env : DynClone + Send + Sync {
5659 fn twists ( & self ) -> ( Vec < Vec < usize > > , Vec < Vec < usize > > ) { ( vec ! [ ] , vec ! [ ] ) }
5760}
5861
59- dyn_clone:: clone_trait_object!( Env ) ;
62+ dyn_clone:: clone_trait_object!( Env ) ;
Original file line number Diff line number Diff line change @@ -59,7 +59,9 @@ pub fn single_solve(
5959 let val = env. reward ( ) ;
6060 total_val += val;
6161
62- ( ( ( val == 1.0 ) as usize as f32 , total_val) , solution)
62+ let success = if env. success ( ) { 1.0 } else { 0.0 } ;
63+
64+ ( ( success, total_val) , solution)
6365}
6466
6567pub fn solve (
You can’t perform that action at this time.
0 commit comments