Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,10 +458,21 @@ def quantized_conv_nhwc_per_tensor(
- out_shift (int): Unused
"""

if not input_tensor.is_contiguous(memory_format=torch.channels_last):
raise ValueError("Input tensor must be in NHWC format")
# Convert to NCHW format to reuse the existing implementation
conv_is_1d = False
if len(input_tensor.shape) == 3:
conv_is_1d = True
input_tensor = input_tensor.movedim(-1, 1).contiguous()
if len(weight.shape) != 3:
raise ValueError("Weight tensor must be 3D if input is 3D")
weight = weight.movedim(-1, 1).contiguous()
else:
input_tensor = input_tensor.movedim(-1, -3)
if len(weight.shape) != 4:
raise ValueError("Weight tensor must be 4D if input is nd > 3")
weight = torch.permute(weight, (0, -1, 1, 2)).contiguous()

return quantized_conv_per_tensor(
nchw_out = quantized_conv_per_tensor(
input_tensor,
weight,
bias,
Expand All @@ -478,6 +489,11 @@ def quantized_conv_nhwc_per_tensor(
out_shift,
)

if conv_is_1d:
return nchw_out.movedim(1, -1).contiguous()
else:
return nchw_out.movedim(-3, -1).contiguous()


def quantized_conv_variant(
layout: str,
Expand Down
25 changes: 17 additions & 8 deletions backends/cadence/aot/tests/test_ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def test_quantized_layer_norm_per_tensor(
), # expected_output: [1+2, 2+3, 3+4] / 0.5 = [6, 10, 14]
memory_format,
)
for memory_format in [torch.contiguous_format]
for memory_format in [torch.contiguous_format, torch.channels_last]
],
# Test case 5: Multiple output channels
*[
Expand Down Expand Up @@ -686,10 +686,13 @@ def test_quantized_conv_per_tensor(
) -> None:
assert memory_format in [torch.contiguous_format, torch.channels_last]

if len(input_tensor.shape) == 3 and memory_format == torch.channels_last:
self.fail("Channels last format is not supported for 3D input tensors")

input_tensor = input_tensor.to(memory_format=memory_format)
if memory_format == torch.channels_last:
if input_tensor.ndim == 3:
input_tensor = input_tensor.movedim(1, -1)
weight = weight.movedim(1, -1)
else:
input_tensor = input_tensor.movedim(-3, -1)
weight = weight.movedim(-3, -1)

convs = [
(
Expand All @@ -701,7 +704,7 @@ def test_quantized_conv_per_tensor(

optimized_convs = []
if input_tensor.dtype == torch.int8 and weight.dtype == torch.int8:
if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
if memory_format == torch.contiguous_format:
optimized_convs = [
torch.ops.cadence.quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor,
torch.ops.cadence.quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor,
Expand All @@ -715,7 +718,7 @@ def test_quantized_conv_per_tensor(
torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor,
]
elif input_tensor.dtype == torch.uint8 and weight.dtype == torch.uint8:
if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
if memory_format == torch.contiguous_format:
optimized_convs = [
torch.ops.cadence.quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor,
torch.ops.cadence.quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor,
Expand Down Expand Up @@ -746,7 +749,13 @@ def test_quantized_conv_per_tensor(
output_zero_point,
out_multiplier,
out_shift,
).to(memory_format=torch.contiguous_format)
)

if memory_format == torch.channels_last:
if input_tensor.ndim == 3:
output = output.movedim(-1, 1)
else:
output = output.movedim(-1, -3)

# Verify output properties
self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
Expand Down
Loading