Commit 92d89f3

MC dropout implementation, with scripts and slurms for reproduction
1 parent c98b892 commit 92d89f3

File tree: 3 files changed, +15 −19 lines

n3fit/src/n3fit/backends/keras_backend/MetaModel.py

Lines changed: 2 additions & 2 deletions

@@ -195,8 +195,8 @@ def mc_dropout_predict(self, x=None, n_samples=100):
 (``training=True``), implementing MC Dropout inference.
 
 Each forward pass draws a fresh random binary mask on every Dropout
-layer, so the spread across samples reflects the epistemic uncertainty
-captured by the dropout regulariser.
+layer, so the spread across samples reflects the uncertainty
+captured by the dropout.
 
 Parameters
 ----------
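The docstring above describes the sampling loop behind ``mc_dropout_predict``: repeated forward passes with dropout left active, then statistics over the stack of outputs. As a minimal illustration of that pattern, here is a toy one-layer numpy network with inverted dropout — a sketch of the technique only, not the actual MetaModel code:

```python
import numpy as np

rng = np.random.default_rng(0)

def forward_pass(x, w, dropout_rate, rng):
    # One stochastic forward pass: a FRESH binary mask is drawn on every
    # call, which is what ``training=True`` gives you from a Dropout layer.
    h = np.tanh(x @ w)
    mask = rng.random(h.shape) > dropout_rate
    return h * mask / (1.0 - dropout_rate)  # inverted-dropout rescaling

def mc_dropout_predict(x, w, n_samples=100, dropout_rate=0.1):
    # Stack n_samples stochastic passes; the spread across them is the
    # uncertainty signal described in the docstring above.
    samples = np.stack(
        [forward_pass(x, w, dropout_rate, rng) for _ in range(n_samples)]
    )
    return samples.mean(axis=0), samples.std(axis=0), samples

x = rng.normal(size=(5, 3))   # 5 input points, 3 features (toy sizes)
w = rng.normal(size=(3, 4))   # toy weight matrix
mean, std, samples = mc_dropout_predict(x, w, n_samples=200)
```

With a nonzero dropout rate the per-point standard deviation is strictly positive, which is exactly the sample spread the commit exploits.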

n3fit/src/n3fit/mc_dropout_assemble.py

Lines changed: 2 additions & 2 deletions

@@ -11,13 +11,13 @@
 --mode mean (default)
 For every trained replica N write ONE exportgrid using the mean over all
 MC Dropout samples. Gives 1 PDF member per trained replica.
-Uncertainty comes from inter-replica diversity (same as standard NNPDF).
+Uncertainty comes from replica diversity (same as standard NNPDF).
 
 --mode samples
 For every trained replica N write ONE exportgrid per MC Dropout sample.
 Gives n_samples PDF members per trained replica (sequential numbering).
 Uncertainty comes from the MC Dropout spread within a single trained
-replica. This is the mode you want to study MC Dropout uncertainty.
+replica.
 Example: 1 trained replica x 100 samples -> 100 PDF members.
 
 Usage
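The two modes documented above reduce to a simple transformation of the per-replica samples array: "mean" collapses all passes into one member, "samples" keeps one member per pass. A hedged sketch of that logic (``assemble_members`` is a hypothetical helper, not the script's actual function):

```python
import numpy as np

def assemble_members(samples, mode="mean"):
    # samples: MC Dropout passes for ONE trained replica,
    # shape (n_samples, n_x, n_flavours).
    if mode == "mean":
        # One PDF member per trained replica: the mean over all passes.
        return [samples.mean(axis=0)]
    if mode == "samples":
        # One PDF member per pass: n_samples members per trained replica.
        return list(samples)
    raise ValueError(f"unknown mode: {mode!r}")

demo = np.random.default_rng(1).normal(size=(100, 50, 14))
mean_members = assemble_members(demo, "mean")
sample_members = assemble_members(demo, "samples")
```

With one trained replica and 100 samples, "samples" mode yields the 100 PDF members mentioned in the docstring's example, while "mean" mode yields exactly one.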

n3fit/src/n3fit/mc_dropout_inference.py

Lines changed: 11 additions & 15 deletions

@@ -1,18 +1,18 @@
 """
-MC Dropout inference for NNPDF dropout fits - (fit on central value)
+MC Dropout inference for NNPDF dropout fits
 ======================================================================
 
 Loads the trained weights of one replica, runs N stochastic forward passes
 with dropout kept active (training=True), and computes the mean PDF over
-those passes. The result is written as a numpy archive (.npz) containing:
+those passes. The result contains:
 
 - ``x`` : x-grid, shape (n_x,)
 - ``mean`` : mean PDF, shape (n_x, 14)
 - ``std`` : std PDF, shape (n_x, 14)
 - ``samples`` : all N samples, shape (N, n_x, 14)
 - ``flavours`` : LHAPDF PID list, shape (14,)
 
-Usage (from the project root, inside environment_nnpdf):
+Usage :
     python -m n3fit.mc_dropout_inference \\
         --fit-dir nnpdf40-like-dropout-cluster \\
         --replica 1 \\
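The five arrays listed in the docstring have shapes tied together by N and n_x, and the pre-commit docstring says they were stored as a numpy ``.npz`` archive. A self-contained round-trip sketch with those keys and shapes (the file name and toy sizes here are hypothetical):

```python
import numpy as np

# Toy stand-in for the script's output arrays; keys and shapes follow
# the docstring above, while the file name is invented for this demo.
n_samples, n_x = 100, 50
samples = np.random.default_rng(0).normal(size=(n_samples, n_x, 14))
np.savez(
    "mc_dropout_demo.npz",
    x=np.geomspace(1e-9, 1.0, n_x),          # x-grid, shape (n_x,)
    mean=samples.mean(axis=0),               # mean PDF, shape (n_x, 14)
    std=samples.std(axis=0),                 # std PDF, shape (n_x, 14)
    samples=samples,                         # all passes, (N, n_x, 14)
    flavours=np.arange(14),                  # placeholder PID list
)

data = np.load("mc_dropout_demo.npz")
```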
@@ -25,8 +25,7 @@
 nnfit/replica_<N>/weights.weights.h5 - saved Keras weights
 filter.yml or n3fit runcard - architecture parameters
 
-The architecture is hardcoded from the runcard below and must match the
-saved weights. If you change the runcard parameters, update the defaults.
+Warning: The architecture parameters are read from the n3fit runcard, not from the saved model.
 """
 
 import argparse
@@ -63,10 +62,7 @@ def _get_xgrid():
     from n3fit.io.writer import XGRID  # noqa - re-raise if truly missing
     return XGRID
 
-
-# ---------------------------------------------------------------------------
-# ARCHITECTURE DEFAULTS : must match what was used during training
-# ---------------------------------------------------------------------------
+# Loading the architecture parameters from the fit runcard
 DEFAULT_RUNCARD = (
     Path(__file__).resolve().parent.parent.parent  # n3fit/src/../.. : n3fit
     / "runcards" / "examples" / "nnpdf40-like-dropout-cluster.yml"
@@ -83,13 +79,13 @@ def _load_architecture(runcard_path):
     fitbasis = rc["fitting"]["fitbasis"]
 
     return dict(
-        nodes        = params["nodes_per_layer"],       # [25, 20, 8]
-        activations  = params["activation_per_layer"],  # ['tanh','tanh','linear']
-        initializer  = params["initializer"],           # 'glorot_normal'
-        architecture = params["layer_type"],            # 'dense'
-        dropout_rate = params.get("dropout", 0.0),      # 0.1
+        nodes        = params["nodes_per_layer"],
+        activations  = params["activation_per_layer"],
+        initializer  = params["initializer"],
+        architecture = params["layer_type"],
+        dropout_rate = params.get("dropout", 0.0),
         flav_info    = basis,
-        fitbasis     = fitbasis,  # 'EVOL'
+        fitbasis     = fitbasis,
     )
 
 