Skip to content

Commit 103c283

Browse files
committed
init
1 parent fbcd332 commit 103c283

File tree

6 files changed

+523
-0
lines changed

6 files changed

+523
-0
lines changed
Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/kernel/kernel_includes.h>
10+
#include <algorithm>
11+
#include <cinttypes>
12+
#include <cmath>
13+
14+
namespace torch {
15+
namespace executor {
16+
namespace native {
17+
18+
using Tensor = exec_aten::Tensor;
19+
using Scalar = exec_aten::Scalar;
20+
using ScalarType = exec_aten::ScalarType;
21+
22+
namespace {
23+
24+
/**
25+
* Asserts that the parameters are valid.
26+
*/
27+
void check_embedding_2bit_args(
28+
const Tensor& weight,
29+
const Tensor& weight_scales,
30+
const optional<Tensor>& opt_weight_zero_points,
31+
const int64_t weight_quant_min,
32+
const int64_t weight_quant_max,
33+
const Tensor& indices,
34+
exec_aten::optional<ScalarType> out_dtype,
35+
Tensor& out) {
36+
ET_CHECK_MSG(
37+
weight.dim() == 2, "weight must be 2D but got() %zd dims", weight.dim());
38+
39+
ET_CHECK_MSG(
40+
weight_scales.dim() == 1 || weight_scales.dim() == 2,
41+
"weight_scales must be 1D or 2D but got() %zd dims",
42+
weight_scales.dim());
43+
44+
ET_CHECK_MSG(
45+
weight_scales.size(0) == weight.size(0),
46+
"Number of scales must be == weight.size(0)=%zd"
47+
", but got %zd",
48+
weight_scales.size(0),
49+
weight.size(0));
50+
51+
if (weight_scales.dim() == 2) {
52+
auto num_groups = weight_scales.size(1);
53+
ET_CHECK_MSG(
54+
// each 8b uint8 column is 4 columns
55+
(4 * weight.size(1)) % num_groups == 0,
56+
"Number of groups must divide weight.size(1)=%zd"
57+
", but got # of groups = %zd",
58+
weight.size(1),
59+
num_groups);
60+
}
61+
62+
ET_CHECK_MSG(
63+
weight.scalar_type() == ScalarType::Byte,
64+
"weight.scalar_type() %" PRId8 " is not supported:",
65+
static_cast<int8_t>(weight.scalar_type()));
66+
67+
ET_CHECK_MSG(
68+
out.scalar_type() == ScalarType::Float ||
69+
out.scalar_type() == ScalarType::Half,
70+
"out.scalar_type() %" PRId8 " is not supported:",
71+
static_cast<int8_t>(out.scalar_type()));
72+
73+
ET_CHECK_MSG(
74+
weight_scales.scalar_type() == ScalarType::Float ||
75+
weight_scales.scalar_type() == ScalarType::Half,
76+
"weight_scales.scalar_type() %" PRId8 " is not supported:",
77+
static_cast<int8_t>(weight_scales.scalar_type()));
78+
79+
if (opt_weight_zero_points.has_value()) {
80+
ET_CHECK_MSG(
81+
opt_weight_zero_points.value().dim() == weight_scales.dim(),
82+
"weight_zero_points's rank match that of weight_scales. "
83+
"weight_zero_points rank: %" PRId8 ", weight_scales rank: %" PRId8,
84+
static_cast<int8_t>(opt_weight_zero_points.value().dim()),
85+
static_cast<int8_t>(weight_scales.dim()));
86+
87+
ET_CHECK_MSG(
88+
opt_weight_zero_points.value().scalar_type() == out.scalar_type(),
89+
"weight zero points scalar type %" PRId8
90+
" does not match out.scalar_type()",
91+
static_cast<int8_t>(opt_weight_zero_points.value().scalar_type()));
92+
93+
for (int32_t i = 0; i < weight_scales.dim(); ++i) {
94+
ET_CHECK_MSG(
95+
opt_weight_zero_points.value().size(i) == weight_scales.size(i),
96+
"Dimension size misatch at dim %" PRId8
97+
"Weight_zero_point size = %zd"
98+
", weight_scales size = %zd.",
99+
i,
100+
opt_weight_zero_points.value().size(i),
101+
weight_scales.size(i));
102+
}
103+
}
104+
105+
ET_CHECK_MSG(
106+
indices.scalar_type() == ScalarType::Long,
107+
"indices.scalar_type() %" PRId8 " is not Long only Long is supported:",
108+
static_cast<int8_t>(indices.scalar_type()));
109+
110+
ET_CHECK_MSG(
111+
weight_quant_min <= weight_quant_max,
112+
"weight quant min: %" PRId64
113+
" is greater than weight quant max: %" PRId64,
114+
weight_quant_min,
115+
weight_quant_max);
116+
117+
if (out_dtype.has_value()) {
118+
ET_CHECK_MSG(
119+
out.scalar_type() == out_dtype.value(),
120+
"output_dtype must match the dtype of the out tensor");
121+
}
122+
}
123+
124+
/**
 * Unpacks the `index`-th 2-bit value from the packed byte buffer `w_data`.
 *
 * Each byte holds four 2-bit codes, code i at bit offset 2*(i%4) (LSB
 * first). The unsigned code in [0, 3] is shifted to the signed range
 * [-2, 1], matching the original per-case masks (3, 12, 48, 192).
 *
 * Fix: the original switch had no default, so control could (as far as the
 * compiler can prove) fall off the end of a non-void function — UB and a
 * -Wreturn-type error. A single shift/mask covers all four sub-byte
 * positions with one return.
 */
static inline int32_t weight_value(const unsigned char* w_data, int32_t index) {
  const int32_t subbyte = index & 3; // == index % 4 for non-negative index
  const uint8_t packed = w_data[index >> 2];
  return static_cast<int32_t>((packed >> (2 * subbyte)) & 3) - 2;
}
138+
139+
/**
140+
* Retrieves the embeddings specified by indices, dequantizes them, and stores
141+
* them in out. Weight will always be uint8
142+
*/
143+
template <typename CTYPE_PARAMS, typename CTYPE_OUT>
144+
void embedding_4bit_per_channel(
145+
const Tensor& weight,
146+
const Tensor& weight_scales,
147+
const optional<Tensor>& opt_weight_zero_points,
148+
const Tensor& indices,
149+
Tensor& out) {
150+
auto embedding_dim = weight.size(1) * 4;
151+
152+
int32_t num_groups_per_channel = 1;
153+
if (weight_scales.dim() == 2) {
154+
num_groups_per_channel = weight_scales.size(1);
155+
}
156+
int32_t group_size = embedding_dim / num_groups_per_channel;
157+
158+
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
159+
const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();
160+
161+
const CTYPE_PARAMS* scales = weight_scales.const_data_ptr<CTYPE_PARAMS>();
162+
const CTYPE_PARAMS* zero_points = nullptr;
163+
if (opt_weight_zero_points.has_value()) {
164+
zero_points = opt_weight_zero_points.value().const_data_ptr<CTYPE_PARAMS>();
165+
}
166+
167+
for (int i = 0; i < indices.numel(); i++) {
168+
int64_t index = indices_ptr[i];
169+
// If using groupwise embedding
170+
int32_t qparams_index = index * num_groups_per_channel;
171+
CTYPE_PARAMS zp = 0.0;
172+
const CTYPE_PARAMS* scale_ptr = scales + qparams_index;
173+
const CTYPE_PARAMS* zero_points_ptr = nullptr;
174+
if (opt_weight_zero_points.has_value()) {
175+
zero_points_ptr = zero_points + qparams_index;
176+
}
177+
178+
const uint8_t* w_data =
179+
weight.const_data_ptr<uint8_t>() + weight.size(1) * index;
180+
181+
for (int j = 0; j < embedding_dim; ++j) {
182+
int32_t group_id = j / group_size;
183+
const CTYPE_PARAMS scale = scale_ptr[group_id];
184+
if (opt_weight_zero_points.has_value()) {
185+
zp = zero_points_ptr[group_id];
186+
}
187+
out_data[j] = static_cast<CTYPE_OUT>(
188+
(static_cast<float>(weight_value(w_data, j)) -
189+
static_cast<float>(zp)) *
190+
static_cast<float>(scale));
191+
}
192+
out_data += embedding_dim;
193+
}
194+
}
195+
196+
void resize_out_tensor(
197+
const Tensor& weight,
198+
const Tensor& indices,
199+
Tensor& out) {
200+
exec_aten::SizesType expected_output_size[kTensorDimensionLimit];
201+
for (size_t i = 0; i < indices.dim(); i++) {
202+
expected_output_size[i] = indices.size(i);
203+
}
204+
const size_t embedding_dim = weight.size(1) * 4;
205+
expected_output_size[out.dim() - 1] = embedding_dim;
206+
207+
exec_aten::ArrayRef<exec_aten::SizesType> output_size{
208+
expected_output_size, static_cast<size_t>(out.dim())};
209+
210+
torch::executor::Error err = resize_tensor(out, output_size);
211+
ET_CHECK_MSG(
212+
err == torch::executor::Error::Ok,
213+
"Failed to resize out Tensor in quantized_embedding_4bit_out");
214+
}
215+
216+
} // namespace
217+
218+
/**
219+
* Retrieves the embeddings specified by indices, dequantizes them, and stores
220+
* them in out. The weight is quantized per channel, with a scale and zero_point
221+
* for each embedding.
222+
*
223+
* Corresponds as the out variant to torch.ops.quantized.embedding_4bit
224+
*
225+
* NOTE: quant_min, quant_max, and Dtype are not used in computation, but rather
226+
* metadata that is passed around which can be useful for pattern matching. See
227+
* https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
228+
* info.
229+
*/
230+
Tensor& quantized_embedding_2bit_out(
231+
// TODO Evaluate whether this name is appropriate for an operator that takes
232+
// non quant input and returns fp output
233+
const Tensor& weight,
234+
const Tensor& weight_scales,
235+
const optional<Tensor>& opt_weight_zero_points,
236+
const int64_t weight_quant_min,
237+
const int64_t weight_quant_max,
238+
const Tensor& indices,
239+
Tensor& out) {
240+
ScalarType out_type = out.scalar_type();
241+
242+
// TODO (jakeszwe): improve these to account for the size of out in relation
243+
// to weight and indices accounting for a possible batch dimension
244+
check_embedding_2bit_args(
245+
weight,
246+
weight_scales,
247+
opt_weight_zero_points,
248+
weight_quant_min,
249+
weight_quant_max,
250+
indices,
251+
out_type,
252+
out);
253+
254+
constexpr auto name = "quantized_decomposed::embedding_2bit.out";
255+
ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
256+
embedding_2bit_per_channel<CTYPE_OUT, CTYPE_OUT>(
257+
weight, weight_scales, opt_weight_zero_points, indices, out);
258+
});
259+
260+
return out;
261+
}
262+
263+
/**
 * KernelRuntimeContext-accepting overload for
 * quantized_decomposed::embedding_2bit.out: resizes `out` and delegates to
 * the context-free implementation above.
 */
Tensor& quantized_embedding_2bit_out(
    KernelRuntimeContext& context,
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  (void)context;
  resize_out_tensor(weight, indices, out);
  // Fix: the original forwarded to quantized_embedding_4bit_out, which is not
  // defined in this file (copy-paste from the 4-bit kernel); the 2-bit op
  // must forward to quantized_embedding_2bit_out.
  return quantized_embedding_2bit_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out);
}
285+
286+
/**
 * Retrieves the embeddings specified by indices, dequantizes them, and stores
 * them in out. Scales/zero points (CTYPE_P) and the output (CTYPE_OUT) may be
 * different float types (Float or Half in any combination).
 *
 * quant_min/quant_max are metadata only; out_dtype, when provided, is
 * validated against out's actual dtype in check_embedding_2bit_args.
 */
