@@ -83,7 +83,7 @@ void check_dequantize_per_tensor_args(
8383} // namespace
8484
8585/* Local function which calls the kernels based on the input datatype */
86- void Dequantize_impl (
86+ void dequantize_impl (
8787 Tensor& out,
8888 const Tensor& input,
8989 float * scale_data,
@@ -502,7 +502,7 @@ Tensor& dequantize_per_tensor_out(
502502 float scale_data = (float )scale;
503503 int zero_point_data = (int )zero_point;
504504
505- Dequantize_impl (out, input, &scale_data, &zero_point_data, NULL , out_dtype);
505+ dequantize_impl (out, input, &scale_data, &zero_point_data, NULL , out_dtype);
506506
507507 return out;
508508}
@@ -620,7 +620,7 @@ Tensor& dequantize_per_channel_out(
620620 for (int i = 0 ; i < scale.numel (); i++) {
621621 scale_data[i] = (float )scale_dt[i];
622622 }
623- Dequantize_impl (out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
623+ dequantize_impl (out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
624624
625625 return out;
626626}
@@ -661,13 +661,19 @@ Tensor& dequantize_per_tensor_out(
661661 int64_t quant_min,
662662 int64_t quant_max,
663663 ScalarType dtype,
664- exec_aten::optional<ScalarType> out_dtype,
665664 Tensor& out) {
666665 // TODO(larryliu): Add a context arg to the real op function and remove this
667666 // wrapper
668667 (void )context;
669668 return dequantize_per_tensor_out (
670- input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
669+ input,
670+ scale,
671+ zero_point,
672+ quant_min,
673+ quant_max,
674+ dtype,
675+ out.scalar_type (),
676+ out);
671677}
672678
673679Tensor& dequantize_per_tensor_tensor_args_out (
@@ -764,4 +770,4 @@ Tensor& dequantize_per_token_out(
764770} // namespace native
765771} // namespace G3
766772} // namespace impl
767- } // namespace cadence
773+ } // namespace cadence
0 commit comments