Skip to content

Commit 3e378a3

Browse files
committed
Fix inference sampling and switch to training mode
1 parent 3cbc8d4 commit 3e378a3

File tree

3 files changed

+169
-140
lines changed

3 files changed

+169
-140
lines changed

n3fit/src/n3fit/backends/keras_backend/base_layers.py

Lines changed: 83 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
"""
2-
This module defines custom base layers to be used by the n3fit
3-
Neural Network.
4-
These layers can use the keras standard set of activation functions
5-
or implement their own.
2+
This module defines custom base layers to be used by the n3fit
3+
Neural Network.
4+
These layers can use the keras standard set of activation functions
5+
or implement their own.
66
7-
For a layer to be used by n3fit it should be contained in the `layers` dictionary defined below.
8-
This dictionary has the following structure:
7+
For a layer to be used by n3fit it should be contained in the `layers` dictionary defined below.
8+
This dictionary has the following structure:
99
10-
'name of the layer' : ( Layer_class, {dictionary of arguments: defaults} )
10+
'name of the layer' : ( Layer_class, {dictionary of arguments: defaults} )
1111
12-
In order to add custom activation functions, they must be added to
13-
the `custom_activations` dictionary with the following structure:
12+
In order to add custom activation functions, they must be added to
13+
the `custom_activations` dictionary with the following structure:
1414
15-
'name of the activation' : function
15+
'name of the activation' : function
1616
17-
The names of the layer and the activation function are the ones to be used in the n3fit runcard.
17+
The names of the layer and the activation function are the ones to be used in the n3fit runcard.
1818
"""
19+
1920
import numpy as np
2021
import keras.backend as K
2122
import tensorflow as tf
@@ -32,6 +33,7 @@
3233
from .MetaLayer import MetaLayer
3334
from contextlib import contextmanager
3435

36+
3537
# Custom activation functions
3638
def square_activation(x):
3739
"""Squares the input"""
@@ -79,17 +81,18 @@ def ReshapedLSTM(input_tensor):
7981

8082
return ReshapedLSTM
8183

84+
8285
class VBDense(Layer):
8386
def __init__(
84-
self,
85-
out_features: int,
86-
in_features: int,
87-
prior_prec: float = 0.001,
88-
map: bool = False,
89-
std_init: float = -9,
90-
lbound=-30,
91-
ubound=11,
92-
training = True
87+
self,
88+
out_features: int,
89+
in_features: int,
90+
prior_prec: float = 0.001,
91+
map: bool = False,
92+
std_init: float = -9,
93+
lbound=-30,
94+
ubound=11,
95+
training=True,
9396
):
9497
super().__init__()
9598
self.output_dim = out_features
@@ -105,79 +108,99 @@ def __init__(
105108

106109
def build(self, input_shape):
107110
self.bias = self.add_weight(
108-
name='bias',
109-
shape=(self.output_dim,),
110-
initializer='glorot_normal',
111-
trainable=True,
112-
dtype=tf.float64
113-
)
114-
111+
name='bias',
112+
shape=(self.output_dim,),
113+
initializer='glorot_normal',
114+
trainable=True,
115+
dtype=tf.float64,
116+
)
117+
115118
self.mu_w = self.add_weight(
116-
name='mu_w',
117-
shape=(self.output_dim, self.input_dim),
118-
initializer='glorot_normal',
119-
trainable=True,
120-
dtype=tf.float64
121-
)
122-
119+
name='mu_w',
120+
shape=(self.output_dim, self.input_dim),
121+
initializer='glorot_normal',
122+
trainable=True,
123+
dtype=tf.float64,
124+
)
125+
123126
self.logsig2_w = self.add_weight(
124-
name='logsig2_w',
125-
shape=(self.output_dim, self.input_dim),
126-
initializer='glorot_normal',
127+
name='logsig2_w',
128+
shape=(self.output_dim, self.input_dim),
129+
initializer='glorot_normal',
127130
trainable=True,
128131
dtype=tf.float64,
129-
)
130-
132+
)
133+
131134
self.reset_parameters()
132135

133136
def reset_parameters(self):
134137
stdv = 1.0 / tf.math.sqrt(tf.cast(self.input_dim, dtype=tf.float64))
135138
self.bias.assign(tf.zeros_like(self.bias))
136-
self.mu_w.assign(tf.random.normal(tf.shape(self.mu_w), mean=0, stddev=stdv, dtype=tf.float64))
137-
self.logsig2_w.assign(tf.random.normal(tf.shape(self.logsig2_w), mean=self.std_init, stddev=0.001, dtype=tf.float64))
138-
139-
#initial_logsig2 = tf.constant(self.std_init, dtype=tf.float64)
140-
#self.logsig2_w.assign(tf.fill(tf.shape(self.logsig2_w), initial_logsig2))
139+
self.mu_w.assign(
140+
tf.random.normal(tf.shape(self.mu_w), mean=0, stddev=stdv, dtype=tf.float64)
141+
)
142+
self.logsig2_w.assign(
143+
tf.random.normal(
144+
tf.shape(self.logsig2_w), mean=self.std_init, stddev=0.001, dtype=tf.float64
145+
)
146+
)
141147

