@@ -1533,3 +1533,176 @@ def test_convolution(
             torch.equal(output, expected_output),
             f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Basic non-quantized average pooling
+            (
+                "basic_non_quantized",
+                torch.tensor(
+                    [
+                        [
+                            [
+                                [1.0, 2.0, 3.0, 4.0],
+                                [5.0, 6.0, 7.0, 8.0],
+                                [9.0, 10.0, 11.0, 12.0],
+                                [13.0, 14.0, 15.0, 16.0],
+                            ]
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),  # input: 1x1x4x4
+                (2, 2),  # kernel_size
+                (2, 2),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                None,  # divisor_override
+                None,  # in_zero_point (non-quantized)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[3.5, 5.5], [11.5, 13.5]]]], dtype=torch.float32
+                ),  # expected: average of each 2x2 block, e.g. (1+2+5+6)/4 = 3.5
+            ),
+            # Non-quantized with count_include_pad=True and padding
+            (
+                "non_quantized_count_include_pad",
+                torch.tensor(
+                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                (3, 3),  # kernel_size (larger than input)
+                (1, 1),  # stride
+                (1, 1),  # padding
+                False,  # ceil_mode
+                True,  # count_include_pad=True
+                None,  # divisor_override
+                None,  # in_zero_point (non-quantized)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[2.5, 2.5], [2.5, 2.5]]]],
+                    dtype=torch.float32,
+                ),
+            ),
+            # Non-quantized with divisor_override
+            (
+                "non_quantized_divisor_override",
+                torch.tensor(
+                    [[[[2.0, 4.0], [6.0, 8.0]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                (2, 2),  # kernel_size
+                (1, 1),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                2,  # divisor_override (instead of 4)
+                None,  # in_zero_point (non-quantized)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[10.0]]]], dtype=torch.float32
+                ),  # expected: (2+4+6+8)/2 = 10
+            ),
+            # Quantized with non-zero zero_point and padding
+            (
+                "quantized_nonzero_zero_point",
+                torch.tensor(
+                    [[[[130, 132], [134, 136]]]], dtype=torch.uint8
+                ),  # input: 1x1x2x2, values around zero_point=128
+                (3, 3),  # kernel_size
+                (1, 1),  # stride
+                (1, 1),  # padding
+                False,  # ceil_mode
+                True,  # count_include_pad=True
+                None,  # divisor_override
+                128,  # in_zero_point=128 (padded areas will have this value)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[130, 130], [130, 130]]]], dtype=torch.uint8
+                ),  # expected: averages including padded zero_point values
+            ),
+            # Divisor_override with larger float values (zero_point=None)
+            (
+                "quantized_divisor_override",
+                torch.tensor(
+                    [[[[64, 96], [128, 160]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                (2, 2),  # kernel_size
+                (1, 1),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                2,  # divisor_override (instead of 4)
+                None,  # in_zero_point=None
+                False,  # channel_last
+                torch.tensor(
+                    [[[[224]]]], dtype=torch.float32
+                ),  # expected: (64+96+128+160)/2 = 224
+            ),
+            # Large values that need clamping
+            (
+                "quantized_clamping_test",
+                torch.tensor(
+                    [[[[120, 125], [125, 127]]]], dtype=torch.int8
+                ),  # input: 1x1x2x2, large values for int8
+                (2, 2),  # kernel_size
+                (1, 1),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                None,  # divisor_override
+                0,  # in_zero_point=0
+                False,  # channel_last
+                torch.tensor(
+                    [[[[124]]]], dtype=torch.int8
+                ),  # expected: (120+125+125+127)/4 = 124.25 -> 124, within int8 range
+            ),
+        ]
+    )
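+    # Each tuple above maps positionally onto the test signature below:
+    # (name, input_tensor, kernel_size, stride, padding, ceil_mode,
+    #  count_include_pad, divisor_override, in_zero_point, channel_last,
+    #  expected_output).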
+    def test_avg_pool2d(
+        self,
+        name: str,
+        input_tensor: torch.Tensor,
+        kernel_size: tuple[int, int],
+        stride: tuple[int, int],
+        padding: tuple[int, int],
+        ceil_mode: bool,
+        count_include_pad: bool,
+        divisor_override: int | None,
+        in_zero_point: int | None,
+        channel_last: bool,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.avg_pool2d(
+            input_tensor,
+            kernel_size,
+            stride,
+            padding,
+            ceil_mode,
+            count_include_pad,
+            divisor_override,
+            None if in_zero_point is None else torch.tensor([in_zero_point]),
+            channel_last,
+        )
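+        # The op takes the zero point as a 1-element tensor, so the scalar
+        # test parameter is wrapped above; passing None selects the
+        # non-quantized path.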
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            input_tensor.dtype,
+            f"Output dtype should match input dtype in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"Output shape should match expected shape in {name}",
+        )
+
+        # Verify output matches expected values
+        if input_tensor.dtype.is_floating_point:
+            self.assertTrue(
+                torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
+                f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+            )
+        else:
+            self.assertTrue(
+                torch.equal(output, expected_output),
+                f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+            )