diff --git a/examples/grpo_alfworld/get_alfworld_data.py b/examples/grpo_alfworld/get_alfworld_data.py
index 93e6c8a1b3..9989e8bffa 100644
--- a/examples/grpo_alfworld/get_alfworld_data.py
+++ b/examples/grpo_alfworld/get_alfworld_data.py
@@ -10,31 +10,55 @@
 random.seed(42)
 
 
-def create_dataset_files(output_dir, train_size=1024, test_size=100):
+def create_dataset_files(output_dir, train_size=None, test_size=None):
     # The ALFWORLD_DATA is the dataset path in the environment variable ALFWORLD_DATA, you need to set it when install alfworld dataset
     from alfworld.info import ALFWORLD_DATA
 
-    # get all matched game files
-    game_files = glob.glob(os.path.expanduser(f"{ALFWORLD_DATA}/json_2.1.1/train/*/*/game.tw-pddl"))
+    # get all matched game files from train and valid_seen directories
+    train_game_files = glob.glob(
+        os.path.expanduser(f"{ALFWORLD_DATA}/json_2.1.1/train/*/*/game.tw-pddl")
+    )
+    test_game_files = glob.glob(
+        os.path.expanduser(f"{ALFWORLD_DATA}/json_2.1.1/valid_seen/*/*/game.tw-pddl")
+    )
 
     # get absolute path
-    game_files = [os.path.abspath(file) for file in game_files]
-    game_files = sorted(game_files)
+    train_game_files = [os.path.abspath(file) for file in train_game_files]
+    test_game_files = [os.path.abspath(file) for file in test_game_files]
+    train_game_files = sorted(train_game_files)
+    test_game_files = sorted(test_game_files)
 
-    # randomly sellect the game files
-    sellected_game_files = random.sample(game_files, train_size + test_size)
+    print(f"Total train game files found: {len(train_game_files)}")
+    print(f"Total test game files found: {len(test_game_files)}")
+
+    # if size is None, use all files
+    if train_size is None:
+        train_size = len(train_game_files)
+    if test_size is None:
+        test_size = len(test_game_files)
+
+    # check sizes
+    assert train_size <= len(
+        train_game_files
+    ), f"train_size {train_size} > available {len(train_game_files)}"
+    assert test_size <= len(
+        test_game_files
+    ), f"test_size {test_size} > available {len(test_game_files)}"
+
+    # randomly select the game files
+    selected_train_files = random.sample(train_game_files, train_size)
+    selected_test_files = random.sample(test_game_files, test_size)
 
     # make the output directory
     os.makedirs(output_dir, exist_ok=True)
 
-    # for webshop dataset, we just need the session id as the task id
-    all_data = []
-    for game_file_path in sellected_game_files:
-        all_data.append({"game_file": game_file_path, "target": ""})
-
-    # split the train and test data
-    train_data = all_data[:train_size]
-    test_data = all_data[train_size : train_size + test_size]
+    # create train and test data
+    train_data = [
+        {"game_file": game_file_path, "target": ""} for game_file_path in selected_train_files
+    ]
+    test_data = [
+        {"game_file": game_file_path, "target": ""} for game_file_path in selected_test_files
+    ]
 
     # create dataset_dict
     dataset_dict = {"train": train_data, "test": test_data}
@@ -58,8 +82,12 @@ def create_dataset_files(output_dir, train_size=1024, test_size=100):
     with open(os.path.join(output_dir, "dataset_dict.json"), "w") as f:
         json.dump(dataset_info, f, indent=2)
 
+    print(f"Created dataset with {len(train_data)} train and {len(test_data)} test examples.")
+
 
 if __name__ == "__main__":
     current_file_dir = os.path.dirname(os.path.abspath(__file__))
     output_dir = f"{current_file_dir}/alfworld_data"
-    create_dataset_files(output_dir, train_size=1024, test_size=100)
+    # use all data by default, or specify train_size and test_size if needed
+    create_dataset_files(output_dir)
+    # create_dataset_files(output_dir, train_size=1024, test_size=100)  # use subset of data for testing
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index f5a218dac0..b6f39979f4 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -80,6 +80,7 @@ def select_batch(self, indices: List[int]) -> List:
         for i in indices:
             assert 0 <= i < self.dataset_size
             batch.append(self.dataset[int(i)])
+        self.progress_bar.update(len(batch))  # update progress bar
         return batch
 
 