142148
def reset_random(self):
143149
self.random = None
144150
self.map = False
145151

146152
def kl_loss(self) -> tf.Tensor:
147153
logsig2_w = tf.clip_by_value(self.logsig2_w, self.lbound, self.ubound)
148-
kl = 0.5 * tf.reduce_sum((self.prior_prec*(tf.math.pow(self.mu_w,2)+tf.math.exp(logsig2_w))
149-
- logsig2_w - tf.constant(1.0, dtype=tf.float64) - tf.math.log(self.prior_prec)))
154+
kl = 0.5 * tf.reduce_sum(
155+
(
156+
self.prior_prec * (tf.math.pow(self.mu_w, 2) + tf.math.exp(logsig2_w))
157+
- logsig2_w
158+
- tf.constant(1.0, dtype=tf.float64)
159+
- tf.math.log(self.prior_prec)
160+
)
161+
)
150162
return kl
151-
163+
164+
def train(self):
165+
self.training = True
166+
167+
def eval(self):
168+
self.training = False
169+
152170
def call(self, input: tf.Tensor) -> tf.Tensor:
153171
# Ensure input is tf.float64
154172
input = tf.cast(input, tf.float64)
155-
173+
156174
if self.training:
157-
mu_out = tf.matmul(input, tf.cast(self.mu_w, input.dtype), transpose_b=True) + tf.cast(self.bias, input.dtype)
175+
mu_out = tf.matmul(input, tf.cast(self.mu_w, input.dtype), transpose_b=True) + tf.cast(
176+
self.bias, input.dtype
177+
)
158178
logsig2_w = tf.clip_by_value(self.logsig2_w, self.lbound, self.ubound)
159179
s2_w = tf.math.exp(logsig2_w)
160180
input2 = tf.math.pow(input, 2)
161181
var_out = tf.matmul(input2, s2_w, transpose_b=True) + tf.cast(self.eps, input.dtype)
162-
163-
return mu_out + tf.math.sqrt(var_out) * tf.random.normal(shape=tf.shape(mu_out), dtype=input.dtype)
164-
182+
183+
return mu_out + tf.math.sqrt(var_out) * tf.random.normal(
184+
shape=tf.shape(mu_out), dtype=input.dtype
185+
)
186+
165187
else:
166188
# During inference, use MAP estimation (posterior mean) for deterministic output
167189
if self.map:
168190
mu_out = tf.matmul(input, self.mu_w, transpose_b=True) + self.bias
169191
return mu_out
170-
192+
171193
logsig2_w = tf.clip_by_value(self.logsig2_w, self.lbound, 11)
172194
if self.random is None:
173-
self.random = tf.Variable(tf.random.normal(shape=tf.shape(self.mu_w), dtype=tf.float64))
195+
self.random = tf.Variable(
196+
tf.random.normal(shape=tf.shape(self.mu_w), dtype=tf.float64)
197+
)
174198
s2_w = tf.math.exp(logsig2_w)
175199
# draw fresh samples instead of caching
176200
epsilon = tf.random.normal(shape=tf.shape(self.mu_w), dtype=tf.float64)
177-
weight = self.mu_w + tf.math.sqrt(s2_w) * epsilon #self.random #
178-
179-
return tf.matmul(input, weight, transpose_b=True) + self.bias
201+
weight = self.mu_w + tf.math.sqrt(s2_w) * epsilon # self.random #
180202

203+
return tf.matmul(input, weight, transpose_b=True) + self.bias
181204

182205

183206
class Dense(KerasDense, MetaLayer):
@@ -237,6 +260,7 @@ def apply_dense(xinput):
237260

238261
return apply_dense
239262

263+
240264
layers = {
241265
"dense": (
242266
Dense,
@@ -262,13 +286,7 @@ def apply_dense(xinput):
262286
LSTM_modified,
263287
{"kernel_initializer": "glorot_normal", "units": 5, "activation": "sigmoid"},
264288
),
265-
"VBDense": (
266-
VBDense,
267-
{
268-
"in_features" : 10, #hardcoded for now
269-
"out_features" : 8,
270-
},
271-
),
289+
"VBDense": (VBDense, {"in_features": 10, "out_features": 8}), # hardcoded for now
272290
"dropout": (Dropout, {"rate": 0.0}),
273291
"concatenate": (Concatenate, {}),
274292
}
@@ -344,4 +362,4 @@ def regularizer_selector(reg_name, **kwargs):
344362
if key in reg_args.keys():
345363
reg_args[key] = value
346364

347-
return reg_class(**reg_args)
365+
return reg_class(**reg_args)

0 commit comments

Comments
 (0)