
Commit 46d7591

Introduce strongly typed quant/dequant ops
Differential Revision: D82183474
Pull Request resolved: #14268
Parent: 18498bf

13 files changed: 594 additions & 23 deletions
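The change replaces generic quantize/dequantize ops that branch on a runtime ScalarType argument with dedicated, strongly typed per-dtype variants. The suffixes encode the quantized storage type: asym8s reads as asymmetric signed 8-bit (int8), asym8u as asymmetric unsigned 8-bit (uint8), and asym16s/asym16u as the signed/unsigned 16-bit counterparts. A minimal sketch of the naming rule implied by the new op names (the helper and table below are illustrative only, not part of the commit):

import torch

# Illustrative only: suffix scheme implied by the new op names.
_SUFFIX_BY_DTYPE = {
    torch.int8: "asym8s",     # asymmetric, signed, 8-bit
    torch.uint8: "asym8u",    # asymmetric, unsigned, 8-bit
    torch.int16: "asym16s",   # asymmetric, signed, 16-bit
    torch.uint16: "asym16u",  # asymmetric, unsigned, 16-bit
}

def typed_variant_name(base_name: str, dtype: torch.dtype) -> str:
    """Map a generic op name plus target dtype to its typed variant."""
    return f"{base_name}_{_SUFFIX_BY_DTYPE[dtype]}"

assert typed_variant_name("quantize_per_tensor", torch.int8) == "quantize_per_tensor_asym8s"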

backends/cadence/aot/TARGETS

Lines changed: 0 additions & 1 deletion
@@ -144,7 +144,6 @@ executorch_generated_lib(
     visibility = ["PUBLIC"],
     deps = [
         "//executorch/backends/cadence/generic/kernels:cadence_kernels",
-        # Individual operator targets instead of combined cadence_generic_ops
         "//executorch/backends/cadence/generic/operators:op_requantize_out",
         "//executorch/backends/cadence/generic/operators:im2row_out",
         "//executorch/backends/cadence/generic/operators:dequantize_per_tensor",

backends/cadence/aot/functions.yaml

Lines changed: 48 additions & 0 deletions
@@ -184,12 +184,60 @@
     - arg_meta: null
       kernel_name: impl::generic::quantize_per_tensor_out
 
+- func: cadence::quantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::quantize_per_tensor_asym8s_out
+
+- func: cadence::quantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::quantize_per_tensor_asym8u_out
+
+- func: cadence::quantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::quantize_per_tensor_asym16s_out
+
+- func: cadence::quantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::quantize_per_tensor_asym16u_out
+
 - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
     - arg_meta: null
       kernel_name: impl::generic::dequantize_per_tensor_out
 
+- func: cadence::dequantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::dequantize_per_tensor_asym8s_out
+
+- func: cadence::dequantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::dequantize_per_tensor_asym8u_out
+
+- func: cadence::dequantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::dequantize_per_tensor_asym16s_out
+
+- func: cadence::dequantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::dequantize_per_tensor_asym16u_out
+
 - func: cadence::quantized_conv2d_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null

backends/cadence/aot/functions_hifi.yaml

Lines changed: 49 additions & 0 deletions
@@ -284,12 +284,61 @@
     - arg_meta: null
      kernel_name: impl::HiFi::quantize_per_tensor_out
 
+- func: cadence::quantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantize_per_tensor_asym8s_out
+
+- func: cadence::quantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantize_per_tensor_asym8u_out
+
+- func: cadence::quantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantize_per_tensor_asym16s_out
+
+- func: cadence::quantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantize_per_tensor_asym16u_out
+
+
 - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
     - arg_meta: null
       kernel_name: impl::HiFi::dequantize_per_tensor_out
 
+- func: cadence::dequantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::dequantize_per_tensor_asym8s_out
+
+- func: cadence::dequantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::dequantize_per_tensor_asym8u_out
+
+- func: cadence::dequantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::dequantize_per_tensor_asym16s_out
+
+- func: cadence::dequantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::dequantize_per_tensor_asym16u_out
+
 - func: cadence::quantized_conv2d_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null

backends/cadence/aot/ops_registrations.py

Lines changed: 148 additions & 0 deletions
@@ -28,12 +28,64 @@
     "quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantize_per_tensor_asym8s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "quantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
+lib.define(
+    "quantize_per_tensor_asym8u(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "quantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
+lib.define(
+    "quantize_per_tensor_asym16s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "quantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
+lib.define(
+    "quantize_per_tensor_asym16u(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "quantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 lib.define(
     "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
 )
 lib.define(
     "dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "dequantize_per_tensor_asym8s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "dequantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "dequantize_per_tensor_asym8u(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "dequantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "dequantize_per_tensor_asym16s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "dequantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "dequantize_per_tensor_asym16u(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "dequantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
 
 lib.define(
     "quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"

@@ -541,6 +593,54 @@ def quantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=dtype)
 
 
+@register_fake("cadence::quantize_per_tensor_asym8s")
+def quantize_per_tensor_asym8s_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=dtype)
+
+
+@register_fake("cadence::quantize_per_tensor_asym8u")
+def quantize_per_tensor_asym8u_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=dtype)
+
+
+@register_fake("cadence::quantize_per_tensor_asym16s")
+def quantize_per_tensor_asym16s_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=dtype)
+
+
+@register_fake("cadence::quantize_per_tensor_asym16u")
+def quantize_per_tensor_asym16u_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=dtype)
+
+
 @register_fake("cadence::dequantize_per_tensor")
 def dequantize_per_tensor_meta(
     input: torch.Tensor,

@@ -553,6 +653,54 @@ def dequantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=torch.float)
 
 
+@register_fake("cadence::dequantize_per_tensor_asym8s")
+def dequantize_per_tensor_asym8s_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=torch.float)
+
+
+@register_fake("cadence::dequantize_per_tensor_asym8u")
+def dequantize_per_tensor_asym8u_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=torch.float)
+
+
+@register_fake("cadence::dequantize_per_tensor_asym16s")
+def dequantize_per_tensor_asym16s_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=torch.float)
+
+
+@register_fake("cadence::dequantize_per_tensor_asym16u")
+def dequantize_per_tensor_asym16u_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=torch.float)
+
+
 @register_fake("cadence::quantized_add")
 def quantized_add_meta(
     X: torch.Tensor,

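Each new schema is paired with a register_fake meta kernel so the typed ops participate in export and shape propagation: the quantize variants return an empty tensor of the requested dtype, while every dequantize variant returns torch.float regardless of its dtype argument. A usage sketch, assuming this registration module has been imported so the cadence op namespace exists:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Under FakeTensorMode the registered meta kernels run, so shape and
# dtype propagation works without any Cadence kernel implementation.
with FakeTensorMode():
    x = torch.empty(4, 8)
    q = torch.ops.cadence.quantize_per_tensor_asym8s(
        x, 0.05, 0, -128, 127, torch.int8
    )
    assert q.dtype == torch.int8 and q.shape == x.shape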
backends/cadence/aot/type_dispatch.py

Lines changed: 24 additions & 0 deletions
@@ -27,6 +27,7 @@ class OpConfig:
     base_name: str
     type_dispatch_suffixes: dict[tuple[torch.dtype, ...], str]
     weight_arg_idx: Optional[int] = None
+    is_quant_op: bool = False
     variant: str = "per_tensor"
 
 

@@ -100,6 +101,27 @@ class CompileTimeTypeDispatchPass(ExportPass):
             },
             variant="default",
         ),
+        exir_ops.edge.cadence.quantize_per_tensor.default: OpConfig(
+            "quantize_per_tensor",
+            type_dispatch_suffixes={
+                (torch.int8,): "asym8s",
+                (torch.uint8,): "asym8u",
+                (torch.int16,): "asym16s",
+                (torch.uint16,): "asym16u",
+            },
+            variant="default",
+            is_quant_op=True,
+        ),
+        exir_ops.edge.cadence.dequantize_per_tensor.default: OpConfig(
+            "dequantize_per_tensor",
+            type_dispatch_suffixes={
+                (torch.int8,): "asym8s",
+                (torch.uint8,): "asym8u",
+                (torch.int16,): "asym16s",
+                (torch.uint16,): "asym16u",
+            },
+            variant="default",
+        ),
     }
 
     def call_operator(

@@ -120,6 +142,8 @@ def call_operator(
         if config.weight_arg_idx is not None:
             weight_dtype = args[config.weight_arg_idx].to_tensor().dtype
             dtype_key = (input_dtype, weight_dtype)
+        elif config.is_quant_op:
+            dtype_key = (args[5],)
         else:
             dtype_key = (input_dtype,)

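The subtlety in this pass is where the dispatch key comes from. For weighted ops it is the (input, weight) dtype pair, and for most other ops the input tensor's dtype; quantize ops are the exception, since their input is always float and the target type is the sixth schema argument, so is_quant_op=True switches the key to args[5]. Dequantize needs no flag: its input tensor already carries the quantized dtype, so the existing else branch applies. A simplified, self-contained sketch of that key selection, inferred from the hunk above (not the full CompileTimeTypeDispatchPass):

import torch

def dispatch_suffix(
    suffixes: dict[tuple[torch.dtype, ...], str],
    args: tuple,
    is_quant_op: bool = False,
) -> str:
    # Quantize ops: the target dtype is schema argument index 5
    # (input, scale, zero_point, quant_min, quant_max, dtype).
    if is_quant_op:
        dtype_key = (args[5],)
    else:
        # Everything else, including dequantize: key off the input dtype.
        dtype_key = (args[0].dtype,)
    return suffixes[dtype_key]

suffixes = {(torch.int8,): "asym8s", (torch.uint16,): "asym16u"}
args = (torch.empty(2, 2), 0.05, 0, -128, 127, torch.int8)
assert dispatch_suffix(suffixes, args, is_quant_op=True) == "asym8s"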