@@ -47,7 +47,7 @@ def fuse_linear_bn(linear, bn):
4747 >>> b1 = nn.BatchNorm1d(20)
4848 >>> m2 = fuse_linear_bn(m1, b1)
4949 """
50- assert ( linear .training == bn .training ), \
50+ assert linear .training == bn .training , \
5151 "Linear and BN both must be in the same mode (train or eval)."
5252
5353 if linear .training :
@@ -59,7 +59,7 @@ def fuse_linear_bn(linear, bn):
5959
6060
6161def fuse_deconv_bn (deconv , bn ):
62- assert ( deconv .training == bn .training ), \
62+ assert deconv .training == bn .training , \
6363 'DeConv and BN must be in the same mode (train or eval)'
6464
6565 if deconv .training :
@@ -72,7 +72,7 @@ def fuse_deconv_bn(deconv, bn):
7272
7373
7474def fuse_deconv_bn_relu (deconv , bn , relu ):
75- assert ( deconv .training == bn .training == relu .training ), \
75+ assert deconv .training == bn .training == relu .training , \
7676 "DeConv and BN both must be in the same mode (train or eval)."
7777
7878 if deconv .training :
@@ -85,7 +85,7 @@ def fuse_deconv_bn_relu(deconv, bn, relu):
8585
8686
8787def fuse_conv_freezebn (conv , bn ):
88- assert ( bn .training is False ) , "Freezebn must be eval."
88+ assert bn .training is False , "Freezebn must be eval."
8989
9090 fused_module_class_map = {
9191 nn .Conv2d : qnni .ConvFreezebn2d ,
@@ -102,7 +102,7 @@ def fuse_conv_freezebn(conv, bn):
102102
103103
104104def fuse_conv_freezebn_relu (conv , bn , relu ):
105- assert ( conv .training == relu .training and bn .training is False ) , "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
105+ assert conv .training == relu .training and bn .training is False , "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
106106 fused_module : Optional [Type [nn .Sequential ]] = None
107107 if conv .training :
108108 map_to_fused_module_train = {
@@ -123,7 +123,7 @@ def fuse_conv_freezebn_relu(conv, bn, relu):
123123
124124
125125def fuse_deconv_freezebn (deconv , bn ):
126- assert ( bn .training is False ) , "Freezebn must be eval."
126+ assert bn .training is False , "Freezebn must be eval."
127127
128128 if deconv .training :
129129 assert bn .num_features == deconv .out_channels , 'Output channel of ConvTranspose2d must match num_features of BatchNorm2d'
@@ -135,7 +135,7 @@ def fuse_deconv_freezebn(deconv, bn):
135135
136136
137137def fuse_deconv_freezebn_relu (deconv , bn , relu ):
138- assert ( deconv .training == relu .training and bn .training is False ) , "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
138+ assert deconv .training == relu .training and bn .training is False , "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
139139
140140 if deconv .training :
141141 assert bn .num_features == deconv .out_channels , 'Output channel of ConvTranspose2d must match num_features of BatchNorm2d'
0 commit comments