
Commit ccc52ee

calibration support for quantized flipout layers
1 parent b3d9980 commit ccc52ee

File tree

4 files changed: +202 -70 lines changed

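The commit adds a prepare() hook to the float flipout layers that registers qint8/quint8 QuantStub observers plus a DeQuantStub, and teaches the quantized conv flipout layer to read calibrated scale/zero_point pairs from a quant_dict. A minimal calibration sketch under the usual eager-mode post-training-quantization workflow (hypothetical driver code; layer sizes and loop are illustrative, and the repository's own quantization utilities may wrap these steps differently):

import torch
from bayesian_torch.layers.flipout_layers.linear_flipout import LinearFlipout

layer = LinearFlipout(in_features=64, out_features=32)   # illustrative sizes
layer.prepare()                             # attach the QuantStub/DeQuantStub lists added by this commit
layer = torch.quantization.prepare(layer)   # insert observers described by each stub's QConfig

layer.eval()
with torch.no_grad():
    for _ in range(16):                     # run calibration batches through the flipout forward pass
        layer(torch.randn(8, 64))
# Each observer now holds min/max statistics from which scale and zero_point can be derived.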

bayesian_torch/layers/flipout_layers/conv_flipout.py

Lines changed: 44 additions & 4 deletions
@@ -37,6 +37,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from ..base_variational_layer import BaseVariationalLayer_, get_kernel_size
+from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver, MinMaxObserver
+from torch.quantization.qconfig import QConfig
 
 from torch.distributions.normal import Normal
 from torch.distributions.uniform import Uniform
@@ -136,6 +138,15 @@ def __init__(self,
         self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare=False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(4)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(8)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare=True
 
     def init_parameters(self):
         # prior values
@@ -303,6 +314,15 @@ def __init__(self,
         self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare=False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(4)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(8)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare=True
 
     def init_parameters(self):
         # prior values
@@ -365,18 +385,38 @@ def forward(self, x, return_kl=True):
                                              self.prior_bias_sigma)
 
         # perturbed feedforward
-        perturbed_outputs = F.conv2d(x * sign_input,
+        x_tmp = x * sign_input
+        perturbed_outputs_tmp = F.conv2d(x * sign_input,
                                      weight=delta_kernel,
                                      bias=bias,
                                      stride=self.stride,
                                      padding=self.padding,
                                      dilation=self.dilation,
-                                     groups=self.groups) * sign_output
+                                     groups=self.groups)
+        perturbed_outputs = perturbed_outputs_tmp * sign_output
+        out = outputs + perturbed_outputs
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            input = self.quint_quant[0](input) # input
+            outputs = self.quint_quant[1](outputs) # output
+            sign_input = self.quint_quant[2](sign_input)
+            sign_output = self.quint_quant[3](sign_output)
+            x_tmp = self.quint_quant[4](x_tmp)
+            perturbed_outputs_tmp = self.quint_quant[5](perturbed_outputs_tmp) # output
+            perturbed_outputs = self.quint_quant[6](perturbed_outputs) # output
+            out = self.quint_quant[7](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            delta_kernel =self.qint_quant[3](delta_kernel) # multiply activation
 
         # returning outputs + perturbations
         if return_kl:
-            return outputs + perturbed_outputs, kl
-        return outputs + perturbed_outputs
+            return out, kl
+        return out
 
 
 class Conv3dFlipout(BaseVariationalLayer_):
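For reference, the float computation that the renamed intermediates (x_tmp, perturbed_outputs_tmp, out) decompose is the standard flipout estimator. A standalone sketch, not the library code:

# Float-only sketch of the flipout estimator the diff splits into named intermediates:
#   out = conv(x, mu) + sign_out * conv(x * sign_in, sigma * eps)
import torch
import torch.nn.functional as F

def flipout_conv2d(x, mu_kernel, rho_kernel, stride=1, padding=0):
    sigma = torch.log1p(torch.exp(rho_kernel))            # softplus, as in the variational layers
    eps = torch.randn_like(mu_kernel)                     # weight-space noise
    delta = sigma * eps                                   # sampled kernel perturbation
    outputs = F.conv2d(x, mu_kernel, stride=stride, padding=padding)
    sign_in = torch.empty_like(x).uniform_(-1, 1).sign()
    sign_out = torch.empty_like(outputs).uniform_(-1, 1).sign()
    perturbed = F.conv2d(x * sign_in, delta, stride=stride, padding=padding) * sign_out
    return outputs + perturbed                            # "out" in the diff above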

bayesian_torch/layers/flipout_layers/linear_flipout.py

Lines changed: 38 additions & 7 deletions
@@ -40,6 +40,8 @@
 from torch.distributions.normal import Normal
 from torch.distributions.uniform import Uniform
 from ..base_variational_layer import BaseVariationalLayer_
+from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver, MinMaxObserver
+from torch.quantization.qconfig import QConfig
 
 __all__ = ["LinearFlipout"]
 
@@ -107,6 +109,15 @@ def __init__(self,
         self.register_buffer('eps_bias', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare=False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(4)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(8)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare=True
 
     def init_parameters(self):
         # init prior mu
@@ -136,7 +147,9 @@ def forward(self, x, return_kl=True):
             return_kl = False
         # sampling delta_W
         sigma_weight = torch.log1p(torch.exp(self.rho_weight))
-        delta_weight = (sigma_weight * self.eps_weight.data.normal_())
+        eps_weight = self.eps_weight.data.normal_()
+        delta_weight = sigma_weight * eps_weight
+        # delta_weight = (sigma_weight * self.eps_weight.data.normal_())
 
         # get kl divergence
         if return_kl:
@@ -153,14 +166,32 @@
 
         # linear outputs
         outputs = F.linear(x, self.mu_weight, self.mu_bias)
-
         sign_input = x.clone().uniform_(-1, 1).sign()
         sign_output = outputs.clone().uniform_(-1, 1).sign()
-
-        perturbed_outputs = F.linear(x * sign_input, delta_weight,
-                                     bias) * sign_output
+        x_tmp = x * sign_input
+        perturbed_outputs_tmp = F.linear(x_tmp, delta_weight, bias)
+        perturbed_outputs = perturbed_outputs_tmp * sign_output
+        out = outputs + perturbed_outputs
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            input = self.quint_quant[0](input) # input
+            outputs = self.quint_quant[1](outputs) # output
+            sign_input = self.quint_quant[2](sign_input)
+            sign_output = self.quint_quant[3](sign_output)
+            x_tmp = self.quint_quant[4](x_tmp)
+            perturbed_outputs_tmp = self.quint_quant[5](perturbed_outputs_tmp) # output
+            perturbed_outputs = self.quint_quant[6](perturbed_outputs) # output
+            out = self.quint_quant[7](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_weight = self.qint_quant[1](self.mu_weight) # weight
+            eps_weight = self.qint_quant[2](eps_weight) # random variable
+            delta_weight =self.qint_quant[3](delta_weight) # multiply activation
+
 
         # returning outputs + perturbations
         if return_kl:
-            return outputs + perturbed_outputs, kl
-        return outputs + perturbed_outputs
+            return out, kl
+        return out
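Once calibration batches have passed through a prepared layer, the observer attached to each stub can report per-tensor quantization parameters. A hedged sketch of harvesting them (collect_quint8_qparams is a hypothetical helper; the attribute layout follows eager-mode torch.quantization.prepare, and the real mapping into the quantized layers' quant_dict may combine quint8 and qint8 observers differently):

import torch

def collect_quint8_qparams(layer):
    """Hypothetical helper: read (scale, zero_point) from each calibrated quint8 stub."""
    params = []
    for stub in layer.quint_quant:
        # After torch.quantization.prepare(), each QuantStub carries an activation_post_process observer.
        scale, zero_point = stub.activation_post_process.calculate_qparams()
        params.append({'scale': float(scale), 'zero_point': int(zero_point)})
    return params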

bayesian_torch/layers/flipout_layers/quantized_conv_flipout.py

Lines changed: 59 additions & 31 deletions
@@ -284,6 +284,7 @@ def __init__(self,
         self.bn_eps = None
 
         self.is_dequant = False
+        self.quant_dict = None
 
     def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255):
         """ An implementation for symmetric quantization
@@ -425,40 +426,67 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1
         if self.dnn_to_bnn_flag:
             return_kl = False
 
-        if x.dtype!=torch.quint8:
-            x = torch.quantize_per_tensor(x, default_scale, default_zero_point, torch.quint8)
-
-        bias = None
-        if self.bias:
-            bias = self.quantized_mu_bias
-
-        outputs = torch.nn.quantized.functional.conv2d(x, self.quantized_mu_weight, bias, self.stride, self.padding,
-            self.dilation, self.groups, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32
-
-        # sampling perturbation signs
-        sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign()
-        sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign()
-        sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8)
-        sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8)
-
-        # getting perturbation weights
-        eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8)
-        new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale())
-        delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0)
-
         bias = None
         if self.bias:
-            eps_bias = self.eps_bias.data.normal_()
-            bias = (self.quantized_sigma_bias * eps_bias)
+            bias = self.quantized_mu_bias # TODO: check correctness
+
+        if self.quant_dict is not None:
+            # getting perturbation weights
+            eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), self.quant_dict[0]['scale'], self.quant_dict[0]['zero_point'], torch.qint8)
+            delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, self.quant_dict[1]['scale'], self.quant_dict[1]['zero_point'])
+
+            if x.dtype!=torch.quint8: # check if input has been quantized
+                x = torch.quantize_per_tensor(x, self.quant_dict[2]['scale'], self.quant_dict[2]['zero_point'], torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format
+
+            outputs = torch.nn.quantized.functional.conv2d(x, self.quantized_mu_weight, bias, self.stride, self.padding,
+                self.dilation, self.groups, scale=self.quant_dict[3]['scale'], zero_point=self.quant_dict[3]['zero_point']) # input: quint8, weight: qint8, bias: fp32
+
+            # sampling perturbation signs
+            sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign()
+            sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign()
+            sign_input = torch.quantize_per_tensor(sign_input, self.quant_dict[4]['scale'], self.quant_dict[4]['zero_point'], torch.quint8)
+            sign_output = torch.quantize_per_tensor(sign_output, self.quant_dict[5]['scale'], self.quant_dict[5]['zero_point'], torch.quint8)
+
+            # perturbed feedforward
+            x = torch.ops.quantized.mul(x, sign_input, self.quant_dict[6]['scale'], self.quant_dict[6]['zero_point'])
+            perturbed_outputs = torch.nn.quantized.functional.conv2d(x,
+                weight=delta_kernel, bias=bias, stride=self.stride, padding=self.padding,
+                dilation=self.dilation, groups=self.groups, scale=self.quant_dict[7]['scale'], zero_point=self.quant_dict[7]['zero_point'])
+            perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, self.quant_dict[8]['scale'], self.quant_dict[8]['zero_point'])
+            out = torch.ops.quantized.add(outputs, perturbed_outputs, self.quant_dict[9]['scale'], self.quant_dict[9]['zero_point'])
+            out = out.dequantize()
 
-        # perturbed feedforward
-        x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point)
-
-        perturbed_outputs = torch.nn.quantized.functional.conv2d(x,
-            weight=delta_kernel, bias=bias, stride=self.stride, padding=self.padding,
-            dilation=self.dilation, groups=self.groups, scale=default_scale, zero_point=default_zero_point)
-        perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point)
-        out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point)
+        else:
+            if x.dtype!=torch.quint8:
+                x = torch.quantize_per_tensor(x, default_scale, default_zero_point, torch.quint8)
+
+            outputs = torch.nn.quantized.functional.conv2d(x, self.quantized_mu_weight, bias, self.stride, self.padding,
+                self.dilation, self.groups, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32
+
+            # sampling perturbation signs
+            sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign()
+            sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign()
+            sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8)
+            sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8)
+
+            # getting perturbation weights
+            eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8)
+            new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale())
+            delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0)
+
+            bias = None
+            if self.bias:
+                eps_bias = self.eps_bias.data.normal_()
+                bias = (self.quantized_sigma_bias * eps_bias)
+
+            # perturbed feedforward
+            x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point)
+
+            perturbed_outputs = torch.nn.quantized.functional.conv2d(x,
+                weight=delta_kernel, bias=bias, stride=self.stride, padding=self.padding,
+                dilation=self.dilation, groups=self.groups, scale=default_scale, zero_point=default_zero_point)
+            perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point)
+            out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point)
 
         if return_kl:
             return out, 0
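The calibrated branch above indexes quant_dict positionally: [0]-[1] the eps and delta kernels, [2] the input, [3] the mu-path conv output, [4]-[5] the sign tensors, [6]-[8] the perturbed path, and [9] the final add. An illustration of the expected shape (all scale/zero_point values below are placeholders, not calibrated results, and quantized_layer is a stand-in name):

# Illustrative only: ten positional {'scale', 'zero_point'} entries, indexed exactly as in forward().
quantized_layer.quant_dict = [
    {'scale': 0.05, 'zero_point': 0},        # [0] eps_kernel (qint8, symmetric)
    {'scale': 0.02, 'zero_point': 0},        # [1] delta_kernel = sigma * eps
    {'scale': 0.10, 'zero_point': 128},      # [2] input x (quint8)
    {'scale': 0.08, 'zero_point': 128},      # [3] mu-path conv output
    {'scale': 1.0 / 127, 'zero_point': 128}, # [4] sign_input
    {'scale': 1.0 / 127, 'zero_point': 128}, # [5] sign_output
    {'scale': 0.10, 'zero_point': 128},      # [6] x * sign_input
    {'scale': 0.08, 'zero_point': 128},      # [7] perturbed conv output
    {'scale': 0.08, 'zero_point': 128},      # [8] perturbed_outputs * sign_output
    {'scale': 0.09, 'zero_point': 128},      # [9] outputs + perturbed_outputs
]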
