|
19 | 19 | "import awkward as ak # Using awkward array for easier handling of jagged data\n", |
20 | 20 | "import time # For timing steps\n", |
21 | 21 | "\n", |
| 22 | + "# Set seeds for reproducibility\n", |
| 23 | + "seed_value = 42\n", |
| 24 | + "random.seed(seed_value)\n", |
| 25 | + "np.random.seed(seed_value)\n", |
| 26 | + "torch.manual_seed(seed_value)\n", |
| 27 | + "if torch.cuda.is_available():\n", |
| 28 | + " torch.cuda.manual_seed(seed_value)\n", |
| 29 | + "os.environ['PYTHONHASHSEED'] = str(seed_value)\n", |
| 30 | + "\n", |
22 | 31 | "def load_root_file(file_path, branches=None, print_branches=False):\n", |
23 | 32 | " all_branches = {}\n", |
24 | 33 | " with uproot.open(file_path) as file:\n", |
|
381 | 390 | " dis_pairs = idxs_triu[dis_mask]\n", |
382 | 391 | "\n", |
383 | 392 | " # down-sample\n", |
| 393 | + " random.seed(evt_idx)\n", |
384 | 394 | " if len(sim_pairs) > max_sim:\n", |
385 | 395 | " sim_pairs = sim_pairs[random.sample(range(len(sim_pairs)), max_sim)]\n", |
386 | 396 | " if len(dis_pairs) > max_dis:\n", |
|
417 | 427 | "\n", |
418 | 428 | " with ProcessPoolExecutor(max_workers=n_workers) as pool:\n", |
419 | 429 | " futures = [pool.submit(_pairs_single_event, *args) for args in work_args]\n", |
420 | | - " for fut in as_completed(futures):\n", |
| 430 | + " for fut in futures:\n", |
421 | 431 | " evt_idx, sim_pairs_evt, dis_pairs_evt = fut.result()\n", |
422 | 432 | " F = features_per_event[evt_idx]\n", |
423 | 433 | " D = displaced_per_event[evt_idx]\n", |
|
546 | 556 | " dis_pairs = np.column_stack((idx_p[dis_mask], idx_t[dis_mask]))\n", |
547 | 557 | "\n", |
548 | 558 | " # down-sample\n", |
| 559 | + " random.seed(evt_idx)\n", |
549 | 560 | " if len(sim_pairs) > max_sim:\n", |
550 | 561 | " sim_pairs = sim_pairs[random.sample(range(len(sim_pairs)), max_sim)]\n", |
551 | 562 | " if len(dis_pairs) > max_dis:\n", |
|
584 | 595 | " )\n", |
585 | 596 | " for ev in range(len(features_per_event))\n", |
586 | 597 | " ]\n", |
587 | | - " for fut in as_completed(futures):\n", |
| 598 | + " for fut in futures:\n", |
588 | 599 | " _, packed = fut.result()\n", |
589 | 600 | " # accumulate\n", |
590 | 601 | " sim_evt = sum(1 for _,_,lbl,_ in packed if lbl == 0)\n", |
|
0 commit comments