Commit 2eb4f97

Add operators for LUT based low bit weight quantization
Differential Revision: D77618971
Pull Request resolved: #2577
1 parent 74808e2 commit 2eb4f97

1 file changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h>

#include <torchao/experimental/ops/library.h>
#include <torchao/experimental/ops/memory.h>
#include <torchao/experimental/ops/parallel.h>
#include <algorithm>
#include <cassert>
#include <vector>

namespace torchao::ops::groupwise_lowbit_weight_lut {

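// Packs quantized weights (LUT indices, optional scales, LUTs, and optional
// bias) into the ukernel's packed layout. Work is split into panels of
// n_step output columns and packed in parallel, one panel per task.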
void pack_weights_operator(
    const UKernelConfig& uk,
    // Outputs
    void* packed_weights_ptr,
    // Inputs
    int n,
    int k,
    int scale_group_size,
    int lut_group_size,
    const uint8_t* weight_qval_indices,
    const float* weight_scales,
    const float* weight_luts,
    const float* bias) {
  TORCHAO_CHECK(
      lut_group_size % scale_group_size == 0,
      "scale_group_size must divide lut_group_size");
  TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k");
  TORCHAO_CHECK(
      lut_group_size % (k * uk.nr) == 0,
      "lut_group_size must be a multiple of k*nr");
  TORCHAO_CHECK(k % uk.kr == 0, "kr must divide k");

  // 1. Define the block size for parallel work.
  int n_step = uk.n_step;
  int nc = std::min(n, n_step);
  const int num_nc_panels = (n + nc - 1) / nc;

  torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) {
    const int n_idx = idx * nc;
    const int nc_tile_size = std::min(nc, n - n_idx);

    auto packed_weights_offset = uk.packed_weights_offset(
        n_idx,
        k,
        uk.weight_nbit,
        scale_group_size,
        uk.has_scales,
        uk.has_bias,
        uk.nr,
        uk.kr,
        uk.sr);

    // Calculate offsets for all input pointers
    int weight_qval_indices_offset = n_idx * k;
    // Scales are packed in groups of nr
    int scales_offset = weight_qval_indices_offset / scale_group_size;
    int luts_offset =
        (weight_qval_indices_offset / lut_group_size) * (1 << uk.weight_nbit);

    // 2. Call pack_weights with chunk arguments
    uk.pack_weights(
        static_cast<uint8_t*>(packed_weights_ptr) + packed_weights_offset,
        weight_qval_indices + weight_qval_indices_offset,
        uk.has_scales ? weight_scales + scales_offset : nullptr,
        weight_luts + luts_offset,
        nc_tile_size,
        k,
        scale_group_size,
        lut_group_size,
        uk.has_scales,
        uk.has_bias,
        uk.has_bias ? bias + n_idx : nullptr,
        uk.nr,
        uk.kr,
        uk.sr);
  });
}

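// Computes (mc, nc) tile sizes so the m x n output splits into roughly
// target_tiles_per_thread tiles per thread: mc is fixed to m_step, and nc
// is derived from the total tile count divided across threads, rounded up
// to a multiple of n_step and clamped to n. For example, with m=32,
// m_step=8, n=4096, n_step=16, 8 threads, and target_tiles_per_thread=5,
// this yields mc=8 and nc=ceil(4096*4/40)=410, rounded up to 416.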
GroupwiseTilingParams GroupwiseTilingParams::from_target_tiles_per_thread(
    int m,
    int m_step,
    int n,
    int n_step,
    int target_tiles_per_thread) {
  TORCHAO_CHECK(m >= 1, "m must be >= 1");
  TORCHAO_CHECK(m_step >= 1, "m_step must be >= 1");

  TORCHAO_CHECK(n >= 1, "n must be >= 1");
  TORCHAO_CHECK(n_step >= 1, "n_step must be >= 1");
  TORCHAO_CHECK(
      target_tiles_per_thread >= 1, "target_tiles_per_thread must be >= 1");
  auto num_threads = torchao::get_num_threads();
  TORCHAO_CHECK(num_threads >= 1, "num_threads must be >= 1");

  int mc = m_step;
  int num_mc_panels = (m + mc - 1) / mc;

  int numerator = n * num_mc_panels;
  int denominator = num_threads * target_tiles_per_thread;

  // Set nc = ceil(numerator / denominator)
  int nc = (numerator + denominator - 1) / denominator;
  assert(nc >= 1);

  // Round nc up to the next multiple of n_step
  nc = ((nc + n_step - 1) / n_step) * n_step;

  // Clamp mc, nc to be no larger than m, n
  mc = std::min(m, mc);
  nc = std::min(n, nc);

  assert((mc == m) || (mc % m_step == 0));
  assert((nc == n) || (nc % n_step == 0));

  GroupwiseTilingParams tiling_params;
  tiling_params.mc = mc;
  tiling_params.nc = nc;
  return tiling_params;
}

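// Computes the m x n output of the groupwise low-bit LUT linear op.
// Activations are packed once per MC block in the serial outer loop; the
// NC tiles within that block are then computed in parallel, each task
// reading its slice of the packed weights and writing a disjoint output tile.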
void groupwise_lowbit_weight_lut_parallel_operator(
    const UKernelConfig& uk,
    const std::optional<GroupwiseTilingParams>& tiling_params,
    float* output,
    int m,
    int n,
    int k,
    int scale_group_size,
    int lut_group_size,
    const void* packed_weights,
    const float* activations,
    bool has_clamp,
    float clamp_min,
    float clamp_max) {
  TORCHAO_CHECK(
      lut_group_size % scale_group_size == 0,
      "scale_group_size must divide lut_group_size");
  TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k");
  TORCHAO_CHECK(
      lut_group_size % (k * uk.nr) == 0, "(k * nr) must divide lut_group_size");
  TORCHAO_CHECK(
      scale_group_size % uk.kr == 0, "kr must divide scale_group_size");
  int config_idx = uk.select_config_idx(m);
  auto& kernel_config = uk.configs[config_idx];
  int n_step = uk.n_step;
  int m_step = kernel_config.m_step;

  int mc, nc;
  if (tiling_params.has_value()) {
    mc = tiling_params->mc;
    nc = tiling_params->nc;
  } else {
    // If no params are provided, calculate them to balance the workload.
    auto params = GroupwiseTilingParams::from_target_tiles_per_thread(
        m_step, m_step, n, n_step, /*target_tiles_per_thread=*/5);
    mc = params.mc;
    nc = params.nc;
  }
  TORCHAO_CHECK(mc >= 1, "mc must be >= 1");
  TORCHAO_CHECK(nc >= 1, "nc must be >= 1");
  TORCHAO_CHECK(
      (mc == m) || (mc % m_step == 0),
      "mc from tiling_params must be m or a multiple of m_step");
  TORCHAO_CHECK(
      (nc == n) || (nc % n_step == 0),
      "nc from tiling_params must be n or a multiple of n_step");

  const int num_mc_tiles = (m + mc - 1) / mc;
  const int num_nc_tiles = (n + nc - 1) / nc;

  const size_t packed_activations_size = kernel_config.packed_activations_size(
      mc, k, kernel_config.mr, uk.kr, uk.sr);
  auto packed_activations = torchao::make_aligned_byte_ptr(
      uk.preferred_alignment, packed_activations_size);

  // Outer loop over M blocks
  for (int mc_tile_idx = 0; mc_tile_idx < num_mc_tiles; ++mc_tile_idx) {
    const int mc_tile_start = mc_tile_idx * mc;
    const int mc_tile_size = std::min(mc, m - mc_tile_start);
    const float* activation_row_ptr = activations + mc_tile_start * k;

    kernel_config.pack_activations(
        (float*)packed_activations.get(),
        mc_tile_size,
        k,
        activation_row_ptr,
        kernel_config.mr,
        uk.kr,
        uk.sr);

    // Parallelize the work over the larger NC-tiles
    torchao::parallel_1d(0, num_nc_tiles, [&](int64_t n_tile_idx) {
      const int nc_tile_start = n_tile_idx * nc;
      const int nc_tile_size = std::min(nc, n - nc_tile_start);
      float* output_tile_ptr = output + mc_tile_start * n + nc_tile_start;

      const size_t packed_weights_offset = uk.packed_weights_offset(
          nc_tile_start,
          k,
          uk.weight_nbit,
          scale_group_size,
          uk.has_scales,
          uk.has_bias,
          uk.nr,
          uk.kr,
          uk.sr);
      const void* packed_weights_for_tile =
          static_cast<const uint8_t*>(packed_weights) + packed_weights_offset;

      kernel_config.kernel(
          output_tile_ptr,
          /*output_m_stride=*/n,
          /*m=*/mc_tile_size,
          /*n=*/nc_tile_size,
          k,
          scale_group_size,
          lut_group_size,
          packed_weights_for_tile,
          packed_activations.get(),
          clamp_min,
          clamp_max,
          uk.has_bias,
          has_clamp);
    });
  }
}

} // namespace torchao::ops::groupwise_lowbit_weight_lut
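
For context, a minimal sketch of how the two operators might be wired together: pack the weights once at load time, then run the parallel matmul per call. This is not part of the diff; it assumes the caller already has a populated UKernelConfig, and it assumes the packed-weights offset evaluated at column n equals the total packed buffer size (an inference from how per-panel offsets are computed above).

#include <torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h>
#include <optional>

using namespace torchao::ops::groupwise_lowbit_weight_lut;

// Hypothetical glue code for illustration only.
void linear_example(
    const UKernelConfig& uk,
    float* output,                      // [m, n]
    const float* activations,           // [m, k]
    const uint8_t* weight_qval_indices, // [n, k] LUT indices
    const float* weight_scales,
    const float* weight_luts,
    const float* bias,
    int m, int n, int k,
    int scale_group_size, int lut_group_size) {
  // Assumption: offset at column n == total packed size.
  size_t packed_size = uk.packed_weights_offset(
      n, k, uk.weight_nbit, scale_group_size,
      uk.has_scales, uk.has_bias, uk.nr, uk.kr, uk.sr);
  auto packed_weights =
      torchao::make_aligned_byte_ptr(uk.preferred_alignment, packed_size);

  // One-time packing (typically at model load).
  pack_weights_operator(
      uk, packed_weights.get(), n, k, scale_group_size, lut_group_size,
      weight_qval_indices, weight_scales, weight_luts, bias);

  // Per-call matmul; std::nullopt lets the operator derive tiling itself.
  groupwise_lowbit_weight_lut_parallel_operator(
      uk, std::nullopt, output, m, n, k, scale_group_size, lut_group_size,
      packed_weights.get(), activations,
      /*has_clamp=*/false, /*clamp_min=*/0.0f, /*clamp_max=*/0.0f);
}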
