Skip to content

Commit 1dd023b

Browse files
dakshdaksh
authored and committed
Update n3fit from BayesianPDF repo
1 parent 7ec5496 commit 1dd023b

File tree

21 files changed

+1521
-589
lines changed

21 files changed

+1521
-589
lines changed

n3fit/src/n3fit/backends/keras_backend/MetaModel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
def _default_loss(y_true, y_pred):  # pylint: disable=unused-argument
    """Default loss to be used when the model is compiled with loss = Null
    (for instance if the prediction of the model is already the loss)."""
    # y_true is ignored on purpose: the model output is already the loss value
    return ops.sum(y_pred)
5454

5555

5656
class MetaModel(Model):
@@ -219,7 +219,7 @@ def losses_fun():
219219
# If we only have one dataset the output changes
220220
if len(out_names) == 2:
221221
predictions = [predictions]
222-
total_loss = ops.nansum(predictions, axis=0)
222+
total_loss = ops.sum(predictions, axis=0)
223223
ret = [total_loss] + predictions
224224
return dict(zip(out_names, ret))
225225

n3fit/src/n3fit/backends/keras_backend/base_layers.py

Lines changed: 137 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,36 @@
11
"""
2-
This module defines custom base layers to be used by the n3fit
3-
Neural Network.
4-
These layers can use the keras standard set of activation function
5-
or implement their own.
2+
This module defines custom base layers to be used by the n3fit
3+
Neural Network.
4+
These layers can use the keras standard set of activation function
5+
or implement their own.
66
7-
For a layer to be used by n3fit it should be contained in the `layers` dictionary defined below.
8-
This dictionary has the following structure:
7+
For a layer to be used by n3fit it should be contained in the `layers` dictionary defined below.
8+
This dictionary has the following structure:
99
10-
'name of the layer' : ( Layer_class, {dictionary of arguments: defaults} )
10+
'name of the layer' : ( Layer_class, {dictionary of arguments: defaults} )
1111
12-
In order to add custom activation functions, they must be added to
13-
the `custom_activations` dictionary with the following structure:
12+
In order to add custom activation functions, they must be added to
13+
the `custom_activations` dictionary with the following structure:
1414
15-
'name of the activation' : function
15+
'name of the activation' : function
1616
17-
The names of the layer and the activation function are the ones to be used in the n3fit runcard.
17+
The names of the layer and the activation function are the ones to be used in the n3fit runcard.
1818
"""
19+
import numpy as np
20+
import keras.backend as K
21+
import tensorflow as tf
22+
import math
23+
from scipy.stats import norm
1924

2025
from keras.layers import Dense as KerasDense
21-
from keras.layers import Dropout, Lambda
26+
from keras.layers import Dropout, Lambda, Layer
2227
from keras.layers import Input # pylint: disable=unused-import
2328
from keras.layers import LSTM, Concatenate
2429
from keras.regularizers import l1_l2
2530

2631
from . import operations as ops
2732
from .MetaLayer import MetaLayer
28-
33+
from contextlib import contextmanager
2934

3035
# Custom activation functions
3136
def square_activation(x):
@@ -74,14 +79,116 @@ def ReshapedLSTM(input_tensor):
7479

7580
return ReshapedLSTM
7681

