Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2244,10 +2244,10 @@ def avg_pool2d_meta(
kernel_size: Tuple[int],
stride: Tuple[int],
padding: Tuple[int],
ceil_mode: bool,
count_include_pad: Optional[bool] = True,
ceil_mode: bool = False,
count_include_pad: bool = True,
divisor_override: Optional[int] = None,
in_zero_point: Optional[int] = None,
in_zero_point: Optional[torch.Tensor] = None,
channel_last: bool = False,
) -> torch.Tensor:
# Use torch native meta kernels when operator semantics are similar
Expand Down
52 changes: 52 additions & 0 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,58 @@ def convolution(
return conv_out


@impl(m, "avg_pool2d")
def avg_pool2d(
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    stride: tuple[int, int],
    padding: tuple[int, int],
    ceil_mode: bool = False,
    count_include_pad: bool = False,
    divisor_override: int | None = None,
    in_zero_point: torch.Tensor | None = None,
    channel_last: bool = False,
) -> torch.Tensor:
    """
    Reference average pooling over the last two dimensions of ``input_tensor``.

    When ``in_zero_point`` is given the input is treated as quantized: it is
    pooled in float32, rounded, clamped to the integer range of the input
    dtype and cast back. Otherwise the call is forwarded unchanged to
    ``torch.nn.functional.avg_pool2d``.

    Args:
        input_tensor: Tensor to pool; the last two dims are treated as (H, W).
        kernel_size: Pooling window size.
        stride: Pooling stride.
        padding: ``(pad_h, pad_w)`` implicit padding applied on both sides.
        ceil_mode: Forwarded to ``torch.nn.functional.avg_pool2d``.
        count_include_pad: Whether padded positions count towards the average.
        divisor_override: Optional fixed divisor replacing the window size.
        in_zero_point: One-element tensor holding the input zero point, or
            ``None`` for the non-quantized path.
        channel_last: Not supported; must be ``False``.

    Returns:
        The pooled tensor, with the same dtype as ``input_tensor``.

    Raises:
        NotImplementedError: If ``channel_last`` is ``True``.

    NOTE(review): ``avg_pool2d_meta`` in ops_registrations.py defaults
    ``count_include_pad`` to ``True`` while this function defaults to
    ``False`` - confirm which default is intended.
    """
    if channel_last:
        raise NotImplementedError("Channel last is not yet supported for avg_pool2d")

    in_dtype = input_tensor.dtype
    pad_h, pad_w = padding
    is_quantized = in_zero_point is not None

    if is_quantized:
        if count_include_pad:
            # The native kernel can only pad with 0, but a padded position
            # that takes part in the average must hold the zero point, so
            # materialize the padding explicitly and disable native padding.
            input_tensor = torch.nn.functional.pad(
                input_tensor,
                (pad_w, pad_w, pad_h, pad_h),
                mode="constant",
                value=in_zero_point.item(),
            )
            padding = (0, 0)
        # Otherwise padded positions must be excluded from both the sum and
        # the divisor. Manually padding with zeros would turn them into real
        # elements that still count in the divisor, so the native padding is
        # kept and the kernel excludes them itself.

        # Pool in floating point; the result is re-quantized below.
        input_tensor = input_tensor.float()

    out = torch.nn.functional.avg_pool2d(
        input_tensor,
        kernel_size,
        stride,
        padding,
        ceil_mode,
        count_include_pad,
        divisor_override,
    )

    if is_quantized:
        # Round to the nearest integer and saturate to the input dtype range
        # so the cast below cannot wrap around.
        dtype_info = torch.iinfo(in_dtype)
        out = torch.clamp(torch.round(out), dtype_info.min, dtype_info.max)

    return out.to(in_dtype)


def quantized_relu_common(
X: torch.Tensor,
X_zero_point: torch.Tensor | int,
Expand Down
173 changes: 173 additions & 0 deletions backends/cadence/aot/tests/test_ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,3 +1533,176 @@ def test_convolution(
torch.equal(output, expected_output),
f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
)

@expand(
    [
        # Basic non-quantized average pooling
        (
            "basic_non_quantized",
            torch.tensor(
                [
                    [
                        [
                            [1.0, 2.0, 3.0, 4.0],
                            [5.0, 6.0, 7.0, 8.0],
                            [9.0, 10.0, 11.0, 12.0],
                            [13.0, 14.0, 15.0, 16.0],
                        ]
                    ]
                ],
                dtype=torch.float32,
            ),  # input: 1x1x4x4
            (2, 2),  # kernel_size
            (2, 2),  # stride
            (0, 0),  # padding
            False,  # ceil_mode
            False,  # count_include_pad
            None,  # divisor_override
            None,  # in_zero_point (non-quantized)
            False,  # channel_last
            torch.tensor(
                [[[[3.5, 5.5], [11.5, 13.5]]]], dtype=torch.float32
            ),  # expected: average of 2x2 blocks
        ),
        # Non-quantized with count_include_pad=True and padding
        (
            "non_quantized_count_include_pad",
            torch.tensor(
                [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
            ),  # input: 1x1x2x2
            (3, 3),  # kernel_size (larger than input)
            (1, 1),  # stride
            (1, 1),  # padding
            False,  # ceil_mode
            True,  # count_include_pad=True
            None,  # divisor_override
            None,  # in_zero_point (non-quantized)
            False,  # channel_last
            # Every 3x3 window covers all four inputs plus five zero pads,
            # and the pads count in the divisor: (1+2+3+4)/9.
            torch.full((1, 1, 2, 2), 10.0 / 9.0, dtype=torch.float32),
        ),
        # Non-quantized with divisor_override
        (
            "non_quantized_divisor_override",
            torch.tensor(
                [[[[2.0, 4.0], [6.0, 8.0]]]], dtype=torch.float32
            ),  # input: 1x1x2x2
            (2, 2),  # kernel_size
            (1, 1),  # stride
            (0, 0),  # padding
            False,  # ceil_mode
            False,  # count_include_pad
            2,  # divisor_override (instead of 4)
            None,  # in_zero_point (non-quantized)
            False,  # channel_last
            torch.tensor(
                [[[[10.0]]]], dtype=torch.float32
            ),  # expected: (2+4+6+8)/2 = 10
        ),
        # Quantized with non-zero zero_point and padding
        (
            "quantized_nonzero_zero_point",
            torch.tensor(
                [[[[130, 132], [134, 136]]]], dtype=torch.uint8
            ),  # input: 1x1x2x2, values around zero_point=128
            (3, 3),  # kernel_size
            (1, 1),  # stride
            (1, 1),  # padding
            False,  # ceil_mode
            True,  # count_include_pad=True
            None,  # divisor_override
            128,  # in_zero_point=128 (padded areas will have this value)
            False,  # channel_last
            torch.tensor(
                [[[[130, 130], [130, 130]]]], dtype=torch.uint8
            ),  # expected: (130+132+134+136 + 5*128)/9 = 130.2 -> 130
        ),
        # Quantized with divisor_override
        (
            "quantized_divisor_override",
            torch.tensor(
                [[[[64, 96], [128, 160]]]], dtype=torch.uint8
            ),  # input: 1x1x2x2
            (2, 2),  # kernel_size
            (1, 1),  # stride
            (0, 0),  # padding
            False,  # ceil_mode
            False,  # count_include_pad
            2,  # divisor_override (instead of 4)
            0,  # in_zero_point=0, so the quantized path is exercised
            False,  # channel_last
            torch.tensor(
                [[[[224]]]], dtype=torch.uint8
            ),  # expected: (64+96+128+160)/2 = 224, within uint8 range
        ),
        # Large int8 values whose average stays in range (no saturation)
        (
            "quantized_clamping_test",
            torch.tensor(
                [[[[120, 125], [125, 127]]]], dtype=torch.int8
            ),  # input: 1x1x2x2, large values for int8
            (2, 2),  # kernel_size
            (1, 1),  # stride
            (0, 0),  # padding
            False,  # ceil_mode
            False,  # count_include_pad
            None,  # divisor_override
            0,  # in_zero_point=0
            False,  # channel_last
            torch.tensor(
                [[[[124]]]], dtype=torch.int8
            ),  # expected: (120+125+125+127)/4 = 124.25 -> 124, within int8 range
        ),
        # Large int8 values whose result exceeds the dtype range
        (
            "quantized_clamping_saturates",
            torch.tensor(
                [[[[120, 125], [125, 127]]]], dtype=torch.int8
            ),  # input: 1x1x2x2, large values for int8
            (2, 2),  # kernel_size
            (1, 1),  # stride
            (0, 0),  # padding
            False,  # ceil_mode
            False,  # count_include_pad
            2,  # divisor_override (instead of 4) pushes the result out of range
            0,  # in_zero_point=0
            False,  # channel_last
            torch.tensor(
                [[[[127]]]], dtype=torch.int8
            ),  # expected: (120+125+125+127)/2 = 248.5 -> clamped to 127
        ),
    ]
)
def test_avg_pool2d(
    self,
    name: str,
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    stride: tuple[int, int],
    padding: tuple[int, int],
    ceil_mode: bool,
    count_include_pad: bool,
    divisor_override: int | None,
    in_zero_point: int | None,
    channel_last: bool,
    expected_output: torch.Tensor,
) -> None:
    """Check cadence.avg_pool2d against hand-computed results for each case."""
    # The op takes the zero point as a tensor; the cases spell it as an int.
    zero_point_tensor = (
        None if in_zero_point is None else torch.tensor([in_zero_point])
    )
    output = torch.ops.cadence.avg_pool2d(
        input_tensor,
        kernel_size,
        stride,
        padding,
        ceil_mode,
        count_include_pad,
        divisor_override,
        zero_point_tensor,
        channel_last,
    )

    # Verify output properties
    self.assertEqual(
        output.dtype,
        input_tensor.dtype,
        f"Output dtype should match input dtype in {name}",
    )
    self.assertEqual(
        output.shape,
        expected_output.shape,
        f"Output shape should match expected shape in {name}",
    )

    # Verify output matches expected values: tolerance for floats, exact
    # equality for quantized (integer) outputs.
    if input_tensor.dtype.is_floating_point:
        self.assertTrue(
            torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
        )
    else:
        self.assertTrue(
            torch.equal(output, expected_output),
            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
        )
Loading