@@ -1,5 +1,3 @@
-from typing import Optional, Type
-
 import torch
 import torch.nn as nn
 from torch.quantization.fx.fusion_patterns import ConvBNReLUFusion, ModuleReLUFusion
@@ -13,7 +11,7 @@
 from mqbench.nn.modules import FrozenBatchNorm2d


-class ConvFreezebnReLUFusion(ConvBNReLUFusion):
+class ConvExtendBnReLUFusion(ConvBNReLUFusion):
     def __init__(self, quantizer: QuantizerCls, node: Node):
         super(ConvBNReLUFusion, self).__init__(quantizer, node)
         self.relu_node = None
@@ -87,39 +85,27 @@ def fuse_deconv_bn_relu(deconv, bn, relu):
 def fuse_conv_freezebn(conv, bn):
     assert bn.training is False, "Freezebn must be eval."

-    fused_module_class_map = {
-        nn.Conv2d: qnni.ConvFreezebn2d,
-    }
-
     if conv.training:
         assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
         assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
         assert bn.track_running_stats, 'Only support fusing BatchNorm2d with track_running_stats set to True'
-        fused_module_class = fused_module_class_map.get((type(conv)), None)
-        return fused_module_class(conv, bn)
+        return qnni.ConvFreezebn2d(conv, bn)
     else:
         return nn.utils.fuse_conv_bn_eval(conv, bn)

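For reference, a minimal usage sketch of fuse_conv_freezebn (not part of the commit; it assumes qnni aliases mqbench.nn.intrinsic, imported earlier in this file):

    # Sketch only: exercising both branches of fuse_conv_freezebn.
    conv = nn.Conv2d(8, 16, kernel_size=3)      # conv.training is True by default
    bn = nn.BatchNorm2d(16).eval()              # the "frozen" BN must be in eval mode
    fused_train = fuse_conv_freezebn(conv, bn)        # train path -> qnni.ConvFreezebn2d
    fused_eval = fuse_conv_freezebn(conv.eval(), bn)  # eval path -> Conv2d with BN folded in
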
 def fuse_conv_freezebn_relu(conv, bn, relu):
-    assert conv.training == relu.training and bn.training is False, "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
-    fused_module: Optional[Type[nn.Sequential]] = None
+    assert conv.training == relu.training and bn.training is False, \
+        "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
+
     if conv.training:
-        map_to_fused_module_train = {
-            nn.Conv2d: qnni.ConvFreezebnReLU2d,
-        }
         assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
         assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
         assert bn.track_running_stats, 'Only support fusing BatchNorm with track_running_stats set to True'
-        fused_module = map_to_fused_module_train.get(type(conv), None)
-        return fused_module(conv, bn, relu)
+        return qnni.ConvFreezebnReLU2d(conv, bn, relu)
     else:
-        map_to_fused_module_eval = {
-            nn.Conv2d: nn.intrinsic.ConvReLU2d,
-        }
-        fused_module = map_to_fused_module_eval.get(type(conv), None)
-        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
-        return fused_module(fused_conv, relu)
+        fused_conv = nn.utils.fuse_conv_bn_eval(conv, bn)
+        return nn.intrinsic.ConvReLU2d(fused_conv, relu)

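A matching sketch for the ReLU variant; the only new behavior is the eval path, which wraps the folded conv in a standard nn.intrinsic.ConvReLU2d:

    # Sketch only: conv and relu must share a mode, while bn stays eval.
    conv, bn, relu = nn.Conv2d(8, 16, 3), nn.BatchNorm2d(16).eval(), nn.ReLU()
    qat_mod = fuse_conv_freezebn_relu(conv, bn, relu)                 # -> qnni.ConvFreezebnReLU2d
    eval_mod = fuse_conv_freezebn_relu(conv.eval(), bn, relu.eval())  # -> nn.intrinsic.ConvReLU2d
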
 def fuse_deconv_freezebn(deconv, bn):
@@ -135,7 +121,8 @@ def fuse_deconv_freezebn(deconv, bn):


 def fuse_deconv_freezebn_relu(deconv, bn, relu):
-    assert deconv.training == relu.training and bn.training is False, "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
+    assert deconv.training == relu.training and bn.training is False, \
+        "Conv and relu both must be in the same mode (train or eval) and bn must be eval."

     if deconv.training:
         assert bn.num_features == deconv.out_channels, 'Output channel of ConvTranspose2d must match num_features of BatchNorm2d'
@@ -171,13 +158,13 @@ def fuse_deconv_freezebn_relu(deconv, bn, relu):
         (torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.ConvTranspose2d)):
             ConvBNReLUFusion,
         (torch.nn.ReLU, (FrozenBatchNorm2d, torch.nn.Conv2d)):
-            ConvFreezebnReLUFusion,
+            ConvExtendBnReLUFusion,
         (FrozenBatchNorm2d, torch.nn.Conv2d):
-            ConvFreezebnReLUFusion,
+            ConvExtendBnReLUFusion,
         (torch.nn.ReLU, (FrozenBatchNorm2d, torch.nn.ConvTranspose2d)):
-            ConvFreezebnReLUFusion,
+            ConvExtendBnReLUFusion,
         (FrozenBatchNorm2d, torch.nn.ConvTranspose2d):
-            ConvFreezebnReLUFusion,
+            ConvExtendBnReLUFusion,
     },
     "additional_qat_module_mappings": {
         nn.ConvTranspose2d: qnn.qat.ConvTranspose2d,
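For reference, these pattern keys follow torch.fx quantization's reversed-tuple convention: the last op of the subgraph comes first and its producer pattern is nested, so (torch.nn.ReLU, (FrozenBatchNorm2d, torch.nn.Conv2d)) matches conv -> frozen bn -> relu. A hypothetical sketch of the resulting dispatch (the local dict name is illustrative only):

    # Sketch only: how a reversed pattern key routes a match to its handler.
    example_patterns = {  # hypothetical mirror of the registration above
        (torch.nn.ReLU, (FrozenBatchNorm2d, torch.nn.Conv2d)): ConvExtendBnReLUFusion,
        (FrozenBatchNorm2d, torch.nn.Conv2d): ConvExtendBnReLUFusion,
    }
    assert example_patterns[(FrozenBatchNorm2d, torch.nn.Conv2d)] is ConvExtendBnReLUFusion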