Skip to content

Commit 782820a

Browse files
update MOPED layer example utility function
Signed-off-by: Ranganath Krishnan <[email protected]>
1 parent c7ff3e7 commit 782820a

File tree

2 files changed

+38
-26
lines changed

2 files changed

+38
-26
lines changed

bayesian_torch/examples/main_bayesian_flipout_imagenet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def MOPED_layer(layer, det_layer, delta):
212212
print(str(layer))
213213
layer.weight.data = det_layer.weight.data
214214
if layer.bias is not None:
215-
layer.bias.data = det_layer.bias.data2
215+
layer.bias.data = det_layer.bias.data
216216

217217
elif (str(layer) == 'LinearFlipout()'
218218
or str(layer) == 'LinearReparameterization()'):

bayesian_torch/utils/util.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ def mutual_information(mc_preds):
5555
Compute the difference between the entropy of the mean of the
5656
predictive distribution and the mean of the entropy.
5757
"""
58-
MI = entropy(np.mean(mc_preds, axis=0)) - np.mean(entropy(mc_preds),
59-
axis=0)
60-
return MI
58+
mutual_info = entropy(np.mean(mc_preds, axis=0)) - np.mean(entropy(mc_preds),
59+
axis=0)
60+
return mutual_info
6161

6262

6363
def get_rho(sigma, delta):
@@ -86,39 +86,51 @@ def MOPED(model, det_model, det_checkpoint, delta):
8686
for (idx, layer), (det_idx,
8787
det_layer) in zip(enumerate(model.modules()),
8888
enumerate(det_model.modules())):
89-
if (str(layer) == 'Conv1dVariational()'
90-
or str(layer) == 'Conv2dVariational()'
91-
or str(layer) == 'Conv3dVariational()'
92-
or str(layer) == 'ConvTranspose1dVariational()'
93-
or str(layer) == 'ConvTranspose2dVariational()'
94-
or str(layer) == 'ConvTranspose3dVariational()'):
89+
if (str(layer) == 'Conv1dReparametrization()'
90+
or str(layer) == 'Conv2dReparameterization()'
91+
or str(layer) == 'Conv3dReparameterization()'
92+
or str(layer) == 'ConvTranspose1dReparameterization()'
93+
or str(layer) == 'ConvTranspose2dReparameterization()'
94+
or str(layer) == 'ConvTranspose3dReparameterization()'
95+
or str(layer) == 'Conv1dFlipout()'
96+
or str(layer) == 'Conv2dFlipout()'
97+
or str(layer) == 'Conv3dFlipout()'
98+
or str(layer) == 'ConvTranspose1dFlipout()'
99+
or str(layer) == 'ConvTranspose2dFlipout()'
100+
or str(layer) == 'ConvTranspose3dFlipout()'):
95101
#set the priors
96-
layer.prior_weight_mu.data = det_layer.weight
97-
layer.prior_bias_mu.data = det_layer.bias
102+
layer.prior_weight_mu = det_layer.weight.data
103+
if layer.prior_bias_mu is not None:
104+
layer.prior_bias_mu = det_layer.bias.data
98105

99106
#initialize surrogate posteriors
100-
layer.mu_kernel.data = det_layer.weight
107+
layer.mu_kernel.data = det_layer.weight.data
101108
layer.rho_kernel.data = get_rho(det_layer.weight.data, delta)
102-
layer.mu_bias.data = det_layer.bias
103-
layer.rho_bias.data = get_rho(det_layer.bias.data, delta)
104-
elif (str(layer) == 'LinearVariational()'):
109+
if layer.mu_bias is not None:
110+
layer.mu_bias.data = det_layer.bias.data
111+
layer.rho_bias.data = get_rho(det_layer.bias.data, delta)
112+
elif (str(layer) == 'LinearReparameterization()'
113+
or str(layer) == 'LinearFlipout()'):
105114
#set the priors
106-
layer.prior_weight_mu.data = det_layer.weight
107-
layer.prior_bias_mu.data = det_layer.bias
115+
layer.prior_weight_mu = det_layer.weight.data
116+
if layer.prior_bias_mu is not None:
117+
layer.prior_bias_mu.data = det_layer.bias
108118

109119
#initialize the surrogate posteriors
110-
layer.mu_weight.data = det_layer.weight
120+
layer.mu_weight.data = det_layer.weight.data
111121
layer.rho_weight.data = get_rho(det_layer.weight.data, delta)
112-
layer.mu_bias.data = det_layer.bias
113-
layer.rho_bias.data = get_rho(det_layer.bias.data, delta)
122+
if layer.mu_bias is not None:
123+
layer.mu_bias.data = det_layer.bias.data
124+
layer.rho_bias.data = get_rho(det_layer.bias.data, delta)
114125

115126
elif str(layer).startswith('Batch'):
116127
#initialize parameters
117-
layer.weight.data = det_layer.weight
118-
layer.bias.data = det_layer.bias
119-
layer.running_mean.data = det_layer.running_mean
120-
layer.running_var.data = det_layer.running_var
121-
layer.num_batches_tracked.data = det_layer.num_batches_tracked
128+
layer.weight.data = det_layer.weight.data
129+
if layer.bias is not None:
130+
layer.bias.data = det_layer.bias
131+
layer.running_mean.data = det_layer.running_mean.data
132+
layer.running_var.data = det_layer.running_var.data
133+
layer.num_batches_tracked.data = det_layer.num_batches_tracked.data
122134

123135
model.state_dict()
124136
return model

0 commit comments

Comments
 (0)