Skip to content

Commit 7364f0e

Browse files
committed
Add seeds and collect worker results in submission order (instead of completion order) to make this reproducible
1 parent aa687e8 commit 7364f0e

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

RecoTracker/LSTCore/standalone/analysis/DNN/embed_train.ipynb

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@
1919
"import awkward as ak # Using awkward array for easier handling of jagged data\n",
2020
"import time # For timing steps\n",
2121
"\n",
22+
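"# Set seeds for reproducibility (note: PYTHONHASHSEED assigned at runtime only affects Python processes launched afterwards, not this running kernel)\n",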
23+
"seed_value = 42\n",
24+
"random.seed(seed_value)\n",
25+
"np.random.seed(seed_value)\n",
26+
"torch.manual_seed(seed_value)\n",
27+
"if torch.cuda.is_available():\n",
28+
" torch.cuda.manual_seed(seed_value)\n",
29+
"os.environ['PYTHONHASHSEED'] = str(seed_value)\n",
30+
"\n",
2231
"def load_root_file(file_path, branches=None, print_branches=False):\n",
2332
" all_branches = {}\n",
2433
" with uproot.open(file_path) as file:\n",
@@ -381,6 +390,7 @@
381390
" dis_pairs = idxs_triu[dis_mask]\n",
382391
"\n",
383392
" # down-sample\n",
393+
" random.seed(evt_idx)\n",
384394
" if len(sim_pairs) > max_sim:\n",
385395
" sim_pairs = sim_pairs[random.sample(range(len(sim_pairs)), max_sim)]\n",
386396
" if len(dis_pairs) > max_dis:\n",
@@ -417,7 +427,7 @@
417427
"\n",
418428
" with ProcessPoolExecutor(max_workers=n_workers) as pool:\n",
419429
" futures = [pool.submit(_pairs_single_event, *args) for args in work_args]\n",
420-
" for fut in as_completed(futures):\n",
430+
" for fut in futures:\n",
421431
" evt_idx, sim_pairs_evt, dis_pairs_evt = fut.result()\n",
422432
" F = features_per_event[evt_idx]\n",
423433
" D = displaced_per_event[evt_idx]\n",
@@ -546,6 +556,7 @@
546556
" dis_pairs = np.column_stack((idx_p[dis_mask], idx_t[dis_mask]))\n",
547557
"\n",
548558
" # down-sample\n",
559+
" random.seed(evt_idx)\n",
549560
" if len(sim_pairs) > max_sim:\n",
550561
" sim_pairs = sim_pairs[random.sample(range(len(sim_pairs)), max_sim)]\n",
551562
" if len(dis_pairs) > max_dis:\n",
@@ -584,7 +595,7 @@
584595
" )\n",
585596
" for ev in range(len(features_per_event))\n",
586597
" ]\n",
587-
" for fut in as_completed(futures):\n",
598+
" for fut in futures:\n",
588599
" _, packed = fut.result()\n",
589600
" # accumulate\n",
590601
" sim_evt = sum(1 for _,_,lbl,_ in packed if lbl == 0)\n",

0 commit comments

Comments
 (0)