@@ -34,12 +34,22 @@ at::Tensor qlinear_woq_pack(
3434 int qw_type,
3535 size_t block_n,
3636 size_t block_k,
37- int64_t lowp_mode) {
37+ int64_t lowp_mode,
38+ int64_t weight_format) {
3839 TLA_ASSERT (qw.is_contiguous (), " qw must be contiguous" );
3940 bool is_4bit_flag = is_4bit (qw_type);
4041 auto sizes = qw.sizes ();
4142 auto N = sizes[0 ];
4243 auto K = is_4bit_flag ? sizes[1 ] * 2 : sizes[1 ];
44+ if (weight_format == GPTQ_WEIGHT_FORMAT) {
45+ // weight shape = [K / 8, N] in int32
46+ N = sizes[1 ];
47+ K = sizes[0 ] * 8 ;
48+ } else if (weight_format == AWQ_WEIGHT_FORMAT) {
49+ // weight shape = [K, N / 8] in int32
50+ N = sizes[1 ] * 8 ;
51+ K = sizes[0 ];
52+ }
4353 TLA_ASSERT (N % block_n == 0 , " N must be multiple of block_n" );
4454 TLA_ASSERT (K % block_k == 0 , " K must be multiple of block_k" );
4555 if (is_4bit_flag) {
@@ -56,8 +66,8 @@ at::Tensor qlinear_woq_pack(
5666 const int Nc = N / block_n;
5767 const int Kc = K / block_k;
5868 if (is_4bit_flag) {
59- // TODO(jgong5): support lowp_mode == LOWP_MODE_INT8
60- auto result = at::empty ( {Nc, Kc, block_k, block_n / 2 }, qw.options ());
69+ auto result = at::empty (
70+ {Nc, Kc, block_k, block_n / 2 }, qw.options (). dtype (at:: kByte ));
6171 // Pack weight in [N,K] to [N/block_n, K/block_k, block_k, block_n]
6272 // And then, pre-shuffle per 32 or 64 4-bit values to save shuffle at
6373 // runtime. Take 32 4-bit values as an example below: x0 x1 x2 x3 x4 x5 x6 x7
@@ -67,32 +77,144 @@ at::Tensor qlinear_woq_pack(
6777 // 4-bit values.
6878 uint8_t * src_data = (uint8_t *)qw.data_ptr ();
6979 uint8_t * dst_data = (uint8_t *)result.data_ptr ();
70- auto psrc = GetVLAPtr<uint8_t >(src_data, {block_n, Kc, block_k / 2 });
7180 auto pdst = GetVLAPtr<uint8_t >(dst_data, {Kc, block_k, block_n / 2 });
7281 auto pdst_4vnni =
7382 GetVLAPtr<uint8_t >(dst_data, {Kc, block_k / 4 , block_n / 2 , 4 });
74- auto pack_loop =
75- ThreadedLoop<3 >({{Nc}, {Kc}, {0 , block_n, N_GROUP_SIZE, false }}, " ABc" );
76- pack_loop ([&](int * idx) {
77- int nc = idx[0 ];
78- int kc = idx[1 ];
79- int nb = idx[2 ];
80- for (int i = 0 ; i < N_GROUP_SIZE / 2 ; i++) {
83+ if (weight_format == PLAIN_WEIGHT_FORMAT) {
84+ // weight shape = [N, K / 2] in uint8
85+ auto psrc = GetVLAPtr<uint8_t >(src_data, {block_n, Kc, block_k / 2 });
86+ auto pack_loop = ThreadedLoop<3 >(
87+ {{Nc}, {Kc}, {0 , block_n, N_GROUP_SIZE, false }}, " ABc" );
88+ pack_loop ([&](int * idx) {
89+ int nc = idx[0 ];
90+ int kc = idx[1 ];
91+ int nb = idx[2 ];
92+ for (int i = 0 ; i < N_GROUP_SIZE / 2 ; i++) {
93+ for (int kb = 0 ; kb < block_k; kb += 2 ) {
94+ auto src0 = psrc[nc][nb + i][kc][kb / 2 ];
95+ auto src1 = psrc[nc][nb + i + N_GROUP_SIZE / 2 ][kc][kb / 2 ];
96+ auto dst0 = (src0 & 0xf ) | ((src1 & 0xf ) << 4 );
97+ auto dst1 = (src0 >> 4 ) | ((src1 >> 4 ) << 4 );
98+ if (lowp_mode != LOWP_MODE_INT8) {
99+ pdst[nc][kc][kb][nb / 2 + i] = dst0;
100+ pdst[nc][kc][kb + 1 ][nb / 2 + i] = dst1;
101+ } else {
102+ pdst_4vnni[nc][kc][kb / 4 ][nb / 2 + i][kb % 4 ] = dst0;
103+ pdst_4vnni[nc][kc][(kb + 1 ) / 4 ][nb / 2 + i][(kb + 1 ) % 4 ] = dst1;
104+ }
105+ }
106+ }
107+ });
108+ } else if (weight_format == GPTQ_WEIGHT_FORMAT) {
109+ // weight shape = [K / 8, N] in int32
110+ // reinterpreted as [K / 8, N, 4] in uint8 (4 bytes per int32, 2 int4 values per byte)
111+ // view as [K / 8, Nc, block_n, 4]
112+ auto psrc = GetVLAPtr<uint8_t >(src_data, {Nc, block_n, 4 });
113+ auto pack_loop = ThreadedLoop<3 >(
114+ {{Nc}, {Kc}, {0 , block_n, N_GROUP_SIZE, false }}, " ABc" );
115+ pack_loop ([&](int * idx) {
116+ int nc = idx[0 ];
117+ int kc = idx[1 ];
118+ int nb = idx[2 ];
119+ int k_start = kc * block_k;
120+ for (int i = 0 ; i < N_GROUP_SIZE / 2 ; i++) {
121+ for (int kb = 0 ; kb < block_k; kb += 2 ) {
122+ int k = k_start + kb;
123+ int k8_idx = k / 8 ;
124+ int k8_off = k % 8 ;
125+ auto src0 = psrc[k8_idx][nc][nb + i][k8_off / 2 ];
126+ auto src1 = psrc[k8_idx][nc][nb + i + N_GROUP_SIZE / 2 ][k8_off / 2 ];
127+ auto dst0 = (src0 & 0xf ) | ((src1 & 0xf ) << 4 );
128+ auto dst1 = (src0 >> 4 ) | ((src1 >> 4 ) << 4 );
129+ if (lowp_mode != LOWP_MODE_INT8) {
130+ pdst[nc][kc][kb][nb / 2 + i] = dst0;
131+ pdst[nc][kc][kb + 1 ][nb / 2 + i] = dst1;
132+ } else {
133+ pdst_4vnni[nc][kc][kb / 4 ][nb / 2 + i][kb % 4 ] = dst0;
134+ pdst_4vnni[nc][kc][(kb + 1 ) / 4 ][nb / 2 + i][(kb + 1 ) % 4 ] = dst1;
135+ }
136+ }
137+ }
138+ });
139+ } else { // AWQ_WEIGHT_FORMAT
140+ TORCH_CHECK (
141+ weight_format == AWQ_WEIGHT_FORMAT,
142+ " Unsupported weight format: " ,
143+ weight_format);
144+ // weight shape = [K, N / 8] in int32
145+ // Every 8 int4 data along N are shuffled from [0, 1, 2, 3, 4, 5, 6, 7] to
146+ // [0, 2, 4, 6, 1, 3, 5, 7] and they are packed as one int32 element.
147+ // reinterpreted as [K, N / 2] in uint8 (4 bytes per int32, 2 int4 values per byte)
148+ // view as [Kc, block_k, Nc, block_n / 2]
149+ auto psrc = GetVLAPtr<uint8_t >(src_data, {block_k, Nc, block_n / 2 });
150+ auto pack_loop = ThreadedLoop<3 >(
151+ {{Nc}, {Kc}, {0 , block_n, N_GROUP_SIZE, false }}, " ABc" );
152+ TORCH_CHECK (
153+ (N_GROUP_SIZE / 2 ) % 8 == 0 , " N_GROUP_SIZE must be multiple of 16" );
154+ pack_loop ([&](int * idx) {
155+ int nc = idx[0 ];
156+ int kc = idx[1 ];
157+ int nb = idx[2 ];
81158 for (int kb = 0 ; kb < block_k; kb += 2 ) {
82- auto src0 = psrc[nc][nb + i][kc][kb / 2 ];
83- auto src1 = psrc[nc][nb + i + N_GROUP_SIZE / 2 ][kc][kb / 2 ];
84- auto dst0 = (src0 & 0xf ) | ((src1 & 0xf ) << 4 );
85- auto dst1 = (src0 >> 4 ) | ((src1 >> 4 ) << 4 );
86- if (lowp_mode != LOWP_MODE_INT8) {
87- pdst[nc][kc][kb][nb / 2 + i] = dst0;
88- pdst[nc][kc][kb + 1 ][nb / 2 + i] = dst1;
89- } else {
90- pdst_4vnni[nc][kc][kb / 4 ][nb / 2 + i][kb % 4 ] = dst0;
91- pdst_4vnni[nc][kc][(kb + 1 ) / 4 ][nb / 2 + i][(kb + 1 ) % 4 ] = dst1;
159+ for (int i = 0 ; i < N_GROUP_SIZE / 2 ; i += 8 ) {
160+ int n_base = (nb + i) / 2 ;
161+ uint8_t src0_low[4 ] = {
162+ psrc[kc][kb][nc][n_base],
163+ psrc[kc][kb][nc][n_base + 1 ],
164+ psrc[kc][kb][nc][n_base + 2 ],
165+ psrc[kc][kb][nc][n_base + 3 ]};
166+ uint8_t src0_high[4 ] = {
167+ psrc[kc][kb + 1 ][nc][n_base],
168+ psrc[kc][kb + 1 ][nc][n_base + 1 ],
169+ psrc[kc][kb + 1 ][nc][n_base + 2 ],
170+ psrc[kc][kb + 1 ][nc][n_base + 3 ]};
171+
172+ n_base += N_GROUP_SIZE / 2 / 2 ;
173+ uint8_t src1_low[4 ] = {
174+ psrc[kc][kb][nc][n_base],
175+ psrc[kc][kb][nc][n_base + 1 ],
176+ psrc[kc][kb][nc][n_base + 2 ],
177+ psrc[kc][kb][nc][n_base + 3 ]};
178+ uint8_t src1_high[4 ] = {
179+ psrc[kc][kb + 1 ][nc][n_base],
180+ psrc[kc][kb + 1 ][nc][n_base + 1 ],
181+ psrc[kc][kb + 1 ][nc][n_base + 2 ],
182+ psrc[kc][kb + 1 ][nc][n_base + 3 ]};
183+
184+ uint8_t dst0[8 ] = {
185+ (src0_low[0 ] & 0xf ) | ((src1_low[0 ] & 0xf ) << 4 ),
186+ (src0_low[2 ] & 0xf ) | ((src1_low[2 ] & 0xf ) << 4 ),
187+ (src0_low[0 ] >> 4 ) | ((src1_low[0 ] >> 4 ) << 4 ),
188+ (src0_low[2 ] >> 4 ) | ((src1_low[2 ] >> 4 ) << 4 ),
189+ (src0_low[1 ] & 0xf ) | ((src1_low[1 ] & 0xf ) << 4 ),
190+ (src0_low[3 ] & 0xf ) | ((src1_low[3 ] & 0xf ) << 4 ),
191+ (src0_low[1 ] >> 4 ) | ((src1_low[1 ] >> 4 ) << 4 ),
192+ (src0_low[3 ] >> 4 ) | ((src1_low[3 ] >> 4 ) << 4 )};
193+ uint8_t dst1[8 ] = {
194+ (src0_high[0 ] & 0xf ) | ((src1_high[0 ] & 0xf ) << 4 ),
195+ (src0_high[2 ] & 0xf ) | ((src1_high[2 ] & 0xf ) << 4 ),
196+ (src0_high[0 ] >> 4 ) | ((src1_high[0 ] >> 4 ) << 4 ),
197+ (src0_high[2 ] >> 4 ) | ((src1_high[2 ] >> 4 ) << 4 ),
198+ (src0_high[1 ] & 0xf ) | ((src1_high[1 ] & 0xf ) << 4 ),
199+ (src0_high[3 ] & 0xf ) | ((src1_high[3 ] & 0xf ) << 4 ),
200+ (src0_high[1 ] >> 4 ) | ((src1_high[1 ] >> 4 ) << 4 ),
201+ (src0_high[3 ] >> 4 ) | ((src1_high[3 ] >> 4 ) << 4 )};
202+ if (lowp_mode != LOWP_MODE_INT8) {
203+ for (int j = 0 ; j < 8 ; j++) {
204+ pdst[nc][kc][kb][nb / 2 + i + j] = dst0[j];
205+ pdst[nc][kc][kb + 1 ][nb / 2 + i + j] = dst1[j];
206+ }
207+ } else {
208+ for (int j = 0 ; j < 8 ; j++) {
209+ pdst_4vnni[nc][kc][kb / 4 ][nb / 2 + i + j][kb % 4 ] = dst0[j];
210+ pdst_4vnni[nc][kc][(kb + 1 ) / 4 ][nb / 2 + i + j][(kb + 1 ) % 4 ] =
211+ dst1[j];
212+ }
213+ }
92214 }
93215 }
94- }
95- });
216+ });
217+ }
96218 return result;
97219 } else {
98220 if (lowp_mode == LOWP_MODE_INT8) {
@@ -427,7 +549,58 @@ at::Tensor qlinear_woq_pack(
427549 int qw_type,
428550 size_t block_n,
429551 size_t block_k,
430- int64_t lowp_mode) {
552+ int64_t lowp_mode,
553+ int64_t weight_format) {
554+ if (weight_format == GPTQ_WEIGHT_FORMAT) {
555+ // weight shape = [K / 8, N] in int32
556+ TORCH_CHECK (
557+ qw.scalar_type () == at::kInt ,
558+ " Unsupported weight type: " ,
559+ qw.scalar_type ());
560+ auto qw_t = qw.t ().contiguous ();
561+ auto qw_uint8 = at::empty (
562+ {qw_t .size (0 ), qw_t .size (1 ) * 8 }, qw_t .options ().dtype (at::kByte ));
563+ using namespace at ::indexing;
564+ for (int i = 0 ; i < 8 ; ++i) {
565+ qw_uint8.index_put_ (
566+ {Slice (), Slice (i, None, 8 )},
567+ (qw_t .bitwise_right_shift (4 * i)).bitwise_and (0xf ).to (at::kByte ));
568+ }
569+ auto new_qw =
570+ qw_uint8.index ({Slice (), Slice (1 , None, 2 )})
571+ .bitwise_left_shift (4 )
572+ .bitwise_or_ (qw_uint8.index ({Slice (), Slice (None, None, 2 )})
573+ .bitwise_and (0xF ));
574+ return new_qw;
575+ } else if (weight_format == AWQ_WEIGHT_FORMAT) {
576+ // weight shape = [K, N / 8] in int32
577+ using namespace at ::indexing;
578+ auto qw_uint8 =
579+ at::empty ({qw.size (0 ), qw.size (1 ) * 8 }, qw.options ().dtype (at::kByte ));
580+ // logic for unpacking:
581+ // for i in range(8):
582+ // unpacked[:, i::8] = (qw >> (4 * i)) & 0xf
583+ for (int i = 0 ; i < 8 ; ++i) {
584+ qw_uint8.index_put_ (
585+ {Slice (), Slice (i, None, 8 )},
586+ qw.bitwise_right_shift (4 * i).bitwise_and (0xf ).to (at::kByte ));
587+ }
588+ // Shuffling along N from [0, 2, 4, 6, 1, 3, 5, 7] to [0, 1, 2, 3, 4, 5, 6,
589+ // 7]
590+ auto qw_uint8_view =
591+ qw_uint8.view ({qw_uint8.size (0 ), qw_uint8.size (1 ) / 8 , 8 });
592+ auto qw_uint8_shuffled = at::index_select (
593+ qw_uint8_view, /* dim */ 2 , at::tensor ({0 , 4 , 1 , 5 , 2 , 6 , 3 , 7 }));
594+ qw_uint8_shuffled =
595+ qw_uint8_shuffled.view ({qw_uint8.size (0 ), qw_uint8.size (1 )});
596+ auto qw_uint8_t = qw_uint8_shuffled.t ().contiguous ();
597+ auto new_qw =
598+ qw_uint8_t .index ({Slice (), Slice (1 , None, 2 )})
599+ .bitwise_left_shift (4 )
600+ .bitwise_or_ (qw_uint8_t .index ({Slice (), Slice (None, None, 2 )})
601+ .bitwise_and (0xF ));
602+ return new_qw;
603+ }
431604 return qw;
432605}
433606
0 commit comments