Skip to content

Commit 57b89bf

Browse files
committed
fix out shape
Signed-off-by: jiqing-feng <[email protected]>
1 parent 580010c commit 57b89bf

File tree

1 file changed

+2
-1
lines changed
  • bitsandbytes/backends/cpu

1 file changed

+2
-1
lines changed

bitsandbytes/backends/cpu/ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def _(
236236
if dtype != torch.bfloat16:
237237
A = A.to(torch.bfloat16)
238238

239+
final_out_shape = (*A.shape[:-1], shapeB[0])
239240
A = A.reshape(-1, A.shape[-1])
240241
out_shape = (*A.shape[:-1], shapeB[0])
241242
out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
@@ -274,4 +275,4 @@ def _(
274275
if dtype != torch.bfloat16:
275276
out = out.to(dtype)
276277

277-
return out
278+
return out.reshape(final_out_shape)

0 commit comments

Comments
 (0)