@@ -1534,6 +1534,143 @@ def test_convolution(
1534
1534
f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
1535
1535
)
1536
1536
1537
+ @expand (
1538
+ [
1539
+ # Basic 2D transposed convolution with stride=1 (current test case - corrected name)
1540
+ (
1541
+ "basic_2d_stride1" ,
1542
+ torch .tensor (
1543
+ [[[[1.0 , 2.0 ], [3.0 , 4.0 ]]]], dtype = torch .float32
1544
+ ), # input: 1x1x2x2
1545
+ torch .tensor (
1546
+ [[[[1.0 , 1.0 ], [1.0 , 1.0 ]]]], dtype = torch .float32
1547
+ ), # weight: 1x1x2x2
1548
+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1549
+ (1 , 1 ), # stride
1550
+ (0 , 0 ), # padding
1551
+ (1 , 1 ), # dilation
1552
+ 1 , # groups
1553
+ (0 , 0 ), # output_padding
1554
+ False , # channel_last
1555
+ torch .tensor (
1556
+ [[[[1.0 , 3.0 , 2.0 ], [4.0 , 10.0 , 6.0 ], [3.0 , 7.0 , 4.0 ]]]],
1557
+ dtype = torch .float32 ,
1558
+ ),
1559
+ ),
1560
+ # 2D transposed convolution with channel_last=True (NHWC format)
1561
+ (
1562
+ "channel_last_nhwc" ,
1563
+ torch .tensor (
1564
+ [[[[1.0 ], [2.0 ]], [[3.0 ], [4.0 ]]]], dtype = torch .float32
1565
+ ), # input: 1x2x2x1 (NHWC)
1566
+ torch .tensor (
1567
+ [[[[1.0 ], [1.0 ]], [[1.0 ], [1.0 ]]]], dtype = torch .float32
1568
+ ), # weight: 1x2x2x1 (NHWC)
1569
+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1570
+ (1 , 1 ), # stride
1571
+ (0 , 0 ), # padding
1572
+ (1 , 1 ), # dilation
1573
+ 1 , # groups
1574
+ (0 , 0 ), # output_padding
1575
+ True , # channel_last=True
1576
+ torch .tensor (
1577
+ [
1578
+ [
1579
+ [[1.0 ], [3.0 ], [2.0 ]],
1580
+ [[4.0 ], [10.0 ], [6.0 ]],
1581
+ [[3.0 ], [7.0 ], [4.0 ]],
1582
+ ]
1583
+ ],
1584
+ dtype = torch .float32 ,
1585
+ ),
1586
+ ),
1587
+ # 2D transposed convolution with non-zero bias
1588
+ (
1589
+ "with_bias" ,
1590
+ torch .tensor (
1591
+ [[[[1.0 , 2.0 ], [3.0 , 4.0 ]]]], dtype = torch .float32
1592
+ ), # input: 1x1x2x2
1593
+ torch .tensor (
1594
+ [[[[1.0 , 0.0 ], [0.0 , 1.0 ]]]], dtype = torch .float32
1595
+ ), # weight: 1x1x2x2
1596
+ torch .tensor ([5.0 ], dtype = torch .float32 ), # bias=5.0
1597
+ (1 , 1 ), # stride
1598
+ (0 , 0 ), # padding
1599
+ (1 , 1 ), # dilation
1600
+ 1 , # groups
1601
+ (0 , 0 ), # output_padding
1602
+ False , # channel_last
1603
+ torch .tensor (
1604
+ [[[[6.0 , 7.0 , 5.0 ], [8.0 , 10.0 , 7.0 ], [5.0 , 8.0 , 9.0 ]]]],
1605
+ dtype = torch .float32 ,
1606
+ ),
1607
+ ),
1608
+ # 1D transposed convolution (3D tensor, NLC format)
1609
+ (
1610
+ "conv1d_nlc" ,
1611
+ torch .tensor (
1612
+ [[[1.0 ], [2.0 ], [3.0 ]]], dtype = torch .float32
1613
+ ), # input: 1x3x1 (NLC)
1614
+ torch .tensor (
1615
+ [[[1.0 ], [0.5 ]]], dtype = torch .float32
1616
+ ), # weight: 1x2x1 (NLC)
1617
+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1618
+ (2 , 0 ), # stride
1619
+ (0 , 0 ), # padding
1620
+ (1 , 1 ), # dilation
1621
+ 1 , # groups
1622
+ (0 , 0 ), # output_padding
1623
+ True , # channel_last=True
1624
+ torch .tensor (
1625
+ [[[1.0 ], [0.5 ], [2.0 ], [1.0 ], [3.0 ], [1.5 ]]], dtype = torch .float32
1626
+ ),
1627
+ ),
1628
+ ]
1629
+ )
1630
+ def test_transposed_convolution (
1631
+ self ,
1632
+ name : str ,
1633
+ input_tensor : torch .Tensor ,
1634
+ weight : torch .Tensor ,
1635
+ bias : torch .Tensor ,
1636
+ stride : tuple [int , int ],
1637
+ padding : tuple [int , int ],
1638
+ dilation : tuple [int , int ],
1639
+ groups : int ,
1640
+ output_padding : tuple [int , int ],
1641
+ channel_last : bool ,
1642
+ expected_output : torch .Tensor ,
1643
+ ) -> None :
1644
+ output = torch .ops .cadence .transposed_convolution (
1645
+ input_tensor ,
1646
+ weight ,
1647
+ bias ,
1648
+ stride ,
1649
+ padding ,
1650
+ dilation ,
1651
+ output_padding ,
1652
+ groups ,
1653
+ channel_last ,
1654
+ )
1655
+
1656
+ # Verify output properties
1657
+ self .assertEqual (
1658
+ output .dtype ,
1659
+ input_tensor .dtype ,
1660
+ f"Output dtype should match input dtype in { name } " ,
1661
+ )
1662
+ self .assertEqual (
1663
+ output .shape ,
1664
+ expected_output .shape ,
1665
+ f"Output shape should match expected shape in { name } " ,
1666
+ )
1667
+
1668
+ # Verify output matches expected values
1669
+ self .assertTrue (
1670
+ torch .equal (output , expected_output ),
1671
+ f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
1672
+ )
1673
+
1537
1674
@expand (
1538
1675
[
1539
1676
# Basic non-quantized average pooling
0 commit comments