@@ -55,9 +55,9 @@ def mutual_information(mc_preds):
55
55
Compute the difference between the entropy of the mean of the
56
56
predictive distribution and the mean of the entropy.
57
57
"""
58
- MI = entropy (np .mean (mc_preds , axis = 0 )) - np .mean (entropy (mc_preds ),
59
- axis = 0 )
60
- return MI
58
+ mutual_info = entropy (np .mean (mc_preds , axis = 0 )) - np .mean (entropy (mc_preds ),
59
+ axis = 0 )
60
+ return mutual_info
61
61
62
62
63
63
def get_rho (sigma , delta ):
@@ -86,39 +86,51 @@ def MOPED(model, det_model, det_checkpoint, delta):
86
86
for (idx , layer ), (det_idx ,
87
87
det_layer ) in zip (enumerate (model .modules ()),
88
88
enumerate (det_model .modules ())):
89
- if (str (layer ) == 'Conv1dVariational()'
90
- or str (layer ) == 'Conv2dVariational()'
91
- or str (layer ) == 'Conv3dVariational()'
92
- or str (layer ) == 'ConvTranspose1dVariational()'
93
- or str (layer ) == 'ConvTranspose2dVariational()'
94
- or str (layer ) == 'ConvTranspose3dVariational()' ):
89
+ if (str (layer ) == 'Conv1dReparameterization()'
90
+ or str (layer ) == 'Conv2dReparameterization()'
91
+ or str (layer ) == 'Conv3dReparameterization()'
92
+ or str (layer ) == 'ConvTranspose1dReparameterization()'
93
+ or str (layer ) == 'ConvTranspose2dReparameterization()'
94
+ or str (layer ) == 'ConvTranspose3dReparameterization()'
95
+ or str (layer ) == 'Conv1dFlipout()'
96
+ or str (layer ) == 'Conv2dFlipout()'
97
+ or str (layer ) == 'Conv3dFlipout()'
98
+ or str (layer ) == 'ConvTranspose1dFlipout()'
99
+ or str (layer ) == 'ConvTranspose2dFlipout()'
100
+ or str (layer ) == 'ConvTranspose3dFlipout()' ):
95
101
#set the priors
96
- layer .prior_weight_mu .data = det_layer .weight
97
- layer .prior_bias_mu .data = det_layer .bias
102
+ layer .prior_weight_mu = det_layer .weight .data
103
+ if layer .prior_bias_mu is not None :
104
+ layer .prior_bias_mu = det_layer .bias .data
98
105
99
106
#initialize surrogate posteriors
100
- layer .mu_kernel .data = det_layer .weight
107
+ layer .mu_kernel .data = det_layer .weight . data
101
108
layer .rho_kernel .data = get_rho (det_layer .weight .data , delta )
102
- layer .mu_bias .data = det_layer .bias
103
- layer .rho_bias .data = get_rho (det_layer .bias .data , delta )
104
- elif (str (layer ) == 'LinearVariational()' ):
109
+ if layer .mu_bias is not None :
110
+ layer .mu_bias .data = det_layer .bias .data
111
+ layer .rho_bias .data = get_rho (det_layer .bias .data , delta )
112
+ elif (str (layer ) == 'LinearReparameterization()'
113
+ or str (layer ) == 'LinearFlipout()' ):
105
114
#set the priors
106
- layer .prior_weight_mu .data = det_layer .weight
107
- layer .prior_bias_mu .data = det_layer .bias
115
+ layer .prior_weight_mu = det_layer .weight .data
116
+ if layer .prior_bias_mu is not None :
117
+ layer .prior_bias_mu .data = det_layer .bias
108
118
109
119
#initialize the surrogate posteriors
110
- layer .mu_weight .data = det_layer .weight
120
+ layer .mu_weight .data = det_layer .weight . data
111
121
layer .rho_weight .data = get_rho (det_layer .weight .data , delta )
112
- layer .mu_bias .data = det_layer .bias
113
- layer .rho_bias .data = get_rho (det_layer .bias .data , delta )
122
+ if layer .mu_bias is not None :
123
+ layer .mu_bias .data = det_layer .bias .data
124
+ layer .rho_bias .data = get_rho (det_layer .bias .data , delta )
114
125
115
126
elif str (layer ).startswith ('Batch' ):
116
127
#initialize parameters
117
- layer .weight .data = det_layer .weight
118
- layer .bias .data = det_layer .bias
119
- layer .running_mean .data = det_layer .running_mean
120
- layer .running_var .data = det_layer .running_var
121
- layer .num_batches_tracked .data = det_layer .num_batches_tracked
128
+ layer .weight .data = det_layer .weight .data
129
+ if layer .bias is not None :
130
+ layer .bias .data = det_layer .bias
131
+ layer .running_mean .data = det_layer .running_mean .data
132
+ layer .running_var .data = det_layer .running_var .data
133
+ layer .num_batches_tracked .data = det_layer .num_batches_tracked .data
122
134
123
135
model .state_dict ()
124
136
return model
0 commit comments