diff --git a/src/toolsgen/core/parallel.py b/src/toolsgen/core/parallel.py
--- a/src/toolsgen/core/parallel.py
+++ b/src/toolsgen/core/parallel.py
@@ -133,6 +133,10 @@ def generate_records_parallel(
 
     results_by_index: Dict[int, Record] = {}
     failed = 0
+    # Records already flushed to the JSONL file, in write order.
+    all_records = []
+    # Lowest sample index that has not been flushed yet.
+    next_id_to_write = 0
 
     ctx = mp.get_context("spawn")
     with ProcessPoolExecutor(
@@ -156,6 +160,15 @@ def generate_records_parallel(
                         if sample_result.record:
                             record = Record.model_validate(sample_result.record)
                             results_by_index[sample_result.sample_index] = record
+
+                            # Flush the contiguous run of finished samples now so
+                            # completed work reaches disk before the pool is done.
+                            while next_id_to_write in results_by_index:
+                                rec = results_by_index.pop(next_id_to_write)
+                                rec.id = f"record_{len(all_records):06d}"
+                                append_record_jsonl(rec, jsonl_path)
+                                all_records.append(rec)
+                                next_id_to_write += 1
                         else:
                             tqdm.write(
                                 "Warning: Failed to generate sample "
@@ -170,8 +183,11 @@ def generate_records_parallel(
                     pbar.update(1)
 
-    all_records = [results_by_index[i] for i in sorted(results_by_index.keys())]
-    for idx, record in enumerate(all_records):
-        record.id = f"record_{idx:06d}"
-        append_record_jsonl(record, jsonl_path)
+    # Anything still buffered sits behind a failed sample index. Flush it in
+    # index order so no successful record is missing from the file or result.
+    for sample_index in sorted(results_by_index):
+        rec = results_by_index[sample_index]
+        rec.id = f"record_{len(all_records):06d}"
+        append_record_jsonl(rec, jsonl_path)
+        all_records.append(rec)
 
     return all_records, failed