
Commit 95886d9

Merge pull request #47 from OpenManus/test-pipeline
Test pipeline add first version
2 parents 85cd7a5 + af34b3d commit 95886d9

11 files changed: +1576, -411 lines

data/generate_train_webshop.py

Lines changed: 36 additions & 18 deletions
@@ -78,9 +78,9 @@ def flatten_data(items_data):
     parser = argparse.ArgumentParser()
     parser.add_argument('--input_file', default="openmanus_rl/agentgym/agentenv-webshop/webshop/data/items_human_ins.json",
                         help="Path to items_human_ins.json")
-    parser.add_argument('--output_dir', required=True, help="Output directory for processed parquet")
+    parser.add_argument('--output_dir', required=False, default="data/webshop", help="Output directory for processed parquet")
     parser.add_argument('--split', type=str, default="train")
-    parser.add_argument('--train_ratio', type=float, default=0.95,
+    parser.add_argument('--train_ratio', type=float, default=0.90,
                         help="Ratio of data to use for training (rest for val/test)")
     parser.add_argument('--val_ratio', type=float, default=0.1,
                         help="Ratio of data to use for validation")
@@ -99,6 +99,11 @@ def flatten_data(items_data):
     if args.split == "all":
         # Process all data with the same split label
         dataset = dataset.map(function=make_map_fn(args.split), with_indices=True)
+        # Save the entire dataset
+        output_path = os.path.join(args.output_dir, f"{args.split}.parquet")
+        dataset.to_parquet(output_path)
+        print(f"Processed {len(dataset)} examples")
+        print(f"Data saved to {output_path}")
     else:
         # Split the dataset
         splits = dataset.train_test_split(
@@ -108,30 +113,43 @@ def flatten_data(items_data):
 
         # Further split the test set into validation and test
         if args.val_ratio > 0:
+            # Calculate the ratio for the validation set from the remaining data
+            remaining_ratio = 1.0 - args.train_ratio
+            val_test_ratio = max(0.5, args.val_ratio / remaining_ratio)  # Ensure ratio is valid
             test_val_split = splits["test"].train_test_split(
-                test_size=(1.0 - (args.val_ratio / (1.0 - args.train_ratio))),
+                test_size=0.5,  # Split remaining data equally between val and test
                 seed=42
             )
             splits = {
                 "train": splits["train"],
                 "validation": test_val_split["train"],
                 "test": test_val_split["test"]
             }
-
-        # Process only the requested split
-        if args.split in splits:
-            dataset = splits[args.split]
-            dataset = dataset.map(function=make_map_fn(args.split), with_indices=True)
+
+            # Process and save all splits
+            for split_name, split_dataset in splits.items():
+                processed_dataset = split_dataset.map(function=make_map_fn(split_name), with_indices=True)
+                output_path = os.path.join(args.output_dir, f"{split_name}.parquet")
+                processed_dataset.to_parquet(output_path)
+                print(f"Processed {len(processed_dataset)} examples for {split_name}")
+                print(f"Data saved to {output_path}")
+
+                # Print sample for the requested split
+                if split_name == args.split:
+                    dataset = processed_dataset
         else:
-            raise ValueError(f"Invalid split: {args.split}. Must be 'train', 'validation', 'test', or 'all'")
+            # If no validation split is requested, just process train and test
+            for split_name, split_dataset in splits.items():
+                processed_dataset = split_dataset.map(function=make_map_fn(split_name), with_indices=True)
+                output_path = os.path.join(args.output_dir, f"{split_name}.parquet")
+                processed_dataset.to_parquet(output_path)
+                print(f"Processed {len(processed_dataset)} examples for {split_name}")
+                print(f"Data saved to {output_path}")
+
+                # Set the dataset for the requested split
+                if split_name == args.split:
+                    dataset = processed_dataset
 
-    # Create output directory and save dataset
-    os.makedirs(args.output_dir, exist_ok=True)
-    output_path = os.path.join(args.output_dir, f"{args.split}.parquet")
-    dataset.to_parquet(output_path)
-
-    print(f"Processed {len(dataset)} examples")
-    print(f"Data saved to {output_path}")
-
-    # Print sample
+    # Print sample from the requested split
+    print(f"\nSample from {args.split} split:")
     pprint(dataset[0])

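For context on the new defaults, here is a minimal sketch (not part of the commit) of the split arithmetic the revised script performs: train_ratio=0.90 keeps 90% for training, and the remaining 10% is halved into validation and test via test_size=0.5. The toy 100-row Dataset below is illustrative only.

    # Illustrative only: approximates the script's 90/5/5 split on a toy dataset.
    from datasets import Dataset

    toy = Dataset.from_dict({"idx": list(range(100))})        # hypothetical 100 examples
    splits = toy.train_test_split(train_size=0.90, seed=42)   # 90 train / 10 held out
    held_out = splits["test"].train_test_split(test_size=0.5, seed=42)  # 5 val / 5 test

    print(len(splits["train"]), len(held_out["train"]), len(held_out["test"]))  # 90 5 5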
data/webshop/test.parquet (14.8 KB): binary file, not shown.

data/webshop/train.parquet (2.5 MB): binary file, not shown.

data/webshop/validation.parquet (299 KB): binary file, not shown.
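The three parquet files above are the script's outputs for the default output directory (data/webshop). A quick way to inspect them, sketched here assuming the same Hugging Face datasets library the script uses:

    # Sketch for inspecting the generated splits; paths assume the default --output_dir.
    from datasets import Dataset

    train = Dataset.from_parquet("data/webshop/train.parquet")
    val = Dataset.from_parquet("data/webshop/validation.parquet")
    test = Dataset.from_parquet("data/webshop/test.parquet")

    print(len(train), len(val), len(test))
    print(train[0])  # same kind of sample the script pprints at the end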
