Commit 4057a23

minor refactor
Signed-off-by: Suguna Velury <[email protected]>
1 parent 5beed13 commit 4057a23

File tree: 2 files changed (+9, -3 lines)

modelopt/torch/export/quant_utils.py

Lines changed: 4 additions & 0 deletions
@@ -727,6 +727,7 @@ def to_quantized_weight(
     quantization: str,
     weights_scaling_factor2: torch.Tensor | None = None,
     block_size: int | None = None,
+    dtype: torch.dtype | None = None,
 ):
     """Converts the weight to the quantized (packed) format."""
     if weights_scaling_factor is not None:
@@ -739,6 +740,9 @@ def to_quantized_weight(
     if isinstance(weight, QTensorWrapper):
         return weight.data
 
+    if dtype:
+        weight = weight.to(dtype)
+
     if quantization == QUANTIZATION_FP8:
         # Fix RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Float
         # in speculative decoding fp8 model export
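The change moves the optional dtype cast into `to_quantized_weight` itself: when a `dtype` is given, the weight is cast before any format-specific packing runs, and already-packed `QTensorWrapper` inputs still return early and are never re-cast. A minimal sketch of that pattern, assuming a simplified per-tensor FP8 path; `to_quantized_weight_sketch` and the scaling math are illustrative stand-ins, not ModelOpt's actual implementation:

```python
import torch

# Constant name taken from the diff; the string value is an assumption here.
QUANTIZATION_FP8 = "fp8"


def to_quantized_weight_sketch(
    weight: torch.Tensor,
    weights_scaling_factor: torch.Tensor,
    quantization: str,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Illustrative packer: cast to the requested dtype first, then quantize."""
    # New behavior from this commit: the optional cast happens inside the
    # helper rather than at every call site.
    if dtype:
        weight = weight.to(dtype)
    if quantization == QUANTIZATION_FP8:
        # Simplified per-tensor FP8 conversion; the real helper also handles
        # block sizes, secondary scales, and other quantization formats.
        return (weight / weights_scaling_factor).to(torch.float8_e4m3fn)
    raise NotImplementedError(quantization)


w = torch.randn(4, 8, dtype=torch.bfloat16)
scale = w.abs().amax().float() / 448.0
packed = to_quantized_weight_sketch(w, scale, QUANTIZATION_FP8, dtype=torch.float16)
print(packed.dtype)  # torch.float8_e4m3fn
```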

modelopt/torch/export/unified_export_hf.py

Lines changed: 5 additions & 3 deletions
@@ -29,7 +29,7 @@
 
 from modelopt.torch.quantization import set_quantizer_by_cfg_context
 from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
-from modelopt.torch.quantization.qtensor import NVFP4QTensor, QTensorWrapper
+from modelopt.torch.quantization.qtensor import NVFP4QTensor
 from modelopt.torch.quantization.utils import quantizer_attr_names
 
 from .convert_hf_config import convert_hf_quant_config_format
@@ -314,23 +314,25 @@ def _export_quantized_weight(
         )[0]
 
         quantized_weight = to_quantized_weight(
-            weight.to(dtype) if not isinstance(weight, QTensorWrapper) else weight,
+            weight,
             weight_scale,
             quantization_format,
             weight_scale_2,
             block_size,
+            dtype,
         )
 
         quantized_weight, weight_scale = maybe_transpose_expert_weight_dimensions(
             quantized_weight, weight_scale, is_bmm_expert_weight=is_bmm_expert_weight
         )
     else:
         quantized_weight = to_quantized_weight(
-            weight.to(dtype) if not isinstance(weight, QTensorWrapper) else weight,
+            weight,
             weight_scale,
             quantization_format,
             weight_scale_2,
             block_size,
+            dtype,
         )
 
     setattr(sub_module, weight_name, nn.Parameter(quantized_weight, requires_grad=False))
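At the two call sites in `_export_quantized_weight`, the ternary `weight.to(dtype) if not isinstance(weight, QTensorWrapper) else weight` collapses to just `weight` plus a trailing `dtype` argument, and the `QTensorWrapper` import becomes unnecessary because the helper already short-circuits on wrapped weights. A hedged sketch of why the call site no longer needs the isinstance check; `FakeQTensorWrapper` and the helper body below are simplified stand-ins for the real classes:

```python
import torch


class FakeQTensorWrapper:
    """Stand-in for ModelOpt's QTensorWrapper: holds already-packed data."""

    def __init__(self, data: torch.Tensor):
        self.data = data


def to_quantized_weight_sketch(weight, dtype: torch.dtype | None = None):
    # Already-quantized weights pass through untouched (the early return in
    # the quant_utils.py hunk above).
    if isinstance(weight, FakeQTensorWrapper):
        return weight.data
    # Plain tensors are cast to the target dtype before packing (packing omitted).
    if dtype:
        weight = weight.to(dtype)
    return weight


# The export path can now hand either kind of weight to the helper uniformly,
# which is why the QTensorWrapper import and the isinstance checks at the
# call sites could be dropped.
plain = torch.randn(2, 2)
packed = FakeQTensorWrapper(torch.zeros(2, 2, dtype=torch.uint8))
assert to_quantized_weight_sketch(plain, torch.float16).dtype == torch.float16
assert to_quantized_weight_sketch(packed, torch.float16).dtype == torch.uint8
```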
