Merged
57 commits (changes shown from 10 commits)
d7ca877  initial commit (ishan-modi, Mar 30, 2025)
a016c56  update (ishan-modi, Mar 31, 2025)
eb73ab0  updates (ishan-modi, Apr 1, 2025)
7fdb79e  update (ishan-modi, Apr 8, 2025)
a83bb98  update (ishan-modi, Apr 8, 2025)
9d9f0b9  update (ishan-modi, Apr 10, 2025)
7b09750  update (ishan-modi, Apr 21, 2025)
4fe06ee  Merge branch 'main' into add-trtquant-backend (ishan-modi, Apr 23, 2025)
71d8a7e  update (ishan-modi, Apr 23, 2025)
6c74c69  update (ishan-modi, Apr 24, 2025)
10fb9fe  Merge branch 'main' into add-trtquant-backend (sayakpaul, Apr 29, 2025)
6c65138  addressed PR comments (ishan-modi, Apr 29, 2025)
4b32567  Merge remote-tracking branch 'origin/add-trtquant-backend' into add-t… (ishan-modi, Apr 29, 2025)
915dbf0  update (ishan-modi, Apr 30, 2025)
3336a08  Merge branch 'main' into add-trtquant-backend (sayakpaul, May 1, 2025)
1c470f2  Merge branch 'main' into add-trtquant-backend (sayakpaul, May 2, 2025)
f823a2c  addressed PR comments (ishan-modi, May 6, 2025)
e78841e  Merge branch 'add-trtquant-backend' of https://github.com/ishan-modi/… (ishan-modi, May 6, 2025)
8f88f29  update (ishan-modi, May 6, 2025)
212603f  update (ishan-modi, May 6, 2025)
24f1bcb  update (ishan-modi, May 6, 2025)
65097f1  update (ishan-modi, May 6, 2025)
97f94ae  update (ishan-modi, May 6, 2025)
752544f  update (ishan-modi, May 9, 2025)
415901f  Merge branch 'main' into add-trtquant-backend (ishan-modi, May 29, 2025)
482fe78  updates (ishan-modi, Jul 21, 2025)
488282f  Merge branch 'main' into add-trtquant-backend (ishan-modi, Jul 21, 2025)
88259c9  Merge branch 'huggingface:main' into add-trtquant-backend (ishan-modi, Aug 3, 2025)
e51be6a  Merge branch 'main' into add-trtquant-backend (ishan-modi, Aug 15, 2025)
d48835d  update (ishan-modi, Aug 16, 2025)
5c4a4ea  Merge branch 'main' into add-trtquant-backend (ishan-modi, Aug 16, 2025)
670202d  update (ishan-modi, Aug 16, 2025)
6dd903f  Merge branch 'add-trtquant-backend' of https://github.com/ishan-modi/… (ishan-modi, Aug 16, 2025)
3f672d3  Merge branch 'main' into add-trtquant-backend (sayakpaul, Aug 20, 2025)
64d018c  Merge branch 'main' into add-trtquant-backend (sayakpaul, Aug 20, 2025)
395e75b  addressed PR comments (ishan-modi, Aug 22, 2025)
9034661  Merge branch 'main' into add-trtquant-backend (sayakpaul, Aug 22, 2025)
bbbc840  updates (ishan-modi, Aug 22, 2025)
2076783  Merge branch 'add-trtquant-backend' of https://github.com/ishan-modi/… (ishan-modi, Aug 22, 2025)
c53d251  code formatting (ishan-modi, Aug 22, 2025)
1ddcc9c  update (ishan-modi, Aug 22, 2025)
5df6926  addressed PR comments (ishan-modi, Aug 22, 2025)
8439f01  Merge branch 'main' into add-trtquant-backend (ishan-modi, Aug 22, 2025)
b96da23  Merge branch 'main' into add-trtquant-backend (ishan-modi, Aug 26, 2025)
0bf90b0  addressed PR comments (ishan-modi, Aug 26, 2025)
b097f0f  Merge branch 'add-trtquant-backend' of https://github.com/ishan-modi/… (ishan-modi, Aug 26, 2025)
cf054d2  addressed PR comments (ishan-modi, Aug 26, 2025)
0828f50  Merge branch 'main' into add-trtquant-backend (sayakpaul, Aug 27, 2025)
031298d  Merge branch 'main' into add-trtquant-backend (sayakpaul, Aug 27, 2025)
f345325  Merge branch 'main' into add-trtquant-backend (sayakpaul, Aug 30, 2025)
dd39595  addressed PR comments (ishan-modi, Sep 1, 2025)
d66709b  Merge branch 'add-trtquant-backend' of https://github.com/ishan-modi/… (ishan-modi, Sep 1, 2025)
81f4785  Merge branch 'main' into add-trtquant-backend (sayakpaul, Sep 1, 2025)
8f60186  fix docs and dependencies (ishan-modi, Sep 1, 2025)
8daf21d  Merge branch 'add-trtquant-backend' of https://github.com/ishan-modi/… (ishan-modi, Sep 1, 2025)
1a8806f  fixed dependency test (ishan-modi, Sep 1, 2025)
cb4e44b  Merge branch 'main' into add-trtquant-backend (sayakpaul, Sep 3, 2025)
4 changes: 4 additions & 0 deletions setup.py
@@ -119,6 +119,7 @@
"peft>=0.6.0",
"protobuf>=3.20.3,<4",
"pytest",
"pulp",
"pytest-timeout",
"pytest-xdist",
"python>=3.8.0",
@@ -128,10 +129,12 @@
"GitPython<3.1.19",
"scipy",
"onnx",
"torchprofile>=0.0.4",
"optimum_quanto>=0.2.6",
"gguf>=0.10.0",
"torchao>=0.7.0",
"bitsandbytes>=0.43.3",
"nvidia_modelopt>=0.27.0",
"regex!=2019.12.17",
"requests",
"tensorboard",
@@ -244,6 +247,7 @@ def run(self):
extras["gguf"] = deps_list("gguf", "accelerate")
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
extras["torchao"] = deps_list("torchao", "accelerate")
extras["nvidia_modelopt"] = deps_list("nvidia_modelopt", "onnx", "pulp", "torchprofile", "accelerate")

if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
21 changes: 21 additions & 0 deletions src/diffusers/__init__.py
@@ -13,6 +13,7 @@
    is_k_diffusion_available,
    is_librosa_available,
    is_note_seq_available,
    is_nvidia_modelopt_available,
    is_onnx_available,
    is_opencv_available,
    is_optimum_quanto_available,
@@ -108,6 +109,18 @@
else:
    _import_structure["quantizers.quantization_config"].append("QuantoConfig")

try:
    if not is_torch_available() and not is_accelerate_available() and not is_nvidia_modelopt_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_nvidia_modelopt_objects

    _import_structure["utils.dummy_nvidia_modelopt_objects"] = [
        name for name in dir(dummy_nvidia_modelopt_objects) if not name.startswith("_")
    ]
else:
    _import_structure["quantizers.quantization_config"].append("NVIDIAModelOptConfig")

try:
    if not is_onnx_available():
        raise OptionalDependencyNotAvailable()
@@ -714,6 +727,14 @@
    else:
        from .quantizers.quantization_config import QuantoConfig

    try:
        if not is_nvidia_modelopt_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_nvidia_modelopt_objects import *
    else:
        from .quantizers.quantization_config import NVIDIAModelOptConfig

    try:
        if not is_onnx_available():
            raise OptionalDependencyNotAvailable()
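With NVIDIAModelOptConfig exported from the top-level diffusers namespace, end-to-end usage could look roughly like the sketch below. This is only an illustration: the checkpoint id is a placeholder, and the quant_type="FP8" keyword is inferred from the quant_type attribute checked later in the quantizer, not a confirmed constructor signature.

import torch
from diffusers import FluxTransformer2DModel, NVIDIAModelOptConfig

# Hypothetical sketch: quantize a transformer while loading it with the new backend.
# The model id is a placeholder and the keyword arguments are assumptions.
quant_config = NVIDIAModelOptConfig(quant_type="FP8")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # placeholder checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

The nvidia_modelopt extra added in setup.py (pip install "diffusers[nvidia_modelopt]") pulls in the nvidia-modelopt, onnx, pulp, torchprofile, and accelerate dependencies listed above.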
3 changes: 3 additions & 0 deletions src/diffusers/dependency_versions_table.py
@@ -26,6 +26,7 @@
"peft": "peft>=0.6.0",
"protobuf": "protobuf>=3.20.3,<4",
"pytest": "pytest",
"pulp": "pulp",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.8.0",
@@ -35,10 +36,12 @@
"GitPython": "GitPython<3.1.19",
"scipy": "scipy",
"onnx": "onnx",
"torchprofile": "torchprofile>=0.0.4",
"optimum_quanto": "optimum_quanto>=0.2.6",
"gguf": "gguf>=0.10.0",
"torchao": "torchao>=0.7.0",
"bitsandbytes": "bitsandbytes>=0.43.3",
"nvidia_modelopt": "nvidia_modelopt>=0.27.0",
"regex": "regex!=2019.12.17",
"requests": "requests",
"tensorboard": "tensorboard",
4 changes: 4 additions & 0 deletions src/diffusers/quantizers/auto.py
@@ -21,9 +21,11 @@

from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
from .modelopt import NVIDIAModelOptQuantizer
from .quantization_config import (
    BitsAndBytesConfig,
    GGUFQuantizationConfig,
    NVIDIAModelOptConfig,
    QuantizationConfigMixin,
    QuantizationMethod,
    QuantoConfig,
@@ -39,6 +41,7 @@
    "gguf": GGUFQuantizer,
    "quanto": QuantoQuantizer,
    "torchao": TorchAoHfQuantizer,
    "modelopt": NVIDIAModelOptQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -47,6 +50,7 @@
    "gguf": GGUFQuantizationConfig,
    "quanto": QuantoConfig,
    "torchao": TorchAoConfig,
    "modelopt": NVIDIAModelOptConfig,
}


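For completeness, the new "modelopt" entries let the auto-quantizer machinery resolve a ModelOpt config to its quantizer. A minimal sketch, assuming the DiffusersAutoQuantizer.from_config entry point and the constructor keyword used in the example above:

from diffusers import NVIDIAModelOptConfig
from diffusers.quantizers.auto import DiffusersAutoQuantizer

# Sketch only: resolve the config to its quantizer via the new mapping entries.
config = NVIDIAModelOptConfig(quant_type="FP8")  # assumed kwarg, see note above
quantizer = DiffusersAutoQuantizer.from_config(config)
print(type(quantizer).__name__)  # expected: NVIDIAModelOptQuantizer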
1 change: 1 addition & 0 deletions src/diffusers/quantizers/modelopt/__init__.py
@@ -0,0 +1 @@
from .modelopt_quantizer import NVIDIAModelOptQuantizer
163 changes: 163 additions & 0 deletions src/diffusers/quantizers/modelopt/modelopt_quantizer.py
@@ -0,0 +1,163 @@
from typing import TYPE_CHECKING, Any, Dict, List, Union

from ...utils import (
    get_module_from_name,
    is_accelerate_available,
    is_nvidia_modelopt_available,
    is_nvidia_modelopt_version,
    is_torch_available,
    logging,
)
from ..base import DiffusersQuantizer


if TYPE_CHECKING:
    from ...models.modeling_utils import ModelMixin


if is_torch_available():
    import torch

if is_accelerate_available():
    from accelerate.utils import set_module_tensor_to_device


logger = logging.get_logger(__name__)


class NVIDIAModelOptQuantizer(DiffusersQuantizer):
    r"""
    Diffusers Quantizer for TensorRT Model Optimizer
    """

    use_keep_in_fp32_modules = True
    requires_calibration = False
    required_packages = ["modelopt"]
Member: Should be nvidia_modelopt no?


    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

    def validate_environment(self, *args, **kwargs):
        if not is_nvidia_modelopt_available():
            raise ImportError(
                "Loading an nvidia-modelopt quantized model requires the nvidia-modelopt library (`pip install nvidia-modelopt`)."
            )
        if not is_nvidia_modelopt_version(">=", "0.25.0"):
            raise ImportError(
                "Loading an nvidia-modelopt quantized model requires `nvidia-modelopt>=0.25.0`. "
                "Please upgrade your installation with `pip install --upgrade nvidia-modelopt`."
            )

        self.offload = False

        device_map = kwargs.get("device_map", None)
        if isinstance(device_map, dict):
            if "cpu" in device_map.values() or "disk" in device_map.values():
                if self.pre_quantized:
                    raise ValueError(
                        "You are attempting to perform cpu/disk offload with a pre-quantized modelopt model. "
                        "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
                    )
                else:
                    self.offload = True

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        # ModelOpt imports diffusers internally. This is here to prevent circular imports
        from modelopt.torch.quantization.nn import QuantInputBase, SequentialQuantizer, TensorQuantizer
        from modelopt.torch.quantization.qtensor import BaseQuantizedTensor

        def is_param_quantized(module):
            for _module in module.modules():
                if isinstance(_module, TensorQuantizer) and not _module._dequantize:
                    return True
                elif isinstance(_module, SequentialQuantizer):
                    for q in _module:
                        if isinstance(q, TensorQuantizer) and not q._dequantize:
                            return True
            return False

        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized and any(isinstance(module, t) for t in [BaseQuantizedTensor]):
            return True
        elif isinstance(module, QuantInputBase) and "weight" in tensor_name:
            return is_param_quantized(module)
        return False

Why don't we simply check if the model has any TensorQuantizer? The is_quantized API from modelopt does exactly this.

Suggested change:
-        from modelopt.torch.quantization.nn import QuantInputBase, SequentialQuantizer, TensorQuantizer
-        from modelopt.torch.quantization.qtensor import BaseQuantizedTensor
-        def is_param_quantized(module):
-            for _module in module.modules():
-                if isinstance(_module, TensorQuantizer) and not _module._dequantize:
-                    return True
-                elif isinstance(_module, SequentialQuantizer):
-                    for q in _module:
-                        if isinstance(q, TensorQuantizer) and not q._dequantize:
-                            return True
-            return False
-        module, tensor_name = get_module_from_name(model, param_name)
-        if self.pre_quantized and any(isinstance(module, t) for t in [BaseQuantizedTensor]):
-            return True
-        elif isinstance(module, QuantInputBase) and "weight" in tensor_name:
-            return is_param_quantized(module)
-        return False
+        from modelopt.torch.quantization.utils import is_quantized
+        return is_quantized(model)

@ishan-modi (Contributor Author), Apr 30, 2025:

Using is_quantized would mean the following:

  • we won't be able to check quantization for SequentialQuantizer
  • for real quantizers we would also need to check whether _dequantize is True (meaning the weights are actually compressed, not just that layers were replaced with their quantized counterparts). IGNORE THIS (it only applied before the mtq.compress functionality).

Hence this function; let me know if it looks correct with the above cases in mind.

@realAsma, May 5, 2025:

Re: "won't be able to check quantization for SequentialQuantizer"

SequentialQuantizer is an nn.Sequential container for TensorQuantizer. See https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/e048fb239de4e4cb6a5f8f86ba70455478a1847e/modelopt/torch/quantization/nn/modules/tensor_quantizer.py#L1174

Hence, is_quantized will work even with SequentialQuantizer quantization.
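(For reference, modelopt's is_quantized check is roughly the one-liner below; this is an illustrative paraphrase of the linked helper, not the exact source.)

from modelopt.torch.quantization.nn import TensorQuantizer

def is_quantized(module):
    # SequentialQuantizer subclasses nn.Sequential and holds TensorQuantizer children,
    # so module.modules() still reaches the inner TensorQuantizer instances.
    return any(isinstance(m, TensorQuantizer) for m in module.modules())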

Member: We should be able to verify if is_quantized(model) cuts the deal for us minimally. @ishan-modi WDYT?

@ishan-modi (Contributor Author): Agreed, I will confirm and make the change!


    def create_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        *args,
        **kwargs,
    ):
        """
        Create the quantized parameter by calling .calibrate() after setting it to the module.
        """
        # ModelOpt imports diffusers internally. This is here to prevent circular imports
        import modelopt.torch.quantization as mtq

        dtype = kwargs.get("dtype", torch.float32)
        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized:
            setattr(module, tensor_name, param_value)
        else:
            set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
            mtq.calibrate(module, self.quantization_config.modelopt_config["algorithm"])
Collaborator: Would mention in the docstring that if we're doing activation calibration, the kwargs for the forward pass need to be passed in as a dict.
https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/e048fb239de4e4cb6a5f8f86ba70455478a1847e/modelopt/torch/quantization/model_quant.py#L62-L64

            mtq.compress(module)

mtq.compress compresses the model weights into lower-bit representations, allowing users to leverage it directly at the Torch level. However, as previously mentioned, to achieve actual speed improvements, we need to utilize the TensorRT runtime rather than the Torch runtime.

Member: @ishan-modi could we also mention this bit in the docs?

            module.weight.requires_grad = False

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        if self.quantization_config.quant_type == "FP8":
            target_dtype = torch.float8_e4m3fn
        return target_dtype

    def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
        if torch_dtype is None:
            logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
            torch_dtype = torch.float32
        return torch_dtype

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        # ModelOpt imports diffusers internally. This is here to prevent circular imports
        import modelopt.torch.opt as mto
        import modelopt.torch.quantization as mtq

        self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

        if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]

        self.modules_to_not_convert.extend(keep_in_fp32_modules)
Collaborator: self.modules_to_not_convert doesn't seem to be used anywhere? Don't we need to update the quant config with these values?


        mto.apply_mode(model, mode=[("quantize", self.quantization_config.modelopt_config)], registry=mtq.mode.QuantizeModeRegistry)
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
        return model

    @property
    def is_trainable(self):
        return True

    @property
    def is_serializable(self):
        return True