|
1 | 1 | #!/usr/bin/env python |
2 | 2 | """ |
3 | | -Create a small dataset by sampling random graphs from a larger dataset. |
| 3 | +Create a small dataset by selecting graphs per width from a larger dataset. |
| 4 | +
|
| 5 | +For each unique width value found in the input dataset, selects a specified |
| 6 | +number of graphs (first ones encountered) and saves them to a new file. |
4 | 7 |
|
5 | 8 | Usage: |
6 | 9 | python create_small_dataset.py |
7 | 10 | """ |
8 | 11 |
|
9 | | -import random |
10 | 12 | import gzip |
| 13 | +import argparse |
11 | 14 | from pathlib import Path |
12 | 15 | import sys |
13 | 16 |
|
|
16 | 19 | from benchmark_utils import DatasetLoader |
17 | 20 |
|
18 | 21 |
|
def derive_output_path(input_path: str, graphs_per_width: int) -> str:
    """Derive the output dataset path from the input path and sampling size.

    The output file is placed next to the input file and is named
    ``<stem>_<graphs_per_width>_perwidth<ext>``.

    ``<ext>`` is the final suffix of the input file name, together with the
    suffix preceding it when the final suffix is a compression suffix
    (e.g. ``.grp.gz`` or ``.tar.gz``).  Any earlier dots are treated as part
    of the stem, so ``Mouse.PacBio_reads.txt`` keeps ``Mouse.PacBio_reads``
    intact instead of being split at its first dot.

    Parameters
    ----------
    input_path : str
        Path to the input dataset.
    graphs_per_width : int
        Number of graphs selected per width; embedded in the output name.

    Returns
    -------
    str
        Path of the derived output file, in the same directory as the input.
    """
    input_file = Path(input_path)
    suffixes = input_file.suffixes

    # Only a compression suffix pulls the preceding suffix into the
    # extension; joining *all* suffixes would misplace the marker for
    # names whose stem itself contains dots.
    compression_suffixes = {'.gz', '.bz2', '.xz', '.zst'}
    if suffixes and suffixes[-1].lower() in compression_suffixes:
        ext = ''.join(suffixes[-2:])
    else:
        ext = ''.join(suffixes[-1:])

    name = input_file.name
    # Guard the empty-extension case: name[:-0] would yield an empty string.
    base_name = name[:-len(ext)] if ext else name
    output_name = f"{base_name}_{graphs_per_width}_perwidth{ext}"

    return str(input_file.with_name(output_name))
