Skip to content

Commit 3a6fa8c

Browse files
committed
use shared code in 4bit
1 parent 5e7a100 commit 3a6fa8c

File tree

4 files changed

+47
-261
lines changed

4 files changed

+47
-261
lines changed

kernels/quantized/cpu/embeddingxb.cpp

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,22 @@ using ScalarType = exec_aten::ScalarType;
2121

2222
namespace {
2323

24-
25-
static inline int32_t weight_value(const unsigned char* w_data, int32_t index, int32_t weight_nbit) {
24+
static inline int32_t
25+
weight_value(const unsigned char* w_data, int32_t index, int32_t weight_nbit) {
2626
if (weight_nbit == 2) {
2727
int32_t subbyte = index % 4;
28-
index >>= 2;
29-
switch (subbyte) {
30-
case 0:
31-
return (int32_t)(w_data[index] & 3) - 2;
32-
case 1:
33-
return (int32_t)((w_data[index] & 12) >> 2) - 2;
34-
case 2:
35-
return (int32_t)((w_data[index] & 48) >> 4) - 2;
36-
case 3:
37-
return (int32_t)((w_data[index] & 192) >> 6) - 2;
38-
}
39-
}
40-
else if (weight_nbit == 4) {
28+
index >>= 2;
29+
switch (subbyte) {
30+
case 0:
31+
return (int32_t)(w_data[index] & 3) - 2;
32+
case 1:
33+
return (int32_t)((w_data[index] & 12) >> 2) - 2;
34+
case 2:
35+
return (int32_t)((w_data[index] & 48) >> 4) - 2;
36+
case 3:
37+
return (int32_t)((w_data[index] & 192) >> 6) - 2;
38+
}
39+
} else if (weight_nbit == 4) {
4140
int32_t odd = index & 1;
4241
index >>= 1;
4342
if (odd) {
@@ -46,10 +45,11 @@ static inline int32_t weight_value(const unsigned char* w_data, int32_t index, i
4645
return (int32_t)((w_data[index] >> 4) & 0x0F) - 8;
4746
}
4847
}
49-
5048
}
5149

52-
static inline int32_t get_embedding_dim(int32_t packed_dim, int32_t weight_nbit) {
50+
static inline int32_t get_embedding_dim(
51+
int32_t packed_dim,
52+
int32_t weight_nbit) {
5353
assert(8 % weight_nbit == 0);
5454
int packed_values_per_byte = 8 / weight_nbit;
5555
return packed_dim * packed_values_per_byte;
@@ -68,7 +68,7 @@ void check_embedding_xbit_args(
6868
exec_aten::optional<ScalarType> out_dtype,
6969
Tensor& out,
7070
int weight_nbit) {
71-
ET_CHECK_MSG(8 % weight_nbit == 0, "nbit must divide 8");
71+
ET_CHECK_MSG(8 % weight_nbit == 0, "nbit must divide 8");
7272

7373
ET_CHECK_MSG(
7474
weight.dim() == 2, "weight must be 2D but got() %zd dims", weight.dim());
@@ -158,8 +158,6 @@ void check_embedding_xbit_args(
158158
}
159159
}
160160

161-
162-
163161
/**
164162
* Retrieves the embeddings specified by indices, dequantizes them, and stores
165163
* them in out. Weight will always be uint8
@@ -172,7 +170,6 @@ void embedding_xbit_per_channel(
172170
const Tensor& indices,
173171
Tensor& out,
174172
int weight_nbit) {
175-
176173
auto embedding_dim = get_embedding_dim(weight.size(1), weight_nbit);
177174

178175
int32_t num_groups_per_channel = 1;
@@ -283,7 +280,12 @@ Tensor& quantized_embedding_xbit_out(
283280
constexpr auto name = "quantized_decomposed::embedding_xbit.out";
284281
ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
285282
embedding_xbit_per_channel<CTYPE_OUT, CTYPE_OUT>(
286-
weight, weight_scales, opt_weight_zero_points, indices, out, weight_nbit);
283+
weight,
284+
weight_scales,
285+
opt_weight_zero_points,
286+
indices,
287+
out,
288+
weight_nbit);
287289
});
288290

289291
return out;
@@ -346,7 +348,12 @@ Tensor& quantized_embedding_xbit_dtype_out(
346348
ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
347349
ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
348350
embedding_xbit_per_channel<CTYPE_P, CTYPE_OUT>(
349-
weight, weight_scales, opt_weight_zero_points, indices, out, weight_nbit);
351+
weight,
352+
weight_scales,
353+
opt_weight_zero_points,
354+
indices,
355+
out,
356+
weight_nbit);
350357
});
351358
});
352359

kernels/quantized/cpu/embeddingxb.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@ Tensor& quantized_embedding_xbit_dtype_out(
5555
Tensor& out,
5656
int weight_nbit);
5757

58-
59-
Tensor& quantized_embedding_xbit_dtype_out(
58+
Tensor& quantized_embedding_xbit_dtype_out(
6059
KernelRuntimeContext& context,
6160
const Tensor& weight,
6261
const Tensor& weight_scales,
@@ -68,7 +67,6 @@ Tensor& quantized_embedding_xbit_dtype_out(
6867
Tensor& out,
6968
int weight_nbit);
7069

71-
7270
} // namespace native
7371
} // namespace executor
7472
} // namespace torch

0 commit comments

Comments
 (0)