
Commit 685e795 (1 parent: 43b50ff)

Add int32 quant/dequant back

Differential Revision: D82282481
Pull Request resolved: #14269

File tree: 10 files changed (+143 −0 lines)

backends/cadence/aot/functions.yaml

Lines changed: 12 additions & 0 deletions
@@ -208,6 +208,12 @@
     - arg_meta: null
       kernel_name: impl::generic::quantize_per_tensor_asym16u_out
 
+- func: cadence::quantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::quantize_per_tensor_asym32s_out
+
 - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
@@ -238,6 +244,12 @@
     - arg_meta: null
       kernel_name: impl::generic::dequantize_per_tensor_asym16u_out
 
+- func: cadence::dequantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::dequantize_per_tensor_asym32s_out
+
 - func: cadence::quantized_conv2d_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
backends/cadence/aot/functions_hifi.yaml

Lines changed: 11 additions & 0 deletions
@@ -308,6 +308,11 @@
     - arg_meta: null
       kernel_name: impl::HiFi::quantize_per_tensor_asym16s_out
 
+- func: cadence::quantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantize_per_tensor_asym32s_out
 
 - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
@@ -339,6 +344,12 @@
     - arg_meta: null
       kernel_name: impl::HiFi::dequantize_per_tensor_asym16u_out
 
+- func: cadence::dequantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::dequantize_per_tensor_asym32s_out
+
 - func: cadence::quantized_conv2d_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null

backends/cadence/aot/ops_registrations.py

Lines changed: 38 additions & 0 deletions
@@ -56,6 +56,13 @@
     "quantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantize_per_tensor_asym32s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "quantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 lib.define(
     "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
 )
@@ -87,6 +94,13 @@
     "dequantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "dequantize_per_tensor_asym32s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "dequantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 lib.define(
     "quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
 )
@@ -641,6 +655,18 @@ def quantize_per_tensor_asym16u_meta(
     return input.new_empty(input.size(), dtype=dtype)
 
 
+@register_fake("cadence::quantize_per_tensor_asym32s")
+def quantize_per_tensor_asym32s_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=dtype)
+
+
 @register_fake("cadence::dequantize_per_tensor")
 def dequantize_per_tensor_meta(
     input: torch.Tensor,
@@ -701,6 +727,18 @@ def dequantize_per_tensor_asym16u_meta(
     return input.new_empty(input.size(), dtype=torch.float)
 
 
+@register_fake("cadence::dequantize_per_tensor_asym32s")
+def dequantize_per_tensor_asym32s_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=torch.float)
+
+
 @register_fake("cadence::quantized_add")
 def quantized_add_meta(
     X: torch.Tensor,
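
The lib.define calls register the operator schemas, and the @register_fake functions give the compiler a metadata-only implementation: each returns an empty tensor of the right shape and dtype (the quantized dtype for quantize, float for dequantize), so graphs containing these ops can be traced and exported without running real kernels. A self-contained sketch of the same pattern, under a hypothetical demo namespace rather than the actual cadence library:

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.library import Library, register_fake

    lib = Library("demo", "DEF")  # hypothetical namespace, for illustration only
    lib.define(
        "quantize_per_tensor_asym32s(Tensor input, float scale, int zero_point, "
        "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
    )

    @register_fake("demo::quantize_per_tensor_asym32s")
    def _(input, scale, zero_point, quant_min, quant_max, dtype):
        # Shape/dtype propagation only; no data is touched.
        return input.new_empty(input.size(), dtype=dtype)

    with FakeTensorMode():
        x = torch.empty(2, 3)
        q = torch.ops.demo.quantize_per_tensor_asym32s(
            x, 2.0**-16, 0, -(2**31), 2**31 - 1, torch.int32
        )
        assert q.shape == (2, 3) and q.dtype == torch.int32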

backends/cadence/aot/type_dispatch.py

Lines changed: 2 additions & 0 deletions
@@ -108,6 +108,7 @@ class CompileTimeTypeDispatchPass(ExportPass):
             (torch.uint8,): "asym8u",
             (torch.int16,): "asym16s",
             (torch.uint16,): "asym16s",
+            (torch.int32,): "asym32s",
         },
         variant="default",
         is_quant_op=True,
@@ -119,6 +120,7 @@ class CompileTimeTypeDispatchPass(ExportPass):
             (torch.uint8,): "asym8u",
             (torch.int16,): "asym16s",
             (torch.uint16,): "asym16s",
+            (torch.int32,): "asym32s",
         },
         variant="default",
     ),
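
CompileTimeTypeDispatchPass uses these tables to rewrite generic cadence::quantize_per_tensor / dequantize_per_tensor nodes into their typed variants; the new rows route torch.int32 inputs to the asym32s kernels registered above. A toy sketch of the lookup (helper names here are illustrative, not the pass's actual API):

    import torch

    # Mirrors the dtype -> variant table shown in the diff.
    VARIANTS = {
        (torch.uint8,): "asym8u",
        (torch.int16,): "asym16s",
        (torch.uint16,): "asym16s",
        (torch.int32,): "asym32s",  # newly restored int32 entry
    }

    def typed_variant(base_op: str, dtypes: tuple) -> str:
        # e.g. "quantize_per_tensor" + (torch.int32,) -> "quantize_per_tensor_asym32s"
        return f"{base_op}_{VARIANTS[dtypes]}"

    assert typed_variant("quantize_per_tensor", (torch.int32,)) == "quantize_per_tensor_asym32s"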

backends/cadence/generic/kernels/kernels.cpp

Lines changed: 4 additions & 0 deletions
@@ -73,6 +73,7 @@ typed_quantize_val(int8_t);
 typed_quantize_val(uint8_t);
 typed_quantize_val(int16_t);
 typed_quantize_val(uint16_t);
+typed_quantize_val(int32_t);
 #undef typed_quantize_val
 
 #define typed_quantize_vec(dtype) \
@@ -86,6 +87,7 @@ typed_quantize_vec(int8_t);
 typed_quantize_vec(uint8_t);
 typed_quantize_vec(int16_t);
 typed_quantize_vec(uint16_t);
+typed_quantize_vec(int32_t);
 #undef typed_quantize_vec
 
 #define typed_dequantize_val(dtype) \
@@ -94,6 +96,7 @@ typed_dequantize_val(int8_t);
 typed_dequantize_val(uint8_t);
 typed_dequantize_val(int16_t);
 typed_dequantize_val(uint16_t);
+typed_dequantize_val(int32_t);
 #undef typed_dequantize_val
 
 #define typed_dequantize_vec(dtype) \
@@ -107,6 +110,7 @@ typed_dequantize_vec(int8_t);
 typed_dequantize_vec(uint8_t);
 typed_dequantize_vec(int16_t);
 typed_dequantize_vec(uint16_t);
+typed_dequantize_vec(int32_t);
 #undef typed_dequantize_vec
 
 } // namespace kernels
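
These macros stamp out explicit template instantiations of kernels::quantize / kernels::dequantize for each supported element type, so adding int32_t here is what makes the asym32s operators link. The practical benefit of the wider type: at a fixed scale, int32 covers a far larger dynamic range before saturating. A quick numeric sketch using the reference formula from earlier (zero_point = 0 for simplicity):

    import torch

    def max_roundtrip_error(x, scale, qmin, qmax):
        q = torch.clamp(torch.round(x / scale), qmin, qmax)
        return (q * scale - x).abs().max().item()

    x = torch.linspace(-1000.0, 1000.0, steps=10_001, dtype=torch.float64)
    scale = 2.0**-16
    # int16 saturates near +-0.5 at this scale, so the error is huge;
    # int32 has headroom to +-32768, so the error stays <= scale / 2.
    print(max_roundtrip_error(x, scale, -(2**15), 2**15 - 1))
    print(max_roundtrip_error(x, scale, -(2**31), 2**31 - 1))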

backends/cadence/generic/operators/dequantize_per_tensor.cpp

Lines changed: 19 additions & 0 deletions
@@ -44,6 +44,9 @@ Tensor& dequantize_per_tensor_out(
   } else if (input.scalar_type() == ScalarType::Short) {
     const int16_t* input_data = input.const_data_ptr<int16_t>();
     dequantize<int16_t>(out_data, input_data, scale, zero_point, numel);
+  } else if (input.scalar_type() == ScalarType::Int) {
+    const int32_t* input_data = input.const_data_ptr<int32_t>();
+    dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);
   } else {
     ET_CHECK_MSG(
         false,
@@ -117,6 +120,22 @@ Tensor& dequantize_per_tensor_asym16u_out(
   return out;
 }
 
+Tensor& dequantize_per_tensor_asym32s_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  float* out_data = out.mutable_data_ptr<float>();
+  size_t numel = out.numel();
+  const int32_t* input_data = input.const_data_ptr<int32_t>();
+  dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);
+  return out;
+}
+
 } // namespace native
 } // namespace generic
 } // namespace impl

backends/cadence/generic/operators/quantize_per_tensor.cpp

Lines changed: 19 additions & 0 deletions
@@ -46,6 +46,9 @@ Tensor& quantize_per_tensor_out(
   } else if (out.scalar_type() == ScalarType::Short) {
     int16_t* out_data = out.mutable_data_ptr<int16_t>();
     quantize<int16_t>(out_data, input_data, 1. / scale, zero_point, numel);
+  } else if (out.scalar_type() == ScalarType::Int) {
+    int32_t* out_data = out.mutable_data_ptr<int32_t>();
+    quantize<int32_t>(out_data, input_data, 1. / scale, zero_point, numel);
   } else {
     ET_CHECK_MSG(
         false,
@@ -119,6 +122,22 @@ Tensor& quantize_per_tensor_asym16u_out(
   return out;
 }
 
+Tensor& quantize_per_tensor_asym32s_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  const float* input_data = input.const_data_ptr<float>();
+  size_t numel = out.numel();
+  int32_t* out_data = out.mutable_data_ptr<int32_t>();
+  quantize<int32_t>(out_data, input_data, 1. / scale, zero_point, numel);
+  return out;
+}
+
 }; // namespace native
 }; // namespace generic
 }; // namespace impl
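
As in the other variants, the wrapper passes 1. / scale rather than scale: the reciprocal is computed once so the element loop inside kernels::quantize can multiply instead of divide. The identity in a one-liner (hypothetical helper, shown only to illustrate the equivalence):

    def quantize_one(x: float, inv_scale: float, zero_point: int) -> int:
        # round(x / scale) + zero_point, with the division hoisted out as inv_scale.
        return round(x * inv_scale) + zero_point

    scale, zero_point = 0.25, 3
    assert quantize_one(1.1, 1.0 / scale, zero_point) == round(1.1 / scale) + zero_point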

backends/cadence/hifi/kernels/kernels.cpp

Lines changed: 2 additions & 0 deletions
@@ -127,6 +127,7 @@ typed_quantize_val(int8_t);
 typed_quantize_val(uint8_t);
 typed_quantize_val(int16_t);
 typed_quantize_val(uint16_t);
+typed_quantize_val(int32_t);
 #undef typed_quantize_val
 
 #define typed_quantize_vec(dtype) \
@@ -150,6 +151,7 @@ typed_dequantize_val(int8_t);
 typed_dequantize_val(uint8_t);
 typed_dequantize_val(int16_t);
 typed_dequantize_val(uint16_t);
+typed_dequantize_val(int32_t);
 #undef typed_dequantize_val
 
 #define typed_dequantize_vec(dtype) \

backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp

Lines changed: 18 additions & 0 deletions
@@ -45,6 +45,9 @@ void dequantize_per_tensor_out(
       input.scalar_type() == ScalarType::UInt16) {
     const uint16_t* input_data = input.const_data_ptr<uint16_t>();
     dequantize<uint16_t>(out_data, input_data, scale, zero_point, numel);
+  } else if (input.scalar_type() == ScalarType::Int) {
+    const int32_t* input_data = input.const_data_ptr<int32_t>();
+    dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);
   } else {
     ET_CHECK_MSG(
         false,
@@ -98,6 +101,21 @@ void dequantize_per_tensor_asym16u_out(
   dequantize<uint16_t>(out_data, input_data, scale, zero_point, numel);
 }
 
+void dequantize_per_tensor_asym32s_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  float* out_data = out.mutable_data_ptr<float>();
+  size_t numel = out.numel();
+  const int32_t* input_data = input.const_data_ptr<int32_t>();
+  dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);
+}
+
 } // namespace native
 } // namespace HiFi
 } // namespace impl

backends/cadence/hifi/operators/op_quantize_per_tensor.cpp

Lines changed: 18 additions & 0 deletions
@@ -108,6 +108,9 @@ void quantize_per_tensor_out(
       out.scalar_type() == ScalarType::UInt16) {
     uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
     quantize<uint16_t>(out_data, input_data, 1. / scale, zero_point, numel);
+  } else if (out.scalar_type() == ScalarType::Int) {
+    int32_t* out_data = out.mutable_data_ptr<int32_t>();
+    quantize<int32_t>(out_data, input_data, 1. / scale, zero_point, numel);
   } else {
     ET_KERNEL_CHECK_MSG(
         ctx,
@@ -164,6 +167,21 @@ void quantize_per_tensor_asym16u_out(
   quantize<uint16_t>(out_data, input_data, 1. / scale, zero_point, numel);
 }
 
+void quantize_per_tensor_asym32s_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  const float* input_data = input.const_data_ptr<float>();
+  size_t numel = out.numel();
+  int32_t* out_data = out.mutable_data_ptr<int32_t>();
+  quantize<int32_t>(out_data, input_data, 1. / scale, zero_point, numel);
+}
+
 }; // namespace native
 }; // namespace HiFi
 }; // namespace impl
