Skip to content

Commit 94fba5f

Browse files
committed
solved BatchNorm1d problem
1 parent f395db4 commit 94fba5f

File tree

2 files changed

+29
-30
lines changed

2 files changed

+29
-30
lines changed

src/INN/INN.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def forward(self, x, log_p=0, log_det_J=0):
171171
var = self.running_var # [dim]
172172
else:
173173
# if in training
174-
var = torch.var(x, dim=0, unbiased=False) # [dim]
174+
var = torch.var(x, dim=0, unbiased=False).detach() # [dim]
175175

176176
x = super(BatchNorm1d, self).forward(x)
177177

tests/quick_tests.ipynb

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,11 @@
209209
},
210210
{
211211
"cell_type": "code",
212-
"execution_count": 5,
212+
"execution_count": 14,
213213
"metadata": {
214214
"ExecuteTime": {
215-
"end_time": "2021-04-25T20:16:59.907029Z",
216-
"start_time": "2021-04-25T20:16:59.900497Z"
215+
"end_time": "2021-04-25T20:27:57.577833Z",
216+
"start_time": "2021-04-25T20:27:57.573240Z"
217217
}
218218
},
219219
"outputs": [],
@@ -223,21 +223,20 @@
223223
},
224224
{
225225
"cell_type": "code",
226-
"execution_count": 11,
226+
"execution_count": 15,
227227
"metadata": {
228228
"ExecuteTime": {
229-
"end_time": "2021-04-25T20:17:21.038359Z",
230-
"start_time": "2021-04-25T20:17:21.015209Z"
229+
"end_time": "2021-04-25T20:27:57.883569Z",
230+
"start_time": "2021-04-25T20:27:57.873327Z"
231231
}
232232
},
233233
"outputs": [
234234
{
235235
"name": "stdout",
236236
"output_type": "stream",
237237
"text": [
238-
"J_g=tensor([-1.5020e-05, -1.5020e-05, -1.5020e-05, -1.5020e-05, -1.5020e-05,\n",
239-
" -1.5020e-05]),\n",
240-
"J_c=-1.5020295904832892e-05\n"
238+
"J_g=tensor([ 1.9498, -0.0728, 1.3574, 0.7588, 1.3956, 0.7595]),\n",
239+
"J_c=tensor([ 1.9498, -0.0728, 1.3574, 0.7588, 1.3956, 0.7595])\n"
241240
]
242241
}
243242
],
@@ -251,23 +250,23 @@
251250
},
252251
{
253252
"cell_type": "code",
254-
"execution_count": 12,
253+
"execution_count": 16,
255254
"metadata": {
256255
"ExecuteTime": {
257-
"end_time": "2021-04-25T20:17:23.285413Z",
258-
"start_time": "2021-04-25T20:17:23.279627Z"
256+
"end_time": "2021-04-25T20:28:05.686690Z",
257+
"start_time": "2021-04-25T20:28:05.681378Z"
259258
}
260259
},
261260
"outputs": [
262261
{
263262
"data": {
264263
"text/plain": [
265-
"tensor([[1.0000, 0.0000, 0.0000],\n",
266-
" [0.0000, 1.0000, 0.0000],\n",
267-
" [0.0000, 0.0000, 1.0000]])"
264+
"tensor([[ 1.9760, -1.3065, 1.2205],\n",
265+
" [ 0.5567, -0.8171, -1.6935],\n",
266+
" [ 0.0597, 2.1638, 2.1138]])"
268267
]
269268
},
270-
"execution_count": 12,
269+
"execution_count": 16,
271270
"metadata": {},
272271
"output_type": "execute_result"
273272
}
@@ -420,40 +419,40 @@
420419
},
421420
{
422421
"cell_type": "code",
423-
"execution_count": 19,
422+
"execution_count": 17,
424423
"metadata": {
425424
"ExecuteTime": {
426-
"end_time": "2021-04-25T20:25:00.696600Z",
427-
"start_time": "2021-04-25T20:25:00.688777Z"
425+
"end_time": "2021-04-25T20:51:57.518139Z",
426+
"start_time": "2021-04-25T20:51:57.513012Z"
428427
}
429428
},
430429
"outputs": [],
431430
"source": [
432431
"x = torch.randn((5, 3))\n",
433-
"bn = nn.BatchNorm1d(3)"
432+
"bn = nn.BatchNorm1d(3, affine=False)"
434433
]
435434
},
436435
{
437436
"cell_type": "code",
438-
"execution_count": 20,
437+
"execution_count": 18,
439438
"metadata": {
440439
"ExecuteTime": {
441-
"end_time": "2021-04-25T20:25:01.168809Z",
442-
"start_time": "2021-04-25T20:25:01.158170Z"
440+
"end_time": "2021-04-25T20:51:58.028429Z",
441+
"start_time": "2021-04-25T20:51:58.009478Z"
443442
}
444443
},
445444
"outputs": [
446445
{
447446
"data": {
448447
"text/plain": [
449-
"tensor([[-1.1245, 1.3747, -0.5232],\n",
450-
" [ 1.1031, 0.5880, 0.0641],\n",
451-
" [-0.8741, 0.0202, -1.3274],\n",
452-
" [-0.3750, -0.3597, 0.0675],\n",
453-
" [ 1.2705, -1.6232, 1.7191]], grad_fn=<NativeBatchNormBackward>)"
448+
"tensor([[-1.6941, 0.2933, -0.2451],\n",
449+
" [-0.1313, -0.2711, 1.4740],\n",
450+
" [ 0.2754, -0.2282, 0.4445],\n",
451+
" [ 0.1287, -1.4409, -0.0721],\n",
452+
" [ 1.4213, 1.6469, -1.6014]])"
454453
]
455454
},
456-
"execution_count": 20,
455+
"execution_count": 18,
457456
"metadata": {},
458457
"output_type": "execute_result"
459458
}

0 commit comments

Comments
 (0)