@@ -20,11 +20,26 @@ namespace cadence {
2020namespace impl {
2121namespace HiFi {
2222namespace native {
23-
23+ namespace {
2424using ::executorch::aten::ScalarType;
2525using ::executorch::aten::Tensor;
2626using ::executorch::runtime::KernelRuntimeContext;
2727
28+ // Checks that the requested quantization bounds match the output dtype.
// For the integer type T, the check passes only when quant_min equals
// std::numeric_limits<T>::min() and quant_max equals
// std::numeric_limits<T>::max(), i.e. the full range of T was requested.
//
// Template parameter:
//   T         - C++ type of the output elements; the callers visible in this
//               file use uint8_t, int8_t, int16_t, uint16_t and int32_t.
// Parameters:
//   ctx       - kernel runtime context handed to ET_KERNEL_CHECK.
//   quant_min - requested lower quantization bound.
//   quant_max - requested upper quantization bound.
//
// NOTE(review): a failing ET_KERNEL_CHECK presumably records InvalidArgument
// on ctx and returns from this helper only; the caller is not interrupted by
// it. Confirm against the macro definition whether callers need to inspect
// ctx after this returns.
29+ template <typename T>
30+ void check_quant_min_and_max (
31+ KernelRuntimeContext& ctx,
32+ const int64_t quant_min,
33+ const int64_t quant_max) {
34+ ET_KERNEL_CHECK (
35+ ctx,
36+ std::numeric_limits<T>::min () == quant_min && // lower bound must be T's minimum
37+ std::numeric_limits<T>::max () == quant_max, // upper bound must be T's maximum
38+ InvalidArgument, ); // empty last argument: this function returns void
39+ }
40+
41+ } // namespace
42+
2843// Quantize the input tensor (PT2 version). Note that quant_<min,max> are not
2944// used in any computation.
3045void quantize_per_tensor_out (
@@ -36,15 +51,43 @@ void quantize_per_tensor_out(
3651 __ET_UNUSED int64_t quant_max,
3752 const ScalarType dtype,
3853 Tensor& out) {
39- // Add checks for dtype quant min/max bounds.
40- ET_SWITCH_REALB_TYPES (
41- out.scalar_type (), ctx, " quantize_per_tensor" , OUT_DTYPE, [&]() {
42- ET_KERNEL_CHECK (
43- ctx,
44- std::numeric_limits<OUT_DTYPE>::min () == quant_min &&
45- std::numeric_limits<OUT_DTYPE>::max () == quant_max,
46- InvalidArgument, );
47- });
54+ // Check for input scalar type.
55+ ET_KERNEL_CHECK_MSG (
56+ ctx,
57+ input.scalar_type () == ScalarType::Float,
58+ InvalidType,
59+ ,
60+ " Input tensor for quantize_per_tensor.out should be type %s, but got %s" ,
61+ ::torch::executor::toString (ScalarType::Float),
62+ ::torch::executor::toString(input.scalar_type()));
63+
64+ // Check quant min/max for output types.
65+ switch (out.scalar_type ()) {
66+ case ScalarType::Byte:
67+ check_quant_min_and_max<uint8_t >(ctx, quant_min, quant_max);
68+ break ;
69+ case ScalarType::Char:
70+ check_quant_min_and_max<int8_t >(ctx, quant_min, quant_max);
71+ break ;
72+ case ScalarType::Short:
73+ check_quant_min_and_max<int16_t >(ctx, quant_min, quant_max);
74+ break ;
75+ case ScalarType::Bits16:
76+ case ScalarType::UInt16:
77+ check_quant_min_and_max<uint16_t >(ctx, quant_min, quant_max);
78+ break ;
79+ case ScalarType::Int:
80+ check_quant_min_and_max<int32_t >(ctx, quant_min, quant_max);
81+ break ;
82+ default :
83+ ET_KERNEL_CHECK_MSG (
84+ ctx,
85+ false ,
86+ InvalidType,
87+ ,
88+ " Unhandled output dtype %s" ,
89+ ::torch::executor::toString (out.scalar_type()));
90+ }
4891
4992 const float * input_data = input.const_data_ptr <float >();
5093 const size_t numel = out.numel ();
0 commit comments