Skip to content

Commit 43fa04f

Browse files
Add success function to env trait (#28)
* Add permutations as twists in training * Fix lint * Add success to env trait. Check it in training * Revert notebooks reformat * Revert notebooks reformat * Add success in test structures
1 parent 3637aa9 commit 43fa04f

File tree

8 files changed

+31
-5
lines changed

8 files changed

+31
-5
lines changed

examples/grid_world/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ impl Env for GridWorld {
154154
}
155155
}
156156

157+
// Implements `Env::success` for the grid world by delegating to `self.at_goal()`.
fn success(&self) -> bool {
158+
self.at_goal()
159+
}
160+
157161
// Encodes the state returned by `get_state()` as a vector of indices: the entry `v`
// at position `i` becomes `i * height * width + v`, so each position is offset into
// its own block of `height * width` values.
// NOTE(review): this only yields distinct indices if every `v < height * width` — confirm.
fn observe(&self) -> Vec<usize> {
158162
self.get_state().iter().enumerate().map(|(i, v)| i * self.height * self.width + v).collect()
159163
}

examples/puzzle.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,18 @@
3535
},
3636
{
3737
"cell_type": "code",
38-
"execution_count": 3,
38+
"execution_count": null,
3939
"metadata": {},
4040
"outputs": [],
4141
"source": [
4242
"import numpy as np\n",
4343
"\n",
44+
"\n",
4445
"def display(state):\n",
4546
" padding = len(str(max(state)))\n",
4647
" txt = \"\"\n",
4748
" for row in np.array(state).reshape([config[\"env\"][\"width\"], config[\"env\"][\"height\"]]):\n",
48-
" txt += \"|\" + \"|\".join((\"{:\"+str(padding)+\"d}\").format(num) if num > 0 else \" \"*padding for num in row) + \"|\\n\"\n",
49+
" txt += \"|\" + \"|\".join((\"{:\" + str(padding) + \"d}\").format(num) if num > 0 else \" \" * padding for num in row) + \"|\\n\"\n",
4950
" print(txt)"
5051
]
5152
},
@@ -121,7 +122,7 @@
121122
},
122123
{
123124
"cell_type": "code",
124-
"execution_count": 6,
125+
"execution_count": null,
125126
"metadata": {},
126127
"outputs": [
127128
{

rust/src/collector/az.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ mod tests {
156156
// Test stub: reports final once `self.step` is at least 1.
fn is_final(&self) -> bool { self.step >= 1 }
157157
// Test stub: constant reward of 1.0.
fn reward(&self) -> f32 { 1.0 }
158158
// Test stub: fixed single-element observation `[0]`.
fn observe(&self) -> Vec<usize> { vec![0] }
159+
// Test stub: always reports success.
fn success(&self) -> bool { true }
159160
}
160161

161162
fn dummy_policy() -> Policy {

rust/src/collector/ppo.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ mod tests {
152152
// Test stub: reports final once `self.step` is at least 1.
fn is_final(&self) -> bool { self.step >= 1 }
153153
// Test stub: constant reward of 1.0.
fn reward(&self) -> f32 { 1.0 }
154154
// Test stub: fixed single-element observation `[0]`.
fn observe(&self) -> Vec<usize> { vec![0] }
155+
// Test stub: always reports success.
fn success(&self) -> bool { true }
155156
}
156157

157158
fn dummy_policy() -> Policy {

rust/src/envs/puzzle.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ impl Env for Puzzle {
176176
}
177177
}
178178

179+
// Implements `Env::success` for the puzzle by delegating to `self.solved()`.
fn success(&self) -> bool {
180+
self.solved()
181+
}
182+
179183
// Encodes `self.state` as a vector of indices: the entry `v` at position `i` becomes
// `i * height * width + v`, so each position is offset into its own block of
// `height * width` values.
// NOTE(review): this only yields distinct indices if every `v < height * width` — confirm.
fn observe(&self,) -> Vec<usize> {
180184
self.state.iter().enumerate().map(|(i, v)| i * self.height * self.width + v).collect()
181185
}

rust/src/python_interface/pyenv.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,16 @@ impl Env for PyEnvImpl {
138138
})
139139
}
140140

141+
// Implements `Env::success` by calling the wrapped Python object's zero-argument
// `success` method inside `Python::with_gil` and extracting the result as a `bool`.
//
// Panics if the Python call returns an error OR if the returned value cannot be
// extracted as `bool`: both failures flow through the same `expect` below, so the
// panic message may not describe the actual cause.
fn success(&self) -> bool {
142+
Python::with_gil(|py| {
143+
let py_env = self.py_env.borrow();
144+
py_env
145+
.call_method0(py, "success")
146+
.and_then(|val| val.extract::<bool>(py))
147+
// Single failure point for both the call error and the extraction error.
.expect("Python `success` method must return a bool.")
148+
})
149+
}
150+
141151
fn observe(&self) -> Vec<usize> {
142152
Python::with_gil(|py| {
143153
let py_env = self.py_env.borrow();

rust/src/rl/env.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ pub trait Env : DynClone + Send + Sync {
4646
// Returns true if the environment's current state is a terminal state
// (no state is passed in; the method inspects `self`)
4747
fn is_final(&self) -> bool;
4848

49+
// Returns true if the goal is achieved in the environment's current state.
// Required method: every implementor of `Env` must provide it (no default body).
50+
fn success(&self) -> bool;
51+
4952
// Returns the value (reward) of the current state
5053
fn reward(&self) -> f32;
5154

@@ -56,4 +59,4 @@ pub trait Env : DynClone + Send + Sync {
5659
// Default implementation: returns a pair of empty vectors, i.e. no twists.
// Implementors may override this.
// NOTE(review): the meaning of each tuple element is not visible here — confirm
// against the implementors and the training code that consumes them.
fn twists(&self) -> (Vec<Vec<usize>>, Vec<Vec<usize>>) {(vec![], vec![])}
5760
}
5861

59-
dyn_clone::clone_trait_object!(Env);
62+
dyn_clone::clone_trait_object!(Env);

rust/src/rl/solve.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ pub fn single_solve(
5959
let val = env.reward();
6060
total_val += val;
6161

62-
(((val == 1.0) as usize as f32, total_val), solution)
62+
let success = if env.success() { 1.0 } else { 0.0 };
63+
64+
((success, total_val), solution)
6365
}
6466

6567
pub fn solve(

0 commit comments

Comments
 (0)