Skip to content

Commit 6d6b4ba

Browse files
committed
refactor
Signed-off-by: weimingc <[email protected]>
1 parent 43cd4b0 commit 6d6b4ba

File tree

1 file changed

+6
-10
lines changed
  • modelopt/torch/quantization/plugins

1 file changed

+6
-10
lines changed

modelopt/torch/quantization/plugins/vllm.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,15 @@ def apply(
6161
x = layer.input_quantizer(x)
6262
if layer.weight_quantizer.is_enabled:
6363
original_weight = layer.weight
64-
_data = None
65-
# for parameter, we keep the original data, otherwise we modify the weight
6664
quantized_tensor = layer.weight_quantizer(layer.weight)
65+
# parameterize the quantized weight
6766
if isinstance(original_weight, torch.nn.Parameter):
68-
_data = original_weight.data
69-
layer.weight.data = quantized_tensor
70-
else:
71-
layer.weight = quantized_tensor
67+
quantized_tensor = torch.nn.Parameter(
68+
quantized_tensor, requires_grad=original_weight.requires_grad
69+
)
70+
layer.weight = quantized_tensor
7271
output = self.quant_method.apply(layer, x, bias)
73-
if _data is not None:
74-
layer.weight.data = _data
75-
else:
76-
layer.weight = original_weight
72+
layer.weight = original_weight
7773
else:
7874
output = self.quant_method.apply(layer, x, bias)
7975
output = layer.output_quantizer(output)

0 commit comments

Comments
 (0)