Skip to content

Commit fa273fd

Browse files
committed
update
1 parent 6a0ae75 commit fa273fd

File tree

2 files changed

+40
-35
lines changed

2 files changed

+40
-35
lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@
4141
is_gguf_available,
4242
is_torch_available,
4343
is_torch_version,
44-
is_torchao_available,
45-
is_torchao_version,
4644
logging,
4745
)
4846

@@ -61,38 +59,6 @@
6159
from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
6260

6361

64-
def _update_torch_safe_globals():
65-
safe_globals = [
66-
(torch.uint1, "torch.uint1"),
67-
(torch.uint2, "torch.uint2"),
68-
(torch.uint3, "torch.uint3"),
69-
(torch.uint4, "torch.uint4"),
70-
(torch.uint5, "torch.uint5"),
71-
(torch.uint6, "torch.uint6"),
72-
(torch.uint7, "torch.uint7"),
73-
]
74-
try:
75-
from torchao.dtypes import NF4Tensor
76-
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
77-
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
78-
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
79-
80-
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
81-
82-
except (ImportError, ModuleNotFoundError) as e:
83-
logger.warning(
84-
"Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
85-
)
86-
logger.debug(e)
87-
88-
finally:
89-
torch.serialization.add_safe_globals(safe_globals=safe_globals)
90-
91-
92-
if is_torch_version(">=", "2.6") and is_torchao_available() and is_torchao_version(">=", "0.7.0"):
93-
_update_torch_safe_globals()
94-
95-
9662
# Adapted from `transformers` (see modeling_utils.py)
9763
def _determine_device_map(
9864
model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,14 @@
2323

2424
from packaging import version
2525

26-
from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
26+
from ...utils import (
27+
get_module_from_name,
28+
is_torch_available,
29+
is_torch_version,
30+
is_torchao_available,
31+
is_torchao_version,
32+
logging,
33+
)
2734
from ..base import DiffusersQuantizer
2835

2936

@@ -62,6 +69,38 @@
6269
from torchao.quantization import quantize_
6370

6471

72+
def _update_torch_safe_globals():
73+
safe_globals = [
74+
(torch.uint1, "torch.uint1"),
75+
(torch.uint2, "torch.uint2"),
76+
(torch.uint3, "torch.uint3"),
77+
(torch.uint4, "torch.uint4"),
78+
(torch.uint5, "torch.uint5"),
79+
(torch.uint6, "torch.uint6"),
80+
(torch.uint7, "torch.uint7"),
81+
]
82+
try:
83+
from torchao.dtypes import NF4Tensor
84+
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
85+
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
86+
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
87+
88+
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
89+
90+
except (ImportError, ModuleNotFoundError) as e:
91+
logger.warning(
92+
"Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
93+
)
94+
logger.debug(e)
95+
96+
finally:
97+
torch.serialization.add_safe_globals(safe_globals=safe_globals)
98+
99+
100+
if is_torch_version(">=", "2.6") and is_torchao_available() and is_torchao_version(">=", "0.7.0"):
101+
_update_torch_safe_globals()
102+
103+
65104
logger = logging.get_logger(__name__)
66105

67106

0 commit comments

Comments
 (0)