82+
class VBDense(Layer):
    """Variational (Bayesian) fully connected layer.

    The kernel is modelled as a factorised Gaussian posterior with mean
    ``mu_w`` and log-variance ``logsig2_w``; the bias is a point estimate.
    During training the forward pass uses the local reparameterisation
    trick (sampling the pre-activations). During inference either the
    posterior mean is used (``map=True``) or a weight matrix is sampled
    from the posterior.

    Parameters
    ----------
    out_features : int
        Number of output units.
    in_features : int
        Number of input units.
    prior_prec : float
        Precision of the zero-mean Gaussian prior over the weights,
        used by :meth:`kl_loss`.
    map : bool
        If True, inference uses the posterior mean (deterministic output).
        NOTE(review): shadows the builtin ``map``; kept for interface
        compatibility with existing callers.
    std_init : float
        Mean of the normal initialisation of ``logsig2_w``.
    lbound, ubound : float
        Clipping bounds for the log-variance, keeping ``exp()`` finite.
    training : bool
        Whether the layer starts in training mode.
    """

    def __init__(
        self,
        out_features: int,
        in_features: int,
        prior_prec: float = 0.001,
        map: bool = False,
        std_init: float = -9,
        lbound=-30,
        ubound=11,
        training=True,
    ):
        super().__init__()
        self.output_dim = out_features
        self.input_dim = in_features
        self.map = map
        self.prior_prec = tf.cast(prior_prec, tf.float64)
        # Optional cached weight sample, cleared by reset_random()
        self.random = None
        # Numerical floor added to the pre-activation variance
        self.eps = 1e-12 if K.floatx() == 'float64' else 1e-8
        self.std_init = tf.cast(std_init, tf.float64)
        self.lbound = lbound
        self.ubound = ubound
        self.training = training

    def build(self, input_shape):
        """Create the bias, posterior mean and posterior log-variance weights."""
        self.bias = self.add_weight(
            name='bias',
            shape=(self.output_dim,),
            initializer='glorot_normal',
            trainable=True,
            dtype=tf.float64,
        )

        self.mu_w = self.add_weight(
            name='mu_w',
            shape=(self.output_dim, self.input_dim),
            initializer='glorot_normal',
            trainable=True,
            dtype=tf.float64,
        )

        self.logsig2_w = self.add_weight(
            name='logsig2_w',
            shape=(self.output_dim, self.input_dim),
            initializer='glorot_normal',
            trainable=True,
            dtype=tf.float64,
        )

        # Overwrite the glorot initialisation with the scheme below
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise: zero bias, fan-in scaled mean, small log-variance."""
        stdv = 1.0 / tf.math.sqrt(tf.cast(self.input_dim, dtype=tf.float64))
        self.bias.assign(tf.zeros_like(self.bias))
        self.mu_w.assign(tf.random.normal(tf.shape(self.mu_w), mean=0, stddev=stdv, dtype=tf.float64))
        self.logsig2_w.assign(
            tf.random.normal(tf.shape(self.logsig2_w), mean=self.std_init, stddev=0.001, dtype=tf.float64)
        )

    def reset_random(self):
        """Discard any cached weight sample and disable MAP mode."""
        self.random = None
        self.map = False

    def train(self):
        """Switch the layer to training mode (stochastic pre-activations)."""
        self.training = True

    def eval(self):
        """Switch the layer to inference mode."""
        self.training = False

    def kl_loss(self) -> tf.Tensor:
        """KL divergence between the weight posterior and the Gaussian prior."""
        logsig2_w = tf.clip_by_value(self.logsig2_w, self.lbound, self.ubound)
        kl = 0.5 * tf.reduce_sum(
            (
                self.prior_prec * (tf.math.pow(self.mu_w, 2) + tf.math.exp(logsig2_w))
                - logsig2_w
                - tf.constant(1.0, dtype=tf.float64)
                - tf.math.log(self.prior_prec)
            )
        )
        return kl

    def call(self, input: tf.Tensor) -> tf.Tensor:
        """Forward pass; see the class docstring for the three modes."""
        # Ensure input is tf.float64
        input = tf.cast(input, tf.float64)

        if self.training:
            # Local reparameterisation: sample the pre-activations rather
            # than the weights (lower-variance gradient estimates)
            mu_out = tf.matmul(input, tf.cast(self.mu_w, input.dtype), transpose_b=True) + tf.cast(
                self.bias, input.dtype
            )
            logsig2_w = tf.clip_by_value(self.logsig2_w, self.lbound, self.ubound)
            s2_w = tf.math.exp(logsig2_w)
            input2 = tf.math.pow(input, 2)
            var_out = tf.matmul(input2, s2_w, transpose_b=True) + tf.cast(self.eps, input.dtype)

            return mu_out + tf.math.sqrt(var_out) * tf.random.normal(shape=tf.shape(mu_out), dtype=input.dtype)

        # During inference, use MAP estimation (posterior mean) for deterministic output
        if self.map:
            mu_out = tf.matmul(input, self.mu_w, transpose_b=True) + self.bias
            return mu_out

        # Was hard-coded to 11; use self.ubound for consistency with training
        logsig2_w = tf.clip_by_value(self.logsig2_w, self.lbound, self.ubound)
        s2_w = tf.math.exp(logsig2_w)
        # Draw a fresh weight sample every call. The previous version also
        # created an unused cached tf.Variable here; that dead allocation
        # has been removed (self.random is kept only for the
        # reset_random() interface).
        epsilon = tf.random.normal(shape=tf.shape(self.mu_w), dtype=tf.float64)
        weight = self.mu_w + tf.math.sqrt(s2_w) * epsilon

        return tf.matmul(input, weight, transpose_b=True) + self.bias
183+
77184

78-
class Dense(KerasDense, MetaLayer):
79185

80-
def __init__(self, *args, **kwargs):
81-
# In Keras == 3.13, np.int() is not accepted by Dense
82-
if "units" in kwargs:
83-
kwargs["units"] = int(kwargs["units"])
84-
super().__init__(*args, **kwargs)
186+
class Dense(KerasDense, MetaLayer):
    """Project wrapper around ``keras.layers.Dense``.

    Defaults the layer dtype to ``tf.float64`` (unless the caller passes an
    explicit ``dtype``) and coerces ``units`` to a plain ``int``, since
    Keras >= 3.13 rejects numpy integer types for that argument.
    """

    def __init__(self, *args, **kwargs):
        # In Keras == 3.13, np.int() is not accepted by Dense
        if "units" in kwargs:
            kwargs["units"] = int(kwargs["units"])
        # Set default dtype to tf.float64 if not provided
        kwargs.setdefault("dtype", tf.float64)
        super().__init__(*args, **kwargs)
85192

86193

87194
def dense_per_flavour(basis_size=8, kernel_initializer="glorot_normal", **dense_kwargs):
@@ -133,7 +240,6 @@ def apply_dense(xinput):
133240

134241
return apply_dense
135242

136-
137243
layers = {
138244
"dense": (
139245
Dense,
@@ -142,6 +248,7 @@ def apply_dense(xinput):
142248
"units": 5,
143249
"activation": "sigmoid",
144250
"kernel_regularizer": None,
251+
"dtype": tf.float64,
145252
},
146253
),
147254
"dense_per_flavour": (
@@ -151,12 +258,20 @@ def apply_dense(xinput):
151258
"units": 5,
152259
"activation": "sigmoid",
153260
"basis_size": 8,
261+
"dtype": tf.float64,
154262
},
155263
),
156264
"LSTM": (
157265
LSTM_modified,
158266
{"kernel_initializer": "glorot_normal", "units": 5, "activation": "sigmoid"},
159267
),
268+
"VBDense": (
269+
VBDense,
270+
{
271+
"in_features" : None,
272+
"out_features" : None,
273+
},
274+
),
160275
"dropout": (Dropout, {"rate": 0.0}),
161276
"concatenate": (Concatenate, {}),
162277
}
@@ -173,7 +288,7 @@ def base_layer_selector(layer_name, **kwargs):
173288
174289
Parameters
175290
----------
176-
`layer_name
291+
`layer_name`
177292
str with the name of the layer
178293
`**kwargs`
179294
extra optional arguments to pass to the layer (beyond their defaults)
@@ -232,4 +347,4 @@ def regularizer_selector(reg_name, **kwargs):
232347
if key in reg_args.keys():
233348
reg_args[key] = value
234349

235-
return reg_class(**reg_args)
350+
return reg_class(**reg_args)

n3fit/src/n3fit/backends/keras_backend/operations.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
expand_dims,
4747
leaky_relu,
4848
reshape,
49-
nan_to_num,
5049
repeat,
5150
split,
5251
sum,
@@ -291,8 +290,3 @@ def tensor_splitter(ishape, split_sizes, axis=2, name="splitter"):
291290
lambda x: Kops.split(x, indices, axis=axis), output_shape=oshapes, name=name
292291
)
293292
return sp_layer
294-
295-
296-
def nansum(x, *args, **kwargs):
297-
"""Like np.nansum, returns the sum treating NaN as 0.0 (and inf as a very large number)."""
298-
return sum(nan_to_num(x), *args, **kwargs)

n3fit/src/n3fit/bnn_wrapper.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""
2+
Wrapper for BNN (Bayesian Neural Network) inference
3+
4+
This module provides utilities for:
5+
1. Detecting if a model is a BNN (has VBDense layers)
6+
2. Generating pseudo-replicas from BNN weight samples for plotting/analysis
7+
"""
8+
9+
import numpy as np
10+
import tensorflow as tf
11+
from n3fit.backends.keras_backend.base_layers import VBDense
12+
13+
14+
15+
def is_bayesian_model(pdf_model):
    """
    Determine whether ``pdf_model`` is a Bayesian neural network.

    A model counts as a BNN when at least one ``VBDense`` layer is found
    anywhere in its layer hierarchy.

    Parameters
    ----------
    pdf_model : MetaModel
        The PDF model to inspect

    Returns
    -------
    bool
        True when the model contains VBDense layers
    """
    return bool(get_vb_layers(pdf_model))
