
Commit 568fbc1
[Release/2.6] port some improvements of loading INT4 checkpoint for WOQ (#3448)
* WOQ INT4: pack GPTQ/AWQ weight directly without converting to plain format (#3435)
  * WOQ INT4: pack GPTQ/AWQ weight directly without converting to plain format
  * refine code
  * Enable ref kernel
  * Avoid unnecessary data copy
  * clang-format
  * Fix IpexWoqLinearAllreduce._init_cls
  * fix UT
* Load auto-round low precision checkpoint (#3443)
1 parent 55853e6 commit 568fbc1

17 files changed: +714 −230 lines

csrc/cpu/aten/Linear.cpp

Lines changed: 23 additions & 4 deletions
@@ -380,7 +380,8 @@ at::Tensor woq_linear_pack_weight(
     int64_t weight_dtype,
     std::vector<int64_t>& weight_shape,
     int64_t group_size,
-    int64_t lowp_mode) {
+    int64_t lowp_mode,
+    int64_t weight_format) {
   // TPP kernel does not support edge cases
   // It generates packed weight in 4d (Nc, Kc, block_k, block_n)
   auto N = weight_shape[0], K = weight_shape[1];
@@ -402,16 +403,34 @@
       at::Tensor weight_int4 =
           at::pad(weight, {0, 0, 0, N_int4 - N}, "constant", 0);
       return woq_tpp_gemm_packB_stub(
-          kCPU, weight_int4, weight_dtype, block_n, block_k, lowp_mode);
+          kCPU,
+          weight_int4,
+          weight_dtype,
+          block_n,
+          block_k,
+          lowp_mode,
+          weight_format);
     }
     if (N % block_n) {
       at::Tensor weight_padded =
           at::pad(weight, {0, 0, 0, block_n - N % block_n}, "constant", 0);
       return woq_tpp_gemm_packB_stub(
-          kCPU, weight_padded, weight_dtype, block_n, block_k, lowp_mode);
+          kCPU,
+          weight_padded,
+          weight_dtype,
+          block_n,
+          block_k,
+          lowp_mode,
+          weight_format);
     } else {
       return woq_tpp_gemm_packB_stub(
-          kCPU, weight, weight_dtype, block_n, block_k, lowp_mode);
+          kCPU,
+          weight,
+          weight_dtype,
+          block_n,
+          block_k,
+          lowp_mode,
+          weight_format);
     }
   }
   return weight;

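For orientation, a hypothetical call into the widened entry point might look as follows. This is a sketch only: `WOQ_DTYPE_INT4` and `LOWP_MODE_BF16` are assumed constant names not shown in this diff; the commit itself only adds `weight_format` and its three values.

```cpp
#include <ATen/ATen.h>

// Hypothetical call site (illustrative values; WOQ_DTYPE_INT4 and
// LOWP_MODE_BF16 are assumed names, not taken from this diff).
at::Tensor pack_gptq_checkpoint_weight(const at::Tensor& qweight) {
  // qweight: [K / 8, N] in int32, straight from a GPTQ checkpoint.
  std::vector<int64_t> weight_shape = {4096, 11008}; // logical [N, K]
  return woq_linear_pack_weight(
      qweight,
      /*weight_dtype=*/WOQ_DTYPE_INT4,
      weight_shape,
      /*group_size=*/128,
      /*lowp_mode=*/LOWP_MODE_BF16,
      /*weight_format=*/GPTQ_WEIGHT_FORMAT);
}
```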
csrc/cpu/aten/Linear.h

Lines changed: 3 additions & 2 deletions
@@ -134,7 +134,8 @@ at::Tensor woq_linear_pack_weight(
     int64_t weight_dtype,
     std::vector<int64_t>& weight_shape,
     int64_t group_size,
-    int64_t lowp_mode);
+    int64_t lowp_mode,
+    int64_t weight_format);
 
 at::Tensor woq_linear_compute_compensation(
     const at::Tensor& weight,
@@ -266,7 +267,7 @@ using woq_int8_gemm_kernel_fn = at::Tensor (*)(
     const c10::optional<at::Tensor>&);
 
 using woq_tpp_gemm_packB_fn =
-    at::Tensor (*)(const at::Tensor&, int, size_t, size_t, int64_t);
+    at::Tensor (*)(const at::Tensor&, int, size_t, size_t, int64_t, int64_t);
 
 using woq_tpp_gemm_unpackB_fn = at::Tensor (*)(const at::Tensor&, int, int64_t);

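A hypothetical compile-time guard (not part of the patch) showing how the widened `woq_tpp_gemm_packB_fn` type constrains kernels; it assumes `qlinear_woq_pack` is visible and not overloaded at the point of the check:

```cpp
#include <type_traits>

// Illustrative only: a kernel entry point must now match the widened
// function-pointer type; against the old five-argument signature this
// static_assert would fail to compile.
static_assert(
    std::is_same_v<decltype(&qlinear_woq_pack), woq_tpp_gemm_packB_fn>,
    "packB kernel must accept the new weight_format argument");
```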
csrc/cpu/aten/kernels/WoqUtilKrnl.cpp

Lines changed: 197 additions & 24 deletions
@@ -34,12 +34,22 @@ at::Tensor qlinear_woq_pack(
     int qw_type,
     size_t block_n,
     size_t block_k,
-    int64_t lowp_mode) {
+    int64_t lowp_mode,
+    int64_t weight_format) {
   TLA_ASSERT(qw.is_contiguous(), "qw must be contiguous");
   bool is_4bit_flag = is_4bit(qw_type);
   auto sizes = qw.sizes();
   auto N = sizes[0];
   auto K = is_4bit_flag ? sizes[1] * 2 : sizes[1];
+  if (weight_format == GPTQ_WEIGHT_FORMAT) {
+    // weight shape = [K / 8, N] in int32
+    N = sizes[1];
+    K = sizes[0] * 8;
+  } else if (weight_format == AWQ_WEIGHT_FORMAT) {
+    // weight shape = [K, N / 8] in int32
+    N = sizes[1] * 8;
+    K = sizes[0];
+  }
   TLA_ASSERT(N % block_n == 0, "N must be multiple of block_n");
   TLA_ASSERT(K % block_k == 0, "K must be multiple of block_k");
   if (is_4bit_flag) {
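The shape recovery above is the crux of packing GPTQ/AWQ checkpoints directly: both formats pack eight 4-bit values per int32 element, transposed relative to each other. A standalone restatement of that mapping (a sketch; the enum is local to this example, while the kernel uses the PLAIN/GPTQ/AWQ_WEIGHT_FORMAT macros from woq_defines.h at the end of this commit):

```cpp
#include <cstdint>

// Sketch: recover logical [N, K] from the on-disk tensor shape for each
// supported INT4 layout. Mirrors the branch above.
enum Int4Format { kPlain, kGptq, kAwq };

void infer_logical_shape(
    int64_t rows, int64_t cols, Int4Format fmt, int64_t& N, int64_t& K) {
  switch (fmt) {
    case kPlain: // [N, K / 2] in uint8: two nibbles per byte along K
      N = rows;
      K = cols * 2;
      break;
    case kGptq: // [K / 8, N] in int32: eight nibbles per word along K
      N = cols;
      K = rows * 8;
      break;
    case kAwq: // [K, N / 8] in int32: eight nibbles per word along N
      N = cols * 8;
      K = rows;
      break;
  }
}
```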
@@ -56,8 +66,8 @@
   const int Nc = N / block_n;
   const int Kc = K / block_k;
   if (is_4bit_flag) {
-    // TODO(jgong5): support lowp_mode == LOWP_MODE_INT8
-    auto result = at::empty({Nc, Kc, block_k, block_n / 2}, qw.options());
+    auto result = at::empty(
+        {Nc, Kc, block_k, block_n / 2}, qw.options().dtype(at::kByte));
     // Pack weight in [N,K] to [N/block_n, K/block_k, block_k, block_n]
     // And then, pre-shuffle per 32 or 64 4-bit values to save shuffle at
     // runtime Take 32 4-bit values as an example below: x0 x1 x2 x3 x4 x5 x6 x7
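Forcing the packed buffer to kByte matters once non-plain inputs arrive: a GPTQ/AWQ source is int32, so inheriting `qw.options()` unchanged would silently allocate 4-byte elements. A minimal illustration of the pitfall (not from the patch; shapes are arbitrary):

```cpp
#include <ATen/ATen.h>

// Illustration (not from the patch): with an int32 GPTQ/AWQ source tensor,
// inheriting qw.options() would change the packed buffer's element type,
// so the kernel now forces kByte explicitly.
void dtype_pitfall() {
  at::Tensor qw = at::zeros({512, 4096}, at::kInt); // GPTQ: [K / 8, N]
  auto inherited = at::empty({4, 32, 64, 16}, qw.options());
  auto forced = at::empty({4, 32, 64, 16}, qw.options().dtype(at::kByte));
  TORCH_CHECK(inherited.scalar_type() == at::kInt);  // 4 bytes per element
  TORCH_CHECK(forced.scalar_type() == at::kByte);    // nibble pairs, as intended
}
```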
@@ -67,32 +77,144 @@
     // 4-bit values.
     uint8_t* src_data = (uint8_t*)qw.data_ptr();
     uint8_t* dst_data = (uint8_t*)result.data_ptr();
-    auto psrc = GetVLAPtr<uint8_t>(src_data, {block_n, Kc, block_k / 2});
     auto pdst = GetVLAPtr<uint8_t>(dst_data, {Kc, block_k, block_n / 2});
     auto pdst_4vnni =
         GetVLAPtr<uint8_t>(dst_data, {Kc, block_k / 4, block_n / 2, 4});
-    auto pack_loop =
-        ThreadedLoop<3>({{Nc}, {Kc}, {0, block_n, N_GROUP_SIZE, false}}, "ABc");
-    pack_loop([&](int* idx) {
-      int nc = idx[0];
-      int kc = idx[1];
-      int nb = idx[2];
-      for (int i = 0; i < N_GROUP_SIZE / 2; i++) {
+    if (weight_format == PLAIN_WEIGHT_FORMAT) {
+      // weight shape = [N, K / 2] in uint8
+      auto psrc = GetVLAPtr<uint8_t>(src_data, {block_n, Kc, block_k / 2});
+      auto pack_loop = ThreadedLoop<3>(
+          {{Nc}, {Kc}, {0, block_n, N_GROUP_SIZE, false}}, "ABc");
+      pack_loop([&](int* idx) {
+        int nc = idx[0];
+        int kc = idx[1];
+        int nb = idx[2];
+        for (int i = 0; i < N_GROUP_SIZE / 2; i++) {
+          for (int kb = 0; kb < block_k; kb += 2) {
+            auto src0 = psrc[nc][nb + i][kc][kb / 2];
+            auto src1 = psrc[nc][nb + i + N_GROUP_SIZE / 2][kc][kb / 2];
+            auto dst0 = (src0 & 0xf) | ((src1 & 0xf) << 4);
+            auto dst1 = (src0 >> 4) | ((src1 >> 4) << 4);
+            if (lowp_mode != LOWP_MODE_INT8) {
+              pdst[nc][kc][kb][nb / 2 + i] = dst0;
+              pdst[nc][kc][kb + 1][nb / 2 + i] = dst1;
+            } else {
+              pdst_4vnni[nc][kc][kb / 4][nb / 2 + i][kb % 4] = dst0;
+              pdst_4vnni[nc][kc][(kb + 1) / 4][nb / 2 + i][(kb + 1) % 4] = dst1;
+            }
+          }
+        }
+      });
+    } else if (weight_format == GPTQ_WEIGHT_FORMAT) {
+      // weight shape = [K / 8, N] in int32
+      // weight shape = [K / 8, N, 4] in uint8
+      // view as [K / 8, Nc, block_n, 4]
+      auto psrc = GetVLAPtr<uint8_t>(src_data, {Nc, block_n, 4});
+      auto pack_loop = ThreadedLoop<3>(
+          {{Nc}, {Kc}, {0, block_n, N_GROUP_SIZE, false}}, "ABc");
+      pack_loop([&](int* idx) {
+        int nc = idx[0];
+        int kc = idx[1];
+        int nb = idx[2];
+        int k_start = kc * block_k;
+        for (int i = 0; i < N_GROUP_SIZE / 2; i++) {
+          for (int kb = 0; kb < block_k; kb += 2) {
+            int k = k_start + kb;
+            int k8_idx = k / 8;
+            int k8_off = k % 8;
+            auto src0 = psrc[k8_idx][nc][nb + i][k8_off / 2];
+            auto src1 = psrc[k8_idx][nc][nb + i + N_GROUP_SIZE / 2][k8_off / 2];
+            auto dst0 = (src0 & 0xf) | ((src1 & 0xf) << 4);
+            auto dst1 = (src0 >> 4) | ((src1 >> 4) << 4);
+            if (lowp_mode != LOWP_MODE_INT8) {
+              pdst[nc][kc][kb][nb / 2 + i] = dst0;
+              pdst[nc][kc][kb + 1][nb / 2 + i] = dst1;
+            } else {
+              pdst_4vnni[nc][kc][kb / 4][nb / 2 + i][kb % 4] = dst0;
+              pdst_4vnni[nc][kc][(kb + 1) / 4][nb / 2 + i][(kb + 1) % 4] = dst1;
+            }
+          }
+        }
+      });
+    } else { // AWQ_WEIGHT_FORMAT
+      TORCH_CHECK(
+          weight_format == AWQ_WEIGHT_FORMAT,
+          "Unsupported weight format: ",
+          weight_format);
+      // weight shape = [K, N / 8] in int32
+      // Every 8 int4 data along N are shuffled from [0, 1, 2, 3, 4, 5, 6, 7] to
+      // [0, 2, 4, 6, 1, 3, 5, 7] and they are packed as one int32 element.
+      // weight shape = [K, N / 2] in uint8
+      // view as [Kc, block_k, Nc, block_n / 2]
+      auto psrc = GetVLAPtr<uint8_t>(src_data, {block_k, Nc, block_n / 2});
+      auto pack_loop = ThreadedLoop<3>(
+          {{Nc}, {Kc}, {0, block_n, N_GROUP_SIZE, false}}, "ABc");
+      TORCH_CHECK(
+          (N_GROUP_SIZE / 2) % 8 == 0, "N_GROUP_SIZE must be multiple of 16");
+      pack_loop([&](int* idx) {
+        int nc = idx[0];
+        int kc = idx[1];
+        int nb = idx[2];
         for (int kb = 0; kb < block_k; kb += 2) {
-          auto src0 = psrc[nc][nb + i][kc][kb / 2];
-          auto src1 = psrc[nc][nb + i + N_GROUP_SIZE / 2][kc][kb / 2];
-          auto dst0 = (src0 & 0xf) | ((src1 & 0xf) << 4);
-          auto dst1 = (src0 >> 4) | ((src1 >> 4) << 4);
-          if (lowp_mode != LOWP_MODE_INT8) {
-            pdst[nc][kc][kb][nb / 2 + i] = dst0;
-            pdst[nc][kc][kb + 1][nb / 2 + i] = dst1;
-          } else {
-            pdst_4vnni[nc][kc][kb / 4][nb / 2 + i][kb % 4] = dst0;
-            pdst_4vnni[nc][kc][(kb + 1) / 4][nb / 2 + i][(kb + 1) % 4] = dst1;
+          for (int i = 0; i < N_GROUP_SIZE / 2; i += 8) {
+            int n_base = (nb + i) / 2;
+            uint8_t src0_low[4] = {
+                psrc[kc][kb][nc][n_base],
+                psrc[kc][kb][nc][n_base + 1],
+                psrc[kc][kb][nc][n_base + 2],
+                psrc[kc][kb][nc][n_base + 3]};
+            uint8_t src0_high[4] = {
+                psrc[kc][kb + 1][nc][n_base],
+                psrc[kc][kb + 1][nc][n_base + 1],
+                psrc[kc][kb + 1][nc][n_base + 2],
+                psrc[kc][kb + 1][nc][n_base + 3]};
+
+            n_base += N_GROUP_SIZE / 2 / 2;
+            uint8_t src1_low[4] = {
+                psrc[kc][kb][nc][n_base],
+                psrc[kc][kb][nc][n_base + 1],
+                psrc[kc][kb][nc][n_base + 2],
+                psrc[kc][kb][nc][n_base + 3]};
+            uint8_t src1_high[4] = {
+                psrc[kc][kb + 1][nc][n_base],
+                psrc[kc][kb + 1][nc][n_base + 1],
+                psrc[kc][kb + 1][nc][n_base + 2],
+                psrc[kc][kb + 1][nc][n_base + 3]};
+
+            uint8_t dst0[8] = {
+                (src0_low[0] & 0xf) | ((src1_low[0] & 0xf) << 4),
+                (src0_low[2] & 0xf) | ((src1_low[2] & 0xf) << 4),
+                (src0_low[0] >> 4) | ((src1_low[0] >> 4) << 4),
+                (src0_low[2] >> 4) | ((src1_low[2] >> 4) << 4),
+                (src0_low[1] & 0xf) | ((src1_low[1] & 0xf) << 4),
+                (src0_low[3] & 0xf) | ((src1_low[3] & 0xf) << 4),
+                (src0_low[1] >> 4) | ((src1_low[1] >> 4) << 4),
+                (src0_low[3] >> 4) | ((src1_low[3] >> 4) << 4)};
+            uint8_t dst1[8] = {
+                (src0_high[0] & 0xf) | ((src1_high[0] & 0xf) << 4),
+                (src0_high[2] & 0xf) | ((src1_high[2] & 0xf) << 4),
+                (src0_high[0] >> 4) | ((src1_high[0] >> 4) << 4),
+                (src0_high[2] >> 4) | ((src1_high[2] >> 4) << 4),
+                (src0_high[1] & 0xf) | ((src1_high[1] & 0xf) << 4),
+                (src0_high[3] & 0xf) | ((src1_high[3] & 0xf) << 4),
+                (src0_high[1] >> 4) | ((src1_high[1] >> 4) << 4),
+                (src0_high[3] >> 4) | ((src1_high[3] >> 4) << 4)};
+            if (lowp_mode != LOWP_MODE_INT8) {
+              for (int j = 0; j < 8; j++) {
+                pdst[nc][kc][kb][nb / 2 + i + j] = dst0[j];
+                pdst[nc][kc][kb + 1][nb / 2 + i + j] = dst1[j];
+              }
+            } else {
+              for (int j = 0; j < 8; j++) {
+                pdst_4vnni[nc][kc][kb / 4][nb / 2 + i + j][kb % 4] = dst0[j];
+                pdst_4vnni[nc][kc][(kb + 1) / 4][nb / 2 + i + j][(kb + 1) % 4] =
+                    dst1[j];
+              }
+            }
           }
         }
-      }
-    });
+      });
+    }
     return result;
   } else {
     if (lowp_mode == LOWP_MODE_INT8) {
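The only new arithmetic in the GPTQ branch above is `k8_idx = k / 8` and `k8_off = k % 8`: the int4 for logical position (k, n) sits in int32 word [k / 8][n] at nibble k % 8, least-significant nibble first. A self-contained round-trip check of that addressing (a sketch; it assumes a little-endian host, as the kernel's byte-level view does):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch: pack values at logical (k, n) into the GPTQ [K / 8, N] int32
// layout, then read them back with the kernel's k8_idx / k8_off arithmetic
// through a byte view (little-endian assumed, as in the kernel).
int main() {
  const int K = 16, N = 4;
  int32_t qweight[K / 8][N] = {};
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n)
      qweight[k / 8][n] |= int32_t((k * 3 + n) & 0xf) << (4 * (k % 8));

  const uint8_t* bytes = reinterpret_cast<const uint8_t*>(qweight);
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n) {
      int k8_idx = k / 8, k8_off = k % 8;
      uint8_t byte = bytes[(k8_idx * N + n) * 4 + k8_off / 2];
      uint8_t nibble = (k8_off % 2 == 0) ? (byte & 0xf) : (byte >> 4);
      if (nibble != ((k * 3 + n) & 0xf)) {
        std::printf("mismatch at k=%d n=%d\n", k, n);
        return 1;
      }
    }
  std::printf("GPTQ nibble addressing round-trips\n");
  return 0;
}
```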
@@ -427,7 +549,58 @@ at::Tensor qlinear_woq_pack(
     int qw_type,
     size_t block_n,
     size_t block_k,
-    int64_t lowp_mode) {
+    int64_t lowp_mode,
+    int64_t weight_format) {
+  if (weight_format == GPTQ_WEIGHT_FORMAT) {
+    // weight shape = [K / 8, N] in int32
+    TORCH_CHECK(
+        qw.scalar_type() == at::kInt,
+        "Unsupported weight type: ",
+        qw.scalar_type());
+    auto qw_t = qw.t().contiguous();
+    auto qw_uint8 = at::empty(
+        {qw_t.size(0), qw_t.size(1) * 8}, qw_t.options().dtype(at::kByte));
+    using namespace at::indexing;
+    for (int i = 0; i < 8; ++i) {
+      qw_uint8.index_put_(
+          {Slice(), Slice(i, None, 8)},
+          (qw_t.bitwise_right_shift(4 * i)).bitwise_and(0xf).to(at::kByte));
+    }
+    auto new_qw =
+        qw_uint8.index({Slice(), Slice(1, None, 2)})
+            .bitwise_left_shift(4)
+            .bitwise_or_(qw_uint8.index({Slice(), Slice(None, None, 2)})
+                             .bitwise_and(0xF));
+    return new_qw;
+  } else if (weight_format == AWQ_WEIGHT_FORMAT) {
+    // weight shape = [K, N / 8] in int32
+    using namespace at::indexing;
+    auto qw_uint8 =
+        at::empty({qw.size(0), qw.size(1) * 8}, qw.options().dtype(at::kByte));
+    // logic for unpacking:
+    // for i in range(8):
+    //   unpacked[:, i::8] = (qw >> (4 * i)) & 0xf
+    for (int i = 0; i < 8; ++i) {
+      qw_uint8.index_put_(
+          {Slice(), Slice(i, None, 8)},
+          qw.bitwise_right_shift(4 * i).bitwise_and(0xf).to(at::kByte));
+    }
+    // Shuffling along N from [0, 2, 4, 6, 1, 3, 5, 7] to [0, 1, 2, 3, 4, 5, 6,
+    // 7]
+    auto qw_uint8_view =
+        qw_uint8.view({qw_uint8.size(0), qw_uint8.size(1) / 8, 8});
+    auto qw_uint8_shuffled = at::index_select(
+        qw_uint8_view, /* dim */ 2, at::tensor({0, 4, 1, 5, 2, 6, 3, 7}));
+    qw_uint8_shuffled =
+        qw_uint8_shuffled.view({qw_uint8.size(0), qw_uint8.size(1)});
+    auto qw_uint8_t = qw_uint8_shuffled.t().contiguous();
+    auto new_qw =
+        qw_uint8_t.index({Slice(), Slice(1, None, 2)})
+            .bitwise_left_shift(4)
+            .bitwise_or_(qw_uint8_t.index({Slice(), Slice(None, None, 2)})
+                             .bitwise_and(0xF));
+    return new_qw;
+  }
   return qw;
 }

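The `at::index_select` permutation {0, 4, 1, 5, 2, 6, 3, 7} in the reference unpack path above is the inverse of the [0, 2, 4, 6, 1, 3, 5, 7] storage order described in the pack kernel's comment. A standalone check of that inverse relationship:

```cpp
#include <cstdint>
#include <cstdio>

// Sketch: AWQ stores eight int4 values, logically [0..7] along N, into one
// int32 word in slot order [0, 2, 4, 6, 1, 3, 5, 7]. Gathering slots in
// order [0, 4, 1, 5, 2, 6, 3, 7] (the index_select above) restores them.
int main() {
  const int slot_to_logical[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  const int gather[8] = {0, 4, 1, 5, 2, 6, 3, 7};

  // Pack: slot i of the word holds the nibble for logical column
  // slot_to_logical[i]; use the logical index itself as the test value.
  uint32_t word = 0;
  for (int slot = 0; slot < 8; ++slot)
    word |= uint32_t(slot_to_logical[slot]) << (4 * slot);

  // Unpack via the gather permutation; expect 0, 1, 2, ..., 7 back.
  for (int j = 0; j < 8; ++j) {
    uint32_t v = (word >> (4 * gather[j])) & 0xf;
    std::printf("logical %d -> %u%s", j, v, j == 7 ? "\n" : ", ");
  }
  return 0;
}
```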
csrc/cpu/aten/utils/woq_defines.h

Lines changed: 8 additions & 0 deletions
@@ -42,6 +42,14 @@
 
 #define WOQ_N_BLOCK_SIZE 32
 
+// INT4 weight format before packing
+// plain: [N, K / 2] in uint8
+// gptq: [K / 8, N] in int32
+// awq: [K, N / 8] in int32
+#define PLAIN_WEIGHT_FORMAT 0
+#define GPTQ_WEIGHT_FORMAT 1
+#define AWQ_WEIGHT_FORMAT 2
+
 constexpr bool is_asymmetric_quant_a(const int quant_a_mode) {
   return quant_a_mode <= QUANT_A_PER_M_K_BLOCK;
 }
