Commit 0c88d43

fix absmax shape
Signed-off-by: jiqing-feng <[email protected]>
1 parent 3179b42

File tree

1 file changed (+7 -5 lines)


bitsandbytes/functional.py

Lines changed: 7 additions & 5 deletions
@@ -2142,14 +2142,16 @@ def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantStat
         if absmax.dtype != torch.float32:
             absmax = absmax.float()
 
-        quant_state.absmax = (
-            absmax.reshape(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
-            .T.to(torch.bfloat16)
-            .contiguous()
-        )
+        quant_state.absmax = absmax
         quant_state.nested = False
         delattr(quant_state, "state2")
 
+    quant_state.absmax = (
+        quant_state.absmax.reshape(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
+        .T.to(torch.bfloat16)
+        .contiguous()
+    )
+
     quant_state.dtype = torch.bfloat16
     quant_state.packing_format_for_cpu = True
     return final_qweight, quant_state
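
In short, the reshape of quant_state.absmax now runs unconditionally after the nested-absmax handling instead of only inside it, so every quant state ends up with absmax in the transposed (blocks_per_row, rows) bfloat16 layout rather than a flat per-block vector. Below is a minimal sketch of what that moved block does, using a hypothetical ToyQuantState stand-in and made-up dimensions (the real QuantState lives in bitsandbytes; only the reshape chain itself is taken verbatim from the diff):

import torch

# Hypothetical stand-in for bitsandbytes' QuantState, holding only the
# fields the moved block touches; this is illustration, not the real API.
class ToyQuantState:
    def __init__(self, shape, blocksize, absmax):
        self.shape = shape          # (rows, cols) of the original weight
        self.blocksize = blocksize  # elements per quantization block
        self.absmax = absmax        # flat float32 tensor, one scale per block

rows, cols, blocksize = 8, 256, 64        # made-up dimensions
n_blocks = rows * cols // blocksize       # 32 quantization blocks
state = ToyQuantState((rows, cols), blocksize, torch.rand(n_blocks, dtype=torch.float32))

# The reshape that the commit applies to every quant state: flat per-block
# scales -> (rows, blocks_per_row) -> transposed, contiguous bfloat16.
state.absmax = (
    state.absmax.reshape(state.shape[0], state.shape[1] // state.blocksize)
    .T.to(torch.bfloat16)
    .contiguous()
)

assert state.absmax.shape == (cols // blocksize, rows)  # (blocks_per_row, rows)

With these toy numbers, 32 flat scales become a (4, 8) bfloat16 matrix; before the fix, a state that never took the nested branch skipped the reshape and kept the flat shape, which is the mismatch the commit title refers to.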