31+
32+
def _get_all_layers_recursively(container):
33+
"""
34+
Recursively get all layers at any depth
35+
Ques.: At what layer depth of network would this hit RecursionError?
36+
"""
37+
layers = [container]
38+
if hasattr(container, 'layers'):
39+
for sub_layer in container.layers:
40+
layers.extend(_get_all_layers_recursively(sub_layer))
41+
return layers
42+
43+
def get_vb_layers(pdf_model):
    """
    Collect every ``VBDense`` layer contained in a PDF model.

    The full layer hierarchy is walked recursively, so layers nested
    inside sub-models at any depth are found as well.
    """
    return [
        layer
        for layer in _get_all_layers_recursively(pdf_model)
        if isinstance(layer, VBDense)
    ]
59+
60+
class BNNPredictor:
    """
    Sampling-based predictor for Bayesian PDF models.

    Repeatedly draws from the posterior encoded in the model's VBDense
    layers to produce predictions with uncertainty estimates.
    """

    def __init__(self, pdf_model, n_samples=3):
        """
        Initialize the BNN predictor

        Parameters
        ----------
        pdf_model : MetaModel
            The trained PDF model with VBDense layers
        n_samples : int
            Number of posterior samples to draw
        """
        self.pdf_model = pdf_model
        self.n_samples = n_samples
        self.vb_layers = get_vb_layers(pdf_model)

    def reset_random(self):
        """Clear the cached random state of every VBDense layer."""
        for layer in self.vb_layers:
            layer.reset_random()

    def eval(self):
        """Put every VBDense layer in inference mode (training=False)."""
        for layer in self.vb_layers:
            layer.eval()

    def train(self):
        """Put every VBDense layer in training mode (training=True)."""
        for layer in self.vb_layers:
            layer.train()

    def generate_bnn_replica(self):
        """Build one single-replica model per posterior sample.

        Each iteration resets the layers' random state and switches to
        inference mode before generating the replica, so the replicas
        correspond to distinct weight samples.
        """
        replica_models = []
        for _ in range(self.n_samples):
            self.reset_random()
            self.eval()
            replica_models.append(self.pdf_model.single_replica_generator(0))
        return replica_models

    def generate_predictions(self, xinput):
        """
        Generate predictions by sampling from the posterior

        Parameters
        ----------
        xinput : InputInfo.input.tensor_content
            an array containing the input values

        Returns
        -------
        predictions : list
            one model output per posterior sample
        """
        predictions = []
        for _ in range(self.n_samples):
            # Fresh weight sample for every prediction
            self.reset_random()
            self.eval()
            predictions.append(self.pdf_model.predict({"pdf_input": xinput}))
        return predictions

0 commit comments

Comments
 (0)