@@ -141,7 +141,7 @@ static inline int32_t weight_value(const unsigned char* w_data, int32_t index) {
141141 * them in out. Weight will always be uint8
142142 */
143143template <typename CTYPE_PARAMS, typename CTYPE_OUT>
144- void embedding_4bit_per_channel (
144+ void embedding_2bit_per_channel (
145145 const Tensor& weight,
146146 const Tensor& weight_scales,
147147 const optional<Tensor>& opt_weight_zero_points,
@@ -210,7 +210,7 @@ void resize_out_tensor(
210210 torch::executor::Error err = resize_tensor (out, output_size);
211211 ET_CHECK_MSG (
212212 err == torch::executor::Error::Ok,
213- " Failed to resize out Tensor in quantized_embedding_4bit_out " );
213+ " Failed to resize out Tensor in quantized_embedding_2bit_out " );
214214}
215215
216216} // namespace
@@ -220,7 +220,7 @@ void resize_out_tensor(
220220 * them in out. The weight is quantized per channel, with a scale and zero_point
221221 * for each embedding.
222222 *
223- * Corresponds as the out variant to torch.ops.quantized.embedding_4bit
223+ * Corresponds as the out variant to torch.ops.quantized.embedding_2bit
224224 *
225225 * NOTE: quant_min, quant_max, and Dtype are not used in computation, but rather
226226 * metadata that is passed around which can be useful for pattern matching. See
@@ -273,7 +273,7 @@ Tensor& quantized_embedding_2bit_out(
273273 // wrapper
274274 (void )context;
275275 resize_out_tensor (weight, indices, out);
276- return quantized_embedding_4bit_out (
276+ return quantized_embedding_2bit_out (
277277 weight,
278278 weight_scales,
279279 opt_weight_zero_points,
@@ -309,10 +309,10 @@ Tensor& quantized_embedding_2bit_dtype_out(
309309 ScalarType params_type = weight_scales.scalar_type ();
310310 ScalarType out_type = out.scalar_type ();
311311
312- constexpr auto name = " quantized_decomposed::embedding_4bit .dtype_out" ;
312+ constexpr auto name = " quantized_decomposed::embedding_2bit .dtype_out" ;
313313 ET_SWITCH_TWO_TYPES (Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
314314 ET_SWITCH_TWO_TYPES (Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
315- embedding_4bit_per_channel <CTYPE_P, CTYPE_OUT>(
315+ embedding_2bit_per_channel <CTYPE_P, CTYPE_OUT>(
316316 weight, weight_scales, opt_weight_zero_points, indices, out);
317317 });
318318 });
0 commit comments