Skip to content

Commit b0d495f

Browse files
committed
add aten target to bazel
1 parent 3c3f47c commit b0d495f

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

kernels/quantized/cpu/op_embedding2b.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,27 @@ Tensor& quantized_embedding_2bit_out(
7474
2);
7575
}
7676

77+
/**
 * 2-bit entry point of the quantized embedding "dtype out" op that takes no
 * runtime context argument.
 *
 * Every argument is passed through unchanged to
 * quantized_embedding_xbit_dtype_out, followed by a trailing constant 2.
 * NOTE(review): the trailing 2 is presumably the per-weight bit width expected
 * by the shared x-bit implementation -- confirm against embeddingxb.h.
 *
 * @return The reference returned by the shared x-bit implementation.
 */
Tensor& quantized_embedding_2bit_dtype_out(
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  // Same value the call previously spelled as a bare literal.
  constexpr int kTrailingArg = 2;

  Tensor& result = quantized_embedding_xbit_dtype_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out,
      kTrailingArg);
  return result;
}
97+
7798
Tensor& quantized_embedding_2bit_dtype_out(
7899
KernelRuntimeContext& context,
79100
const Tensor& weight,

kernels/quantized/cpu/targets.bzl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,16 @@ _QUANT_OPS = (
2626
op_target(
2727
name = "op_embedding2b",
2828
deps = ["//executorch/kernels/quantized/cpu:embeddingxb"],
29+
_aten_mode_deps = [
30+
"//executorch/kernels/quantized/cpu:embeddingxb_aten",
31+
],
2932
),
3033
op_target(
3134
name = "op_embedding4b",
3235
deps = ["//executorch/kernels/quantized/cpu:embeddingxb"],
36+
_aten_mode_deps = [
37+
"//executorch/kernels/quantized/cpu:embeddingxb_aten",
38+
],
3339
),
3440
op_target(
3541
name = "op_mixed_mm",
@@ -80,6 +86,16 @@ def define_common_targets():
8086
deps = ["//executorch/runtime/kernel:kernel_includes"],
8187
)
8288

89+
runtime.cxx_library(
90+
name = "embeddingxb_aten",
91+
srcs = ["embeddingxb.cpp"],
92+
exported_headers = ["embeddingxb.h"],
93+
visibility = [
94+
"//executorch/kernels/quantized/...",
95+
],
96+
deps = ["//executorch/runtime/kernel:kernel_includes_aten"],
97+
)
98+
8399
runtime.cxx_library(
84100
name = "quantized_cpu_aten",
85101
srcs = [],

kernels/quantized/test/op_embedding2b_test.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,33 @@ TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath2) {
159159
out),
160160
"");
161161
}
162+
163+
// Death test: the 2-bit quantized embedding op must abort when the number of
// scale/zero-point columns does not evenly divide the embedding dimension
// implied by the packed weights.
//
// Every fixture is built with a data list whose length matches its declared
// shape, and the indices stay inside the 2-row table, so the only
// inconsistency left is the one under test. Otherwise the process could abort
// while constructing the inputs (outside ET_EXPECT_DEATH) instead of inside
// the kernel's argument checks.
TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath3) {
  et_pal_init();
  TensorFactory<ScalarType::Byte> tfb;
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tfl;

  int64_t quant_min = -2;
  int64_t quant_max = 1;

  // 2 rows x 3 columns -> exactly 6 values each.
  Tensor weight_scales = tf.make({2, 3}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
  Tensor weight_zero_points = tf.make({2, 3}, {0, 0, 0, 0, 0, 0});

  // 2 rows x 2 packed bytes. With 2-bit values a byte holds 4 of them, so each
  // row unpacks to 8 elements, matching the width of `out` below. The byte
  // values themselves are not expected to matter for this test.
  Tensor qweight = tfb.make({2, 2}, {236, 134, 228, 134});

  // Both indices address existing rows of the 2-row table.
  Tensor indices = tfl.make({2}, {0, 1});
  Tensor out = tf.zeros({2, 8});

  // scales/zeros have 3 columns (presumably 3 groups per row), which does not
  // divide the embed dimension from qvals (8)
  ET_EXPECT_DEATH(
      quantized_embedding_2bit_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}

0 commit comments

Comments
 (0)