@@ -52,7 +52,7 @@ void check_dequantize_per_tensor_args(
5252  ET_CHECK_MSG (
5353      input.scalar_type () == ScalarType::Byte ||
5454          input.scalar_type () == ScalarType::Char ||
55-           input.scalar_type () == ScalarType::Bits16  ||
55+           input.scalar_type () == ScalarType::UInt16  ||
5656          input.scalar_type () == ScalarType::Short ||
5757          input.scalar_type () == (ScalarType)Ushort ||
5858          input.scalar_type () == (ScalarType)Bits4 ||
@@ -83,7 +83,7 @@ void check_dequantize_per_tensor_args(
8383} //  namespace
8484
8585/*  Local function which calls the kernels based on the input datatype */ 
86- void  Dequantize_impl (
86+ void  dequantize_impl (
8787    Tensor& out,
8888    const  Tensor& input,
8989    float * scale_data,
@@ -211,7 +211,7 @@ void Dequantize_impl(
211211    break ;
212212        switch  (input.scalar_type ()) {
213213          ET_FORALL_INT_TYPES (ASYM_CALCULATE_INT_TYPE_TENSOR);
214-           ASYM_CALCULATE_INT_TYPE_TENSOR (uint16_t , Bits16 );
214+           ASYM_CALCULATE_INT_TYPE_TENSOR (uint16_t , UInt16 );
215215          default :
216216            ET_CHECK_MSG (
217217                false ,
@@ -302,7 +302,7 @@ void Dequantize_impl(
302302    break ;
303303        switch  (input.scalar_type ()) {
304304          ET_FORALL_INT_TYPES (ASYM_CALCULATE_INT_TYPE_CHANNEL);
305-           ASYM_CALCULATE_INT_TYPE_CHANNEL (uint16_t , Bits16 );
305+           ASYM_CALCULATE_INT_TYPE_CHANNEL (uint16_t , UInt16 );
306306          default :
307307            ET_CHECK_MSG (
308308                false ,
@@ -368,7 +368,7 @@ void Dequantize_impl(
368368    break ;
369369        switch  (input.scalar_type ()) {
370370          ET_FORALL_INT_TYPES (SYM_CALCULATE_INT_TYPE_TENSOR);
371-           SYM_CALCULATE_INT_TYPE_TENSOR (uint16_t , Bits16 );
371+           SYM_CALCULATE_INT_TYPE_TENSOR (uint16_t , UInt16 );
372372          default :
373373            ET_CHECK_MSG (
374374                false ,
@@ -459,7 +459,7 @@ void Dequantize_impl(
459459    break ;
460460        switch  (input.scalar_type ()) {
461461          ET_FORALL_INT_TYPES (SYM_CALCULATE_INT_TYPE_CHANNEL);
462-           SYM_CALCULATE_INT_TYPE_CHANNEL (uint16_t , Bits16 );
462+           SYM_CALCULATE_INT_TYPE_CHANNEL (uint16_t , UInt16 );
463463          default :
464464            ET_CHECK_MSG (
465465                false ,
@@ -502,7 +502,7 @@ Tensor& dequantize_per_tensor_out(
502502  float  scale_data = (float )scale;
503503  int  zero_point_data = (int )zero_point;
504504
505-   Dequantize_impl (out, input, &scale_data, &zero_point_data, NULL , out_dtype);
505+   dequantize_impl (out, input, &scale_data, &zero_point_data, NULL , out_dtype);
506506
507507  return  out;
508508}
@@ -620,7 +620,7 @@ Tensor& dequantize_per_channel_out(
620620  for  (int  i = 0 ; i < scale.numel (); i++) {
621621    scale_data[i] = (float )scale_dt[i];
622622  }
623-   Dequantize_impl (out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
623+   dequantize_impl (out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
624624
625625  return  out;
626626}
@@ -661,13 +661,19 @@ Tensor& dequantize_per_tensor_out(
661661    int64_t  quant_min,
662662    int64_t  quant_max,
663663    ScalarType dtype,
664-     exec_aten::optional<ScalarType> out_dtype,
665664    Tensor& out) {
666665  //  TODO(larryliu): Add a context arg to the real op function and remove this
667666  //  wrapper
668667  (void )context;
669668  return  dequantize_per_tensor_out (
670-       input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
669+       input,
670+       scale,
671+       zero_point,
672+       quant_min,
673+       quant_max,
674+       dtype,
675+       out.scalar_type (),
676+       out);
671677}
672678
673679Tensor& dequantize_per_tensor_tensor_args_out (
@@ -764,4 +770,4 @@ Tensor& dequantize_per_token_out(
764770} //  namespace native
765771} //  namespace G3
766772} //  namespace impl
767- } //  namespace cadence
773+ } //  namespace cadence
0 commit comments