
Commit 0b3227f

Add support for conv1d
Differential Revision: D82160616
Pull Request resolved: #14189
1 parent: 897b0d5

13 files changed: +1144, -10 lines

backends/cadence/aot/functions.yaml

Lines changed: 20 additions & 0 deletions
@@ -359,6 +359,26 @@
     - arg_meta: null
       kernel_name: impl::reference::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out
 
+- func: cadence::quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_out
+
+- func: cadence::quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_out
+
 - func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
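Each new entry binds one of the conv1d schemas to a reference kernel symbol via `kernel_name`. Once the ops library is loaded, the out-variant can be smoke-tested from Python. The following call is a hypothetical sketch: it assumes the Cadence ops are registered under `torch.ops.cadence`, and the quantization parameters are placeholder values, not meaningful calibration results.

import torch

# Hypothetical smoke test for the newly registered out-variant; assumes the
# Cadence op library has been loaded so torch.ops.cadence is populated.
x = torch.randint(-128, 127, (1, 3, 8), dtype=torch.int8)   # (N, C, L)
w = torch.randint(-128, 127, (16, 3, 3), dtype=torch.int8)  # (C_out, C_in, K)
b = torch.zeros(16, dtype=torch.int32)
out = torch.empty((1, 16, 6), dtype=torch.int8)             # L_out = 8 - 3 + 1 = 6

torch.ops.cadence.quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor_out(
    x, w, b,
    [1, 1], [0, 0], [1, 1],  # stride, padding, dilation; index 1 is used for 1D
    1,                       # groups
    0, 0,                    # input / weight zero points (placeholders)
    1.0, 1.0, 0,             # bias_scale, out_scale, out_zero_point (placeholders)
    1, 1,                    # out_multiplier, out_shift (placeholders)
    out=out,
)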

backends/cadence/aot/functions_hifi.yaml

Lines changed: 20 additions & 0 deletions
@@ -370,6 +370,26 @@
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out
 
+- func: cadence::quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_out
+
+- func: cadence::quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_out
+
 - func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null

backends/cadence/aot/ops_registrations.py

Lines changed: 168 additions & 0 deletions
@@ -169,6 +169,30 @@
 lib.define(
     "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
 )
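Each operator is defined in two flavors: a functional `.per_tensor` that returns a fresh tensor (`-> (Tensor Z)`), and a `.per_tensor_out` that writes into a caller-supplied buffer, which is what the AoT flow ultimately lowers to. A generic sketch of how such pairs typically relate, using a toy op rather than the actual ExecuTorch registration machinery:

import torch


def toy_relu_per_tensor_out(x: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
    # Out-variant: writes into caller-provided storage (the ".per_tensor_out" flavor).
    torch.clamp(x, min=0, out=out)
    return out


def toy_relu_per_tensor(x: torch.Tensor) -> torch.Tensor:
    # Functional variant: allocates its own result (the ".per_tensor" flavor),
    # then delegates to the out-variant.
    out = torch.empty_like(x)
    return toy_relu_per_tensor_out(x, out=out)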
@@ -2153,6 +2177,150 @@ def roi_align_box_processor_meta(
     return rois.new_empty((rois.shape[0], 80), dtype=torch.uint8)
 
 
+@register_fake("cadence::quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor")
+def quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    assert input.dim() == 3 and weight.dim() == 3
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
+    out_channels, _, kernel_size = weight.shape
+    output_size = get_conv1d_output_size(
+        input.shape,
+        out_channels,
+        stride[1],
+        padding[1],
+        dilation[1],
+        kernel_size,
+        False,
+    )
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor")
+def quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    assert input.dim() == 3 and weight.dim() == 3
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
+    out_channels, _, kernel_size = weight.shape
+    output_size = get_conv1d_output_size(
+        input.shape,
+        out_channels,
+        stride[1],
+        padding[1],
+        dilation[1],
+        kernel_size,
+        False,
+    )
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor")
+def quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    assert input.dim() == 3 and weight.dim() == 3
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
+    out_channels, kernel_size, _ = weight.shape
+    output_size = get_conv1d_output_size(
+        input.shape,
+        out_channels,
+        stride[1],
+        padding[1],
+        dilation[1],
+        kernel_size,
+        True,
+    )
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor")
+def quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    assert input.dim() == 3 and weight.dim() == 3
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
+    out_channels, kernel_size, _ = weight.shape
+    output_size = get_conv1d_output_size(
+        input.shape,
+        out_channels,
+        stride[1],
+        padding[1],
+        dilation[1],
+        kernel_size,
+        True,
+    )
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
 @register_fake("cadence::_softmax_f32_f32")
 def softmax_f32_f32_meta(
     self: torch.Tensor,
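The fake (meta) kernels above derive the output shape through `get_conv1d_output_size`, a helper that is not part of this diff. Below is a minimal sketch of the standard conv1d length arithmetic it presumably implements; the parameter order and the `channel_last` flag are inferred from the call sites above, so treat the signature as an assumption.

from typing import Sequence

import torch


def get_conv1d_output_size(
    in_size: Sequence[int],  # (N, C, L) for NCHW, (N, L, C) when channel_last
    out_channels: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
    channel_last: bool,
) -> torch.Size:
    # Hypothetical reimplementation inferred from the call sites; the real
    # helper ships with the Cadence AoT utilities, not with this diff.
    N = in_size[0]
    L = in_size[1] if channel_last else in_size[2]
    # Standard 1D convolution output-length formula.
    L_out = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return torch.Size(
        (N, L_out, out_channels) if channel_last else (N, out_channels, L_out)
    )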

backends/cadence/aot/ref_implementations.py

Lines changed: 29 additions & 0 deletions
@@ -613,6 +613,7 @@ def quantized_conv_variant(
     layout: str,
     input_dtype: torch.dtype,
     weight_dtype: torch.dtype,
+    is_1d: bool = False,
 ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
     """Create a quantized conv variant with type checking."""
 
@@ -644,6 +645,14 @@ def variant(
             bias.dtype == torch.int32
         ), f"Expected bias dtype int32, got {bias.dtype}"
 
+        if is_1d:
+            assert (
+                len(input_tensor.shape) == 3
+            ), f"1D convolution requires 3D input tensor, got {len(input_tensor.shape)}D"
+            assert (
+                len(weight.shape) == 3
+            ), f"1D convolution requires 3D weight tensor, got {len(weight.shape)}D"
+
         # Call the appropriate base function
         match layout:
             case "nchw":
@@ -748,6 +757,26 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tens
 def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
+@impl(m, "quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8, is_1d=True)
+def quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8, is_1d=True)
+def quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8, is_1d=True)
+def quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8, is_1d=True)
+def quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
 def quantized_relu_common(
     X: torch.Tensor,
     X_zero_point: torch.Tensor | int,
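The `...` bodies are intentional: `quantized_conv_variant` is a decorator factory that supplies the entire implementation, so the decorated stubs only contribute a registration-friendly name. A simplified, self-contained sketch of the pattern follows; the real decorator also dispatches to layout-specific base convolutions, which this sketch elides, and all names here are illustrative.

from typing import Callable

import torch


def conv_variant(
    layout: str,
    input_dtype: torch.dtype,
    weight_dtype: torch.dtype,
    is_1d: bool = False,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
    # Illustrative stand-in for quantized_conv_variant: the returned decorator
    # discards the stub it wraps and substitutes a type-checked implementation.
    def decorator(stub: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        def variant(
            input_tensor: torch.Tensor, weight: torch.Tensor, *rest
        ) -> torch.Tensor:
            assert input_tensor.dtype == input_dtype
            assert weight.dtype == weight_dtype
            if is_1d:
                # 1D convolutions take rank-3 activations and weights.
                assert input_tensor.dim() == 3 and weight.dim() == 3
            # The real implementation dispatches to an nchw/nhwc base conv here.
            raise NotImplementedError(f"{layout} base conv elided in this sketch")

        variant.__name__ = stub.__name__  # keep the stub's name for registration
        return variant

    return decorator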

backends/cadence/aot/tests/test_type_dispatch_passes.py

Lines changed: 104 additions & 0 deletions
@@ -446,6 +446,110 @@ def test_uint8_dispatch_quantized_conv_nhwc_dilated(self) -> None:
             1,
         )
 
+    def test_int8_dispatch_quantized_conv_nchw_1d(self) -> None:
+        """Test int8 x int8 inputs for 1D conv should dispatch to 1d_asym8sxsym8s_asym8s variant for quantized_conv_nchw"""
+        x = torch.randint(-128, 127, (1, 3, 8), dtype=torch.int8)
+        w = torch.randint(-128, 127, (16, 3, 3), dtype=torch.int8)
+        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
+        gm = single_op_builder(
+            placeholders=(x, w, b),
+            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
+            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
+            0,
+        )
+        # Should be replaced with 1D int8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor,
+            ),
+            1,
+        )
+
+    def test_uint8_dispatch_quantized_conv_nchw_1d(self) -> None:
+        """Test uint8 x uint8 inputs for 1D conv should dispatch to 1d_asym8uxsym8u_asym8u variant for quantized_conv_nchw"""
+        x = torch.randint(0, 255, (1, 3, 8), dtype=torch.uint8)
+        w = torch.randint(0, 255, (16, 3, 3), dtype=torch.uint8)
+        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
+        gm = single_op_builder(
+            placeholders=(x, w, b),
+            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
+            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
+            0,
+        )
+        # Should be replaced with 1D uint8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor,
+            ),
+            1,
+        )
+
+    def test_int8_dispatch_quantized_conv_nhwc_1d(self) -> None:
+        """Test int8 x int8 inputs for 1D conv should dispatch to 1d_asym8sxsym8s_asym8s variant for quantized_conv_nhwc"""
+        x = torch.randint(-128, 127, (1, 8, 3), dtype=torch.int8)
+        w = torch.randint(-128, 127, (16, 3, 3), dtype=torch.int8)
+        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
+        gm = single_op_builder(
+            placeholders=(x, w, b),
+            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
+            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
+            0,
+        )
+        # Should be replaced with 1D int8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor,
+            ),
+            1,
+        )
+
+    def test_uint8_dispatch_quantized_conv_nhwc_1d(self) -> None:
+        """Test uint8 x uint8 inputs for 1D conv should dispatch to 1d_asym8uxsym8u_asym8u variant for quantized_conv_nhwc"""
+        x = torch.randint(0, 255, (1, 8, 3), dtype=torch.uint8)
+        w = torch.randint(0, 255, (16, 3, 3), dtype=torch.uint8)
+        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
+        gm = single_op_builder(
+            placeholders=(x, w, b),
+            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
+            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
+            0,
+        )
+        # Should be replaced with 1D uint8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor,
+            ),
+            1,
+        )
+
     def test_int8_dispatch_quantized_add(self) -> None:
         """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_add"""
         x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
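These tests pin down the behavior of `CompileTimeTypeDispatchPass`: a generic `quantized_conv_nchw`/`quantized_conv_nhwc` `.per_tensor` node with a rank-3 activation must be rewritten to the matching conv1d variant, with the int8/uint8 activation dtype selecting the `asym8sxsym8s`/`asym8uxsym8u` specialization. A schematic of the implied selection rule; the helper name is illustrative, and only the op names come from the tests above:

import torch


def select_variant_name(layout: str, input_rank: int, dtype: torch.dtype) -> str:
    # Reconstruction of the dispatch rule implied by the tests: rank-3 inputs
    # route to the conv1d kernels, and the activation dtype picks the int8
    # (asym8s) or uint8 (asym8u) specialization.
    suffix = {
        torch.int8: "asym8sxsym8s_asym8s",
        torch.uint8: "asym8uxsym8u_asym8u",
    }[dtype]
    base = (
        f"quantized_conv1d_{layout}" if input_rank == 3 else f"quantized_conv_{layout}"
    )
    return f"{base}_{suffix}"


assert (
    select_variant_name("nchw", 3, torch.int8)
    == "quantized_conv1d_nchw_asym8sxsym8s_asym8s"
)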
