@@ -1533,3 +1533,176 @@ def test_convolution(
                torch.equal(output, expected_output),
                f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
            )
+
+    @expand(
+        [
+            # Basic non-quantized average pooling
+            (
+                "basic_non_quantized",
+                torch.tensor(
+                    [
+                        [
+                            [
+                                [1.0, 2.0, 3.0, 4.0],
+                                [5.0, 6.0, 7.0, 8.0],
+                                [9.0, 10.0, 11.0, 12.0],
+                                [13.0, 14.0, 15.0, 16.0],
+                            ]
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),  # input: 1x1x4x4
+                (2, 2),  # kernel_size
+                (2, 2),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                None,  # divisor_override
+                None,  # in_zero_point (non-quantized)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[3.5, 5.5], [11.5, 13.5]]]], dtype=torch.float32
+                ),  # expected: average of 2x2 blocks
+            ),
+            # Non-quantized with count_include_pad=True and padding
+            (
+                "non_quantized_count_include_pad",
+                torch.tensor(
+                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                (3, 3),  # kernel_size (larger than input)
+                (1, 1),  # stride
+                (1, 1),  # padding
+                False,  # ceil_mode
+                True,  # count_include_pad=True
+                None,  # divisor_override
+                None,  # in_zero_point (non-quantized)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[2.5, 2.5], [2.5, 2.5]]]],
+                    dtype=torch.float32,
+                ),
+            ),
+            # Non-quantized with divisor_override
+            (
+                "non_quantized_divisor_override",
+                torch.tensor(
+                    [[[[2.0, 4.0], [6.0, 8.0]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                (2, 2),  # kernel_size
+                (1, 1),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                2,  # divisor_override (instead of 4)
+                None,  # in_zero_point (non-quantized)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[10.0]]]], dtype=torch.float32
+                ),  # expected: (2+4+6+8)/2 = 10
+            ),
+            # Quantized with non-zero zero_point and padding
+            (
+                "quantized_nonzero_zero_point",
+                torch.tensor(
+                    [[[[130, 132], [134, 136]]]], dtype=torch.uint8
+                ),  # input: 1x1x2x2, values around zero_point=128
+                (3, 3),  # kernel_size
+                (1, 1),  # stride
+                (1, 1),  # padding
+                False,  # ceil_mode
+                True,  # count_include_pad=True
+                None,  # divisor_override
+                128,  # in_zero_point=128 (padded areas will have this value)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[130, 130], [130, 130]]]], dtype=torch.uint8
+                ),  # expected: averages including padded zero_point values
+            ),
+            # Quantized with divisor_override
+            (
+                "quantized_divisor_override",
+                torch.tensor(
+                    [[[[64, 96], [128, 160]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                (2, 2),  # kernel_size
+                (1, 1),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                2,  # divisor_override (instead of 4)
+                None,  # in_zero_point=None
+                False,  # channel_last
+                torch.tensor(
+                    [[[[224]]]], dtype=torch.float32
+                ),  # expected: (64+96+128+160)/2 = 224
+            ),
+            # Large values that need clamping
+            (
+                "quantized_clamping_test",
+                torch.tensor(
+                    [[[[120, 125], [125, 127]]]], dtype=torch.int8
+                ),  # input: 1x1x2x2, large values for int8
+                (2, 2),  # kernel_size
+                (1, 1),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                None,  # divisor_override
+                0,  # in_zero_point=0
+                False,  # channel_last
+                torch.tensor(
+                    [[[[124]]]], dtype=torch.int8
+                ),  # expected: (120+125+125+127)/4 = 124.25 -> 124, within int8 range
+            ),
+        ]
+    )
+    def test_avg_pool2d(
+        self,
+        name: str,
+        input_tensor: torch.Tensor,
+        kernel_size: tuple[int, int],
+        stride: tuple[int, int],
+        padding: tuple[int, int],
+        ceil_mode: bool,
+        count_include_pad: bool,
+        divisor_override: int | None,
+        in_zero_point: int | None,
+        channel_last: bool,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.avg_pool2d(
+            input_tensor,
+            kernel_size,
+            stride,
+            padding,
+            ceil_mode,
+            count_include_pad,
+            divisor_override,
+            in_zero_point if in_zero_point is None else torch.tensor([in_zero_point]),
+            channel_last,
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            input_tensor.dtype,
+            f"Output dtype should match input dtype in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"Output shape should match expected shape in {name}",
+        )
+
+        # Verify output matches expected values
+        if input_tensor.dtype.is_floating_point:
+            self.assertTrue(
+                torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
+                f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+            )
+        else:
+            self.assertTrue(
+                torch.equal(output, expected_output),
+                f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+            )
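
For reference, the two plain floating-point, unpadded cases above can be reproduced with stock torch.nn.functional.avg_pool2d, and the quantized zero-point case can be worked out by hand. The padded and quantized cases otherwise follow the cadence op's own semantics (padding is filled with in_zero_point rather than zero), so the stock op is not a drop-in oracle for them. A minimal, self-contained sanity check, assuming only a standard PyTorch install:

    import torch
    import torch.nn.functional as F

    # "basic_non_quantized": 1x1x4x4 input, 2x2 kernel, stride 2 -> each
    # output element is the mean of a disjoint 2x2 block.
    x = torch.arange(1.0, 17.0).reshape(1, 1, 4, 4)
    ref = F.avg_pool2d(x, kernel_size=(2, 2), stride=(2, 2))
    assert torch.equal(ref, torch.tensor([[[[3.5, 5.5], [11.5, 13.5]]]]))

    # "non_quantized_divisor_override": the window sum is divided by the
    # override (2) rather than the element count (4): (2+4+6+8)/2 = 10.
    y = torch.tensor([[[[2.0, 4.0], [6.0, 8.0]]]])
    ref = F.avg_pool2d(y, kernel_size=(2, 2), stride=(1, 1), divisor_override=2)
    assert torch.equal(ref, torch.tensor([[[[10.0]]]]))

    # "quantized_nonzero_zero_point", by hand: with padding=1 every 3x3
    # window covers all four real pixels plus five padded cells, and the
    # cadence op fills padding with in_zero_point=128:
    # (130 + 132 + 134 + 136 + 5 * 128) / 9 = 1172 / 9 ~= 130.2 -> 130.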