@@ -42,20 +42,37 @@ def convert_qnniqat_linearbn(model, fused_node):
 
 @register_convert_function(qnniqat.ConvFreezebn2d)
 @register_convert_function(nniqat.ConvBn2d)
+@register_convert_function(nniqat.ConvBn3d)
 def convert_nniqat_convbn(model, fused_node):
47+ """nniqat.ConvBn2d ----> nn.Conv2d ----> nniqat.Conv2d
48+ """
+    fused_module_class_map = {
+        qnniqat.ConvFreezebn2d: torch.nn.Conv2d,
+        qnniqat.ConvFreezebnReLU2d: torch.nn.Conv2d,
+        nniqat.ConvBn2d: torch.nn.Conv2d,
+        nniqat.ConvBnReLU2d: torch.nn.Conv2d,
+        nniqat.ConvBn3d: torch.nn.Conv3d,
+        nniqat.ConvBnReLU3d: torch.nn.Conv3d,
+    }
+    fused_qat_module_class_map = {
+        torch.nn.Conv2d: torch.nn.qat.Conv2d,
+        torch.nn.Conv3d: torch.nn.qat.Conv3d,
+    }
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
     # Create a Conv2d from FusedModule.
-    conv = torch.nn.Conv2d(fused_module.in_channels, fused_module.out_channels, fused_module.kernel_size,
-                           fused_module.stride, fused_module.padding, fused_module.dilation,
-                           fused_module.groups, fused_module.bias is not None, fused_module.padding_mode)
+    conv = fused_module_class_map[type(fused_module)](fused_module.in_channels, fused_module.out_channels,
+                                                      fused_module.kernel_size, fused_module.stride,
+                                                      fused_module.padding, fused_module.dilation,
+                                                      fused_module.groups, fused_module.bias is not None,
+                                                      fused_module.padding_mode)
     conv.weight = fused_module.weight
     if fused_module.bias is not None:
         conv.bias = fused_module.bias
     fused_conv = fuse_conv_bn_eval(conv.eval(), fused_module.bn)
     # We need nn.qat.conv here to export weight quantize node.
     fused_conv.qconfig = fused_module.qconfig
-    fused_conv = torch.nn.qat.Conv2d.from_float(fused_conv)
+    fused_conv = fused_qat_module_class_map[type(conv)].from_float(fused_conv)
     # Attach weight fake quantize params.
     fused_conv.weight_fake_quant = fused_module.weight_fake_quant
     conv_parent_name, conv_name = _parent_name(fused_node.target)
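
Reviewer note: for readers unfamiliar with the two-step conversion this hunk generalizes to 3d, here is a minimal, self-contained sketch of the same flow on plain torch modules: fold the BN statistics into the conv with fuse_conv_bn_eval, then re-wrap the result via from_float so export still emits a weight fake-quantize node. The qconfig choice and the shapes below are assumptions for illustration, not values from this patch.

import torch
from torch.nn.utils.fusion import fuse_conv_bn_eval

# Plain float modules standing in for the attributes read off fused_module.
conv = torch.nn.Conv2d(3, 8, 3, padding=1)
bn = torch.nn.BatchNorm2d(8)

# Fold BN statistics into the conv weight/bias (both must be in eval mode).
fused_conv = fuse_conv_bn_eval(conv.eval(), bn.eval())

# Re-wrap as nn.qat.Conv2d so a weight fake-quantize node is exported.
fused_conv.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
qat_conv = torch.nn.qat.Conv2d.from_float(fused_conv)

x = torch.randn(1, 3, 16, 16)
print(qat_conv(x).shape)  # torch.Size([1, 8, 16, 16])
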
@@ -64,7 +81,8 @@ def convert_nniqat_convbn(model, fused_node):
 
 @register_convert_function(qnniqat.ConvFreezebnReLU2d)
 @register_convert_function(nniqat.ConvBnReLU2d)
-def convert_nniqat_convbnrelu(model, fused_node):
+@register_convert_function(nniqat.ConvBnReLU3d)
+def convert_nniqat_convbnrelu(model, fused_node):
     convert_nniqat_convbn(model, fused_node)
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
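
Stacking several register_convert_function decorators on one function is what lets ConvBnReLU2d and the new ConvBnReLU3d share a single converter body. A minimal sketch of how such a per-type registry decorator can be built; the dict name and structure here are assumptions, not MQBench's actual implementation.

# Hypothetical registry: maps a fused module type to its converter.
FUSED_CONVERT_FUNCTIONS = {}

def register_convert_function(module_type):
    def insert(fn):
        FUSED_CONVERT_FUNCTIONS[module_type] = fn
        return fn
    return insert

@register_convert_function(int)    # stand-in types for demonstration
@register_convert_function(float)
def demo_converter(model, fused_node):
    return 'converted'

# Both keys now dispatch to the same function.
assert FUSED_CONVERT_FUNCTIONS[int] is FUSED_CONVERT_FUNCTIONS[float]
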
@@ -196,6 +214,9 @@ def convert_qnniqat_deconvbnrelu(model, fused_node):
 
 @register_convert_function(qnniqat.ConvBn2d)
 def convert_qnniqat_convbn(model, fused_node):
217+ """mqbench.nn.intrinsic.qat module add bias quant.
218+ That is the difference between torch.nn.intrinsic.qat module.
219+ """
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
     # Create a Conv2d from FusedModule.
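
The docstring's point, that the qnniqat variants fake-quantize the bias in addition to the weight, can be illustrated with a small stand-in module. Everything below (the class name, the 8-bit per-tensor settings, the bias_fake_quant attribute layout) is an assumption for illustration; MQBench's actual bias quantizers are backend-specific.

import torch
import torch.nn.functional as F
from torch.quantization import FakeQuantize, MovingAverageMinMaxObserver

def make_fq():
    # 8-bit symmetric per-tensor fake-quant; settings are illustrative only.
    return FakeQuantize(observer=MovingAverageMinMaxObserver,
                        quant_min=-128, quant_max=127,
                        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)

class BiasQuantConv2d(torch.nn.Conv2d):
    """Sketch only: fake-quantize the bias as well as the weight."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_fake_quant = make_fq()
        self.bias_fake_quant = make_fq()   # the extra step qnniqat adds

    def forward(self, x):
        w = self.weight_fake_quant(self.weight)
        b = self.bias_fake_quant(self.bias) if self.bias is not None else None
        return F.conv2d(x, w, b, self.stride, self.padding,
                        self.dilation, self.groups)

m = BiasQuantConv2d(3, 8, 3, padding=1, bias=True)
print(m(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 8, 8, 8])
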
@@ -222,6 +243,9 @@ def convert_qnniqat_convbn(model, fused_node):
 
 @register_convert_function(qnniqat.ConvBnReLU2d)
 def convert_qnniqat_convbnrelu(model, fused_node):
246+ """mqbench.nn.intrinsic.qat module add bias quant.
247+ That is the difference between torch.nn.intrinsic.qat module.
248+ """
     convert_qnniqat_convbn(model, fused_node)
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]