
Commit af7f8aa

Fix alfworld dataset loading to use correct train/test split (#378)
1 parent 81b3554 · commit af7f8aa

2 files changed: +45, -16 lines

examples/grpo_alfworld/get_alfworld_data.py

Lines changed: 44 additions & 16 deletions
@@ -10,31 +10,55 @@
 random.seed(42)
 
 
-def create_dataset_files(output_dir, train_size=1024, test_size=100):
+def create_dataset_files(output_dir, train_size=None, test_size=None):
     # The ALFWORLD_DATA is the dataset path in the environment variable ALFWORLD_DATA, you need to set it when install alfworld dataset
     from alfworld.info import ALFWORLD_DATA
 
-    # get all matched game files
-    game_files = glob.glob(os.path.expanduser(f"{ALFWORLD_DATA}/json_2.1.1/train/*/*/game.tw-pddl"))
+    # get all matched game files from train and valid_seen directories
+    train_game_files = glob.glob(
+        os.path.expanduser(f"{ALFWORLD_DATA}/json_2.1.1/train/*/*/game.tw-pddl")
+    )
+    test_game_files = glob.glob(
+        os.path.expanduser(f"{ALFWORLD_DATA}/json_2.1.1/valid_seen/*/*/game.tw-pddl")
+    )
 
     # get absolute path
-    game_files = [os.path.abspath(file) for file in game_files]
-    game_files = sorted(game_files)
+    train_game_files = [os.path.abspath(file) for file in train_game_files]
+    test_game_files = [os.path.abspath(file) for file in test_game_files]
+    train_game_files = sorted(train_game_files)
+    test_game_files = sorted(test_game_files)
 
-    # randomly sellect the game files
-    sellected_game_files = random.sample(game_files, train_size + test_size)
+    print(f"Total train game files found: {len(train_game_files)}")
+    print(f"Total test game files found: {len(test_game_files)}")
+
+    # if size is None, use all files
+    if train_size is None:
+        train_size = len(train_game_files)
+    if test_size is None:
+        test_size = len(test_game_files)
+
+    # check sizes
+    assert train_size <= len(
+        train_game_files
+    ), f"train_size {train_size} > available {len(train_game_files)}"
+    assert test_size <= len(
+        test_game_files
+    ), f"test_size {test_size} > available {len(test_game_files)}"
+
+    # randomly select the game files
+    selected_train_files = random.sample(train_game_files, train_size)
+    selected_test_files = random.sample(test_game_files, test_size)
 
     # make the output directory
     os.makedirs(output_dir, exist_ok=True)
 
-    # for webshop dataset, we just need the session id as the task id
-    all_data = []
-    for game_file_path in sellected_game_files:
-        all_data.append({"game_file": game_file_path, "target": ""})
-
-    # split the train and test data
-    train_data = all_data[:train_size]
-    test_data = all_data[train_size : train_size + test_size]
+    # create train and test data
+    train_data = [
+        {"game_file": game_file_path, "target": ""} for game_file_path in selected_train_files
+    ]
+    test_data = [
+        {"game_file": game_file_path, "target": ""} for game_file_path in selected_test_files
+    ]
 
     # create dataset_dict
     dataset_dict = {"train": train_data, "test": test_data}
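
The hunk above follows a simple sizing pattern: None means keep every file found, while an explicit size is validated and then drawn as a seeded random subset. A stripped-down sketch of that pattern (the pick_files helper and the sample file names are illustrative, not part of the commit):

# Stripped-down sketch of the sizing pattern above; pick_files and the sample data are illustrative only.
import random

random.seed(42)  # same seed as the script, so the chosen subset is reproducible across runs


def pick_files(files, size=None):
    # None means "use everything"; an explicit size must not exceed what is available
    if size is None:
        size = len(files)
    assert size <= len(files), f"size {size} > available {len(files)}"
    return random.sample(files, size)


if __name__ == "__main__":
    files = sorted(f"game_{i}.tw-pddl" for i in range(5))
    print(pick_files(files))           # all 5 files in a deterministic shuffled order
    print(pick_files(files, size=2))   # a reproducible 2-file subset
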
@@ -58,8 +82,12 @@ def create_dataset_files(output_dir, train_size=1024, test_size=100):
     with open(os.path.join(output_dir, "dataset_dict.json"), "w") as f:
         json.dump(dataset_info, f, indent=2)
 
+    print(f"Created dataset with {len(train_data)} train and {len(test_data)} test examples.")
+
 
 if __name__ == "__main__":
     current_file_dir = os.path.dirname(os.path.abspath(__file__))
     output_dir = f"{current_file_dir}/alfworld_data"
-    create_dataset_files(output_dir, train_size=1024, test_size=100)
+    # use all data by default, or specify train_size and test_size if needed
+    create_dataset_files(output_dir)
+    # create_dataset_files(output_dir, train_size=1024, test_size=100) # use subset of data for testing
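
For context, a rough usage sketch of the updated script follows. The ALFWORLD_DATA path and the import location are placeholders, not something the commit specifies; the commit only changes how the game files are gathered, split, and sampled.

# Hypothetical usage sketch; the data path and import location are placeholders.
import os

# alfworld.info reads ALFWORLD_DATA from the environment, so it must point at the
# downloaded ALFWorld data before create_dataset_files() is called.
os.environ["ALFWORLD_DATA"] = "/path/to/alfworld/data"  # placeholder path

from get_alfworld_data import create_dataset_files  # assumes the script's directory is on sys.path

# New default: every game file under json_2.1.1/train/ becomes the train split and
# every file under json_2.1.1/valid_seen/ becomes the test split.
create_dataset_files("alfworld_data")

# Or cap both splits, mirroring the old hard-coded defaults.
create_dataset_files("alfworld_data_small", train_size=1024, test_size=100)
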

trinity/buffer/reader/file_reader.py

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ def select_batch(self, indices: List[int]) -> List:
         for i in indices:
             assert 0 <= i < self.dataset_size
             batch.append(self.dataset[int(i)])
+        self.progress_bar.update(len(batch)) # update progress bar
         return batch
 
 