@@ -30,16 +30,38 @@ at::Tensor linear_weight_int4_reference_impl(
3030 const size_t ndim = original_x_size.size ();
3131 const int64_t out_features = weights_4x2.size (0 );
3232 const at::Tensor x_flattened = x.reshape ({-1 , original_x_size[ndim - 1 ]});
33- const at::Tensor packed_weights =
34- at::_convert_weight_to_int4pack (weights_4x2, inner_k_tiles);
35- at::Tensor out = at::_weight_int4pack_mm (
36- x_flattened, packed_weights, groupsize, scales_and_zeros);
33+ at::Tensor out = at::_weight_int4pack_mm_for_cpu (
34+ x_flattened, weights_4x2, groupsize, scales_and_zeros);
3735 std::vector<int64_t > out_shape (
3836 original_x_size.begin (), original_x_size.end ());
3937 out_shape.at (ndim - 1 ) = out_features;
4038 return out.reshape (out_shape);
4139}
4240
41+ at::Tensor unpack_weights_4x2 (const at::Tensor& weights_4x2) {
42+ std::vector<int64_t > weights_shape (weights_4x2.sizes ().vec ());
43+ weights_shape[1 ] *= 2 ;
44+
45+ at::Tensor weights_unpacked =
46+ at::empty (weights_shape, at::device (at::kCPU ).dtype (at::kInt ));
47+
48+ const int64_t N = weights_unpacked.size (0 );
49+ const int64_t K = weights_unpacked.size (1 );
50+
51+ for (int n = 0 ; n < N; n++) {
52+ for (int k = 0 ; k < K; k += 2 ) {
53+ const uint8_t packed_val = weights_4x2[n][k / 2 ].item ().to <uint8_t >();
54+ const uint8_t second_val = packed_val & 0x0F ;
55+ const uint8_t first_val = (packed_val & 0xF0 ) >> 4 ;
56+
57+ weights_unpacked[n][k] = int (first_val);
58+ weights_unpacked[n][k + 1 ] = int (second_val);
59+ }
60+ }
61+
62+ return weights_unpacked;
63+ }
64+
4365at::Tensor dequantize_and_linear (
4466 const at::Tensor& x,
4567 const at::Tensor& weights_4x2,
@@ -91,13 +113,18 @@ void test_reference_linear_int4(
91113 at::Tensor x = at::rand ({B, M, K}, at::device (at::kCPU ).dtype (at::kFloat ));
92114 at::Tensor weights_4x2 =
93115 at::randint (0 , 256 , {N, K / 2 }, at::device (at::kCPU ).dtype (at::kByte ));
116+ at::Tensor weights_int = unpack_weights_4x2 (weights_4x2);
94117
95118 const int k_groups = K / group_size;
96119 at::Tensor scales_and_zeros =
97120 at::rand ({k_groups, N, 2 }, at::device (at::kCPU ).dtype (at::kFloat ));
98121
99122 at::Tensor out = linear_weight_int4_reference_impl (
100- x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
123+ x,
124+ at::_convert_weight_to_int4pack_for_cpu (weights_int, group_size),
125+ group_size,
126+ scales_and_zeros,
127+ inner_k_tiles);
101128
102129 at::Tensor out_ref = dequantize_and_linear (
103130 x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
0 commit comments