Merged
60 changes: 44 additions & 16 deletions examples/grpo_alfworld/get_alfworld_data.py
@@ -10,31 +10,55 @@
 random.seed(42)
 
 
-def create_dataset_files(output_dir, train_size=1024, test_size=100):
+def create_dataset_files(output_dir, train_size=None, test_size=None):
     # ALFWORLD_DATA is the dataset path read from the ALFWORLD_DATA environment variable; set it when installing the ALFWorld dataset
     from alfworld.info import ALFWORLD_DATA
 
-    # get all matched game files
-    game_files = glob.glob(os.path.expanduser(f"{ALFWORLD_DATA}/json_2.1.1/train/*/*/game.tw-pddl"))
+    # get all matched game files from the train and valid_seen directories
+    train_game_files = glob.glob(
+        os.path.expanduser(f"{ALFWORLD_DATA}/json_2.1.1/train/*/*/game.tw-pddl")
+    )
+    test_game_files = glob.glob(
+        os.path.expanduser(f"{ALFWORLD_DATA}/json_2.1.1/valid_seen/*/*/game.tw-pddl")
+    )
 
     # get absolute paths
-    game_files = [os.path.abspath(file) for file in game_files]
-    game_files = sorted(game_files)
+    train_game_files = [os.path.abspath(file) for file in train_game_files]
+    test_game_files = [os.path.abspath(file) for file in test_game_files]
+    train_game_files = sorted(train_game_files)
+    test_game_files = sorted(test_game_files)
 
-    # randomly select the game files
-    selected_game_files = random.sample(game_files, train_size + test_size)
+    print(f"Total train game files found: {len(train_game_files)}")
+    print(f"Total test game files found: {len(test_game_files)}")
+
+    # if size is None, use all files
+    if train_size is None:
+        train_size = len(train_game_files)
+    if test_size is None:
+        test_size = len(test_game_files)
+
+    # check sizes
+    assert train_size <= len(
+        train_game_files
+    ), f"train_size {train_size} > available {len(train_game_files)}"
+    assert test_size <= len(
+        test_game_files
+    ), f"test_size {test_size} > available {len(test_game_files)}"
+
+    # randomly select the game files
+    selected_train_files = random.sample(train_game_files, train_size)
+    selected_test_files = random.sample(test_game_files, test_size)
 
     # make the output directory
     os.makedirs(output_dir, exist_ok=True)
 
-    # for webshop dataset, we just need the session id as the task id
-    all_data = []
-    for game_file_path in selected_game_files:
-        all_data.append({"game_file": game_file_path, "target": ""})
-
-    # split the train and test data
-    train_data = all_data[:train_size]
-    test_data = all_data[train_size : train_size + test_size]
+    # create train and test data
+    train_data = [
+        {"game_file": game_file_path, "target": ""} for game_file_path in selected_train_files
+    ]
+    test_data = [
+        {"game_file": game_file_path, "target": ""} for game_file_path in selected_test_files
+    ]
 
     # create dataset_dict
     dataset_dict = {"train": train_data, "test": test_data}
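The new size checks guard random.sample, which raises ValueError ("Sample larger than population or is negative") when asked for more items than exist; the asserts fail earlier, with a message naming the offending split. A quick illustration of that behavior, using hypothetical file names:

    import random

    random.seed(42)  # same module-level seeding the script uses
    files = [f"game_{i}.tw-pddl" for i in range(5)]
    print(random.sample(files, 3))  # fine: 3 <= 5 available files
    # random.sample(files, 10)  # would raise ValueError: Sample larger than population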
@@ -58,8 +82,12 @@ def create_dataset_files(output_dir, train_size=1024, test_size=100):
     with open(os.path.join(output_dir, "dataset_dict.json"), "w") as f:
         json.dump(dataset_info, f, indent=2)
 
+    print(f"Created dataset with {len(train_data)} train and {len(test_data)} test examples.")
+
 
 if __name__ == "__main__":
     current_file_dir = os.path.dirname(os.path.abspath(__file__))
     output_dir = f"{current_file_dir}/alfworld_data"
-    create_dataset_files(output_dir, train_size=1024, test_size=100)
+    # use all data by default, or specify train_size and test_size if needed
+    create_dataset_files(output_dir)
+    # create_dataset_files(output_dir, train_size=1024, test_size=100)  # use a subset of the data for testing
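To sanity-check the generated files locally, a minimal sketch; it assumes the script was run with the default output_dir from __main__ above (path relative to the repo root is an assumption) and relies only on the dataset_dict.json written by the code shown:

    import json
    import os

    output_dir = "examples/grpo_alfworld/alfworld_data"  # assumed location after running the script
    with open(os.path.join(output_dir, "dataset_dict.json")) as f:
        info = json.load(f)
    print(info)  # the metadata the script serialized with json.dump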
1 change: 1 addition & 0 deletions trinity/buffer/reader/file_reader.py
@@ -80,6 +80,7 @@ def select_batch(self, indices: List[int]) -> List:
         for i in indices:
             assert 0 <= i < self.dataset_size
             batch.append(self.dataset[int(i)])
+        self.progress_bar.update(len(batch))  # update progress bar
         return batch
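
The added line assumes self.progress_bar exposes a tqdm-style update(n). A minimal standalone sketch of the pattern, with a hypothetical Reader standing in for the real file reader:

    from typing import List

    from tqdm import tqdm

    class Reader:
        # hypothetical stand-in for the buffer's file reader
        def __init__(self, dataset: List):
            self.dataset = dataset
            self.dataset_size = len(dataset)
            self.progress_bar = tqdm(total=self.dataset_size)

        def select_batch(self, indices: List[int]) -> List:
            batch = []
            for i in indices:
                assert 0 <= i < self.dataset_size
                batch.append(self.dataset[int(i)])
            self.progress_bar.update(len(batch))  # advance by the number of items returned
            return batch

    reader = Reader(list(range(100)))
    reader.select_batch([0, 1, 2])  # progress bar advances by 3

Note that update(len(batch)) counts repeated indices too, so progress can exceed the total if the same index is selected more than once.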

