@@ -1533,3 +1533,176 @@ def test_convolution(
15331533 torch .equal (output , expected_output ),
15341534 f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
15351535 )
1536+
1537+ @expand (
1538+ [
1539+ # Basic non-quantized average pooling
1540+ (
1541+ "basic_non_quantized" ,
1542+ torch .tensor (
1543+ [
1544+ [
1545+ [
1546+ [1.0 , 2.0 , 3.0 , 4.0 ],
1547+ [5.0 , 6.0 , 7.0 , 8.0 ],
1548+ [9.0 , 10.0 , 11.0 , 12.0 ],
1549+ [13.0 , 14.0 , 15.0 , 16.0 ],
1550+ ]
1551+ ]
1552+ ],
1553+ dtype = torch .float32 ,
1554+ ), # input: 1x1x4x4
1555+ (2 , 2 ), # kernel_size
1556+ (2 , 2 ), # stride
1557+ (0 , 0 ), # padding
1558+ False , # ceil_mode
1559+ False , # count_include_pad
1560+ None , # divisor_override
1561+ None , # in_zero_point (non-quantized)
1562+ False , # channel_last
1563+ torch .tensor (
1564+ [[[[3.5 , 5.5 ], [11.5 , 13.5 ]]]], dtype = torch .float32
1565+ ), # expected: average of 2x2 blocks
1566+ ),
1567+ # Non-quantized with count_include_pad=True and padding
1568+ (
1569+ "non_quantized_count_include_pad" ,
1570+ torch .tensor (
1571+ [[[[1.0 , 2.0 ], [3.0 , 4.0 ]]]], dtype = torch .float32
1572+ ), # input: 1x1x2x2
1573+ (3 , 3 ), # kernel_size (larger than input)
1574+ (1 , 1 ), # stride
1575+ (1 , 1 ), # padding
1576+ False , # ceil_mode
1577+ True , # count_include_pad=True
1578+ None , # divisor_override
1579+ None , # in_zero_point (non-quantized)
1580+ False , # channel_last
1581+ torch .tensor (
1582+ [[[[2.5 , 2.5 ], [2.5 , 2.5 ]]]],
1583+ dtype = torch .float32 ,
1584+ ),
1585+ ),
1586+ # Non-quantized with divisor_override
1587+ (
1588+ "non_quantized_divisor_override" ,
1589+ torch .tensor (
1590+ [[[[2.0 , 4.0 ], [6.0 , 8.0 ]]]], dtype = torch .float32
1591+ ), # input: 1x1x2x2
1592+ (2 , 2 ), # kernel_size
1593+ (1 , 1 ), # stride
1594+ (0 , 0 ), # padding
1595+ False , # ceil_mode
1596+ False , # count_include_pad
1597+ 2 , # divisor_override (instead of 4)
1598+ None , # in_zero_point (non-quantized)
1599+ False , # channel_last
1600+ torch .tensor (
1601+ [[[[10.0 ]]]], dtype = torch .float32
1602+ ), # expected: (2+4+6+8)/2 = 10
1603+ ),
1604+ # Quantized with non-zero zero_point and padding
1605+ (
1606+ "quantized_nonzero_zero_point" ,
1607+ torch .tensor (
1608+ [[[[130 , 132 ], [134 , 136 ]]]], dtype = torch .uint8
1609+ ), # input: 1x1x2x2, values around zero_point=128
1610+ (3 , 3 ), # kernel_size
1611+ (1 , 1 ), # stride
1612+ (1 , 1 ), # padding
1613+ False , # ceil_mode
1614+ True , # count_include_pad=True
1615+ None , # divisor_override
1616+ 128 , # in_zero_point=128 (padded areas will have this value)
1617+ False , # channel_last
1618+ torch .tensor (
1619+ [[[[130 , 130 ], [130 , 130 ]]]], dtype = torch .uint8
1620+ ), # expected: averages including padded zero_point values
1621+ ),
1622+ # Quantized with divisor_override
1623+ (
1624+ "quantized_divisor_override" ,
1625+ torch .tensor (
1626+ [[[[64 , 96 ], [128 , 160 ]]]], dtype = torch .float32
1627+ ), # input: 1x1x2x2
1628+ (2 , 2 ), # kernel_size
1629+ (1 , 1 ), # stride
1630+ (0 , 0 ), # padding
1631+ False , # ceil_mode
1632+ False , # count_include_pad
1633+ 2 , # divisor_override (instead of 4)
1634+ None , # in_zero_point=None
1635+ False , # channel_last
1636+ torch .tensor (
1637+ [[[[224 ]]]], dtype = torch .float32
1638+ ), # expected: (64+96+128+160)/2 = 224
1639+ ),
1640+ # Large values that need clamping
1641+ (
1642+ "quantized_clamping_test" ,
1643+ torch .tensor (
1644+ [[[[120 , 125 ], [125 , 127 ]]]], dtype = torch .int8
1645+ ), # input: 1x1x2x2, large values for int8
1646+ (2 , 2 ), # kernel_size
1647+ (1 , 1 ), # stride
1648+ (0 , 0 ), # padding
1649+ False , # ceil_mode
1650+ False , # count_include_pad
1651+ None , # divisor_override
1652+ 0 , # in_zero_point=0
1653+ False , # channel_last
1654+ torch .tensor (
1655+ [[[[124 ]]]], dtype = torch .int8
1656+ ), # expected: (120+125+125+127)/4 = 124.25 -> 124, within int8 range
1657+ ),
1658+ ]
1659+ )
1660+ def test_avg_pool2d (
1661+ self ,
1662+ name : str ,
1663+ input_tensor : torch .Tensor ,
1664+ kernel_size : tuple [int , int ],
1665+ stride : tuple [int , int ],
1666+ padding : tuple [int , int ],
1667+ ceil_mode : bool ,
1668+ count_include_pad : bool ,
1669+ divisor_override : int | None ,
1670+ in_zero_point : int | None ,
1671+ channel_last : bool ,
1672+ expected_output : torch .Tensor ,
1673+ ) -> None :
1674+ output = torch .ops .cadence .avg_pool2d (
1675+ input_tensor ,
1676+ kernel_size ,
1677+ stride ,
1678+ padding ,
1679+ ceil_mode ,
1680+ count_include_pad ,
1681+ divisor_override ,
1682+ in_zero_point if in_zero_point is None else torch .tensor ([in_zero_point ]),
1683+ channel_last ,
1684+ )
1685+
1686+ # Verify output properties
1687+ self .assertEqual (
1688+ output .dtype ,
1689+ input_tensor .dtype ,
1690+ f"Output dtype should match input dtype in { name } " ,
1691+ )
1692+ self .assertEqual (
1693+ output .shape ,
1694+ expected_output .shape ,
1695+ f"Output shape should match expected shape in { name } " ,
1696+ )
1697+
1698+ # Verify output matches expected values
1699+ if input_tensor .dtype .is_floating_point :
1700+ self .assertTrue (
1701+ torch .allclose (output , expected_output , rtol = 1e-4 , atol = 1e-4 ),
1702+ f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
1703+ )
1704+ else :
1705+ self .assertTrue (
1706+ torch .equal (output , expected_output ),
1707+ f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
1708+ )
0 commit comments