
Commit b4370b8

Try removing extra transpose ops in 4bit
1 parent: 4e41a4b

File tree

2 files changed: +3 -3 lines

bitsandbytes/autograd/_functions.py

2 additions & 2 deletions

@@ -327,7 +327,7 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]

         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype), bias)

         # 3. Save state
         ctx.state = quant_state

@@ -393,7 +393,7 @@ def matmul_4bit(
             )
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
-            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
+            out = F.gemv_4bit(A, B, out, state=quant_state)
             if bias is not None:
                 out += bias
             return out
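
Both edits in _functions.py drop a .t() on the weight path: dequantize_4bit reconstructs the weight in the shape recorded in quant_state, here (out_features, in_features), which is exactly the layout torch.nn.functional.linear expects, so the extra transpose was redundant. A minimal sketch of that round trip (hypothetical shapes; assumes a CUDA device with bitsandbytes installed):

import torch
import bitsandbytes.functional as F

out_features, in_features = 64, 128
W = torch.randn(out_features, in_features, dtype=torch.float16, device="cuda")
x = torch.randn(4, in_features, dtype=torch.float16, device="cuda")

# Quantize to NF4; quant_state remembers the original (out_features, in_features) shape.
W_q, quant_state = F.quantize_4bit(W, quant_type="nf4")

# Dequantize and apply linear directly -- no .t(), matching the updated code path.
W_dq = F.dequantize_4bit(W_q, quant_state).to(x.dtype)
reference = torch.nn.functional.linear(x, W)
approx = torch.nn.functional.linear(x, W_dq)

print((reference - approx).abs().max())  # small; bounded by the 4-bit quantization error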

bitsandbytes/nn/modules.py

1 addition & 1 deletion

@@ -480,7 +480,7 @@ def forward(self, x: torch.Tensor):

         bias = None if self.bias is None else self.bias.to(self.compute_dtype)

-        return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bnb.matmul_4bit(x, self.weight.data, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)


 class LinearFP4(Linear4bit):
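
For the module-level change, Linear4bit.forward now hands self.weight.data straight to bnb.matmul_4bit instead of transposing the packed weight first. A quick end-to-end check of that path (a sketch, assuming a CUDA device; the weights are quantized when the layer is moved to the GPU):

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# Build a 4-bit linear layer; moving it to CUDA packs and quantizes the weight.
layer = bnb.nn.Linear4bit(128, 64, bias=False, compute_dtype=torch.float16, quant_type="nf4").cuda()
x = torch.randn(4, 128, dtype=torch.float16, device="cuda")

# Reference: dequantize the packed weight and run a plain linear, again without .t().
W_dq = F.dequantize_4bit(layer.weight.data, layer.weight.quant_state).to(torch.float16)
reference = torch.nn.functional.linear(x, W_dq)

# The module forward (bnb.matmul_4bit under the hood) should agree with the reference,
# since both paths dequantize the same 4-bit weight.
print((layer(x) - reference).abs().max())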
