-
Couldn't load subscription status.
- Fork 6.5k
[Quantization] Add TRT-ModelOpt as a Backend #11173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
d7ca877
a016c56
eb73ab0
7fdb79e
a83bb98
9d9f0b9
7b09750
4fe06ee
71d8a7e
6c74c69
10fb9fe
6c65138
4b32567
915dbf0
3336a08
1c470f2
f823a2c
e78841e
8f88f29
212603f
24f1bcb
65097f1
97f94ae
752544f
415901f
482fe78
488282f
88259c9
e51be6a
d48835d
5c4a4ea
670202d
6dd903f
3f672d3
64d018c
395e75b
9034661
bbbc840
2076783
c53d251
1ddcc9c
5df6926
8439f01
b96da23
0bf90b0
b097f0f
cf054d2
0828f50
031298d
f345325
dd39595
d66709b
81f4785
8f60186
8daf21d
1a8806f
cb4e44b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from .modelopt_quantizer import ModelOptQuantizer |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,162 @@ | ||||||||||||||||||||||||||||||||||||||||||||
| from typing import TYPE_CHECKING, Any, Dict, List, Union | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| from ...utils import ( | ||||||||||||||||||||||||||||||||||||||||||||
| get_module_from_name, | ||||||||||||||||||||||||||||||||||||||||||||
| is_accelerate_available, | ||||||||||||||||||||||||||||||||||||||||||||
| is_nvidia_modelopt_available, | ||||||||||||||||||||||||||||||||||||||||||||
| is_nvidia_modelopt_version, | ||||||||||||||||||||||||||||||||||||||||||||
| is_torch_available, | ||||||||||||||||||||||||||||||||||||||||||||
| logging, | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
| from ..base import DiffusersQuantizer | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| if TYPE_CHECKING: | ||||||||||||||||||||||||||||||||||||||||||||
| from ...models.modeling_utils import ModelMixin | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| if is_torch_available(): | ||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| if is_accelerate_available(): | ||||||||||||||||||||||||||||||||||||||||||||
| from accelerate.utils import set_module_tensor_to_device | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| logger = logging.get_logger(__name__) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| class ModelOptQuantizer(DiffusersQuantizer): | ||||||||||||||||||||||||||||||||||||||||||||
sayakpaul marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||
| r""" | ||||||||||||||||||||||||||||||||||||||||||||
| Diffusers Quantizer for TensorRT Model Optimizer | ||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| use_keep_in_fp32_modules = True | ||||||||||||||||||||||||||||||||||||||||||||
| requires_calibration = False | ||||||||||||||||||||||||||||||||||||||||||||
| required_packages = ["modelopt"] | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, quantization_config, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||
| super().__init__(quantization_config, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def validate_environment(self, *args, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||
| if not is_nvidia_modelopt_available(): | ||||||||||||||||||||||||||||||||||||||||||||
| raise ImportError( | ||||||||||||||||||||||||||||||||||||||||||||
| "Loading an nvidia-modelopt quantized model requires nvidia-modelopt library (`pip install nvidia-modelopt`)" | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
| if not is_nvidia_modelopt_version(">=", "0.25.0"): | ||||||||||||||||||||||||||||||||||||||||||||
| raise ImportError( | ||||||||||||||||||||||||||||||||||||||||||||
| "Loading an nvidia-modelopt quantized model requires `nvidia-modelopt>=0.25.0`. " | ||||||||||||||||||||||||||||||||||||||||||||
| "Please upgrade your installation with `pip install --upgrade nvidia-modelopt" | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| self.offload = False | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| device_map = kwargs.get("device_map", None) | ||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(device_map, dict): | ||||||||||||||||||||||||||||||||||||||||||||
| if "cpu" in device_map.values() or "disk" in device_map.values(): | ||||||||||||||||||||||||||||||||||||||||||||
| if self.pre_quantized: | ||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||
| "You are attempting to perform cpu/disk offload with a pre-quantized modelopt model " | ||||||||||||||||||||||||||||||||||||||||||||
| "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument." | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||
| self.offload = True | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def check_if_quantized_param( | ||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||
| model: "ModelMixin", | ||||||||||||||||||||||||||||||||||||||||||||
| param_value: "torch.Tensor", | ||||||||||||||||||||||||||||||||||||||||||||
| param_name: str, | ||||||||||||||||||||||||||||||||||||||||||||
| state_dict: Dict[str, Any], | ||||||||||||||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||
| # ModelOpt imports diffusers internally. This is here to prevent circular imports | ||||||||||||||||||||||||||||||||||||||||||||
| from modelopt.torch.quantization.nn import QuantInputBase, SequentialQuantizer, TensorQuantizer | ||||||||||||||||||||||||||||||||||||||||||||
| from modelopt.torch.quantization.qtensor import BaseQuantizedTensor | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def is_param_quantized(module): | ||||||||||||||||||||||||||||||||||||||||||||
| for _module in module.modules(): | ||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(_module, TensorQuantizer) and not _module._dequantize: | ||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(_module, SequentialQuantizer): | ||||||||||||||||||||||||||||||||||||||||||||
| for q in _module: | ||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(q, TensorQuantizer) and not q._dequantize: | ||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| module, tensor_name = get_module_from_name(model, param_name) | ||||||||||||||||||||||||||||||||||||||||||||
| if self.pre_quantized and any(isinstance(module, t) for t in [BaseQuantizedTensor]): | ||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(module, QuantInputBase) and "weight" in tensor_name: | ||||||||||||||||||||||||||||||||||||||||||||
| return is_param_quantized(module) | ||||||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
| from modelopt.torch.quantization.nn import QuantInputBase, SequentialQuantizer, TensorQuantizer | |
| from modelopt.torch.quantization.qtensor import BaseQuantizedTensor | |
| def is_param_quantized(module): | |
| for _module in module.modules(): | |
| if isinstance(_module, TensorQuantizer) and not _module._dequantize: | |
| return True | |
| elif isinstance(_module, SequentialQuantizer): | |
| for q in _module: | |
| if isinstance(q, TensorQuantizer) and not q._dequantize: | |
| return True | |
| return False | |
| module, tensor_name = get_module_from_name(model, param_name) | |
| if self.pre_quantized and any(isinstance(module, t) for t in [BaseQuantizedTensor]): | |
| return True | |
| elif isinstance(module, QuantInputBase) and "weight" in tensor_name: | |
| return is_param_quantized(module) | |
| return False | |
| from modelopt.torch.quantization.utils import is_quantized | |
| return is_quantized(model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using is_quantized would mean the following
- won't be able to check quantization for
SequentialQuantizer for real_quantizer we need to also check ifIGNORE THIS (applicable before mtq.compress functionality)._dequantizeis True (this means that it is actually compressed and not just replacing layers with their counterparts)
hence this function, let me know if this looks correct with above cases in mind
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
won't be able to check quantization for SequentialQuantizer
SequentialQuantizer is an nn.Sequential container for TensorQuantizer. See https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/e048fb239de4e4cb6a5f8f86ba70455478a1847e/modelopt/torch/quantization/nn/modules/tensor_quantizer.py#L1174
Hence, is_quantized will work even with SequentialQuantizer quantization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should be able to verify if is_quantized(model) cuts the deal for us minimally. @ishan-modi WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, I will confirm and make the change !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mtq.compress compresses the model weights into lower-bit representations, allowing users to leverage it directly at the Torch level. However, as previously mentioned, to achieve actual speed improvements, we need to utilize the TensorRT runtime rather than the Torch runtime.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ishan-modi could we also mention this bit in the docs?
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.modules_to_not_convert doesn't seem to be used anywhere? Don't we need to update the quant config with these values?
Uh oh!
There was an error while loading. Please reload this page.