
Commit 9b0118f: bnn to qbnn

1 parent 9b5a9dc

File tree: 1 file changed, 13 insertions(+), 1 deletion(-)


bayesian_torch/models/bnn_to_qbnn.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -119,6 +119,15 @@ def qbnn_conv_layer(d):
         groups=d.groups,
     )
     qbnn_layer.__dict__.update(d.__dict__)
+
+    if d.quant_prepare:
+        qbnn_layer.quant_dict = []
+        for qstub in d.qint_quant:
+            qbnn_layer.quant_dict.append({'scale':qstub.scale.item(), 'zero_point':qstub.zero_point.item()})
+        qbnn_layer.quant_dict = qbnn_layer.quant_dict[2:]
+        for qstub in d.quint_quant:
+            qbnn_layer.quant_dict.append({'scale':qstub.scale.item(), 'zero_point':qstub.zero_point.item()})
+
     qbnn_layer.quantize()
     if d.dnn_to_bnn_flag:
         qbnn_layer.dnn_to_bnn_flag = True
```
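For context, here is a minimal sketch (an illustration, not part of this commit) of where per-tensor `scale` and `zero_point` values like those read off each `qstub` come from: in PyTorch post-training quantization, an observer records tensor statistics during a calibration pass and then derives the quantization parameters. The observer settings and calibration tensor below are illustrative assumptions.

```python
# Minimal sketch: producing a per-tensor (scale, zero_point) pair like the
# ones bnn_to_qbnn copies into quant_dict. The observer choice and the
# input shape are illustrative assumptions, not the library's setup.
import torch
from torch.ao.quantization.observer import MinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
obs(torch.randn(16, 3, 32, 32))        # calibration pass records min/max
scale, zero_point = obs.calculate_qparams()

# The commit flattens such values into plain Python numbers:
entry = {'scale': scale.item(), 'zero_point': zero_point.item()}
print(entry)
```

Note the `quant_dict[2:]` slice: the first two qint entries are dropped before the quint entries are appended, presumably because only the remaining quantization points are consumed by the quantized layer at inference time.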
```diff
@@ -180,7 +189,10 @@ def batch_norm_folding(conv, bn):
 def bnn_to_qbnn(m, fuse_conv_bn=False):
     for name, value in list(m._modules.items()):
         if m._modules[name]._modules:
-            bnn_to_qbnn(m._modules[name], fuse_conv_bn=fuse_conv_bn)
+            if "Conv" in m._modules[name].__class__.__name__:
+                setattr(m, name, qbnn_conv_layer(m._modules[name]))
+            else:
+                bnn_to_qbnn(m._modules[name], fuse_conv_bn=fuse_conv_bn)
         elif "Linear" in m._modules[name].__class__.__name__:
             setattr(m, name, qbnn_linear_layer(m._modules[name]))
         elif "LSTM" in m._modules[name].__class__.__name__:
```
