@@ -54,7 +54,6 @@ def _check_input_dim(self, input):
                              input.dim()))

    def forward(self, input):
-        self._check_input_dim(input[0])
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
@@ -63,13 +62,21 @@ def forward(self, input):
        else:  # use exponential moving average
            exponential_average_factor = self.momentum

-        out = F.batch_norm(input[0], self.running_mean, self.running_var,
-                           self.weight, self.bias, self.training
-                           or not self.track_running_stats,
-                           exponential_average_factor, self.eps)
-        kl = 0
-        return out, kl
-
+        if len(input) == 2:
+            self._check_input_dim(input[0])
+            out = F.batch_norm(input[0], self.running_mean, self.running_var,
+                               self.weight, self.bias, self.training
+                               or not self.track_running_stats,
+                               exponential_average_factor, self.eps)
+            kl = 0
+            return out, kl
+        else:
+            out = F.batch_norm(input, self.running_mean, self.running_var,
+                               self.weight, self.bias, self.training
+                               or not self.track_running_stats,
+                               exponential_average_factor, self.eps)
+            return out
+

class BatchNorm1dLayer(nn.Module):
    def __init__(self,
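
In short, `forward` previously assumed a `(tensor, kl)` tuple and always returned `(out, kl)`; after this commit it also accepts a plain tensor and returns only the normalized output. A minimal, self-contained sketch of that dispatch for illustration only: the subclass name `_SketchBN`, the `isinstance` check, and the toy shapes are assumptions and not part of this repository (the patch itself branches on `len(input) == 2`, and the real layer calls `F.batch_norm` directly rather than going through a `BatchNorm2d` parent).

import torch

class _SketchBN(torch.nn.BatchNorm2d):
    # Mirrors the patched forward: tuple input -> (out, kl), tensor input -> out.
    def forward(self, input):
        if isinstance(input, tuple) and len(input) == 2:
            x, _ = input                  # activations plus a KL placeholder
            out = super().forward(x)      # standard batch-norm statistics path
            kl = 0                        # this layer contributes no KL term, as in the patch
            return out, kl
        return super().forward(input)     # plain tensor: behave like stock BatchNorm

bn = _SketchBN(3)
x = torch.randn(4, 3, 8, 8)
out, kl = bn((x, 0.0))                    # tuple path, as used by the surrounding Bayesian layers
out_only = bn(x)                          # plain-tensor path added by this commit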