11"""
2- This module defines custom base layers to be used by the n3fit
3- Neural Network.
4- These layers can use the keras standard set of activation function
5- or implement their own.
2+ This module defines custom base layers to be used by the n3fit
3+ Neural Network.
4+ These layers can use the keras standard set of activation function
5+ or implement their own.
66
7- For a layer to be used by n3fit it should be contained in the `layers` dictionary defined below.
8- This dictionary has the following structure:
7+ For a layer to be used by n3fit it should be contained in the `layers` dictionary defined below.
8+ This dictionary has the following structure:
99
10- 'name of the layer' : ( Layer_class, {dictionary of arguments: defaults} )
10+ 'name of the layer' : ( Layer_class, {dictionary of arguments: defaults} )
1111
12- In order to add custom activation functions, they must be added to
13- the `custom_activations` dictionary with the following structure:
12+ In order to add custom activation functions, they must be added to
13+ the `custom_activations` dictionary with the following structure:
1414
15- 'name of the activation' : function
15+ 'name of the activation' : function
1616
17- The names of the layer and the activation function are the ones to be used in the n3fit runcard.
17+ The names of the layer and the activation function are the ones to be used in the n3fit runcard.
1818"""
19+
1920import numpy as np
2021import keras .backend as K
2122import tensorflow as tf
3233from .MetaLayer import MetaLayer
3334from contextlib import contextmanager
3435
36+
3537# Custom activation functions
3638def square_activation (x ):
3739 """Squares the input"""
@@ -79,17 +81,18 @@ def ReshapedLSTM(input_tensor):
7981
8082 return ReshapedLSTM
8183
84+
8285class VBDense (Layer ):
8386 def __init__ (
84- self ,
85- out_features : int ,
86- in_features : int ,
87- prior_prec : float = 0.001 ,
88- map : bool = False ,
89- std_init : float = - 9 ,
90- lbound = - 30 ,
91- ubound = 11 ,
92- training = True
87+ self ,
88+ out_features : int ,
89+ in_features : int ,
90+ prior_prec : float = 0.001 ,
91+ map : bool = False ,
92+ std_init : float = - 9 ,
93+ lbound = - 30 ,
94+ ubound = 11 ,
95+ training = True ,
9396 ):
9497 super ().__init__ ()
9598 self .output_dim = out_features
@@ -105,79 +108,99 @@ def __init__(
105108
106109 def build (self , input_shape ):
107110 self .bias = self .add_weight (
108- name = 'bias' ,
109- shape = (self .output_dim ,),
110- initializer = 'glorot_normal' ,
111- trainable = True ,
112- dtype = tf .float64
113- )
114-
111+ name = 'bias' ,
112+ shape = (self .output_dim ,),
113+ initializer = 'glorot_normal' ,
114+ trainable = True ,
115+ dtype = tf .float64 ,
116+ )
117+
115118 self .mu_w = self .add_weight (
116- name = 'mu_w' ,
117- shape = (self .output_dim , self .input_dim ),
118- initializer = 'glorot_normal' ,
119- trainable = True ,
120- dtype = tf .float64
121- )
122-
119+ name = 'mu_w' ,
120+ shape = (self .output_dim , self .input_dim ),
121+ initializer = 'glorot_normal' ,
122+ trainable = True ,
123+ dtype = tf .float64 ,
124+ )
125+
123126 self .logsig2_w = self .add_weight (
124- name = 'logsig2_w' ,
125- shape = (self .output_dim , self .input_dim ),
126- initializer = 'glorot_normal' ,
127+ name = 'logsig2_w' ,
128+ shape = (self .output_dim , self .input_dim ),
129+ initializer = 'glorot_normal' ,
127130 trainable = True ,
128131 dtype = tf .float64 ,
129- )
130-
132+ )
133+
131134 self .reset_parameters ()
132135
133136 def reset_parameters (self ):
134137 stdv = 1.0 / tf .math .sqrt (tf .cast (self .input_dim , dtype = tf .float64 ))
135138 self .bias .assign (tf .zeros_like (self .bias ))
136- self .mu_w .assign (tf .random .normal (tf .shape (self .mu_w ), mean = 0 , stddev = stdv , dtype = tf .float64 ))
137- self .logsig2_w .assign (tf .random .normal (tf .shape (self .logsig2_w ), mean = self .std_init , stddev = 0.001 , dtype = tf .float64 ))
138-
139- #initial_logsig2 = tf.constant(self.std_init, dtype=tf.float64)
140- #self.logsig2_w.assign(tf.fill(tf.shape(self.logsig2_w), initial_logsig2))
139+ self .mu_w .assign (
140+ tf .random .normal (tf .shape (self .mu_w ), mean = 0 , stddev = stdv , dtype = tf .float64 )
141+ )
142+ self .logsig2_w .assign (
143+ tf .random .normal (
144+ tf .shape (self .logsig2_w ), mean = self .std_init , stddev = 0.001 , dtype = tf .float64
145+ )
146+ )
141147
142148 def reset_random (self ):
143149 self .random = None
144150 self .map = False
145151
146152 def kl_loss (self ) -> tf .Tensor :
147153 logsig2_w = tf .clip_by_value (self .logsig2_w , self .lbound , self .ubound )
148- kl = 0.5 * tf .reduce_sum ((self .prior_prec * (tf .math .pow (self .mu_w ,2 )+ tf .math .exp (logsig2_w ))
149- - logsig2_w - tf .constant (1.0 , dtype = tf .float64 ) - tf .math .log (self .prior_prec )))
154+ kl = 0.5 * tf .reduce_sum (
155+ (
156+ self .prior_prec * (tf .math .pow (self .mu_w , 2 ) + tf .math .exp (logsig2_w ))
157+ - logsig2_w
158+ - tf .constant (1.0 , dtype = tf .float64 )
159+ - tf .math .log (self .prior_prec )
160+ )
161+ )
150162 return kl
151-
163+
164+ def train (self ):
165+ self .training = True
166+
167+ def eval (self ):
168+ self .training = False
169+
152170 def call (self , input : tf .Tensor ) -> tf .Tensor :
153171 # Ensure input is tf.float64
154172 input = tf .cast (input , tf .float64 )
155-
173+
156174 if self .training :
157- mu_out = tf .matmul (input , tf .cast (self .mu_w , input .dtype ), transpose_b = True ) + tf .cast (self .bias , input .dtype )
175+ mu_out = tf .matmul (input , tf .cast (self .mu_w , input .dtype ), transpose_b = True ) + tf .cast (
176+ self .bias , input .dtype
177+ )
158178 logsig2_w = tf .clip_by_value (self .logsig2_w , self .lbound , self .ubound )
159179 s2_w = tf .math .exp (logsig2_w )
160180 input2 = tf .math .pow (input , 2 )
161181 var_out = tf .matmul (input2 , s2_w , transpose_b = True ) + tf .cast (self .eps , input .dtype )
162-
163- return mu_out + tf .math .sqrt (var_out ) * tf .random .normal (shape = tf .shape (mu_out ), dtype = input .dtype )
164-
182+
183+ return mu_out + tf .math .sqrt (var_out ) * tf .random .normal (
184+ shape = tf .shape (mu_out ), dtype = input .dtype
185+ )
186+
165187 else :
166188 # During inference, use MAP estimation (posterior mean) for deterministic output
167189 if self .map :
168190 mu_out = tf .matmul (input , self .mu_w , transpose_b = True ) + self .bias
169191 return mu_out
170-
192+
171193 logsig2_w = tf .clip_by_value (self .logsig2_w , self .lbound , 11 )
172194 if self .random is None :
173- self .random = tf .Variable (tf .random .normal (shape = tf .shape (self .mu_w ), dtype = tf .float64 ))
195+ self .random = tf .Variable (
196+ tf .random .normal (shape = tf .shape (self .mu_w ), dtype = tf .float64 )
197+ )
174198 s2_w = tf .math .exp (logsig2_w )
175199 # draw fresh samples instead of caching
176200 epsilon = tf .random .normal (shape = tf .shape (self .mu_w ), dtype = tf .float64 )
177- weight = self .mu_w + tf .math .sqrt (s2_w ) * epsilon #self.random #
178-
179- return tf .matmul (input , weight , transpose_b = True ) + self .bias
201+ weight = self .mu_w + tf .math .sqrt (s2_w ) * epsilon # self.random #
180202
203+ return tf .matmul (input , weight , transpose_b = True ) + self .bias
181204
182205
183206class Dense (KerasDense , MetaLayer ):
@@ -237,6 +260,7 @@ def apply_dense(xinput):
237260
238261 return apply_dense
239262
263+
240264layers = {
241265 "dense" : (
242266 Dense ,
@@ -262,13 +286,7 @@ def apply_dense(xinput):
262286 LSTM_modified ,
263287 {"kernel_initializer" : "glorot_normal" , "units" : 5 , "activation" : "sigmoid" },
264288 ),
265- "VBDense" : (
266- VBDense ,
267- {
268- "in_features" : 10 , #hardcoded for now
269- "out_features" : 8 ,
270- },
271- ),
289+ "VBDense" : (VBDense , {"in_features" : 10 , "out_features" : 8 }), # hardcoded for now
272290 "dropout" : (Dropout , {"rate" : 0.0 }),
273291 "concatenate" : (Concatenate , {}),
274292}
@@ -344,4 +362,4 @@ def regularizer_selector(reg_name, **kwargs):
344362 if key in reg_args .keys ():
345363 reg_args [key ] = value
346364
347- return reg_class (** reg_args )
365+ return reg_class (** reg_args )
0 commit comments