Skip to content

Commit 794710d

Browse files
Guard for torch < 2.5
1 parent 0d5cda7 commit 794710d

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

bitsandbytes/nn/parametrize.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,13 +137,14 @@ def _enable_parametrization_cache(module: nn.Module, inputs: tuple[Any, ...]):
137137

138138

139139
def _register_parametrization_hooks(module: nn.Module, param_name: str):
140-
# Register a state dict hook for saving.
141-
module.register_state_dict_post_hook(
142-
partial(
143-
_parametrized_state_dict_post_hook,
144-
param_name=param_name,
140+
# Register a state dict hook for saving. Note that this requires torch >= 2.5.0.
141+
if torch.__version__ >= (2, 5):
142+
module.register_state_dict_post_hook(
143+
partial(
144+
_parametrized_state_dict_post_hook,
145+
param_name=param_name,
146+
)
145147
)
146-
)
147148

148149
# Register hooks to enable caching for the dequantization parametrization.
149150
# This helps preserve time and memory when the same quantized parameter

tests/test_parametrize.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def test_prequantized_replacement(device, dtype, quant_type):
161161
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
162162
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
163163
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
164+
@pytest.mark.skipif(torch.__version__ < (2, 5), reason="state dict hook requires torch >= 2.5.0")
164165
def test_state_dict_functionality(device, dtype, quant_type, compress_statistics):
165166
"""Test that state dict saving works with quantized parameters."""
166167
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):

0 commit comments

Comments
 (0)