Skip to content

Commit 18b296f

Browse files
committed
Modify the quantized bias sigma to a torch.nn.Parameter to allow for JIT tracing
1 parent 17480e6 commit 18b296f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

bayesian_torch/layers/variational_layers/quantize_linear_variational.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,7 @@ def quantize(self):
119119
delattr(self, "rho_weight")
120120

121121
self.quantized_mu_bias = self.mu_bias#Parameter(self.get_quantized_tensor(self.mu_bias), requires_grad=False)
122-
self.quantized_sigma_bias = torch.log1p(torch.exp(self.rho_bias))#Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_bias))), requires_grad=False)
122+
self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False)#Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_bias))), requires_grad=False)
123123
delattr(self, "mu_bias")
124124
delattr(self, "rho_bias")
125125

@@ -131,7 +131,7 @@ def dequantize(self): # Deprecated
131131
self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias)
132132
return
133133

134-
def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True):
134+
def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.2, default_zero_point=128, return_kl=True):
135135
""" Forward pass
136136
137137
Parameters

0 commit comments

Comments (0)