Skip to content

Commit 03bdf88

Browse files
jiqing-feng and rsshaik1
authored and committed
Signed-off-by: jiqing-feng <[email protected]>
1 parent c0b1a62 commit 03bdf88

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,8 @@ def matmul_4bit(
579579
assert quant_state is not None
580580
if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
581581
if getattr(quant_state, "ipex", False):
582-
out = F.gemv_4bit(A, B.t(), out, state=quant_state)
582+
B = B.t() if len(B.shape) == 2 else B
583+
out = F.gemv_4bit(A, B, out, state=quant_state)
583584
if bias is not None:
584585
out += bias
585586
return out

bitsandbytes/nn/modules.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,8 @@ def forward(self, x: torch.Tensor):
508508
x = x.to(self.compute_dtype)
509509

510510
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
511-
out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
511+
weight = self.weight.t() if len(self.weight.shape) == 2 else self.weight
512+
out = bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state)
512513

513514
out = out.to(inp_dtype)
514515

0 commit comments

Comments (0)