We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 580010c commit 57b89bfCopy full SHA for 57b89bf
bitsandbytes/backends/cpu/ops.py
@@ -236,6 +236,7 @@ def _(
236
if dtype != torch.bfloat16:
237
A = A.to(torch.bfloat16)
238
239
+ final_out_shape = (*A.shape[:-1], shapeB[0])
240
A = A.reshape(-1, A.shape[-1])
241
out_shape = (*A.shape[:-1], shapeB[0])
242
out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
@@ -274,4 +275,4 @@ def _(
274
275
276
out = out.to(dtype)
277
- return out
278
+ return out.reshape(final_out_shape)
0 commit comments