| 40 | + |
| 41 | + |
19 | 42 | def write_graph_to_lines(graph): |
20 | 43 | """ |
21 | 44 | Convert a graph back to the .grp format lines. |
@@ -50,38 +73,44 @@ def write_graph_to_lines(graph): |
50 | 73 | def create_small_dataset( |
51 | 74 | input_path: str, |
52 | 75 | output_path: str, |
53 | | - num_graphs: int = 50, |
54 | | - seed: int = 42 |
| 76 | + graphs_per_width: int = 5 |
55 | 77 | ): |
56 | 78 | """ |
57 | | - Sample random graphs from a dataset and save to a new file. |
| 79 | + Select graphs from a dataset by taking a fixed number per width value. |
58 | 80 | |
59 | 81 | Parameters |
60 | 82 | ---------- |
61 | 83 | input_path : str |
62 | 84 | Path to input dataset |
63 | 85 | output_path : str |
64 | 86 | Path to output dataset |
65 | | - num_graphs : int |
66 | | - Number of graphs to sample |
67 | | - seed : int |
68 | | - Random seed for reproducibility |
| 87 | + graphs_per_width : int |
| 88 | + Number of graphs to select for each unique width value (default: 5) |
69 | 89 | """ |
70 | 90 | print(f"Loading graphs from {input_path}...") |
71 | 91 | loader = DatasetLoader(input_path) |
72 | 92 | graphs = loader.load_graphs() |
73 | 93 | print(f"Loaded {len(graphs)} graphs") |
74 | 94 |
|
75 | | - # Set random seed for reproducibility |
76 | | - random.seed(seed) |
| 95 | + # Group graphs by width |
| 96 | + width_groups = {} |
| 97 | + for graph in graphs: |
| 98 | + width = graph.graph.get('w', 0) |
| 99 | + if width not in width_groups: |
| 100 | + width_groups[width] = [] |
| 101 | + width_groups[width].append(graph) |
77 | 102 |
|
78 | | - # Sample graphs |
79 | | - if num_graphs >= len(graphs): |
80 | | - print(f"Warning: Requested {num_graphs} graphs but dataset only has {len(graphs)}") |
81 | | - sampled_graphs = graphs |
82 | | - else: |
83 | | - sampled_graphs = random.sample(graphs, num_graphs) |
84 | | - print(f"Sampled {len(sampled_graphs)} random graphs") |
| 103 | + print(f"Found {len(width_groups)} unique width values") |
| 104 | + |
| 105 | + # Select first graphs_per_width graphs for each width |
| 106 | + sampled_graphs = [] |
| 107 | + for width in sorted(width_groups.keys()): |
| 108 | + graphs_with_width = width_groups[width] |
| 109 | + selected = graphs_with_width[:graphs_per_width] |
| 110 | + sampled_graphs.extend(selected) |
| 111 | + print(f" Width {width}: selected {len(selected)} of {len(graphs_with_width)} graphs") |
| 112 | + |
| 113 | + print(f"\nTotal selected: {len(sampled_graphs)} graphs") |
85 | 114 |
|
86 | 115 | # Create output directory if needed |
87 | 116 | output_file = Path(output_path) |
@@ -119,18 +148,46 @@ def create_small_dataset( |
119 | 148 |
|
120 | 149 | def main(): |
121 | 150 | """Main function.""" |
122 | | - # Configuration |
123 | | - input_path = "datasets/esa2025/Mouse.PacBio_reads.grp.gz" |
124 | | - output_path = "datasets/small/Mouse.PacBio_reads_500.grp.gz" |
125 | | - num_graphs = 500 |
126 | | - seed = 42 |
| 151 | + parser = argparse.ArgumentParser( |
| 152 | + description="Create a smaller dataset by sampling a fixed number of graphs per width" |
| 153 | + ) |
| 154 | + parser.add_argument( |
| 155 | + "--input", |
| 156 | + default="datasets/esa2025/Mouse.PacBio_reads.grp.gz", |
| 157 | + help="Path to input dataset (default: %(default)s)", |
| 158 | + ) |
| 159 | + parser.add_argument( |
| 160 | + "--graphs-per-width", |
| 161 | + type=int, |
| 162 | + default=5, |
| 163 | + help="Number of graphs to select per width value (default: %(default)s)", |
| 164 | + ) |
| 165 | + parser.add_argument( |
| 166 | + "--output", |
| 167 | + default=None, |
| 168 | + help=( |
| 169 | + "Path to output dataset. If omitted, it is derived from --input as " |
| 170 | + "<input>_<graphs-per-width>_perwidth with the same extension(s)." |
| 171 | + ), |
| 172 | + ) |
| 173 | + args = parser.parse_args() |
| 174 | + |
| 175 | + if args.graphs_per_width <= 0: |
| 176 | + parser.error("--graphs-per-width must be a positive integer") |
| 177 | + |
| 178 | + input_path = args.input |
| 179 | + graphs_per_width = args.graphs_per_width |
| 180 | + output_path = args.output or derive_output_path(input_path, graphs_per_width) |
127 | 181 |
|
128 | 182 | print("="*70) |
129 | 183 | print("Creating small dataset") |
130 | 184 | print("="*70) |
| 185 | + print(f"Input: {input_path}") |
| 186 | + print(f"Output: {output_path}") |
| 187 | + print(f"Graphs per width: {graphs_per_width}") |
131 | 188 |
|
132 | 189 | try: |
133 | | - create_small_dataset(input_path, output_path, num_graphs, seed) |
| 190 | + create_small_dataset(input_path, output_path, graphs_per_width) |
134 | 191 |
|
135 | 192 | # Verify the output |
136 | 193 | print("\nVerifying output...") |
|
0 commit comments