Skip to content

Commit 9ad7b33

Browse files
Improvement for torch.compile support on Params4bit (#1673)
(cherry picked from commit d9333aa)
1 parent bbb257d commit 9ad7b33

File tree

2 files changed

+1
-11
lines changed

2 files changed

+1
-11
lines changed

bitsandbytes/nn/modules.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,6 @@ def from_prequantized(
290290

291291
return self
292292

293-
@classmethod
294-
def __torch_function__(cls, func, types, args=(), kwargs=None):
295-
if kwargs is None:
296-
kwargs = {}
297-
with torch._C.DisableTorchFunctionSubclass():
298-
return func(*args, **kwargs)
299-
300293
def _quantize(self, device):
301294
w = self.data.contiguous().to(device)
302295
w_4bit, quant_state = bnb.functional.quantize_4bit(

tests/test_linear4bit.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,7 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
294294
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
295295
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
296296
def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
297-
if device == "cpu" and quant_type == "fp4":
298-
pytest.skip("FP4 is not supported for CPU")
299-
300-
if fullgraph and torch.__version__ < (2, 8):
297+
if fullgraph and torch.__version__ < (2, 8, 0, "dev"):
301298
pytest.skip("fullgraph mode requires torch 2.8 or higher")
302299

303300
if device == "cuda" and platform.system() == "Windows":

0 commit comments

Comments (0)