@@ -2109,6 +2109,10 @@ def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantStat
21092109 return: packed_weight
21102110 """
21112111 assert qweight .dtype == torch .uint8 , "qweight must be uint8"
2112+ quant_state .original_dtype = quant_state .dtype
2113+ quant_state .original_nested = quant_state .nested
2114+ quant_state .original_qshape = qweight .shape
2115+
21122116 qweight = qweight .reshape (- 1 )
21132117 unpacked_w = torch .empty (qweight .shape [0 ] * 2 , dtype = torch .int32 , device = qweight .device )
21142118 unpacked_w [1 ::2 ] = qweight & 0xF
@@ -2145,9 +2149,73 @@ def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantStat
21452149 delattr (quant_state , "state2" )
21462150
21472151 quant_state .dtype = torch .bfloat16
2152+ quant_state .packing_format_for_cpu = True
21482153 return final_qweight , quant_state
21492154
21502155
def _convert_weight_packed_for_cpu_inverse(
    packed_weight: torch.Tensor,
    quant_state: QuantState,
    block_n: int = 32,
) -> tuple[torch.Tensor, QuantState]:
    """
    Best-effort inverse of `_convert_weight_packed_for_cpu`.

    packed_weight: [N, K/2] uint8 tensor produced by
        `_convert_weight_packed_for_cpu` (its `final_qweight` output).
    quant_state: the QuantState that was mutated by
        `_convert_weight_packed_for_cpu`.
    block_n: row-block size used by the forward packing (default 32).

    Returns:
        qweight: uint8 tensor reshaped to `quant_state.original_qshape`.
        recovered_state: the same QuantState object with dtype, absmax and
            nested-quantization fields restored as far as possible.
    """
    assert quant_state.packing_format_for_cpu, "only for packing format"
    assert packed_weight.dtype == torch.uint8
    assert packed_weight.dim() == 2, "packed_weight should be [N, K/2]"

    n_rows, k_half = packed_weight.shape
    k_cols = 2 * k_half
    # Bytes per packed group; each group encodes 64 4-bit values
    # (32 low nibbles followed by 32 high nibbles).
    group = 32

    assert n_rows % block_n == 0, "N must be divisible by block_n"
    assert k_cols % 2 == 0, "K must be even"

    # 1) Split every group of 32 packed bytes back into its 64 nibbles,
    #    low halves first, then high halves.
    grouped = packed_weight.reshape(-1, group)
    nibbles = torch.cat([grouped & 0xF, (grouped >> 4) & 0xF], dim=-1).to(torch.uint8)

    # [N/block_n, K/2, block_n, 2] -> [N/block_n, block_n, K/2, 2] -> [N, K]
    nibbles = (
        nibbles.reshape(n_rows // block_n, k_half, block_n, 2)
        .transpose(-3, -2)
        .contiguous()
        .reshape(n_rows, k_cols)
    )

    # Re-pack adjacent nibble pairs into bytes: even positions supply the
    # high nibble, odd positions the low nibble (mirrors the forward unpack).
    flat = nibbles.reshape(-1).to(torch.int32)
    hi = (flat[::2] & 0xF).to(torch.uint8)
    lo = (flat[1::2] & 0xF).to(torch.uint8)
    repacked = (hi << 4) | lo  # [K*N/2]

    # 2) Best-effort restore of quant_state fields (mutated in place).
    state = quant_state

    if state.original_nested:
        # Re-quantize absmax the way the original nested state stored it.
        absmax = state.absmax.T.reshape(-1).to(state.original_dtype)
        offset = absmax.mean()
        qabsmax, state2 = quantize_blockwise(absmax - offset, blocksize=256)
        state.absmax = qabsmax
        state.offset = offset
        state.state2 = state2

    state.dtype = state.original_dtype
    state.packing_format_for_cpu = False

    return repacked.to(torch.uint8).reshape(state.original_qshape), state
2218+
21512219def has_avx512bf16 ():
21522220 if hasattr (lib , "has_avx512bf16_cpu" ) and lib .has_avx512bf16_cpu ():
21532221 return True
0 commit comments