@@ -284,6 +284,7 @@ def __init__(self,
284
284
self .bn_eps = None
285
285
286
286
self .is_dequant = False
287
+ self .quant_dict = None
287
288
288
289
def get_scale_and_zero_point (self , x , upper_bound = 100 , target_range = 255 ):
289
290
""" An implementation for symmetric quantization
@@ -425,40 +426,67 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1
425
426
if self .dnn_to_bnn_flag :
426
427
return_kl = False
427
428
428
- if x .dtype != torch .quint8 :
429
- x = torch .quantize_per_tensor (x , default_scale , default_zero_point , torch .quint8 )
430
-
431
- bias = None
432
- if self .bias :
433
- bias = self .quantized_mu_bias
434
-
435
- outputs = torch .nn .quantized .functional .conv2d (x , self .quantized_mu_weight , bias , self .stride , self .padding ,
436
- self .dilation , self .groups , scale = default_scale , zero_point = default_zero_point ) # input: quint8, weight: qint8, bias: fp32
437
-
438
- # sampling perturbation signs
439
- sign_input = torch .zeros (x .shape ).uniform_ (- 1 , 1 ).sign ()
440
- sign_output = torch .zeros (outputs .shape ).uniform_ (- 1 , 1 ).sign ()
441
- sign_input = torch .quantize_per_tensor (sign_input , default_scale , default_zero_point , torch .quint8 )
442
- sign_output = torch .quantize_per_tensor (sign_output , default_scale , default_zero_point , torch .quint8 )
443
-
444
- # getting perturbation weights
445
- eps_kernel = torch .quantize_per_tensor (self .eps_kernel .data .normal_ (), normal_scale , 0 , torch .qint8 )
446
- new_scale = (self .quantized_sigma_weight .q_scale ())* (eps_kernel .q_scale ())
447
- delta_kernel = torch .ops .quantized .mul (self .quantized_sigma_weight , eps_kernel , new_scale , 0 )
448
-
449
429
bias = None
450
430
if self .bias :
451
- eps_bias = self .eps_bias .data .normal_ ()
452
- bias = (self .quantized_sigma_bias * eps_bias )
431
+ bias = self .quantized_mu_bias # TODO: check correctness
432
+
433
+ if self .quant_dict is not None :
434
+ # getting perturbation weights
435
+ eps_kernel = torch .quantize_per_tensor (self .eps_kernel .data .normal_ (), self .quant_dict [0 ]['scale' ], self .quant_dict [0 ]['zero_point' ], torch .qint8 )
436
+ delta_kernel = torch .ops .quantized .mul (self .quantized_sigma_weight , eps_kernel , self .quant_dict [1 ]['scale' ], self .quant_dict [1 ]['zero_point' ])
437
+
438
+ if x .dtype != torch .quint8 : # check if input has been quantized
439
+ x = torch .quantize_per_tensor (x , self .quant_dict [2 ]['scale' ], self .quant_dict [2 ]['zero_point' ], torch .quint8 ) # scale=0.1 by grid search; zero_point=128 for uint8 format
440
+
441
+ outputs = torch .nn .quantized .functional .conv2d (x , self .quantized_mu_weight , bias , self .stride , self .padding ,
442
+ self .dilation , self .groups , scale = self .quant_dict [3 ]['scale' ], zero_point = self .quant_dict [3 ]['zero_point' ]) # input: quint8, weight: qint8, bias: fp32
443
+
444
+ # sampling perturbation signs
445
+ sign_input = torch .zeros (x .shape ).uniform_ (- 1 , 1 ).sign ()
446
+ sign_output = torch .zeros (outputs .shape ).uniform_ (- 1 , 1 ).sign ()
447
+ sign_input = torch .quantize_per_tensor (sign_input , self .quant_dict [4 ]['scale' ], self .quant_dict [4 ]['zero_point' ], torch .quint8 )
448
+ sign_output = torch .quantize_per_tensor (sign_output , self .quant_dict [5 ]['scale' ], self .quant_dict [5 ]['zero_point' ], torch .quint8 )
449
+
450
+ # perturbed feedforward
451
+ x = torch .ops .quantized .mul (x , sign_input , self .quant_dict [6 ]['scale' ], self .quant_dict [6 ]['zero_point' ])
452
+ perturbed_outputs = torch .nn .quantized .functional .conv2d (x ,
453
+ weight = delta_kernel , bias = bias , stride = self .stride , padding = self .padding ,
454
+ dilation = self .dilation , groups = self .groups , scale = self .quant_dict [7 ]['scale' ], zero_point = self .quant_dict [7 ]['zero_point' ])
455
+ perturbed_outputs = torch .ops .quantized .mul (perturbed_outputs , sign_output , self .quant_dict [8 ]['scale' ], self .quant_dict [8 ]['zero_point' ])
456
+ out = torch .ops .quantized .add (outputs , perturbed_outputs , self .quant_dict [9 ]['scale' ], self .quant_dict [9 ]['zero_point' ])
457
+ out = out .dequantize ()
453
458
454
- # perturbed feedforward
455
- x = torch .ops .quantized .mul (x , sign_input , default_scale , default_zero_point )
456
-
457
- perturbed_outputs = torch .nn .quantized .functional .conv2d (x ,
458
- weight = delta_kernel , bias = bias , stride = self .stride , padding = self .padding ,
459
- dilation = self .dilation , groups = self .groups , scale = default_scale , zero_point = default_zero_point )
460
- perturbed_outputs = torch .ops .quantized .mul (perturbed_outputs , sign_output , default_scale , default_zero_point )
461
- out = torch .ops .quantized .add (outputs , perturbed_outputs , default_scale , default_zero_point )
459
+ else :
460
+ if x .dtype != torch .quint8 :
461
+ x = torch .quantize_per_tensor (x , default_scale , default_zero_point , torch .quint8 )
462
+
463
+ outputs = torch .nn .quantized .functional .conv2d (x , self .quantized_mu_weight , bias , self .stride , self .padding ,
464
+ self .dilation , self .groups , scale = default_scale , zero_point = default_zero_point ) # input: quint8, weight: qint8, bias: fp32
465
+
466
+ # sampling perturbation signs
467
+ sign_input = torch .zeros (x .shape ).uniform_ (- 1 , 1 ).sign ()
468
+ sign_output = torch .zeros (outputs .shape ).uniform_ (- 1 , 1 ).sign ()
469
+ sign_input = torch .quantize_per_tensor (sign_input , default_scale , default_zero_point , torch .quint8 )
470
+ sign_output = torch .quantize_per_tensor (sign_output , default_scale , default_zero_point , torch .quint8 )
471
+
472
+ # getting perturbation weights
473
+ eps_kernel = torch .quantize_per_tensor (self .eps_kernel .data .normal_ (), normal_scale , 0 , torch .qint8 )
474
+ new_scale = (self .quantized_sigma_weight .q_scale ())* (eps_kernel .q_scale ())
475
+ delta_kernel = torch .ops .quantized .mul (self .quantized_sigma_weight , eps_kernel , new_scale , 0 )
476
+
477
+ bias = None
478
+ if self .bias :
479
+ eps_bias = self .eps_bias .data .normal_ ()
480
+ bias = (self .quantized_sigma_bias * eps_bias )
481
+
482
+ # perturbed feedforward
483
+ x = torch .ops .quantized .mul (x , sign_input , default_scale , default_zero_point )
484
+
485
+ perturbed_outputs = torch .nn .quantized .functional .conv2d (x ,
486
+ weight = delta_kernel , bias = bias , stride = self .stride , padding = self .padding ,
487
+ dilation = self .dilation , groups = self .groups , scale = default_scale , zero_point = default_zero_point )
488
+ perturbed_outputs = torch .ops .quantized .mul (perturbed_outputs , sign_output , default_scale , default_zero_point )
489
+ out = torch .ops .quantized .add (outputs , perturbed_outputs , default_scale , default_zero_point )
462
490
463
491
if return_kl :
464
492
return out , 0
0 commit comments