Commit 395e75b: addressed PR comments

1 parent 64d018c

3 files changed: +19, -14 lines

docs/source/en/quantization/modelopt.md (1 addition, 2 deletions)

````diff
@@ -9,7 +9,7 @@
 an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 specific language governing permissions and limitations under the License. -->

-# Nvidia ModelOpt
+# NVIDIA ModelOpt

 [nvidia_modelopt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is a unified library of state-of-the-art model optimization techniques like quantization, pruning, distillation, speculative decoding, etc. It compresses deep learning models for downstream deployment frameworks like TensorRT-LLM or TensorRT to optimize inference speed.

@@ -19,7 +19,6 @@
 pip install -U "nvidia_modelopt[hf]"
 ```
-

 Quantize a model by passing [`NVIDIAModelOptConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.

 The example below only quantizes the weights to FP8.
````

src/diffusers/quantizers/modelopt/modelopt_quantizer.py (1 addition, 1 deletion)

```diff
@@ -32,7 +32,7 @@ class NVIDIAModelOptQuantizer(DiffusersQuantizer):

     use_keep_in_fp32_modules = True
     requires_calibration = False
-    required_packages = ["modelopt"]
+    required_packages = ["nvidia_modelopt"]

     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
```
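The rename above matters because a `required_packages` list is only useful if each entry matches the name the availability check looks up. As a generic illustration (not diffusers' actual mechanism, whose helper names differ), a minimal sketch of such a gate using only the standard library:

```python
import importlib.util

def missing_packages(required):
    """Return the entries of `required` that cannot be imported.

    A hypothetical sketch of a required-packages gate: find_spec() resolves
    importable module names, so the list must use the name the package is
    looked up by, not an arbitrary alias.
    """
    return [
        name
        for name in required
        if importlib.util.find_spec(name) is None
    ]

# Stdlib modules are always found; a made-up name is reported as missing.
print(missing_packages(["json", "surely_not_a_real_pkg_123"]))
```

The same mismatch the commit fixes would show up here: checking for `"modelopt"` when the lookup is keyed on `"nvidia_modelopt"` (or vice versa) makes an installed dependency look absent.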

src/diffusers/quantizers/quantization_config.py (17 additions, 11 deletions)

```diff
@@ -767,9 +767,13 @@ class NVIDIAModelOptConfig(QuantizationConfigMixin):
         "FP8": (4, 3),
         "INT8": 8,
         "INT4": 4,
-        # "NF4": 4,  # TODO: enable this upon modelopt release https://github.com/NVIDIA/TensorRT-Model-Optimizer/issues/183
+        "NF4": 4,
         "NVFP4": (2, 1),
     }
+    quanttype_to_scalingbits = {
+        "NF4": 8,
+        "NVFP4": (4, 3),
+    }

     def __init__(
         self,
@@ -884,15 +888,17 @@ def get_config_from_quant_type(self) -> Dict[str, Any]:
             quant_cfg["*input_quantizer"]["axis"] = self.channel_quantize
             quant_cfg["*input_quantizer"]["type"] = "dynamic"

-        # Only fixed sizes are supported for now in modelopt
-        if "NF4" in w_type:
-            quant_cfg["*weight_quantizer"]["block_sizes"].update({"scale_bits": 8, "scale_block_sizes": {self.scale_channel_quantize: self.scale_block_quantize}})
-        elif "NVFP4" in w_type:
-            quant_cfg["*weight_quantizer"]["block_sizes"].update({"scale_bits": (4, 3), "type": "dynamic"})
-        if act_type:
-            if "NF4" in act_type:
-                quant_cfg["*input_quantizer"]["block_sizes"].update({"scale_bits": 8, "scale_block_sizes": {self.scale_channel_quantize: self.scale_block_quantize}})
-            elif "NVFP4" in act_type:
-                quant_cfg["*input_quantizer"]["block_sizes"].update({"scale_bits": (4, 3), "type": "dynamic"})
+        # Only fixed scaling sizes are supported for now in modelopt
+        if self.scale_channel_quantize is not None and self.scale_block_quantize is not None:
+            if w_type in NVIDIAModelOptConfig.quanttype_to_scalingbits:
+                quant_cfg["*weight_quantizer"]["block_sizes"].update({
+                    "scale_bits": NVIDIAModelOptConfig.quanttype_to_scalingbits[w_type],
+                    "scale_block_sizes": {self.scale_channel_quantize: self.scale_block_quantize},
+                })
+            if act_type and act_type in NVIDIAModelOptConfig.quanttype_to_scalingbits:
+                quant_cfg["*input_quantizer"]["block_sizes"].update({
+                    "scale_bits": NVIDIAModelOptConfig.quanttype_to_scalingbits[act_type],
+                    "scale_block_sizes": {self.scale_channel_quantize: self.scale_block_quantize},
+                })

         return BASE_CONFIG
```
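The refactor above replaces a hardcoded if/elif chain with a class-level lookup table (`quanttype_to_scalingbits`), so adding a new quant type with a fixed scaling format becomes a one-line dict entry. A standalone sketch of the pattern (the helper name `scale_block_update` is hypothetical; the dict values mirror the diff):

```python
# Lookup table mirroring the diff: scale bits per quant type.
# NF4 scales are stored in 8 bits; NVFP4 scales use a (4, 3)
# exponent/mantissa layout, per the values in the commit.
QUANTTYPE_TO_SCALINGBITS = {
    "NF4": 8,
    "NVFP4": (4, 3),
}

def scale_block_update(quant_type, scale_channel_axis, scale_block_size):
    """Return the block_sizes update for a quantizer, or {} when the
    quant type has no fixed scaling format (e.g. FP8, INT8)."""
    if quant_type not in QUANTTYPE_TO_SCALINGBITS:
        return {}
    return {
        "scale_bits": QUANTTYPE_TO_SCALINGBITS[quant_type],
        "scale_block_sizes": {scale_channel_axis: scale_block_size},
    }

# Table hit vs. miss:
print(scale_block_update("NF4", -1, 16))
print(scale_block_update("FP8", -1, 16))
```

Compared with the old branch-per-type code, the weight and activation paths now share one membership test and one update shape, which is why the commit can apply the same logic to both `*weight_quantizer` and `*input_quantizer`.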
