Skip to content

Commit 91650f6

Browse files
Merge pull request #2 from y0ast/main
Fix KL computation
2 parents 09b8d0a + 69955b7 commit 91650f6

File tree

4 files changed

+22
-54
lines changed

4 files changed

+22
-54
lines changed

bayesian_torch/models/bayesian/resnet_flipout.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,13 @@ def forward(self, x):
132132
out = F.relu(out)
133133
for l in self.layer1:
134134
out, kl = l(out)
135-
kl_sum += kl
135+
kl_sum += kl
136136
for l in self.layer2:
137137
out, kl = l(out)
138-
kl_sum += kl
138+
kl_sum += kl
139139
for l in self.layer3:
140140
out, kl = l(out)
141-
kl_sum += kl
141+
kl_sum += kl
142142

143143
out = F.avg_pool2d(out, out.size()[3])
144144
out = out.view(out.size(0), -1)

bayesian_torch/models/bayesian/resnet_flipout_large.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -215,36 +215,20 @@ def forward(self, x):
215215
x = self.maxpool(x)
216216

217217
for layer in self.layer1:
218-
if 'Flipout' in str(layer):
219-
x, kl = layer(x)
220-
if kl is None:
221-
kl_sum += kl
222-
else:
223-
x = layer(x)
218+
x, kl = layer(x)
219+
kl_sum += kl
224220

225221
for layer in self.layer2:
226-
if 'Flipout' in str(layer):
227-
x, kl = layer(x)
228-
if kl is None:
229-
kl_sum += kl
230-
else:
231-
x = layer(x)
222+
x, kl = layer(x)
223+
kl_sum += kl
232224

233225
for layer in self.layer3:
234-
if 'Flipout' in str(layer):
235-
x, kl = layer(x)
236-
if kl is None:
237-
kl_sum += kl
238-
else:
239-
x = layer(x)
226+
x, kl = layer(x)
227+
kl_sum += kl
240228

241229
for layer in self.layer4:
242-
if 'Flipout' in str(layer):
243-
x, kl = layer(x)
244-
if kl is None:
245-
kl_sum += kl
246-
else:
247-
x = layer(x)
230+
x, kl = layer(x)
231+
kl_sum += kl
248232

249233
x = self.avgpool(x)
250234
x = x.view(x.size(0), -1)

bayesian_torch/models/bayesian/resnet_variational.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,13 @@ def forward(self, x):
152152
out = F.relu(out)
153153
for l in self.layer1:
154154
out, kl = l(out)
155-
kl_sum += kl
155+
kl_sum += kl
156156
for l in self.layer2:
157157
out, kl = l(out)
158-
kl_sum += kl
158+
kl_sum += kl
159159
for l in self.layer3:
160160
out, kl = l(out)
161-
kl_sum += kl
161+
kl_sum += kl
162162

163163
out = F.avg_pool2d(out, out.size()[3])
164164
out = out.view(out.size(0), -1)

bayesian_torch/models/bayesian/resnet_variational_large.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -220,36 +220,20 @@ def forward(self, x):
220220
x = self.maxpool(x)
221221

222222
for layer in self.layer1:
223-
if 'Reparameterization' in str(layer):
224-
x, kl = layer(x)
225-
if kl is None:
226-
kl_sum += kl
227-
else:
228-
x = layer(x)
223+
x, kl = layer(x)
224+
kl_sum += kl
229225

230226
for layer in self.layer2:
231-
if 'Reparameterization' in str(layer):
232-
x, kl = layer(x)
233-
if kl is None:
234-
kl_sum += kl
235-
else:
236-
x = layer(x)
227+
x, kl = layer(x)
228+
kl_sum += kl
237229

238230
for layer in self.layer3:
239-
if 'Reparameterization' in str(layer):
240-
x, kl = layer(x)
241-
if kl is None:
242-
kl_sum += kl
243-
else:
244-
x = layer(x)
231+
x, kl = layer(x)
232+
kl_sum += kl
245233

246234
for layer in self.layer4:
247-
if 'Reparameterization' in str(layer):
248-
x, kl = layer(x)
249-
if kl is None:
250-
kl_sum += kl
251-
else:
252-
x = layer(x)
235+
x, kl = layer(x)
236+
kl_sum += kl
253237

254238
x = self.avgpool(x)
255239
x = x.view(x.size(0), -1)

0 commit comments

Comments
 (0)