Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions src/diffusers/quantizers/torchao/torchao_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
from ..base import DiffusersQuantizer


logger = logging.get_logger(__name__)


if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin

Expand Down Expand Up @@ -83,11 +86,18 @@ def _update_torch_safe_globals():
]
try:
from torchao.dtypes import NF4Tensor
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor

safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
safe_globals.extend([UintxTensor, UintxAQTTensorImpl, NF4Tensor])

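# NOTE: is_torchao_version(">=", "0.16.0") does not behave correctly on torchao
# nightly builds (presumably because their dev version strings compare below the
# final "0.16.0" release -- TODO confirm), so we use a "> 0.15.0" check instead,
# which does work correctly for nightlies.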
if is_torchao_version(">", "0.15.0"):
pass
else:
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
safe_globals.extend([UInt4Tensor, Float8AQTTensorImpl])

except (ImportError, ModuleNotFoundError) as e:
logger.warning(
Expand Down Expand Up @@ -123,9 +133,6 @@ def fuzzy_match_size(config_name: str) -> Optional[str]:
return None


logger = logging.get_logger(__name__)


def _quantization_type(weight):
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
Expand Down
Loading