Skip to content

Commit 93bdd8c

Browse files
authored
Sneaky noop: action=-1 (#14)
It is useful to have a NOOP action in the environment, for "thinking time" experiments or other ablations. But we don't want the NNs that learn from the environment to have NOOP as one of their actions. So here we intentionally send -1 (or any other invalid action < 0). We catch that case in `Step` and return the current state without advancing the step counter, so the environment is not truncated or reset.
2 parents 899f346 + 9de7244 commit 93bdd8c

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

envpool/sokoban/sokoban_envpool.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "envpool/sokoban/sokoban_envpool.h"
1616

1717
#include <array>
18+
#include <limits>
1819
#include <sstream>
1920
#include <stdexcept>
2021
#include <vector>
@@ -76,10 +77,17 @@ constexpr std::array<std::array<int, 2>, 4> kChangeCoordinates = {
7677
{{0, -1}, {0, 1}, {-1, 0}, {1, 0}}};
7778

7879
void SokobanEnv::Step(const Action& action_dict) {
79-
current_step_++;
80-
8180
const int action = action_dict["action"_];
81+
// Sneaky Noop action
82+
if (action < 0) {
83+
WriteState(std::numeric_limits<float>::signaling_NaN());
84+
// Avoid advancing the current_step_. `envpool/core/env.h` advances
85+
// `current_step_` at every non-Reset step, and sets it to 0 when it is a
86+
// Reset.
87+
return;
88+
}
8289

90+
current_step_++;
8391
const int change_coordinates_idx = action;
8492
const int delta_x = kChangeCoordinates.at(change_coordinates_idx).at(0);
8593
const int delta_y = kChangeCoordinates.at(change_coordinates_idx).at(1);

envpool/sokoban/sokoban_py_envpool_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,44 @@ def test_astar_log(tmp_path) -> None:
342342
assert f"0,{SOLVE_LEVEL_ZERO},21,1380" == log.split("\n")[1]
343343

344344

345+
def test_sneaky_noop():
    """Check that any action < 0 is treated as a NOOP by the environment.

    Even though an action < 0 is not part of the environment's action space,
    we overload it to mean NOOP. This lets us easily do thinking-time
    experiments.

    The test first sends only NOOPs and checks that the observation never
    changes, that no episode ends, and that the reward is NaN. It then sends
    real actions and checks that every environment is truncated within
    MAX_EP_STEPS steps, i.e. the preceding NOOPs did not disturb the episode
    step accounting.
    """
    MIN_EP_STEPS = 1
    MAX_EP_STEPS = 3
    NUM_ENVS = 5

    env = envpool.make(
        "Sokoban-v0",
        env_type="gymnasium",
        num_envs=NUM_ENVS,
        batch_size=NUM_ENVS,
        min_episode_steps=MIN_EP_STEPS,
        max_episode_steps=MAX_EP_STEPS,
        levels_dir="/app/envpool/sokoban/sample_levels",
    )
    init_obs, _ = env.reset()
    # NOOP must not be exposed as a regular action.
    assert env.action_space.n == 4

    # Send far more NOOPs than an episode may last: if NOOPs counted as steps,
    # some environment would be truncated here.
    for _ in range(MAX_EP_STEPS * 5):
        obs, reward, terminated, truncated, info = env.step(
            -np.ones([NUM_ENVS], dtype=np.int64)
        )
        assert np.array_equal(init_obs, obs)
        assert not np.any(terminated | truncated)
        assert np.all(np.isnan(reward))

    truncs = []
    for _ in range(MAX_EP_STEPS):
        _, _, _, truncated, _ = env.step(np.zeros([NUM_ENVS], dtype=np.int64))
        truncs.append(truncated)

    # Every environment must have been truncated at least once over the real
    # steps. Reduce over the collected per-step flags (`truncs`), not just the
    # flags of the last step.
    assert np.all(np.any(truncs, axis=0), axis=0)
381+
382+
345383
if __name__ == "__main__":
    # Run this file's tests verbosely and exit with pytest's return code.
    sys.exit(pytest.main(["-v", __file__]))

0 commit comments

Comments
 (0)