10
10
import torch .optim
11
11
import torch .utils .data
12
12
from torch .utils .tensorboard import SummaryWriter
13
+ import torchvision
13
14
import torchvision .transforms as transforms
14
15
import torchvision .datasets as datasets
15
16
17
+ import bayesian_torch
16
18
import bayesian_torch .models .bayesian .resnet_variational_large as resnet
17
19
import numpy as np
18
20
from bayesian_torch .models .bnn_to_qbnn import bnn_to_qbnn
19
- import bayesian_torch .models .bayesian .quantized_resnet_variational_large as qresnet
20
- # import bayesian_torch.models.bayesian.quantized_resnet_flipout_large as qresnet
21
+ from bayesian_torch .models .dnn_to_bnn import dnn_to_bnn
22
+ # import bayesian_torch.models.bayesian.quantized_resnet_variational_large as qresnet
23
+ import bayesian_torch .models .bayesian .quantized_resnet_flipout_large as qresnet
21
24
22
25
torch .cuda .is_available = lambda : False
23
26
os .environ ["CUDA_VISIBLE_DEVICES" ] = "-1"
68
71
"--save-dir" ,
69
72
dest = "save_dir" ,
70
73
help = "The directory used to save the trained models" ,
71
- default = "./checkpoint/bayesian" ,
74
+ default = "../../bayesian-torch-20221214/bayesian_torch /checkpoint/bayesian" ,
72
75
type = str ,
73
76
)
74
77
parser .add_argument (
134
137
help = "use tensorboard for logging and visualization of training progress" ,
135
138
)
136
139
137
- def evaluate (args , model , val_loader ):
140
+ def evaluate (args , model , val_loader , calibration = False ):
138
141
pred_probs_mc = []
139
142
test_loss = 0
140
143
correct = 0
@@ -159,6 +162,9 @@ def evaluate(args, model, val_loader):
159
162
i += 1
160
163
end = time .time ()
161
164
print ("inference throughput: " , i * args .val_batch_size / (end - begin ), " images/s" )
165
+ # break
166
+ if calibration and i == 3 :
167
+ break
162
168
163
169
output = torch .cat (output_list , 1 )
164
170
output = torch .nn .functional .softmax (output , dim = 2 )
@@ -232,7 +238,7 @@ def main():
232
238
233
239
tb_writer = None
234
240
235
- valdir = os .path .join (args .data , 'Imagenet_2012Val ' )
241
+ valdir = os .path .join (args .data , 'val ' )
236
242
normalize = transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ],
237
243
std = [0.229 , 0.224 , 0.225 ])
238
244
val_dataset = datasets .ImageFolder (
@@ -256,45 +262,58 @@ def main():
256
262
os .makedirs (args .save_dir )
257
263
258
264
if args .mode == "test" :
265
+ const_bnn_prior_parameters = {
266
+ "prior_mu" : 0.0 ,
267
+ "prior_sigma" : 1.0 ,
268
+ "posterior_mu_init" : 0.0 ,
269
+ "posterior_rho_init" : args .bnn_rho_init ,
270
+ "type" : "Flipout" if args .use_flipout_layers else "Reparameterization" , # Flipout or Reparameterization
271
+ "moped_enable" : moped_enable , # initialize mu/sigma from the dnn weights
272
+ "moped_delta" : args .moped_delta_factor ,
273
+ }
274
+ quantizable_model = torchvision .models .quantization .resnet50 ()
275
+ dnn_to_bnn (quantizable_model , const_bnn_prior_parameters )
276
+ model = torch .nn .DataParallel (quantizable_model )
277
+
278
+
259
279
checkpoint_file = args .save_dir + "/bayesian_{}_imagenet.pth" .format (args .arch )
260
280
261
281
checkpoint = torch .load (checkpoint_file , map_location = torch .device ("cpu" ))
262
282
model .load_state_dict (checkpoint ["state_dict" ])
263
283
model .module = model .module .cpu ()
264
284
265
285
mp = bayesian_torch .quantization .prepare (model )
266
- evaluate (args , mp , val_loader ) # calibration
286
+ evaluate (args , mp , val_loader , calibration = True ) # calibration
267
287
qmodel = bayesian_torch .quantization .convert (mp )
268
288
evaluate (args , qmodel , val_loader )
269
289
290
+ # save weights
291
+ save_checkpoint (
292
+ {
293
+ 'epoch' : None ,
294
+ 'state_dict' : qmodel .state_dict (),
295
+ 'best_prec1' : None ,
296
+ },
297
+ True ,
298
+ filename = os .path .join (
299
+ args .save_dir ,
300
+ 'quantized_bayesian_{}_imagenetv2.pth' .format (args .arch )))
301
+
302
+ # reconstruct (no calibration)
303
+ quantizable_model = torchvision .models .quantization .resnet50 ()
304
+ dnn_to_bnn (quantizable_model , const_bnn_prior_parameters )
305
+ model = torch .nn .DataParallel (quantizable_model )
306
+ mp = bayesian_torch .quantization .prepare (model )
307
+ qmodel1 = bayesian_torch .quantization .convert (mp )
270
308
309
+ # load
310
+ checkpoint_file = args .save_dir + "/quantized_bayesian_{}_imagenetv2.pth" .format (args .arch )
311
+ checkpoint = torch .load (checkpoint_file , map_location = torch .device ("cpu" ))
312
+ qmodel1 .load_state_dict (checkpoint ["state_dict" ])
313
+ evaluate (args , qmodel1 , val_loader )
271
314
272
- # bnn_to_qbnn(model, fuse_conv_bn=False) # only replaces linear and conv layers
273
-
274
- # model = model.cpu()
275
315
276
- # save weights
277
- # save_checkpoint(
278
- # {
279
- # 'epoch': None,
280
- # 'state_dict': model.state_dict(),
281
- # 'best_prec1': None,
282
- # },
283
- # True,
284
- # filename=os.path.join(
285
- # args.save_dir,
286
- # 'quantized_bayesian_q{}_imagenet.pth'.format(args.arch)))
287
-
288
- # qmodel = torch.nn.DataParallel(qresnet.__dict__['q'+args.arch](bias=False)) # set bias=True to make qconv has bias
289
- # qmodel.module.quant_then_dequant(qmodel, fuse_conv_bn=False)
290
-
291
- # load weights
292
- # checkpoint_file = args.save_dir + "/quantized_bayesian_q{}_imagenet.pth".format(args.arch)
293
- # checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
294
- # qmodel.load_state_dict(checkpoint["state_dict"])
295
-
296
- # qmodel.load_state_dict(model.state_dict())
297
- # evaluate(args, qmodel, val_loader)
316
+ return mp , qmodel , qmodel1
298
317
299
318
if __name__ == "__main__" :
300
- main ()
319
+ mp , qmodel , qmodel1 = main ()
0 commit comments