Tensor& quantized_embedding_2bit_dtype_out(
    // TODO Evaluate whether this name is appropriate for an operator that takes
    // non quant input and returns fp output
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  // TODO (jakeszwe): improve these to account for the size of out in relation
  // to weight and indices accounting for a possible batch dimension
  check_embedding_2bit_args(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out);

  ScalarType params_type = weight_scales.scalar_type();
  ScalarType out_type = out.scalar_type();

  // Fix: the op name string said "embedding_4bit.dtype_out" — copy-paste
  // from the 4-bit kernel; quantized.yaml registers this function as
  // embedding_2bit.dtype_out.
  constexpr auto name = "quantized_decomposed::embedding_2bit.dtype_out";
  ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
    ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
      // Helper's name says 4bit, but its unpacking is 2-bit (4 codes/byte).
      embedding_4bit_per_channel<CTYPE_P, CTYPE_OUT>(
          weight, weight_scales, opt_weight_zero_points, indices, out);
    });
  });

  return out;
}
322+
323+
/**
 * KernelRuntimeContext-accepting overload registered in quantized.yaml for
 * quantized_decomposed::embedding_2bit.dtype_out. Resizes `out` to the
 * expected shape, then delegates to the context-free implementation above.
 */
Tensor& quantized_embedding_2bit_dtype_out(
    KernelRuntimeContext& context,
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  // The context is currently unused; see TODO(larryliu) about threading it
  // through to the real op function and dropping this wrapper.
  (void)context;
  resize_out_tensor(weight, indices, out);
  return quantized_embedding_2bit_dtype_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out);
}
347+
348+
} // namespace native
349+
} // namespace executor
350+
} // namespace torch

