diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index 694fab3dc6b..3c8191dc57d 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -52,7 +52,9 @@ def _derive_bias_qparams_fn( act_scale, weight_scale ) derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) - derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) + derived_zero = torch.zeros(derived_scale.size(), device=weight_zp.device).to( + torch.int32 + ) if isinstance(weight_obs_or_fq, PerBlockParamObserver): # keep maximum scale of each channel for bias derived_scale = (