
Commit b3d9980

support loading and storing quantized models
1 parent 1e8bd69

2 files changed: +57 −38 lines changed


bayesian_torch/examples/main_bayesian_imagenet_bnn2qbnn.py

Lines changed: 51 additions & 32 deletions
@@ -10,14 +10,17 @@
 import torch.optim
 import torch.utils.data
 from torch.utils.tensorboard import SummaryWriter
+import torchvision
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets

+import bayesian_torch
 import bayesian_torch.models.bayesian.resnet_variational_large as resnet
 import numpy as np
 from bayesian_torch.models.bnn_to_qbnn import bnn_to_qbnn
-import bayesian_torch.models.bayesian.quantized_resnet_variational_large as qresnet
-# import bayesian_torch.models.bayesian.quantized_resnet_flipout_large as qresnet
+from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn
+# import bayesian_torch.models.bayesian.quantized_resnet_variational_large as qresnet
+import bayesian_torch.models.bayesian.quantized_resnet_flipout_large as qresnet

 torch.cuda.is_available = lambda : False
 os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
@@ -68,7 +71,7 @@
     "--save-dir",
     dest="save_dir",
     help="The directory used to save the trained models",
-    default="./checkpoint/bayesian",
+    default="../../bayesian-torch-20221214/bayesian_torch/checkpoint/bayesian",
     type=str,
 )
 parser.add_argument(
@@ -134,7 +137,7 @@
     help="use tensorboard for logging and visualization of training progress",
 )

-def evaluate(args, model, val_loader):
+def evaluate(args, model, val_loader, calibration=False):
     pred_probs_mc = []
     test_loss = 0
     correct = 0
@@ -159,6 +162,9 @@ def evaluate(args, model, val_loader):
         i+=1
         end = time.time()
         print("inference throughput: ", i*args.val_batch_size / (end - begin), " images/s")
+        # break
+        if calibration and i==3:
+            break

     output = torch.cat(output_list, 1)
     output = torch.nn.functional.softmax(output, dim=2)
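
Note on the calibration change above: bayesian_torch.quantization.prepare() instruments the model with observers that only need a handful of forward passes to estimate activation ranges, so the new calibration flag cuts the evaluation pass short after three batches instead of sweeping the whole validation set. A minimal sketch of that pattern (model and loader are hypothetical arguments):

import torch

def calibrate(model, loader, num_batches=3):
    # Feed a few batches through the observer-instrumented model so it
    # can record activation statistics; accuracy is irrelevant here.
    model.eval()
    with torch.no_grad():
        for i, (images, _) in enumerate(loader):
            model(images)
            if i + 1 >= num_batches:
                break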
@@ -232,7 +238,7 @@ def main():

     tb_writer = None

-    valdir = os.path.join(args.data, 'Imagenet_2012Val')
+    valdir = os.path.join(args.data, 'val')
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
     val_dataset = datasets.ImageFolder(
@@ -256,45 +262,58 @@ def main():
         os.makedirs(args.save_dir)

     if args.mode == "test":
+        const_bnn_prior_parameters = {
+            "prior_mu": 0.0,
+            "prior_sigma": 1.0,
+            "posterior_mu_init": 0.0,
+            "posterior_rho_init": args.bnn_rho_init,
+            "type": "Flipout" if args.use_flipout_layers else "Reparameterization",  # Flipout or Reparameterization
+            "moped_enable": moped_enable,  # initialize mu/sigma from the dnn weights
+            "moped_delta": args.moped_delta_factor,
+        }
+        quantizable_model = torchvision.models.quantization.resnet50()
+        dnn_to_bnn(quantizable_model, const_bnn_prior_parameters)
+        model = torch.nn.DataParallel(quantizable_model)
+
         checkpoint_file = args.save_dir + "/bayesian_{}_imagenet.pth".format(args.arch)

         checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
         model.load_state_dict(checkpoint["state_dict"])
         model.module = model.module.cpu()

         mp = bayesian_torch.quantization.prepare(model)
-        evaluate(args, mp, val_loader) # calibration
+        evaluate(args, mp, val_loader, calibration=True) # calibration
         qmodel = bayesian_torch.quantization.convert(mp)
         evaluate(args, qmodel, val_loader)

+        # save weights
+        save_checkpoint(
+            {
+                'epoch': None,
+                'state_dict': qmodel.state_dict(),
+                'best_prec1': None,
+            },
+            True,
+            filename=os.path.join(
+                args.save_dir,
+                'quantized_bayesian_{}_imagenetv2.pth'.format(args.arch)))
+
+        # reconstruct (no calibration)
+        quantizable_model = torchvision.models.quantization.resnet50()
+        dnn_to_bnn(quantizable_model, const_bnn_prior_parameters)
+        model = torch.nn.DataParallel(quantizable_model)
+        mp = bayesian_torch.quantization.prepare(model)
+        qmodel1 = bayesian_torch.quantization.convert(mp)

+        # load
+        checkpoint_file = args.save_dir + "/quantized_bayesian_{}_imagenetv2.pth".format(args.arch)
+        checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
+        qmodel1.load_state_dict(checkpoint["state_dict"])
+        evaluate(args, qmodel1, val_loader)

-        # bnn_to_qbnn(model, fuse_conv_bn=False) # only replaces linear and conv layers
-
-        # model = model.cpu()

-        # save weights
-        # save_checkpoint(
-        #     {
-        #         'epoch': None,
-        #         'state_dict': model.state_dict(),
-        #         'best_prec1': None,
-        #     },
-        #     True,
-        #     filename=os.path.join(
-        #         args.save_dir,
-        #         'quantized_bayesian_q{}_imagenet.pth'.format(args.arch)))
-
-        # qmodel = torch.nn.DataParallel(qresnet.__dict__['q'+args.arch](bias=False)) # set bias=True to make qconv has bias
-        # qmodel.module.quant_then_dequant(qmodel, fuse_conv_bn=False)
-
-        # load weights
-        # checkpoint_file = args.save_dir + "/quantized_bayesian_q{}_imagenet.pth".format(args.arch)
-        # checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
-        # qmodel.load_state_dict(checkpoint["state_dict"])
-
-        # qmodel.load_state_dict(model.state_dict())
-        # evaluate(args, qmodel, val_loader)
+    return mp, qmodel, qmodel1

 if __name__ == "__main__":
-    main()
+    mp, qmodel, qmodel1 = main()
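
Taken together, the new test-mode flow above is: build a quantizable BNN from torchvision's quantizable ResNet-50 via dnn_to_bnn(), calibrate and convert it, save the quantized state_dict, then reconstruct an identically shaped quantized model (prepare + convert, no calibration) and load the weights back. A condensed sketch of that round trip, assuming const_bnn_prior_parameters is defined as in the diff and using plain torch.save in place of the script's save_checkpoint helper (the qbnn.pth filename is illustrative):

import torch
import torchvision
import bayesian_torch
from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn

def build_bnn(prior_params):
    # Same construction path as the diff: quantizable ResNet-50 -> BNN.
    m = torchvision.models.quantization.resnet50()
    dnn_to_bnn(m, prior_params)
    return torch.nn.DataParallel(m)

# Store: prepare, calibrate on a few batches, convert, save.
mp = bayesian_torch.quantization.prepare(build_bnn(const_bnn_prior_parameters))
# ... run a few calibration batches through mp here ...
qmodel = bayesian_torch.quantization.convert(mp)
torch.save({"state_dict": qmodel.state_dict()}, "qbnn.pth")

# Load: rebuild the same quantized graph without calibrating, then
# restore the saved scales, zero points, and quantized weights.
qmodel1 = bayesian_torch.quantization.convert(
    bayesian_torch.quantization.prepare(build_bnn(const_bnn_prior_parameters)))
qmodel1.load_state_dict(torch.load("qbnn.pth")["state_dict"])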

bayesian_torch/models/bnn_to_qbnn.py

Lines changed: 6 additions & 6 deletions
@@ -103,12 +103,12 @@ def qbnn_linear_layer(d):
     qbnn_layer.__dict__.update(d.__dict__)

     if d.quant_prepare:
-        qbnn_layer.quant_dict = []
+        qbnn_layer.quant_dict = nn.ModuleList()
         for qstub in d.qint_quant:
-            qbnn_layer.quant_dict.append({'scale':qstub.scale.item(), 'zero_point':qstub.zero_point.item()})
+            qbnn_layer.quant_dict.append(nn.ParameterDict({'scale': torch.nn.Parameter(qstub.scale.float()), 'zero_point': torch.nn.Parameter(qstub.zero_point.float())}))
         qbnn_layer.quant_dict = qbnn_layer.quant_dict[2:]
         for qstub in d.quint_quant:
-            qbnn_layer.quant_dict.append({'scale':qstub.scale.item(), 'zero_point':qstub.zero_point.item()})
+            qbnn_layer.quant_dict.append(nn.ParameterDict({'scale': torch.nn.Parameter(qstub.scale.float()), 'zero_point': torch.nn.Parameter(qstub.zero_point.float())}))

     qbnn_layer.quantize()
     if d.dnn_to_bnn_flag:
@@ -130,12 +130,12 @@ def qbnn_conv_layer(d):
     qbnn_layer.__dict__.update(d.__dict__)

     if d.quant_prepare:
-        qbnn_layer.quant_dict = []
+        qbnn_layer.quant_dict = nn.ModuleList()
         for qstub in d.qint_quant:
-            qbnn_layer.quant_dict.append({'scale':qstub.scale.item(), 'zero_point':qstub.zero_point.item()})
+            qbnn_layer.quant_dict.append(nn.ParameterDict({'scale': torch.nn.Parameter(qstub.scale.float()), 'zero_point': torch.nn.Parameter(qstub.zero_point.float())}))
         qbnn_layer.quant_dict = qbnn_layer.quant_dict[2:]
         for qstub in d.quint_quant:
-            qbnn_layer.quant_dict.append({'scale':qstub.scale.item(), 'zero_point':qstub.zero_point.item()})
+            qbnn_layer.quant_dict.append(nn.ParameterDict({'scale': torch.nn.Parameter(qstub.scale.float()), 'zero_point': torch.nn.Parameter(qstub.zero_point.float())}))

     qbnn_layer.quantize()
     if d.dnn_to_bnn_flag:
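
This change in both layer converters is what actually makes the quantized checkpoints loadable: a plain Python list of dicts is an ordinary attribute that state_dict() never sees, so the per-tensor scales and zero points were previously dropped on save. Wrapping them in nn.ParameterDict entries inside an nn.ModuleList registers them in the module hierarchy. A small self-contained sketch of the difference (class names are illustrative):

import torch
import torch.nn as nn

class PlainDicts(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain list of dicts is an ordinary Python attribute:
        # nothing in it reaches state_dict().
        self.quant_dict = [{"scale": 0.05, "zero_point": 0}]

class RegisteredDicts(nn.Module):
    def __init__(self):
        super().__init__()
        # ModuleList + ParameterDict registers scale/zero_point as
        # parameters, so they are saved and restored with the model.
        self.quant_dict = nn.ModuleList([
            nn.ParameterDict({
                "scale": nn.Parameter(torch.tensor(0.05)),
                "zero_point": nn.Parameter(torch.tensor(0.0)),
            })
        ])

print(list(PlainDicts().state_dict()))       # []
print(list(RegisteredDicts().state_dict()))  # ['quant_dict.0.scale', 'quant_dict.0.zero_point']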