kernels/quantized/cpu/targets.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ _QUANT_OPS = (
2323
op_target(
2424
name = "op_embedding",
2525
),
26+
op_target(
27+
name = "op_embedding2b",
28+
),
2629
op_target(
2730
name = "op_embedding4b",
2831
),

kernels/quantized/quantized.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,18 @@
4646
- arg_meta: null
4747
kernel_name: torch::executor::quantized_embedding_byte_dtype_out
4848

49+
- func: quantized_decomposed::embedding_2bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
50+
variants: function
51+
kernels:
52+
- arg_meta: null
53+
kernel_name: torch::executor::quantized_embedding_2bit_out
54+
55+
- func: quantized_decomposed::embedding_2bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
56+
variants: function
57+
kernels:
58+
- arg_meta: null
59+
kernel_name: torch::executor::quantized_embedding_2bit_dtype_out
60+
4961
- func: quantized_decomposed::embedding_4bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
5062
variants: function
5163
kernels:

kernels/quantized/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ set(_kernels_quantized_test_sources
2525
op_add_test.cpp
2626
op_choose_qparams_test.cpp
2727
op_dequantize_test.cpp
28+
op_embedding2b_test.cpp
2829
op_embedding4b_test.cpp
2930
op_embedding_test.cpp
3031
op_mixed_linear_test.cpp

0 commit comments

Comments
 (0)