Skip to content

Commit 515dff4

Browse files
authored
Merge pull request #48442 from SegmentLinking/reproducible_training
LST notebooks for neural network training: make training reproducible
2 parents 2ba4e54 + 0e04cdd commit 515dff4

File tree

4 files changed: 44 additions (+44) and 2 deletions (-2).

RecoTracker/LSTCore/standalone/analysis/DNN/embed_train.ipynb

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,13 @@
1919
"import awkward as ak # Using awkward array for easier handling of jagged data\n",
2020
"import time # For timing steps\n",
2121
"\n",
22+
"# Set seeds for reproducibility\n",
23+
"seed_value = 42\n",
24+
"random.seed(seed_value)\n",
25+
"np.random.seed(seed_value)\n",
26+
"torch.manual_seed(seed_value)\n",
27+
"os.environ['PYTHONHASHSEED'] = str(seed_value)\n",
28+
"\n",
2229
"def load_root_file(file_path, branches=None, print_branches=False):\n",
2330
" all_branches = {}\n",
2431
" with uproot.open(file_path) as file:\n",
@@ -381,6 +388,7 @@
381388
" dis_pairs = idxs_triu[dis_mask]\n",
382389
"\n",
383390
" # down-sample\n",
391+
" random.seed(evt_idx)\n",
384392
" if len(sim_pairs) > max_sim:\n",
385393
" sim_pairs = sim_pairs[random.sample(range(len(sim_pairs)), max_sim)]\n",
386394
" if len(dis_pairs) > max_dis:\n",
@@ -417,7 +425,7 @@
417425
"\n",
418426
" with ProcessPoolExecutor(max_workers=n_workers) as pool:\n",
419427
" futures = [pool.submit(_pairs_single_event, *args) for args in work_args]\n",
420-
" for fut in as_completed(futures):\n",
428+
" for fut in futures:\n",
421429
" evt_idx, sim_pairs_evt, dis_pairs_evt = fut.result()\n",
422430
" F = features_per_event[evt_idx]\n",
423431
" D = displaced_per_event[evt_idx]\n",
@@ -546,6 +554,7 @@
546554
" dis_pairs = np.column_stack((idx_p[dis_mask], idx_t[dis_mask]))\n",
547555
"\n",
548556
" # down-sample\n",
557+
" random.seed(evt_idx)\n",
549558
" if len(sim_pairs) > max_sim:\n",
550559
" sim_pairs = sim_pairs[random.sample(range(len(sim_pairs)), max_sim)]\n",
551560
" if len(dis_pairs) > max_dis:\n",
@@ -584,7 +593,7 @@
584593
" )\n",
585594
" for ev in range(len(features_per_event))\n",
586595
" ]\n",
587-
" for fut in as_completed(futures):\n",
596+
" for fut in futures:\n",
588597
" _, packed = fut.result()\n",
589598
" # accumulate\n",
590599
" sim_evt = sum(1 for _,_,lbl,_ in packed if lbl == 0)\n",

RecoTracker/LSTCore/standalone/analysis/DNN/train_T3_DNN.ipynb

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,16 @@
11
{
22
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"# set seed for reproducibility\n",
10+
"import torch\n",
11+
"torch.manual_seed(42)"
12+
]
13+
},
314
{
415
"cell_type": "code",
516
"execution_count": 1,

RecoTracker/LSTCore/standalone/analysis/DNN/train_T5_DNN.ipynb

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,16 @@
11
{
22
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"# set seed for reproducibility\n",
10+
"import torch\n",
11+
"torch.manual_seed(42)"
12+
]
13+
},
314
{
415
"cell_type": "code",
516
"execution_count": 1,

RecoTracker/LSTCore/standalone/analysis/DNN/train_pT3_DNN.ipynb

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,16 @@
11
{
22
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"# set seed for reproducibility\n",
10+
"import torch\n",
11+
"torch.manual_seed(42)"
12+
]
13+
},
314
{
415
"cell_type": "code",
516
"execution_count": 1,

0 commit comments

Comments
 (0)