@@ -119,6 +119,15 @@ def qbnn_conv_layer(d):
             groups=d.groups,
         )
     qbnn_layer.__dict__.update(d.__dict__)
+
+    if d.quant_prepare:
+        qbnn_layer.quant_dict = []
+        for qstub in d.qint_quant:
+            qbnn_layer.quant_dict.append({'scale': qstub.scale.item(), 'zero_point': qstub.zero_point.item()})
+        qbnn_layer.quant_dict = qbnn_layer.quant_dict[2:]
+        for qstub in d.quint_quant:
+            qbnn_layer.quant_dict.append({'scale': qstub.scale.item(), 'zero_point': qstub.zero_point.item()})
+
     qbnn_layer.quantize()
     if d.dnn_to_bnn_flag:
         qbnn_layer.dnn_to_bnn_flag = True
@@ -180,7 +189,10 @@ def batch_norm_folding(conv, bn):
 def bnn_to_qbnn(m, fuse_conv_bn=False):
     for name, value in list(m._modules.items()):
         if m._modules[name]._modules:
-            bnn_to_qbnn(m._modules[name], fuse_conv_bn=fuse_conv_bn)
+            if "Conv" in m._modules[name].__class__.__name__:
+                setattr(m, name, qbnn_conv_layer(m._modules[name]))
+            else:
+                bnn_to_qbnn(m._modules[name], fuse_conv_bn=fuse_conv_bn)
         elif "Linear" in m._modules[name].__class__.__name__:
             setattr(m, name, qbnn_linear_layer(m._modules[name]))
         elif "LSTM" in m._modules[name].__class__.__name__:
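
For context, the quant_dict bookkeeping introduced in the first hunk can be exercised in isolation. The sketch below is a minimal, self-contained illustration only: SimpleNamespace objects stand in for the real Bayesian layer and its calibrated quant stubs, and the scale/zero-point values are made up; it is not the library's API.

import torch
from types import SimpleNamespace

def make_stub(scale, zero_point):
    # stand-in for a calibrated quant stub exposing tensor-valued scale/zero_point
    return SimpleNamespace(scale=torch.tensor(scale), zero_point=torch.tensor(zero_point))

# stand-in for the source layer `d` (attribute names taken from the diff above)
d = SimpleNamespace(
    quant_prepare=True,
    qint_quant=[make_stub(s, 0) for s in (0.1, 0.2, 0.3, 0.4)],
    quint_quant=[make_stub(s, 128) for s in (0.5, 0.6)],
)

quant_dict = []
if d.quant_prepare:
    for qstub in d.qint_quant:
        quant_dict.append({'scale': qstub.scale.item(), 'zero_point': qstub.zero_point.item()})
    quant_dict = quant_dict[2:]  # drop the first two int8 entries, as in the hunk above
    for qstub in d.quint_quant:
        quant_dict.append({'scale': qstub.scale.item(), 'zero_point': qstub.zero_point.item()})

print(quant_dict)  # four plain dicts: two from qint_quant (after slicing), two from quint_quant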