```python
@register_convert_function(qnniqat.LinearBn1d)
def convert_qnniqat_linearbn(model, fused_node):
    modules = dict(model.named_modules())
    fused_module = modules[fused_node.target]
    # Create a Linear from FusedModule.
    linear = torch.nn.Linear(fused_module.in_features, fused_module.out_features, fused_module.bias is not None)
    linear.weight = fused_module.weight
    if fused_module.bias is not None:
        linear.bias = fused_module.bias
    # Merge Linear + BN
    fused_linear = fuse_linear_bn_eval(linear.eval(), fused_module.bn.eval())
    # We need nn.qat.linear here to export weight quantize node.
    linear.qconfig = fused_module.qconfig
    linear = torch.nn.qat.Linear.from_float(linear)
    # Attach weight fake quantize params.
    linear.weight_fake_quant = fused_module.weight_fake_quant
    linear_parent_name, linear_name = _parent_name(fused_node.target)
    setattr(modules[linear_parent_name], linear_name, fused_linear)
```
The last line, `setattr(modules[linear_parent_name], linear_name, fused_linear)`, should be `setattr(modules[linear_parent_name], linear_name, linear)`. At that point `linear` is the `torch.nn.qat.Linear` produced by `from_float`, with `weight_fake_quant` attached to it, while `fused_linear` is a plain `nn.Linear` that never received a qconfig or fake-quantize module. Writing `fused_linear` back into the parent therefore silently drops the weight quantize node from the exported model.
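To make the consequence concrete, here is a minimal, self-contained sketch using only plain torch APIs (the layer sizes and the `fbgemm` qconfig are arbitrary choices for illustration, not taken from MQBench):

```python
import torch
from torch.nn.utils.fusion import fuse_linear_bn_eval

linear = torch.nn.Linear(4, 8)
bn = torch.nn.BatchNorm1d(8)

# Merge Linear + BN, as the function above does.
fused_linear = fuse_linear_bn_eval(linear.eval(), bn.eval())

# Convert to nn.qat.Linear, as the function above does with `linear`.
linear.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
qat_linear = torch.nn.qat.Linear.from_float(linear)

print(hasattr(qat_linear, "weight_fake_quant"))    # True: the qat module carries the fake-quant
print(hasattr(fused_linear, "weight_fake_quant"))  # False: plain nn.Linear, the node is lost
```

One caveat worth noting: `fuse_linear_bn_eval` returns a new module, so with the suggested one-line fix the BN-merged weights held by `fused_linear` are discarded. If those merged weights are also meant to reach the deployed model, the `qconfig`/`from_float` steps would presumably have to be applied to `fused_linear` rather than `linear`; that goes beyond the one-line fix proposed here.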