
Commit a731206

Adding avgpool2d
Differential Revision: D83579062
Pull Request resolved: #14703
1 parent 8f7d045 commit a731206

File tree

3 files changed: +228 -3 lines changed

backends/cadence/aot/ops_registrations.py
backends/cadence/aot/ref_implementations.py
backends/cadence/aot/tests/test_ref_implementations.py

backends/cadence/aot/ops_registrations.py

Lines changed: 3 additions & 3 deletions
@@ -2244,10 +2244,10 @@ def avg_pool2d_meta(
     kernel_size: Tuple[int],
     stride: Tuple[int],
     padding: Tuple[int],
-    ceil_mode: bool,
-    count_include_pad: Optional[bool] = True,
+    ceil_mode: bool = False,
+    count_include_pad: bool = True,
     divisor_override: Optional[int] = None,
-    in_zero_point: Optional[int] = None,
+    in_zero_point: Optional[torch.Tensor] = None,
     channel_last: bool = False,
 ) -> torch.Tensor:
     # Use torch native meta kernels when operator semantics are similar
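The signature change above gives ceil_mode a default, makes count_include_pad a plain bool, and switches in_zero_point from Optional[int] to an optional tensor. A minimal sketch of a call against the updated schema (illustrative only; the input values are arbitrary, and it assumes the Cadence AoT op registrations have been imported, as the tests below do):

import torch

# Hypothetical 1x1x4x4 uint8 input; values are arbitrary.
x = torch.randint(0, 256, (1, 1, 4, 4), dtype=torch.uint8)
out = torch.ops.cadence.avg_pool2d(
    x,
    (2, 2),               # kernel_size
    (2, 2),               # stride
    (0, 0),               # padding
    False,                # ceil_mode
    True,                 # count_include_pad
    None,                 # divisor_override
    torch.tensor([128]),  # in_zero_point is now passed as a tensor, not an int
    False,                # channel_last
)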

backends/cadence/aot/ref_implementations.py

Lines changed: 52 additions & 0 deletions
@@ -978,6 +978,58 @@ def convolution(
     return conv_out
 
 
+@impl(m, "avg_pool2d")
+def avg_pool2d(
+    input_tensor: torch.Tensor,
+    kernel_size: tuple[int, int],
+    stride: tuple[int, int],
+    padding: tuple[int, int],
+    ceil_mode: bool = False,
+    count_include_pad: bool = False,
+    divisor_override: int | None = None,
+    in_zero_point: torch.Tensor | None = None,
+    channel_last: bool = False,
+) -> torch.Tensor:
+    if channel_last:
+        raise NotImplementedError("Channel last is not yet supported for avg_pool2d")
+
+    in_dtype = input_tensor.dtype
+    pad_h, pad_w = padding
+    if in_zero_point is not None:
+        # Avg pool2d does not allow non-0 padding,
+        # so we manually pad the input
+        pad_value = in_zero_point.item()
+        if not count_include_pad:
+            # To simulate this, just pad with 0s
+            pad_value = 0
+
+        input_tensor = torch.nn.functional.pad(
+            input_tensor,
+            (pad_w, pad_w, pad_h, pad_h),
+            mode="constant",
+            value=pad_value,
+        ).float()
+
+        padding = (0, 0)
+
+    out = torch.nn.functional.avg_pool2d(
+        input_tensor,
+        kernel_size,
+        stride,
+        padding,
+        ceil_mode,
+        count_include_pad,
+        divisor_override,
+    )
+
+    if in_zero_point is not None:
+        min_val = torch.iinfo(in_dtype).min
+        max_val = torch.iinfo(in_dtype).max
+        out = torch.clamp(torch.round(out), min_val, max_val)
+
+    return out.to(in_dtype)
+
+
 def quantized_relu_common(
     X: torch.Tensor,
     X_zero_point: torch.Tensor | int,
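A minimal sketch of what the reference above computes on the quantized path, reproduced by hand for the uint8 case used in the tests below (zero point 128, 3x3 kernel, stride 1, padding 1, count_include_pad=True); it just restates the pad-with-zero-point, pool, round, clamp, cast sequence:

import torch
import torch.nn.functional as F

x = torch.tensor([[[[130, 132], [134, 136]]]], dtype=torch.uint8)
zero_point = 128

# avg_pool2d cannot pad with a non-zero value, so the reference pads manually
# with the zero point and then pools with zero padding.
padded = F.pad(x, (1, 1, 1, 1), mode="constant", value=zero_point).float()
pooled = F.avg_pool2d(padded, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0))

# Round, clamp to the uint8 range, and cast back to the input dtype.
out = torch.clamp(torch.round(pooled), 0, 255).to(torch.uint8)
print(out)  # tensor([[[[130, 130], [130, 130]]]], dtype=torch.uint8)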

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 173 additions & 0 deletions
@@ -1533,3 +1533,176 @@ def test_convolution(
             torch.equal(output, expected_output),
             f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Basic non-quantized average pooling
+            (
+                "basic_non_quantized",
+                torch.tensor(
+                    [
+                        [
+                            [
+                                [1.0, 2.0, 3.0, 4.0],
+                                [5.0, 6.0, 7.0, 8.0],
+                                [9.0, 10.0, 11.0, 12.0],
+                                [13.0, 14.0, 15.0, 16.0],
+                            ]
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),  # input: 1x1x4x4
+                (2, 2),  # kernel_size
+                (2, 2),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                None,  # divisor_override
+                None,  # in_zero_point (non-quantized)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[3.5, 5.5], [11.5, 13.5]]]], dtype=torch.float32
+                ),  # expected: average of 2x2 blocks
+            ),
+            # Non-quantized with count_include_pad=True and padding
+            (
+                "non_quantized_count_include_pad",
+                torch.tensor(
+                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                (3, 3),  # kernel_size (larger than input)
+                (1, 1),  # stride
+                (1, 1),  # padding
+                False,  # ceil_mode
+                True,  # count_include_pad=True
+                None,  # divisor_override
+                None,  # in_zero_point (non-quantized)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[2.5, 2.5], [2.5, 2.5]]]],
+                    dtype=torch.float32,
+                ),
+            ),
+            # Non-quantized with divisor_override
+            (
+                "non_quantized_divisor_override",
+                torch.tensor(
+                    [[[[2.0, 4.0], [6.0, 8.0]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                (2, 2),  # kernel_size
+                (1, 1),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                2,  # divisor_override (instead of 4)
+                None,  # in_zero_point (non-quantized)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[10.0]]]], dtype=torch.float32
+                ),  # expected: (2+4+6+8)/2 = 10
+            ),
+            # Quantized with non-zero zero_point and padding
+            (
+                "quantized_nonzero_zero_point",
+                torch.tensor(
+                    [[[[130, 132], [134, 136]]]], dtype=torch.uint8
+                ),  # input: 1x1x2x2, values around zero_point=128
+                (3, 3),  # kernel_size
+                (1, 1),  # stride
+                (1, 1),  # padding
+                False,  # ceil_mode
+                True,  # count_include_pad=True
+                None,  # divisor_override
+                128,  # in_zero_point=128 (padded areas will have this value)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[130, 130], [130, 130]]]], dtype=torch.uint8
+                ),  # expected: averages including padded zero_point values
+            ),
+            # Quantized with divisor_override
+            (
+                "quantized_divisor_override",
+                torch.tensor(
+                    [[[[64, 96], [128, 160]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                (2, 2),  # kernel_size
+                (1, 1),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                2,  # divisor_override (instead of 4)
+                None,  # in_zero_point=None
+                False,  # channel_last
+                torch.tensor(
+                    [[[[224]]]], dtype=torch.float32
+                ),  # expected: (64+96+128+160)/2 = 224
+            ),
+            # Large values that need clamping
+            (
+                "quantized_clamping_test",
+                torch.tensor(
+                    [[[[120, 125], [125, 127]]]], dtype=torch.int8
+                ),  # input: 1x1x2x2, large values for int8
+                (2, 2),  # kernel_size
+                (1, 1),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                None,  # divisor_override
+                0,  # in_zero_point=0
+                False,  # channel_last
+                torch.tensor(
+                    [[[[124]]]], dtype=torch.int8
+                ),  # expected: (120+125+125+127)/4 = 124.25 -> 124, within int8 range
+            ),
+        ]
+    )
+    def test_avg_pool2d(
+        self,
+        name: str,
+        input_tensor: torch.Tensor,
+        kernel_size: tuple[int, int],
+        stride: tuple[int, int],
+        padding: tuple[int, int],
+        ceil_mode: bool,
+        count_include_pad: bool,
+        divisor_override: int | None,
+        in_zero_point: int | None,
+        channel_last: bool,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.avg_pool2d(
+            input_tensor,
+            kernel_size,
+            stride,
+            padding,
+            ceil_mode,
+            count_include_pad,
+            divisor_override,
+            in_zero_point if in_zero_point is None else torch.tensor([in_zero_point]),
+            channel_last,
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            input_tensor.dtype,
+            f"Output dtype should match input dtype in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"Output shape should match expected shape in {name}",
+        )
+
+        # Verify output matches expected values
+        if input_tensor.dtype.is_floating_point:
+            self.assertTrue(
+                torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
+                f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+            )
+        else:
+            self.assertTrue(
+                torch.equal(output, expected_output),
+                f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+            )
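As a quick sanity check of the divisor_override expectations in the cases above, plain torch.nn.functional.avg_pool2d already implements that semantic, so the non-quantized case can be reproduced directly (a sketch, not part of the change):

import torch
import torch.nn.functional as F

# With divisor_override=2, the 2x2 window sum (2 + 4 + 6 + 8) is divided by 2
# instead of the window size 4, giving 10.
x = torch.tensor([[[[2.0, 4.0], [6.0, 8.0]]]])
out = F.avg_pool2d(x, kernel_size=(2, 2), stride=(1, 1), divisor_override=2)
print(out)  # tensor([[[[10.]]]])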
