
Commit 1cdb872

fix torchao quantizer for new torchao versions (#12901)
* fix torchao quantizer for new torchao versions

Summary: `torchao==0.16.0` (not yet released) has some bc-breaking changes; this PR updates the diffusers repo to work with those changes.

Specifics on the changes:
1. `UInt4Tensor` is removed: pytorch/ao#3536
2. the old float8 tensors v1 are removed: pytorch/ao#3510

In this PR:
1. move the `logger` variable up (it was previously defined in the middle of the file) to get better error messages
2. gate the old torchao objects by torchao version

Test Plan: importing diffusers objects with new versions of torchao works:

```bash
> python -c "import torchao; print(torchao.__version__); from diffusers import StableDiffusionPipeline"
0.16.0.dev20251229+cu129
```

Reviewers:
Subscribers:
Tasks:
Tags:

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent f6b6a71 commit 1cdb872
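The test plan above runs against a torchao nightly (`0.16.0.dev20251229+cu129`), which is why the patch below gates on `is_torchao_version(">", "0.15.0")` rather than `">=", "0.16.0"`: under PEP 440 ordering, a `.dev` pre-release sorts before its final release. A minimal sketch with `packaging.version`, assuming the version guard compares versions the usual `packaging` way:

```python
# Minimal sketch of the version-ordering gotcha, using packaging.version directly.
# Assumption: diffusers' is_torchao_version helper compares versions this way.
from packaging.version import Version

nightly = Version("0.16.0.dev20251229+cu129")

# A .dev pre-release sorts before the final release, so a ">= 0.16.0" gate
# would wrongly classify the nightly as an *old* torchao build.
print(nightly >= Version("0.16.0"))  # False

# Comparing against the previous release works for both nightlies and 0.16.0 final.
print(nightly > Version("0.15.0"))   # True
```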

File tree

1 file changed: +14 -6 lines changed


src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 14 additions & 6 deletions
```diff
@@ -36,6 +36,9 @@
 from ..base import DiffusersQuantizer
 
 
+logger = logging.get_logger(__name__)
+
+
 if TYPE_CHECKING:
     from ...models.modeling_utils import ModelMixin
 
@@ -83,11 +86,19 @@ def _update_torch_safe_globals():
     ]
     try:
         from torchao.dtypes import NF4Tensor
-        from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
-        from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
         from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
 
-        safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
+        safe_globals.extend([UintxTensor, UintxAQTTensorImpl, NF4Tensor])
+
+        # note: is_torchao_version(">=", "0.16.0") does not work correctly
+        # with torchao nightly, so using a ">" check which does work correctly
+        if is_torchao_version(">", "0.15.0"):
+            pass
+        else:
+            from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
+            from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
+
+            safe_globals.extend([UInt4Tensor, Float8AQTTensorImpl])
 
     except (ImportError, ModuleNotFoundError) as e:
         logger.warning(
@@ -123,9 +134,6 @@ def fuzzy_match_size(config_name: str) -> Optional[str]:
     return None
 
 
-logger = logging.get_logger(__name__)
-
-
 def _quantization_type(weight):
     from torchao.dtypes import AffineQuantizedTensor
     from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
```
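For context, the gated hunk sits inside `_update_torch_safe_globals`, which collects these tensor classes so they can be allow-listed for `weights_only` checkpoint loading. A rough, self-contained sketch of that overall pattern; the registration via `torch.serialization.add_safe_globals` and the `diffusers.utils` import path for `is_torchao_version` are assumptions here, not taken verbatim from the file:

```python
# Sketch of the version-gated allow-listing pattern, not the exact diffusers code.
import torch
from diffusers.utils import is_torchao_version  # assumed import path for the helper


def register_torchao_safe_globals() -> None:
    safe_globals = []
    try:
        from torchao.dtypes import NF4Tensor
        from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor

        safe_globals.extend([UintxTensor, UintxAQTTensorImpl, NF4Tensor])

        # These classes were removed in newer torchao, so only import them on <= 0.15.x.
        if not is_torchao_version(">", "0.15.0"):
            from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
            from torchao.dtypes.uintx.uint4_layout import UInt4Tensor

            safe_globals.extend([UInt4Tensor, Float8AQTTensorImpl])
    except (ImportError, ModuleNotFoundError):
        # Mirrors the warning branch in the diff: missing torchao pieces are non-fatal.
        pass

    # Allow torch.load(..., weights_only=True) to reconstruct these tensor subclasses.
    torch.serialization.add_safe_globals(safe_globals)
```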

0 commit comments
