Skip to content

Commit aa270fc

Browse files
authored
Add level infos on reset (#15)
2 parents 93bdd8c + dbeb9a2 commit aa270fc

File tree

7 files changed

+53
-17
lines changed

7 files changed

+53
-17
lines changed

envpool/sokoban/astar_log.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void RunAStar(const std::string& level_file_name,
4141
if (line.empty()) {
4242
continue;
4343
}
44-
SokobanLevel level = *level_loader.GetLevel(gen);
44+
SokobanLevel level = level_loader.GetLevel(gen).data;
4545
level_idx++;
4646
}
4747
}
@@ -50,7 +50,7 @@ void RunAStar(const std::string& level_file_name,
5050
while (level_idx < total_levels_to_run) {
5151
std::AStarSearch<SokobanNode> astarsearch(fsa_limit);
5252
std::cout << "Running level " << level_idx << std::endl;
53-
SokobanLevel level = *level_loader.GetLevel(gen);
53+
SokobanLevel level = level_loader.GetLevel(gen).data;
5454

5555
SokobanNode node_start(dim_room, level, false);
5656
SokobanNode node_end(dim_room, level, true);

envpool/sokoban/astar_log_level.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void RunAStar(const std::string& level_file_name,
4141
}
4242
std::AStarSearch<SokobanNode> astarsearch(fsa_limit);
4343
std::cout << "Running level " << level_idx << std::endl;
44-
SokobanLevel level = *level_loader.GetLevel(gen);
44+
SokobanLevel level = level_loader.GetLevel(gen).data;
4545

4646
SokobanNode node_start(dim_room, level, false);
4747
SokobanNode node_end(dim_room, level, true);

envpool/sokoban/level_loader.cc

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,24 +118,26 @@ void LevelLoader::LoadFile(std::mt19937& gen) {
118118
if (cur_file_ == level_file_paths_.end()) {
119119
throw std::runtime_error("No more files to load.");
120120
}
121+
cur_level_file_++;
121122
file_path = *cur_file_;
122123
cur_file_++;
123124
} else {
124-
const size_t load_file_idx = SafeUniformInt(
125-
static_cast<size_t>(0), level_file_paths_.size() - 1, gen);
126-
file_path = level_file_paths_.at(load_file_idx);
125+
cur_level_file_ = SafeUniformInt(static_cast<size_t>(0),
126+
level_file_paths_.size() - 1, gen);
127+
file_path = level_file_paths_.at(cur_level_file_);
127128
}
128129
std::ifstream file(file_path);
129130

130131
levels_.clear();
132+
int cur_level_idx = 0;
131133
std::string line;
132134
while (std::getline(file, line)) {
133135
if (line.empty()) {
134136
continue;
135137
}
136138

137139
if (line.at(0) == '#') {
138-
SokobanLevel& cur_level = levels_.emplace_back(0);
140+
SokobanLevel cur_level(0);
139141
cur_level.reserve(10 * 10); // In practice most levels are this size
140142

141143
// Count contiguous '#' characters and use this as the box dimension
@@ -163,6 +165,8 @@ void LevelLoader::LoadFile(std::mt19937& gen) {
163165
<< "x" << dim_room << std::endl;
164166
throw std::runtime_error(msg.str());
165167
}
168+
levels_.emplace_back(
169+
std::make_pair(cur_level_idx++, std::move(cur_level)));
166170
}
167171
}
168172
if (!load_sequentially_) {
@@ -178,20 +182,21 @@ void LevelLoader::LoadFile(std::mt19937& gen) {
178182
std::cout << "***Loaded " << levels_.size() << " levels from " << file_path
179183
<< std::endl;
180184
if (verbose >= 2) {
181-
PrintLevel(std::cout, levels_.at(0));
185+
PrintLevel(std::cout, levels_.at(0).second);
182186
std::cout << std::endl;
183-
PrintLevel(std::cout, levels_.at(1));
187+
PrintLevel(std::cout, levels_.at(1).second);
184188
std::cout << std::endl;
185189
}
186190
}
187191
}
188192

189-
std::vector<SokobanLevel>::iterator LevelLoader::GetLevel(std::mt19937& gen) {
193+
TaggedSokobanLevel LevelLoader::GetLevel(std::mt19937& gen) {
190194
if (n_levels_to_load_ > 0 && levels_loaded_ >= n_levels_to_load_) {
191195
// std::cerr << "Warning: All levels loaded. Looping around now." <<
192196
// std::endl;
193197
levels_loaded_ = 0;
194198
cur_file_ = level_file_paths_.begin();
199+
cur_level_file_ = -1;
195200
LoadFile(gen);
196201
// re-start from the `env_id`th level, like we do in the constructor.
197202
cur_level_ = env_id_;
@@ -206,7 +211,9 @@ std::vector<SokobanLevel>::iterator LevelLoader::GetLevel(std::mt19937& gen) {
206211
auto out = levels_.begin() + cur_level_;
207212
cur_level_ += num_envs_;
208213
levels_loaded_++;
209-
return out;
214+
215+
TaggedSokobanLevel tagged_level{cur_level_file_, out->first, out->second};
216+
return tagged_level;
210217
}
211218

212219
} // namespace sokoban

envpool/sokoban/level_loader.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include <filesystem>
2121
#include <random>
22+
#include <utility>
2223
#include <vector>
2324

2425
namespace sokoban {
@@ -34,23 +35,28 @@ constexpr uint8_t kPlayer = 5;
3435
constexpr uint8_t kPlayerOnTarget = 6;
3536
constexpr uint8_t kMaxLevelObject = kPlayerOnTarget;
3637

38+
struct TaggedSokobanLevel {
39+
int file_idx, level_idx;
40+
SokobanLevel data;
41+
};
42+
3743
class LevelLoader {
3844
protected:
3945
bool load_sequentially_;
4046
int n_levels_to_load_;
4147
int levels_loaded_{0};
4248
int env_id_{0};
4349
int num_envs_{1};
44-
std::vector<SokobanLevel> levels_{0};
45-
int cur_level_;
50+
std::vector<std::pair<int, SokobanLevel>> levels_{0};
51+
int cur_level_{-1}, cur_level_file_{-1};
4652
std::vector<std::filesystem::path> level_file_paths_{0};
4753
std::vector<std::filesystem::path>::iterator cur_file_;
4854
void LoadFile(std::mt19937& gen);
4955

5056
public:
5157
int verbose;
5258

53-
std::vector<SokobanLevel>::iterator GetLevel(std::mt19937& gen);
59+
TaggedSokobanLevel GetLevel(std::mt19937& gen);
5460
explicit LevelLoader(const std::filesystem::path& base_path,
5561
bool load_sequentially, int n_levels_to_load,
5662
int env_id = 0, int num_envs = 1, int verbose = 0);

envpool/sokoban/sokoban_envpool.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
#include "envpool/sokoban/sokoban_envpool.h"
1616

1717
#include <array>
18+
#include <iostream>
1819
#include <limits>
1920
#include <sstream>
2021
#include <stdexcept>
22+
#include <utility>
2123
#include <vector>
2224

2325
#include "envpool/core/py_envpool.h"
@@ -31,7 +33,11 @@ void SokobanEnv::ResetWithoutWrite() {
3133
current_max_episode_steps_ =
3234
SafeUniformInt(min_episode_steps, max_episode_steps, gen_);
3335

34-
world_ = *(level_loader_.GetLevel(gen_));
36+
TaggedSokobanLevel level = level_loader_.GetLevel(gen_);
37+
world_ = level.data;
38+
level_idx_ = level.level_idx;
39+
level_file_idx_ = level.file_idx;
40+
3541
if (world_.size() != dim_room_ * dim_room_) {
3642
std::stringstream msg;
3743
msg << "Loaded level is not dim_room x dim_room. world_.size()="
@@ -204,6 +210,9 @@ void SokobanEnv::WriteState(float reward) {
204210
}
205211
}
206212
obs.Assign(out.data(), out.size());
213+
214+
state["info:level_file_idx"_] = level_file_idx_;
215+
state["info:level_idx"_] = level_idx_;
207216
}
208217

209218
} // namespace sokoban

envpool/sokoban/sokoban_envpool.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ class SokobanEnvFns {
4848
template <typename Config>
4949
static decltype(auto) StateSpec(const Config& conf) {
5050
int dim_room = conf["dim_room"_];
51-
return MakeDict("obs"_.Bind(Spec<uint8_t>({3, dim_room, dim_room})));
51+
return MakeDict("obs"_.Bind(Spec<uint8_t>({3, dim_room, dim_room})),
52+
"info:level_file_idx"_.Bind(Spec<int>({-1})),
53+
"info:level_idx"_.Bind(Spec<int>({-1})));
5254
}
5355
template <typename Config>
5456
static decltype(auto) ActionSpec(const Config& conf) {
@@ -106,6 +108,7 @@ class SokobanEnv : public Env<SokobanEnvSpec> {
106108
std::filesystem::path levels_dir_;
107109

108110
LevelLoader level_loader_;
111+
int level_file_idx_{-1}, level_idx_{-1};
109112
SokobanLevel world_;
110113
int verbose_;
111114

envpool/sokoban/sokoban_py_envpool_test.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,10 +285,12 @@ def test_load_sequentially_with_multiple_envs() -> None:
285285
levels_dir = "/app/envpool/sokoban/sample_levels"
286286
files = glob.glob(f"{levels_dir}/*.txt")
287287
levels_by_files = []
288+
levels_per_file = []
288289
total_levels, num_envs = 8, 2
289290
for file in sorted(files):
290291
levels = read_levels_file(file)
291292
levels_by_files.extend(levels)
293+
levels_per_file.append(len(levels))
292294
assert len(levels_by_files) == total_levels, "8 levels stored in files."
293295

294296
env = envpool.make(
@@ -307,8 +309,17 @@ def test_load_sequentially_with_multiple_envs() -> None:
307309
printed_obs = []
308310

309311
for _ in range(2): # check loader loops around and loads levels again
312+
gt_file_idx, gt_level_idx = 0, 0
310313
for _ in range(total_levels // num_envs):
311-
obs, _ = env.reset()
314+
obs, info = env.reset()
315+
level_file_idxs, level_idxs = info["level_file_idx"], info["level_idx"]
316+
for lfi, li in zip(level_file_idxs, level_idxs):
317+
assert lfi == gt_file_idx, f"lfi: {lfi}, gt_file_idx: {gt_file_idx}"
318+
assert li == gt_level_idx, f"li: {li}, gt_level_idx: {gt_level_idx}"
319+
gt_level_idx += 1
320+
if gt_level_idx == levels_per_file[gt_file_idx]:
321+
gt_file_idx += 1
322+
gt_level_idx = 0
312323
assert obs.shape == (
313324
num_envs,
314325
3,

0 commit comments

Comments
 (0)