@@ -78,9 +78,9 @@ def flatten_data(items_data):
7878 parser = argparse .ArgumentParser ()
7979 parser .add_argument ('--input_file' , default = "openmanus_rl/agentgym/agentenv-webshop/webshop/data/items_human_ins.json" ,
8080 help = "Path to items_human_ins.json" )
81- parser .add_argument ('--output_dir' , required = True , help = "Output directory for processed parquet" )
81+ parser .add_argument ('--output_dir' , required = False , default = "data/webshop" , help = "Output directory for processed parquet" )
8282 parser .add_argument ('--split' , type = str , default = "train" )
83- parser .add_argument ('--train_ratio' , type = float , default = 0.95 ,
83+ parser .add_argument ('--train_ratio' , type = float , default = 0.90 ,
8484 help = "Ratio of data to use for training (rest for val/test)" )
8585 parser .add_argument ('--val_ratio' , type = float , default = 0.1 ,
8686 help = "Ratio of data to use for validation" )
@@ -99,6 +99,11 @@ def flatten_data(items_data):
9999 if args .split == "all" :
100100 # Process all data with the same split label
101101 dataset = dataset .map (function = make_map_fn (args .split ), with_indices = True )
102+ # Save the entire dataset
103+ output_path = os .path .join (args .output_dir , f"{ args .split } .parquet" )
104+ dataset .to_parquet (output_path )
105+ print (f"Processed { len (dataset )} examples" )
106+ print (f"Data saved to { output_path } " )
102107 else :
103108 # Split the dataset
104109 splits = dataset .train_test_split (
@@ -108,30 +113,43 @@ def flatten_data(items_data):
108113
109114 # Further split the test set into validation and test
110115 if args .val_ratio > 0 :
116+ # Calculate the ratio for the validation set from the remaining data
117+ remaining_ratio = 1.0 - args .train_ratio
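118+ val_test_ratio = max (0.5 , args .val_ratio / remaining_ratio ) # NOTE(review): max() only enforces a floor of 0.5, and this value is not used in the visible code (the split below hard-codes test_size = 0.5)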
111119 test_val_split = splits ["test" ].train_test_split (
112- test_size = ( 1.0 - ( args . val_ratio / ( 1.0 - args . train_ratio ))),
120+ test_size = 0.5 , # Split remaining data equally between val and test
113121 seed = 42
114122 )
115123 splits = {
116124 "train" : splits ["train" ],
117125 "validation" : test_val_split ["train" ],
118126 "test" : test_val_split ["test" ]
119127 }
120-
121- # Process only the requested split
122- if args .split in splits :
123- dataset = splits [args .split ]
124- dataset = dataset .map (function = make_map_fn (args .split ), with_indices = True )
128+
129+ # Process and save all splits
130+ for split_name , split_dataset in splits .items ():
131+ processed_dataset = split_dataset .map (function = make_map_fn (split_name ), with_indices = True )
132+ output_path = os .path .join (args .output_dir , f"{ split_name } .parquet" )
133+ processed_dataset .to_parquet (output_path )
134+ print (f"Processed { len (processed_dataset )} examples for { split_name } " )
135+ print (f"Data saved to { output_path } " )
136+
137+ # Remember the processed requested split so its sample can be printed at the end
138+ if split_name == args .split :
139+ dataset = processed_dataset
125140 else :
126- raise ValueError (f"Invalid split: { args .split } . Must be 'train', 'validation', 'test', or 'all'" )
141+ # If no validation split is requested, just process train and test
142+ for split_name , split_dataset in splits .items ():
143+ processed_dataset = split_dataset .map (function = make_map_fn (split_name ), with_indices = True )
144+ output_path = os .path .join (args .output_dir , f"{ split_name } .parquet" )
145+ processed_dataset .to_parquet (output_path )
146+ print (f"Processed { len (processed_dataset )} examples for { split_name } " )
147+ print (f"Data saved to { output_path } " )
148+
149+ # Set the dataset for the requested split
150+ if split_name == args .split :
151+ dataset = processed_dataset
127152
128- # Create output directory and save dataset
129- os .makedirs (args .output_dir , exist_ok = True )
130- output_path = os .path .join (args .output_dir , f"{ args .split } .parquet" )
131- dataset .to_parquet (output_path )
132-
133- print (f"Processed { len (dataset )} examples" )
134- print (f"Data saved to { output_path } " )
135-
136- # Print sample
153+ # Print sample from the requested split
154+ print (f"\n Sample from { args .split } split:" )
137155 pprint (dataset [0 ])
0 commit comments