@@ -41,8 +41,6 @@
     is_gguf_available,
     is_torch_available,
     is_torch_version,
-    is_torchao_available,
-    is_torchao_version,
     logging,
 )
 
@@ -61,38 +59,6 @@
 from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
 
 
-def _update_torch_safe_globals():
-    safe_globals = [
-        (torch.uint1, "torch.uint1"),
-        (torch.uint2, "torch.uint2"),
-        (torch.uint3, "torch.uint3"),
-        (torch.uint4, "torch.uint4"),
-        (torch.uint5, "torch.uint5"),
-        (torch.uint6, "torch.uint6"),
-        (torch.uint7, "torch.uint7"),
-    ]
-    try:
-        from torchao.dtypes import NF4Tensor
-        from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
-        from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
-        from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
-
-        safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
-
-    except (ImportError, ModuleNotFoundError) as e:
-        logger.warning(
-            "Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
-        )
-        logger.debug(e)
-
-    finally:
-        torch.serialization.add_safe_globals(safe_globals=safe_globals)
-
-
-if is_torch_version(">=", "2.6") and is_torchao_available() and is_torchao_version(">=", "0.7.0"):
-    _update_torch_safe_globals()
-
-
 # Adapted from `transformers` (see modeling_utils.py)
 def _determine_device_map(
     model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
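For context on the removed helper: `torch.serialization.add_safe_globals` allowlists objects so that checkpoints referencing them can still be deserialized with `torch.load(weights_only=True)`. A minimal sketch of the same registration pattern, assuming PyTorch >= 2.6 (which accepts `(object, "qualified.name")` tuples) and a hypothetical checkpoint path:

import torch

# Allowlist sub-byte uint dtypes for weights_only deserialization.
# PyTorch >= 2.6 accepts (object, "fully.qualified.name") tuples here.
torch.serialization.add_safe_globals(
    safe_globals=[(torch.uint4, "torch.uint4"), (torch.uint7, "torch.uint7")]
)

# Hypothetical path: with the globals registered above, torch.load can
# unpickle a checkpoint that references these dtypes while staying in the
# restricted weights_only=True mode.
state_dict = torch.load("torchao_quantized.pt", weights_only=True)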