@@ -21,23 +21,22 @@ using ScalarType = exec_aten::ScalarType;
2121
2222namespace {
2323
24-
25- static inline int32_t weight_value (const unsigned char * w_data, int32_t index, int32_t weight_nbit) {
24+ static inline int32_t
25+ weight_value (const unsigned char * w_data, int32_t index, int32_t weight_nbit) {
2626 if (weight_nbit == 2 ) {
2727 int32_t subbyte = index % 4 ;
28- index >>= 2 ;
29- switch (subbyte) {
30- case 0 :
31- return (int32_t )(w_data[index] & 3 ) - 2 ;
32- case 1 :
33- return (int32_t )((w_data[index] & 12 ) >> 2 ) - 2 ;
34- case 2 :
35- return (int32_t )((w_data[index] & 48 ) >> 4 ) - 2 ;
36- case 3 :
37- return (int32_t )((w_data[index] & 192 ) >> 6 ) - 2 ;
38- }
39- }
40- else if (weight_nbit == 4 ) {
28+ index >>= 2 ;
29+ switch (subbyte) {
30+ case 0 :
31+ return (int32_t )(w_data[index] & 3 ) - 2 ;
32+ case 1 :
33+ return (int32_t )((w_data[index] & 12 ) >> 2 ) - 2 ;
34+ case 2 :
35+ return (int32_t )((w_data[index] & 48 ) >> 4 ) - 2 ;
36+ case 3 :
37+ return (int32_t )((w_data[index] & 192 ) >> 6 ) - 2 ;
38+ }
39+ } else if (weight_nbit == 4 ) {
4140 int32_t odd = index & 1 ;
4241 index >>= 1 ;
4342 if (odd) {
@@ -46,10 +45,11 @@ static inline int32_t weight_value(const unsigned char* w_data, int32_t index, i
4645 return (int32_t )((w_data[index] >> 4 ) & 0x0F ) - 8 ;
4746 }
4847 }
49-
5048}
5149
52- static inline int32_t get_embedding_dim (int32_t packed_dim, int32_t weight_nbit) {
50+ static inline int32_t get_embedding_dim (
51+ int32_t packed_dim,
52+ int32_t weight_nbit) {
5353 assert (8 % weight_nbit == 0 );
5454 int packed_values_per_byte = 8 / weight_nbit;
5555 return packed_dim * packed_values_per_byte;
@@ -68,7 +68,7 @@ void check_embedding_xbit_args(
6868 exec_aten::optional<ScalarType> out_dtype,
6969 Tensor& out,
7070 int weight_nbit) {
71- ET_CHECK_MSG (8 % weight_nbit == 0 , " nbit must divide 8" );
71+ ET_CHECK_MSG (8 % weight_nbit == 0 , " nbit must divide 8" );
7272
7373 ET_CHECK_MSG (
7474 weight.dim () == 2 , " weight must be 2D but got() %zd dims" , weight.dim ());
@@ -158,8 +158,6 @@ void check_embedding_xbit_args(
158158 }
159159}
160160
161-
162-
163161/* *
164162 * Retrieves the embeddings specified by indices, dequantizes them, and stores
165163 * them in out. Weight will always be uint8
@@ -172,7 +170,6 @@ void embedding_xbit_per_channel(
172170 const Tensor& indices,
173171 Tensor& out,
174172 int weight_nbit) {
175-
176173 auto embedding_dim = get_embedding_dim (weight.size (1 ), weight_nbit);
177174
178175 int32_t num_groups_per_channel = 1 ;
@@ -283,7 +280,12 @@ Tensor& quantized_embedding_xbit_out(
283280 constexpr auto name = " quantized_decomposed::embedding_xbit.out" ;
284281 ET_SWITCH_TWO_TYPES (Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
285282 embedding_xbit_per_channel<CTYPE_OUT, CTYPE_OUT>(
286- weight, weight_scales, opt_weight_zero_points, indices, out, weight_nbit);
283+ weight,
284+ weight_scales,
285+ opt_weight_zero_points,
286+ indices,
287+ out,
288+ weight_nbit);
287289 });
288290
289291 return out;
@@ -346,7 +348,12 @@ Tensor& quantized_embedding_xbit_dtype_out(
346348 ET_SWITCH_TWO_TYPES (Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
347349 ET_SWITCH_TWO_TYPES (Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
348350 embedding_xbit_per_channel<CTYPE_P, CTYPE_OUT>(
349- weight, weight_scales, opt_weight_zero_points, indices, out, weight_nbit);
351+ weight,
352+ weight_scales,
353+ opt_weight_zero_points,
354+ indices,
355+ out,
356+ weight_nbit);
350357 });
351358 });
352359
0 commit comments