@@ -210,13 +210,33 @@ def forward(self, x, return_kl=True):
210
210
self .prior_bias_sigma )
211
211
212
212
# perturbed feedforward
213
- perturbed_outputs = F . conv1d ( x * sign_input ,
214
- bias = bias ,
213
+ x_tmp = x * sign_input
214
+ perturbed_outputs_tmp = F . conv1d ( x * sign_input ,
215
215
weight = delta_kernel ,
216
+ bias = bias ,
216
217
stride = self .stride ,
217
218
padding = self .padding ,
218
219
dilation = self .dilation ,
219
- groups = self .groups ) * sign_output
220
+ groups = self .groups )
221
+ perturbed_outputs = perturbed_outputs_tmp * sign_output
222
+ out = outputs + perturbed_outputs
223
+
224
+ if self .quant_prepare :
225
+ # quint8 quantstub
226
+ x = self .quint_quant [0 ](x ) # input
227
+ outputs = self .quint_quant [1 ](outputs ) # output
228
+ sign_input = self .quint_quant [2 ](sign_input )
229
+ sign_output = self .quint_quant [3 ](sign_output )
230
+ x_tmp = self .quint_quant [4 ](x_tmp )
231
+ perturbed_outputs_tmp = self .quint_quant [5 ](perturbed_outputs_tmp ) # output
232
+ perturbed_outputs = self .quint_quant [6 ](perturbed_outputs ) # output
233
+ out = self .quint_quant [7 ](out ) # output
234
+
235
+ # qint8 quantstub
236
+ sigma_weight = self .qint_quant [0 ](sigma_weight ) # weight
237
+ mu_kernel = self .qint_quant [1 ](self .mu_kernel ) # weight
238
+ eps_kernel = self .qint_quant [2 ](eps_kernel ) # random variable
239
+ delta_kernel = self .qint_quant [3 ](delta_kernel ) # multiply activation
220
240
221
241
self .kl = kl
222
242
# returning outputs + perturbations
@@ -513,6 +533,15 @@ def __init__(self,
513
533
self .register_buffer ('prior_bias_sigma' , None , persistent = False )
514
534
515
535
self .init_parameters ()
536
+ self .quant_prepare = False
537
+
538
+ def prepare (self ):
# Opt-in quantization setup: creates the stub modules that the forward() hunk in this diff
# indexes inside its `if self.quant_prepare:` branch, then turns that branch on.
# Signed stubs: 4 QuantStubs, each given a QConfig whose weight and activation observers are
# MinMaxObserver with dtype=torch.qint8 and qscheme=torch.per_tensor_symmetric.
# forward() applies indices 0-3 to sigma_weight, self.mu_kernel, eps_kernel and delta_kernel.
539
+ self .qint_quant = nn .ModuleList ([torch .quantization .QuantStub (
540
+ QConfig (weight = MinMaxObserver .with_args (dtype = torch .qint8 , qscheme = torch .per_tensor_symmetric ), activation = MinMaxObserver .with_args (dtype = torch .qint8 ,qscheme = torch .per_tensor_symmetric ))) for _ in range (4 )])
# Unsigned stubs: 8 QuantStubs with MinMaxObserver(dtype=torch.quint8) and no explicit qscheme.
# forward() applies indices 0-7 to x, outputs, sign_input, sign_output, x_tmp,
# perturbed_outputs_tmp, perturbed_outputs and out.
541
+ self .quint_quant = nn .ModuleList ([torch .quantization .QuantStub (
542
+ QConfig (weight = MinMaxObserver .with_args (dtype = torch .quint8 ), activation = MinMaxObserver .with_args (dtype = torch .quint8 ))) for _ in range (8 )])
# NOTE(review): self.dequant is created here but is not referenced in the visible hunks -- confirm where it is used.
543
+ self .dequant = torch .quantization .DeQuantStub ()
# Flag is initialised to False at the end of __init__; setting it enables the stub calls in forward().
544
+ self .quant_prepare = True
516
545
517
546
def init_parameters (self ):
518
547
# prior values
@@ -575,13 +604,33 @@ def forward(self, x, return_kl=True):
575
604
self .prior_bias_sigma )
576
605
577
606
# perturbed feedforward
578
- perturbed_outputs = F .conv3d (x * sign_input ,
607
+ x_tmp = x * sign_input
608
+ perturbed_outputs_tmp = F .conv3d (x * sign_input ,
579
609
weight = delta_kernel ,
580
610
bias = bias ,
581
611
stride = self .stride ,
582
612
padding = self .padding ,
583
613
dilation = self .dilation ,
584
- groups = self .groups ) * sign_output
614
+ groups = self .groups )
615
+ perturbed_outputs = perturbed_outputs_tmp * sign_output
616
+ out = outputs + perturbed_outputs
617
+
618
+ if self .quant_prepare :
619
+ # quint8 quantstub
620
+ x = self .quint_quant [0 ](x ) # input
621
+ outputs = self .quint_quant [1 ](outputs ) # output
622
+ sign_input = self .quint_quant [2 ](sign_input )
623
+ sign_output = self .quint_quant [3 ](sign_output )
624
+ x_tmp = self .quint_quant [4 ](x_tmp )
625
+ perturbed_outputs_tmp = self .quint_quant [5 ](perturbed_outputs_tmp ) # output
626
+ perturbed_outputs = self .quint_quant [6 ](perturbed_outputs ) # output
627
+ out = self .quint_quant [7 ](out ) # output
628
+
629
+ # qint8 quantstub
630
+ sigma_weight = self .qint_quant [0 ](sigma_weight ) # weight
631
+ mu_kernel = self .qint_quant [1 ](self .mu_kernel ) # weight
632
+ eps_kernel = self .qint_quant [2 ](eps_kernel ) # random variable
633
+ delta_kernel = self .qint_quant [3 ](delta_kernel ) # multiply activation
585
634
586
635
self .kl = kl
587
636
# returning outputs + perturbations
@@ -677,12 +726,20 @@ def __init__(self,
677
726
self .register_buffer ('prior_bias_sigma' , None , persistent = False )
678
727
679
728
self .init_parameters ()
729
+ self .quant_prepare = False
730
+
731
+ def prepare (self ):
# Opt-in quantization setup: creates the stub modules that the forward() hunk in this diff
# indexes inside its `if self.quant_prepare:` branch, then turns that branch on.
# Signed stubs: 4 QuantStubs, each given a QConfig whose weight and activation observers are
# MinMaxObserver with dtype=torch.qint8 and qscheme=torch.per_tensor_symmetric.
# forward() applies indices 0-3 to sigma_weight, self.mu_kernel, eps_kernel and delta_kernel.
732
+ self .qint_quant = nn .ModuleList ([torch .quantization .QuantStub (
733
+ QConfig (weight = MinMaxObserver .with_args (dtype = torch .qint8 , qscheme = torch .per_tensor_symmetric ), activation = MinMaxObserver .with_args (dtype = torch .qint8 ,qscheme = torch .per_tensor_symmetric ))) for _ in range (4 )])
# Unsigned stubs: 8 QuantStubs with MinMaxObserver(dtype=torch.quint8) and no explicit qscheme.
# forward() applies indices 0-7 to x, outputs, sign_input, sign_output, x_tmp,
# perturbed_outputs_tmp, perturbed_outputs and out.
734
+ self .quint_quant = nn .ModuleList ([torch .quantization .QuantStub (
735
+ QConfig (weight = MinMaxObserver .with_args (dtype = torch .quint8 ), activation = MinMaxObserver .with_args (dtype = torch .quint8 ))) for _ in range (8 )])
# NOTE(review): self.dequant is created here but is not referenced in the visible hunks -- confirm where it is used.
736
+ self .dequant = torch .quantization .DeQuantStub ()
# Flag is initialised to False at the end of __init__; setting it enables the stub calls in forward().
737
+ self .quant_prepare = True
680
738
681
739
def init_parameters (self ):
682
740
# prior values
683
741
self .prior_weight_mu .data .fill_ (self .prior_mean )
684
- self .prior_weight_sigma .data .fill_
685
- (self .prior_variance )
742
+ self .prior_weight_sigma .data .fill_ (self .prior_variance )
686
743
687
744
# init our weights for the deterministic and perturbed weights
688
745
self .mu_kernel .data .normal_ (mean = self .posterior_mu_init , std = .1 )
@@ -741,15 +798,34 @@ def forward(self, x, return_kl=True):
741
798
self .prior_bias_sigma )
742
799
743
800
# perturbed feedforward
744
- perturbed_outputs = F .conv_transpose1d (
745
- x * sign_input ,
746
- weight = delta_kernel ,
747
- bias = bias ,
748
- stride = self .stride ,
749
- padding = self .padding ,
750
- output_padding = self .output_padding ,
751
- dilation = self .dilation ,
752
- groups = self .groups ) * sign_output
801
+ x_tmp = x * sign_input
802
+ perturbed_outputs_tmp = F .conv_transpose1d (x * sign_input ,
803
+ weight = delta_kernel ,
804
+ bias = bias ,
805
+ stride = self .stride ,
806
+ padding = self .padding ,
807
+ output_padding = self .output_padding ,
808
+ dilation = self .dilation ,
809
+ groups = self .groups )
810
+ perturbed_outputs = perturbed_outputs_tmp * sign_output
811
+ out = outputs + perturbed_outputs
812
+
813
+ if self .quant_prepare :
814
+ # quint8 quantstub
815
+ x = self .quint_quant [0 ](x ) # input
816
+ outputs = self .quint_quant [1 ](outputs ) # output
817
+ sign_input = self .quint_quant [2 ](sign_input )
818
+ sign_output = self .quint_quant [3 ](sign_output )
819
+ x_tmp = self .quint_quant [4 ](x_tmp )
820
+ perturbed_outputs_tmp = self .quint_quant [5 ](perturbed_outputs_tmp ) # output
821
+ perturbed_outputs = self .quint_quant [6 ](perturbed_outputs ) # output
822
+ out = self .quint_quant [7 ](out ) # output
823
+
824
+ # qint8 quantstub
825
+ sigma_weight = self .qint_quant [0 ](sigma_weight ) # weight
826
+ mu_kernel = self .qint_quant [1 ](self .mu_kernel ) # weight
827
+ eps_kernel = self .qint_quant [2 ](eps_kernel ) # random variable
828
+ delta_kernel = self .qint_quant [3 ](delta_kernel ) # multiply activation
753
829
754
830
self .kl = kl
755
831
# returning outputs + perturbations
@@ -850,6 +926,15 @@ def __init__(self,
850
926
self .register_buffer ('prior_bias_sigma' , None , persistent = False )
851
927
852
928
self .init_parameters ()
929
+ self .quant_prepare = False
930
+
931
+ def prepare (self ):
# Opt-in quantization setup: creates the stub modules that the forward() hunk in this diff
# indexes inside its `if self.quant_prepare:` branch, then turns that branch on.
# Signed stubs: 4 QuantStubs, each given a QConfig whose weight and activation observers are
# MinMaxObserver with dtype=torch.qint8 and qscheme=torch.per_tensor_symmetric.
# forward() applies indices 0-3 to sigma_weight, self.mu_kernel, eps_kernel and delta_kernel.
932
+ self .qint_quant = nn .ModuleList ([torch .quantization .QuantStub (
933
+ QConfig (weight = MinMaxObserver .with_args (dtype = torch .qint8 , qscheme = torch .per_tensor_symmetric ), activation = MinMaxObserver .with_args (dtype = torch .qint8 ,qscheme = torch .per_tensor_symmetric ))) for _ in range (4 )])
# Unsigned stubs: 8 QuantStubs with MinMaxObserver(dtype=torch.quint8) and no explicit qscheme.
# forward() applies indices 0-7 to x, outputs, sign_input, sign_output, x_tmp,
# perturbed_outputs_tmp, perturbed_outputs and out.
934
+ self .quint_quant = nn .ModuleList ([torch .quantization .QuantStub (
935
+ QConfig (weight = MinMaxObserver .with_args (dtype = torch .quint8 ), activation = MinMaxObserver .with_args (dtype = torch .quint8 ))) for _ in range (8 )])
# NOTE(review): self.dequant is created here but is not referenced in the visible hunks -- confirm where it is used.
936
+ self .dequant = torch .quantization .DeQuantStub ()
# Flag is initialised to False at the end of __init__; setting it enables the stub calls in forward().
937
+ self .quant_prepare = True
853
938
854
939
def init_parameters (self ):
855
940
# prior values
@@ -913,15 +998,34 @@ def forward(self, x, return_kl=True):
913
998
self .prior_bias_sigma )
914
999
915
1000
# perturbed feedforward
916
- perturbed_outputs = F .conv_transpose2d (
917
- x * sign_input ,
918
- bias = bias ,
919
- weight = delta_kernel ,
920
- stride = self .stride ,
921
- padding = self .padding ,
922
- output_padding = self .output_padding ,
923
- dilation = self .dilation ,
924
- groups = self .groups ) * sign_output
1001
+ x_tmp = x * sign_input
1002
+ perturbed_outputs_tmp = F .conv_transpose2d (x * sign_input ,
1003
+ weight = delta_kernel ,
1004
+ bias = bias ,
1005
+ stride = self .stride ,
1006
+ padding = self .padding ,
1007
+ output_padding = self .output_padding ,
1008
+ dilation = self .dilation ,
1009
+ groups = self .groups )
1010
+ perturbed_outputs = perturbed_outputs_tmp * sign_output
1011
+ out = outputs + perturbed_outputs
1012
+
1013
+ if self .quant_prepare :
1014
+ # quint8 quantstub
1015
+ x = self .quint_quant [0 ](x ) # input
1016
+ outputs = self .quint_quant [1 ](outputs ) # output
1017
+ sign_input = self .quint_quant [2 ](sign_input )
1018
+ sign_output = self .quint_quant [3 ](sign_output )
1019
+ x_tmp = self .quint_quant [4 ](x_tmp )
1020
+ perturbed_outputs_tmp = self .quint_quant [5 ](perturbed_outputs_tmp ) # output
1021
+ perturbed_outputs = self .quint_quant [6 ](perturbed_outputs ) # output
1022
+ out = self .quint_quant [7 ](out ) # output
1023
+
1024
+ # qint8 quantstub
1025
+ sigma_weight = self .qint_quant [0 ](sigma_weight ) # weight
1026
+ mu_kernel = self .qint_quant [1 ](self .mu_kernel ) # weight
1027
+ eps_kernel = self .qint_quant [2 ](eps_kernel ) # random variable
1028
+ delta_kernel = self .qint_quant [3 ](delta_kernel ) # multiply activation
925
1029
926
1030
self .kl = kl
927
1031
# returning outputs + perturbations
@@ -1022,6 +1126,15 @@ def __init__(self,
1022
1126
self .register_buffer ('prior_bias_sigma' , None , persistent = False )
1023
1127
1024
1128
self .init_parameters ()
1129
+ self .quant_prepare = False
1130
+
1131
+ def prepare (self ):
# Opt-in quantization setup: creates the stub modules that the forward() hunk in this diff
# indexes inside its `if self.quant_prepare:` branch, then turns that branch on.
# Signed stubs: 4 QuantStubs, each given a QConfig whose weight and activation observers are
# MinMaxObserver with dtype=torch.qint8 and qscheme=torch.per_tensor_symmetric.
# forward() applies indices 0-3 to sigma_weight, self.mu_kernel, eps_kernel and delta_kernel.
1132
+ self .qint_quant = nn .ModuleList ([torch .quantization .QuantStub (
1133
+ QConfig (weight = MinMaxObserver .with_args (dtype = torch .qint8 , qscheme = torch .per_tensor_symmetric ), activation = MinMaxObserver .with_args (dtype = torch .qint8 ,qscheme = torch .per_tensor_symmetric ))) for _ in range (4 )])
# Unsigned stubs: 8 QuantStubs with MinMaxObserver(dtype=torch.quint8) and no explicit qscheme.
# forward() applies indices 0-7 to x, outputs, sign_input, sign_output, x_tmp,
# perturbed_outputs_tmp, perturbed_outputs and out.
1134
+ self .quint_quant = nn .ModuleList ([torch .quantization .QuantStub (
1135
+ QConfig (weight = MinMaxObserver .with_args (dtype = torch .quint8 ), activation = MinMaxObserver .with_args (dtype = torch .quint8 ))) for _ in range (8 )])
# NOTE(review): self.dequant is created here but is not referenced in the visible hunks -- confirm where it is used.
1136
+ self .dequant = torch .quantization .DeQuantStub ()
# Flag is initialised to False at the end of __init__; setting it enables the stub calls in forward().
1137
+ self .quant_prepare = True
1025
1138
1026
1139
def init_parameters (self ):
1027
1140
# prior values
@@ -1084,15 +1197,34 @@ def forward(self, x, return_kl=True):
1084
1197
self .prior_bias_sigma )
1085
1198
1086
1199
# perturbed feedforward
1087
- perturbed_outputs = F .conv_transpose3d (
1088
- x * sign_input ,
1089
- weight = delta_kernel ,
1090
- bias = bias ,
1091
- stride = self .stride ,
1092
- padding = self .padding ,
1093
- output_padding = self .output_padding ,
1094
- dilation = self .dilation ,
1095
- groups = self .groups ) * sign_output
1200
+ x_tmp = x * sign_input
1201
+ perturbed_outputs_tmp = F .conv_transpose3d (x * sign_input ,
1202
+ weight = delta_kernel ,
1203
+ bias = bias ,
1204
+ stride = self .stride ,
1205
+ padding = self .padding ,
1206
+ output_padding = self .output_padding ,
1207
+ dilation = self .dilation ,
1208
+ groups = self .groups )
1209
+ perturbed_outputs = perturbed_outputs_tmp * sign_output
1210
+ out = outputs + perturbed_outputs
1211
+
1212
+ if self .quant_prepare :
1213
+ # quint8 quantstub
1214
+ x = self .quint_quant [0 ](x ) # input
1215
+ outputs = self .quint_quant [1 ](outputs ) # output
1216
+ sign_input = self .quint_quant [2 ](sign_input )
1217
+ sign_output = self .quint_quant [3 ](sign_output )
1218
+ x_tmp = self .quint_quant [4 ](x_tmp )
1219
+ perturbed_outputs_tmp = self .quint_quant [5 ](perturbed_outputs_tmp ) # output
1220
+ perturbed_outputs = self .quint_quant [6 ](perturbed_outputs ) # output
1221
+ out = self .quint_quant [7 ](out ) # output
1222
+
1223
+ # qint8 quantstub
1224
+ sigma_weight = self .qint_quant [0 ](sigma_weight ) # weight
1225
+ mu_kernel = self .qint_quant [1 ](self .mu_kernel ) # weight
1226
+ eps_kernel = self .qint_quant [2 ](eps_kernel ) # random variable
1227
+ delta_kernel = self .qint_quant [3 ](delta_kernel ) # multiply activation
1096
1228
1097
1229
self .kl = kl
1098
1230
# returning outputs + perturbations
0 commit comments