Skip to content

Commit 6d58352

Browse files
committed
minor
Signed-off-by: realAsma <[email protected]>
1 parent a13c5d4 commit 6d58352

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

modelopt/torch/quantization/model_calib.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,10 @@ def __init__(self, module, name):
451451
self.num_search_steps = 0
452452
self.block_size = _get_awq_quantizer_block_size(module.weight, module.weight_quantizer)
453453
self.weight_scale = get_weight_scale(module.weight, self.block_size)
454-
self.loss = {k.item(): 0.0 for k in torch.arange(0, 1.0 + alpha_step, alpha_step)}
454+
self.loss = {
455+
k.item(): torch.zeros((), device=module.weight.device, dtype=torch.float32)
456+
for k in torch.arange(0, 1.0 + alpha_step, alpha_step)
457+
}
455458
self.best_scale = None
456459
self.best_alpha = None
457460
self.is_input_quantized = module.input_quantizer.is_enabled

0 commit comments

Comments
 (0)