"""
This module defines custom base layers to be used by the n3fit
Neural Network.
These layers can use the keras standard set of activation function
or implement their own.

For a layer to be used by n3fit it should be contained in the `layers` dictionary defined below.
This dictionary has the following structure:

    'name of the layer' : ( Layer_class, {dictionary of arguments: defaults} )

In order to add custom activation functions, they must be added to
the `custom_activations` dictionary with the following structure:

    'name of the activation' : function

The names of the layer and the activation function are the ones to be used in the n3fit runcard.
"""
19+ import numpy as np
20+ import keras .backend as K
21+ import tensorflow as tf
22+ import math
23+ from scipy .stats import norm
1924
2025from keras .layers import Dense as KerasDense
21- from keras .layers import Dropout , Lambda
26+ from keras .layers import Dropout , Lambda , Layer
2227from keras .layers import Input # pylint: disable=unused-import
2328from keras .layers import LSTM , Concatenate
2429from keras .regularizers import l1_l2
2530
2631from . import operations as ops
2732from .MetaLayer import MetaLayer
28-
33+ from contextlib import contextmanager
2934
3035# Custom activation functions
3136def square_activation (x ):
@@ -74,14 +79,116 @@ def ReshapedLSTM(input_tensor):
7479
7580 return ReshapedLSTM
7681
class VBDense(Layer):
    """Bayesian dense layer with a factorised Gaussian posterior over the kernel.

    The kernel is parameterised by a mean ``mu_w`` and a log-variance
    ``logsig2_w``.  In training mode the forward pass uses the local
    reparameterisation trick: it propagates the mean and variance of the
    pre-activations and samples from the resulting Gaussian.  In inference
    mode it either returns the deterministic posterior-mean output
    (``map=True``) or a forward pass through one sampled weight matrix.
    ``kl_loss`` gives the KL divergence of the posterior against a
    zero-mean Gaussian prior with precision ``prior_prec``.

    Parameters
    ----------
    out_features: int
        number of output units
    in_features: int
        number of input features
    prior_prec: float
        precision of the zero-mean Gaussian prior used in ``kl_loss``
    map: bool
        if True, inference returns the deterministic posterior-mean output
    std_init: float
        mean of the normal initialisation of ``logsig2_w``
    lbound, ubound: float
        clipping bounds applied to ``logsig2_w`` before exponentiation
    training: bool
        whether the layer starts in training mode
    """

    def __init__(
        self,
        out_features: int,
        in_features: int,
        prior_prec: float = 0.001,
        map: bool = False,
        std_init: float = -9,
        lbound=-30,
        ubound=11,
        training=True,
    ):
        super().__init__()
        self.output_dim = out_features
        self.input_dim = in_features
        self.map = map
        self.prior_prec = tf.cast(prior_prec, tf.float64)
        # cached weight sample for deterministic eval; cleared by reset_random()
        self.random = None
        # numerical floor added to the pre-activation variance
        self.eps = 1e-12 if K.floatx() == 'float64' else 1e-8
        self.std_init = tf.cast(std_init, tf.float64)
        self.lbound = lbound
        self.ubound = ubound
        self.training = training

    def build(self, input_shape):
        """Create bias, posterior mean and posterior log-variance weights."""
        self.bias = self.add_weight(
            name='bias',
            shape=(self.output_dim,),
            initializer='glorot_normal',
            trainable=True,
            dtype=tf.float64,
        )
        self.mu_w = self.add_weight(
            name='mu_w',
            shape=(self.output_dim, self.input_dim),
            initializer='glorot_normal',
            trainable=True,
            dtype=tf.float64,
        )
        self.logsig2_w = self.add_weight(
            name='logsig2_w',
            shape=(self.output_dim, self.input_dim),
            initializer='glorot_normal',
            trainable=True,
            dtype=tf.float64,
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the posterior: zero bias, N(0, 1/sqrt(n_in)) means,
        and log-variances tightly centred on ``std_init``."""
        stdv = 1.0 / tf.math.sqrt(tf.cast(self.input_dim, dtype=tf.float64))
        self.bias.assign(tf.zeros_like(self.bias))
        self.mu_w.assign(
            tf.random.normal(tf.shape(self.mu_w), mean=0, stddev=stdv, dtype=tf.float64)
        )
        self.logsig2_w.assign(
            tf.random.normal(
                tf.shape(self.logsig2_w), mean=self.std_init, stddev=0.001, dtype=tf.float64
            )
        )

    def reset_random(self):
        """Discard any cached weight sample and disable MAP inference."""
        self.random = None
        self.map = False

    def train(self):
        """Switch the layer to (stochastic) training mode."""
        self.training = True

    def eval(self):
        """Switch the layer to inference mode."""
        self.training = False

    def kl_loss(self) -> tf.Tensor:
        """KL divergence between the posterior N(mu_w, sig2_w) and the
        zero-mean Gaussian prior with precision ``prior_prec``."""
        logsig2_w = tf.clip_by_value(self.logsig2_w, self.lbound, self.ubound)
        kl = 0.5 * tf.reduce_sum(
            self.prior_prec * (tf.math.pow(self.mu_w, 2) + tf.math.exp(logsig2_w))
            - logsig2_w
            - tf.constant(1.0, dtype=tf.float64)
            - tf.math.log(self.prior_prec)
        )
        return kl

    def call(self, input: tf.Tensor) -> tf.Tensor:
        # All posterior parameters live in float64, so cast the input up front
        input = tf.cast(input, tf.float64)

        if self.training:
            # Local reparameterisation trick: sample the pre-activations
            # instead of the weights.
            mu_out = (
                tf.matmul(input, tf.cast(self.mu_w, input.dtype), transpose_b=True)
                + tf.cast(self.bias, input.dtype)
            )
            logsig2_w = tf.clip_by_value(self.logsig2_w, self.lbound, self.ubound)
            s2_w = tf.math.exp(logsig2_w)
            input2 = tf.math.pow(input, 2)
            var_out = tf.matmul(input2, s2_w, transpose_b=True) + tf.cast(self.eps, input.dtype)

            return mu_out + tf.math.sqrt(var_out) * tf.random.normal(
                shape=tf.shape(mu_out), dtype=input.dtype
            )

        # During inference, use MAP estimation (posterior mean) for a
        # deterministic output when requested.
        if self.map:
            return tf.matmul(input, self.mu_w, transpose_b=True) + self.bias

        # Otherwise draw a fresh weight sample for this forward pass.
        # BUGFIX: the upper clip bound was a hard-coded 11, ignoring the
        # ``ubound`` parameter; additionally a tf.Variable was created inside
        # call() (invalid under tf.function, one new variable per trace)
        # whose value was never used — removed.
        logsig2_w = tf.clip_by_value(self.logsig2_w, self.lbound, self.ubound)
        s2_w = tf.math.exp(logsig2_w)
        epsilon = tf.random.normal(shape=tf.shape(self.mu_w), dtype=tf.float64)
        weight = self.mu_w + tf.math.sqrt(s2_w) * epsilon

        return tf.matmul(input, weight, transpose_b=True) + self.bias
183+
77184
class Dense(KerasDense, MetaLayer):
    """Wrapper around the Keras ``Dense`` layer.

    Ensures ``units`` is a plain Python int (Keras >= 3.13 rejects numpy
    integer types) and defaults the layer dtype to float64 so it matches
    the double-precision weights used elsewhere in this module.
    """

    def __init__(self, *args, **kwargs):
        # In Keras >= 3.13, numpy integers are not accepted for ``units``
        if "units" in kwargs:
            kwargs["units"] = int(kwargs["units"])
        # Default to double precision unless the caller chose otherwise
        kwargs.setdefault("dtype", tf.float64)
        super().__init__(*args, **kwargs)
85192
86193
87194def dense_per_flavour (basis_size = 8 , kernel_initializer = "glorot_normal" , ** dense_kwargs ):
@@ -133,7 +240,6 @@ def apply_dense(xinput):
133240
134241 return apply_dense
135242
136-
137243layers = {
138244 "dense" : (
139245 Dense ,
@@ -142,6 +248,7 @@ def apply_dense(xinput):
142248 "units" : 5 ,
143249 "activation" : "sigmoid" ,
144250 "kernel_regularizer" : None ,
251+ "dtype" : tf .float64 ,
145252 },
146253 ),
147254 "dense_per_flavour" : (
@@ -151,12 +258,20 @@ def apply_dense(xinput):
151258 "units" : 5 ,
152259 "activation" : "sigmoid" ,
153260 "basis_size" : 8 ,
261+ "dtype" : tf .float64 ,
154262 },
155263 ),
156264 "LSTM" : (
157265 LSTM_modified ,
158266 {"kernel_initializer" : "glorot_normal" , "units" : 5 , "activation" : "sigmoid" },
159267 ),
268+ "VBDense" : (
269+ VBDense ,
270+ {
271+ "in_features" : None ,
272+ "out_features" : None ,
273+ },
274+ ),
160275 "dropout" : (Dropout , {"rate" : 0.0 }),
161276 "concatenate" : (Concatenate , {}),
162277}
@@ -173,7 +288,7 @@ def base_layer_selector(layer_name, **kwargs):
173288
174289 Parameters
175290 ----------
176- `layer_name
291+ `layer_name`
177292 str with the name of the layer
178293 `**kwargs`
179294 extra optional arguments to pass to the layer (beyond their defaults)
@@ -232,4 +347,4 @@ def regularizer_selector(reg_name, **kwargs):
232347 if key in reg_args .keys ():
233348 reg_args [key ] = value
234349
235- return reg_class (** reg_args )
350+ return reg_class (** reg_args )