@@ -275,40 +275,3 @@ def _(
         out = out.to(dtype)
 
     return out
-
-# def unpack_weight_packed_for_cpu(packed_qweight: torch.Tensor, block_n: int = 32):
-#     """
-#     Inverse of convert_weight_packed_for_cpu.
-#     packed_qweight: (N, K//2) uint8, each byte = (high<<4)|low, both 4-bit values in 0..15
-#     returns: qweight_final (N, K) uint8 with original 4-bit values (0..15)
-#     """
-#     assert packed_qweight.dtype == torch.uint8
-#     assert packed_qweight.dim() == 2
-#     N, K_half = packed_qweight.shape
-#     assert N % block_n == 0
-#     BIT_COUNT = block_n  # 32
-#     # reshape to rows of 32 packed bytes
-#     qw = packed_qweight.reshape(-1, BIT_COUNT)  # [(N//block_n)*K_half, 32]
-#     low = (qw & 0x0F)
-#     high = (qw >> 4) & 0x0F
-#     # restore 64 nibbles (low first then high, matching original pack order)
-#     restored = torch.cat([low, high], dim=1)  # [..., 64]
-#     # reshape back (inverse of flatten)
-#     restored = restored.reshape(N // block_n, K_half, block_n, 2)  # [N/block_n, K//2, block_n, 2]
-#     # inverse transpose
-#     restored = restored.transpose(-3, -2)  # [N/block_n, block_n, K//2, 2]
-#     # final shape
-#     qweight_final = restored.reshape(N, K_half * 2).to(torch.uint8)
-#     return qweight_final
-
-
-# _NF4_QUANT_TABLE = torch.tensor([ -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
-#     0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0 ], dtype=torch.float32)
-
-# def fused_matmul(x, packed_weight, scales, group_size):
-#     unpacked_weight = unpack_weight_packed_for_cpu(packed_weight)
-#     shape = unpacked_weight.shape
-#     # original_weight = _INT4_0_TO_15_TABLE[unpacked_weight.reshape(-1).int()].reshape(shape) * scales.T.repeat_interleave(group_size, dim=1)
-#     original_weight = _NF4_QUANT_TABLE[unpacked_weight.reshape(-1).int()].reshape(shape) * scales.T.repeat_interleave(group_size, dim=1)
-#     res = torch.matmul(x, original_weight.T.to(x.dtype))
-#     return res