@@ -52,7 +52,7 @@ void check_dequantize_per_tensor_args(
5252 ET_CHECK_MSG (
5353 input.scalar_type () == ScalarType::Byte ||
5454 input.scalar_type () == ScalarType::Char ||
55- input.scalar_type () == ScalarType::Bits16 ||
55+ input.scalar_type () == ScalarType::UInt16 ||
5656 input.scalar_type () == ScalarType::Short ||
5757 input.scalar_type () == (ScalarType)Ushort ||
5858 input.scalar_type () == (ScalarType)Bits4 ||
@@ -83,7 +83,7 @@ void check_dequantize_per_tensor_args(
8383} // namespace
8484
8585/* Local function which calls the kernels based on the input datatype */
86- void Dequantize_impl (
86+ void dequantize_impl (
8787 Tensor& out,
8888 const Tensor& input,
8989 float * scale_data,
@@ -211,7 +211,7 @@ void Dequantize_impl(
211211 break ;
212212 switch (input.scalar_type ()) {
213213 ET_FORALL_INT_TYPES (ASYM_CALCULATE_INT_TYPE_TENSOR);
214- ASYM_CALCULATE_INT_TYPE_TENSOR (uint16_t , Bits16 );
214+ ASYM_CALCULATE_INT_TYPE_TENSOR (uint16_t , UInt16 );
215215 default :
216216 ET_CHECK_MSG (
217217 false ,
@@ -302,7 +302,7 @@ void Dequantize_impl(
302302 break ;
303303 switch (input.scalar_type ()) {
304304 ET_FORALL_INT_TYPES (ASYM_CALCULATE_INT_TYPE_CHANNEL);
305- ASYM_CALCULATE_INT_TYPE_CHANNEL (uint16_t , Bits16 );
305+ ASYM_CALCULATE_INT_TYPE_CHANNEL (uint16_t , UInt16 );
306306 default :
307307 ET_CHECK_MSG (
308308 false ,
@@ -368,7 +368,7 @@ void Dequantize_impl(
368368 break ;
369369 switch (input.scalar_type ()) {
370370 ET_FORALL_INT_TYPES (SYM_CALCULATE_INT_TYPE_TENSOR);
371- SYM_CALCULATE_INT_TYPE_TENSOR (uint16_t , Bits16 );
371+ SYM_CALCULATE_INT_TYPE_TENSOR (uint16_t , UInt16 );
372372 default :
373373 ET_CHECK_MSG (
374374 false ,
@@ -459,7 +459,7 @@ void Dequantize_impl(
459459 break ;
460460 switch (input.scalar_type ()) {
461461 ET_FORALL_INT_TYPES (SYM_CALCULATE_INT_TYPE_CHANNEL);
462- SYM_CALCULATE_INT_TYPE_CHANNEL (uint16_t , Bits16 );
462+ SYM_CALCULATE_INT_TYPE_CHANNEL (uint16_t , UInt16 );
463463 default :
464464 ET_CHECK_MSG (
465465 false ,
@@ -502,7 +502,7 @@ Tensor& dequantize_per_tensor_out(
502502 float scale_data = (float )scale;
503503 int zero_point_data = (int )zero_point;
504504
505- Dequantize_impl (out, input, &scale_data, &zero_point_data, NULL , out_dtype);
505+ dequantize_impl (out, input, &scale_data, &zero_point_data, NULL , out_dtype);
506506
507507 return out;
508508}
@@ -620,7 +620,7 @@ Tensor& dequantize_per_channel_out(
620620 for (int i = 0 ; i < scale.numel (); i++) {
621621 scale_data[i] = (float )scale_dt[i];
622622 }
623- Dequantize_impl (out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
623+ dequantize_impl (out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
624624
625625 return out;
626626}
@@ -661,13 +661,19 @@ Tensor& dequantize_per_tensor_out(
661661 int64_t quant_min,
662662 int64_t quant_max,
663663 ScalarType dtype,
664- exec_aten::optional<ScalarType> out_dtype,
665664 Tensor& out) {
666665 // TODO(larryliu): Add a context arg to the real op function and remove this
667666 // wrapper
668667 (void )context;
669668 return dequantize_per_tensor_out (
670- input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
669+ input,
670+ scale,
671+ zero_point,
672+ quant_min,
673+ quant_max,
674+ dtype,
675+ out.scalar_type (),
676+ out);
671677}
672678
673679Tensor& dequantize_per_tensor_tensor_args_out (
@@ -764,4 +770,4 @@ Tensor& dequantize_per_token_out(
764770} // namespace native
765771} // namespace G3
766772} // namespace impl
767- } // namespace cadence
773+ } // namespace cadence
0 commit comments