From 6033747f14e54391f0edb28c2192f7d70349196f Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 28 Apr 2025 13:52:49 -0400 Subject: [PATCH] Support 4bit torch.compile fullgraph with PyTorch nightly --- bitsandbytes/nn/modules.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ea5451502..8fb61a7a6 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -290,6 +290,13 @@ def from_prequantized( return self + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + def _quantize(self, device): w = self.data.contiguous().to(device) w_4bit, quant_state = bnb.functional.quantize_4bit( @@ -486,7 +493,7 @@ def forward(self, x: torch.Tensor): bias = None if self.bias is None else self.bias.to(self.compute_dtype) - return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) + return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) class LinearFP4(Linear4bit):