
Commit 342ca39

Add quant prepare functions
1 parent 39b41a5 commit 342ca39

File tree

2 files changed: +291 -39 lines changed


bayesian_torch/layers/flipout_layers/conv_flipout.py

Lines changed: 166 additions & 34 deletions
@@ -210,13 +210,33 @@ def forward(self, x, return_kl=True):
                                   self.prior_bias_sigma)
 
         # perturbed feedforward
-        perturbed_outputs = F.conv1d(x * sign_input,
-                                     bias=bias,
+        x_tmp = x * sign_input
+        perturbed_outputs_tmp = F.conv1d(x * sign_input,
                                      weight=delta_kernel,
+                                     bias=bias,
                                      stride=self.stride,
                                      padding=self.padding,
                                      dilation=self.dilation,
-                                     groups=self.groups) * sign_output
+                                     groups=self.groups)
+        perturbed_outputs = perturbed_outputs_tmp * sign_output
+        out = outputs + perturbed_outputs
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            x = self.quint_quant[0](x) # input
+            outputs = self.quint_quant[1](outputs) # output
+            sign_input = self.quint_quant[2](sign_input)
+            sign_output = self.quint_quant[3](sign_output)
+            x_tmp = self.quint_quant[4](x_tmp)
+            perturbed_outputs_tmp = self.quint_quant[5](perturbed_outputs_tmp) # output
+            perturbed_outputs = self.quint_quant[6](perturbed_outputs) # output
+            out = self.quint_quant[7](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            delta_kernel = self.qint_quant[3](delta_kernel) # multiply activation
 
         self.kl = kl
         # returning outputs + perturbations
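
Note on the forward-pass change above: the fused conv-and-scale expression is split into named temporaries (x_tmp, perturbed_outputs_tmp, out) so that, when quant_prepare is set, each intermediate tensor can be routed through its own QuantStub and observed during calibration. A minimal standalone sketch of that pattern (the class and names below are illustrative, not part of bayesian_torch):

import torch
import torch.nn as nn

class ObservedAffine(nn.Module):
    """Toy module showing the observe-each-intermediate pattern."""
    def __init__(self):
        super().__init__()
        self.quant_prepare = False

    def prepare(self):
        # one stub per tensor we want calibrated
        self.stubs = nn.ModuleList(
            [torch.quantization.QuantStub() for _ in range(3)])
        self.quant_prepare = True

    def forward(self, a, b):
        tmp = a * b                 # named so an observer can see it
        out = tmp + a
        if self.quant_prepare:      # stubs record ranges once prepared
            a = self.stubs[0](a)
            tmp = self.stubs[1](tmp)
            out = self.stubs[2](out)
        return out

QuantStub is an identity in fp32; it only starts recording ranges after torch.quantization.prepare attaches observers to it.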
@@ -513,6 +533,15 @@ def __init__(self,
         self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare = False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))) for _ in range(4)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(8)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare = True
 
     def init_parameters(self):
         # prior values
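
Usage sketch (hedged: assumes this commit's Conv1dFlipout and the prepare() method shown above; shapes and pass count are arbitrary). The layer's prepare() attaches the stub lists, torch.quantization.prepare then inserts MinMaxObserver instances into them, and a few forward passes record tensor ranges:

import torch
from bayesian_torch.layers.flipout_layers.conv_flipout import Conv1dFlipout

layer = Conv1dFlipout(in_channels=4, out_channels=8, kernel_size=3)
layer.prepare()  # attach the qint/quint QuantStub lists defined above
prepared = torch.quantization.prepare(layer, inplace=False)  # insert observers

# a few forward passes so the MinMaxObserver instances record min/max ranges
with torch.no_grad():
    for _ in range(8):
        x = torch.randn(2, 4, 16)   # (batch, in_channels, length)
        out, kl = prepared(x)       # forward() also returns the KL term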
@@ -575,13 +604,33 @@ def forward(self, x, return_kl=True):
                                   self.prior_bias_sigma)
 
         # perturbed feedforward
-        perturbed_outputs = F.conv3d(x * sign_input,
+        x_tmp = x * sign_input
+        perturbed_outputs_tmp = F.conv3d(x * sign_input,
                                      weight=delta_kernel,
                                      bias=bias,
                                      stride=self.stride,
                                      padding=self.padding,
                                      dilation=self.dilation,
-                                     groups=self.groups) * sign_output
+                                     groups=self.groups)
+        perturbed_outputs = perturbed_outputs_tmp * sign_output
+        out = outputs + perturbed_outputs
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            x = self.quint_quant[0](x) # input
+            outputs = self.quint_quant[1](outputs) # output
+            sign_input = self.quint_quant[2](sign_input)
+            sign_output = self.quint_quant[3](sign_output)
+            x_tmp = self.quint_quant[4](x_tmp)
+            perturbed_outputs_tmp = self.quint_quant[5](perturbed_outputs_tmp) # output
+            perturbed_outputs = self.quint_quant[6](perturbed_outputs) # output
+            out = self.quint_quant[7](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            delta_kernel = self.qint_quant[3](delta_kernel) # multiply activation
 
         self.kl = kl
         # returning outputs + perturbations
@@ -677,12 +726,20 @@ def __init__(self,
         self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare = False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))) for _ in range(4)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(8)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare = True
 
     def init_parameters(self):
         # prior values
         self.prior_weight_mu.data.fill_(self.prior_mean)
-        self.prior_weight_sigma.data.fill_
-        (self.prior_variance)
+        self.prior_weight_sigma.data.fill_(self.prior_variance)
 
         # init our weights for the deterministic and perturbated weights
         self.mu_kernel.data.normal_(mean=self.posterior_mu_init, std=.1)
@@ -741,15 +798,34 @@ def forward(self, x, return_kl=True):
                                   self.prior_bias_sigma)
 
         # perturbed feedforward
-        perturbed_outputs = F.conv_transpose1d(
-            x * sign_input,
-            weight=delta_kernel,
-            bias=bias,
-            stride=self.stride,
-            padding=self.padding,
-            output_padding=self.output_padding,
-            dilation=self.dilation,
-            groups=self.groups) * sign_output
+        x_tmp = x * sign_input
+        perturbed_outputs_tmp = F.conv_transpose1d(x * sign_input,
+                                                   weight=delta_kernel,
+                                                   bias=bias,
+                                                   stride=self.stride,
+                                                   padding=self.padding,
+                                                   output_padding=self.output_padding,
+                                                   dilation=self.dilation,
+                                                   groups=self.groups)
+        perturbed_outputs = perturbed_outputs_tmp * sign_output
+        out = outputs + perturbed_outputs
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            x = self.quint_quant[0](x) # input
+            outputs = self.quint_quant[1](outputs) # output
+            sign_input = self.quint_quant[2](sign_input)
+            sign_output = self.quint_quant[3](sign_output)
+            x_tmp = self.quint_quant[4](x_tmp)
+            perturbed_outputs_tmp = self.quint_quant[5](perturbed_outputs_tmp) # output
+            perturbed_outputs = self.quint_quant[6](perturbed_outputs) # output
+            out = self.quint_quant[7](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            delta_kernel = self.qint_quant[3](delta_kernel) # multiply activation
 
         self.kl = kl
         # returning outputs + perturbations
@@ -850,6 +926,15 @@ def __init__(self,
         self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare = False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))) for _ in range(4)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(8)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare = True
 
     def init_parameters(self):
         # prior values
@@ -913,15 +998,34 @@ def forward(self, x, return_kl=True):
                                   self.prior_bias_sigma)
 
         # perturbed feedforward
-        perturbed_outputs = F.conv_transpose2d(
-            x * sign_input,
-            bias=bias,
-            weight=delta_kernel,
-            stride=self.stride,
-            padding=self.padding,
-            output_padding=self.output_padding,
-            dilation=self.dilation,
-            groups=self.groups) * sign_output
+        x_tmp = x * sign_input
+        perturbed_outputs_tmp = F.conv_transpose2d(x * sign_input,
+                                                   weight=delta_kernel,
+                                                   bias=bias,
+                                                   stride=self.stride,
+                                                   padding=self.padding,
+                                                   output_padding=self.output_padding,
+                                                   dilation=self.dilation,
+                                                   groups=self.groups)
+        perturbed_outputs = perturbed_outputs_tmp * sign_output
+        out = outputs + perturbed_outputs
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            x = self.quint_quant[0](x) # input
+            outputs = self.quint_quant[1](outputs) # output
+            sign_input = self.quint_quant[2](sign_input)
+            sign_output = self.quint_quant[3](sign_output)
+            x_tmp = self.quint_quant[4](x_tmp)
+            perturbed_outputs_tmp = self.quint_quant[5](perturbed_outputs_tmp) # output
+            perturbed_outputs = self.quint_quant[6](perturbed_outputs) # output
+            out = self.quint_quant[7](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            delta_kernel = self.qint_quant[3](delta_kernel) # multiply activation
 
         self.kl = kl
         # returning outputs + perturbations
@@ -1022,6 +1126,15 @@ def __init__(self,
         self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare = False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))) for _ in range(4)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(8)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare = True
 
     def init_parameters(self):
         # prior values
@@ -1084,15 +1197,34 @@ def forward(self, x, return_kl=True):
                                   self.prior_bias_sigma)
 
         # perturbed feedforward
-        perturbed_outputs = F.conv_transpose3d(
-            x * sign_input,
-            weight=delta_kernel,
-            bias=bias,
-            stride=self.stride,
-            padding=self.padding,
-            output_padding=self.output_padding,
-            dilation=self.dilation,
-            groups=self.groups) * sign_output
+        x_tmp = x * sign_input
+        perturbed_outputs_tmp = F.conv_transpose3d(x * sign_input,
+                                                   weight=delta_kernel,
+                                                   bias=bias,
+                                                   stride=self.stride,
+                                                   padding=self.padding,
+                                                   output_padding=self.output_padding,
+                                                   dilation=self.dilation,
+                                                   groups=self.groups)
+        perturbed_outputs = perturbed_outputs_tmp * sign_output
+        out = outputs + perturbed_outputs
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            x = self.quint_quant[0](x) # input
+            outputs = self.quint_quant[1](outputs) # output
+            sign_input = self.quint_quant[2](sign_input)
+            sign_output = self.quint_quant[3](sign_output)
+            x_tmp = self.quint_quant[4](x_tmp)
+            perturbed_outputs_tmp = self.quint_quant[5](perturbed_outputs_tmp) # output
+            perturbed_outputs = self.quint_quant[6](perturbed_outputs) # output
+            out = self.quint_quant[7](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            delta_kernel = self.qint_quant[3](delta_kernel) # multiply activation
 
         self.kl = kl
         # returning outputs + perturbations
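
Once calibration passes have run, the recorded ranges can be read back from the stubs. A hedged sketch continuing the earlier Conv1dFlipout example (activation_post_process is the attribute torch.quantization.prepare attaches to each stub; the index follows the quint_quant layout in this diff):

# observer for the 'outputs' slot, i.e. self.quint_quant[1] above
obs = prepared.quint_quant[1].activation_post_process
scale, zero_point = obs.calculate_qparams()
print(float(scale), int(zero_point))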
