@@ -1534,6 +1534,143 @@ def test_convolution(
15341534 f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
15351535 )
15361536
1537+ @expand (
1538+ [
1539+ # Basic 2D transposed convolution with stride=1 (current test case - corrected name)
1540+ (
1541+ "basic_2d_stride1" ,
1542+ torch .tensor (
1543+ [[[[1.0 , 2.0 ], [3.0 , 4.0 ]]]], dtype = torch .float32
1544+ ), # input: 1x1x2x2
1545+ torch .tensor (
1546+ [[[[1.0 , 1.0 ], [1.0 , 1.0 ]]]], dtype = torch .float32
1547+ ), # weight: 1x1x2x2
1548+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1549+ (1 , 1 ), # stride
1550+ (0 , 0 ), # padding
1551+ (1 , 1 ), # dilation
1552+ 1 , # groups
1553+ (0 , 0 ), # output_padding
1554+ False , # channel_last
1555+ torch .tensor (
1556+ [[[[1.0 , 3.0 , 2.0 ], [4.0 , 10.0 , 6.0 ], [3.0 , 7.0 , 4.0 ]]]],
1557+ dtype = torch .float32 ,
1558+ ),
1559+ ),
1560+ # 2D transposed convolution with channel_last=True (NHWC format)
1561+ (
1562+ "channel_last_nhwc" ,
1563+ torch .tensor (
1564+ [[[[1.0 ], [2.0 ]], [[3.0 ], [4.0 ]]]], dtype = torch .float32
1565+ ), # input: 1x2x2x1 (NHWC)
1566+ torch .tensor (
1567+ [[[[1.0 ], [1.0 ]], [[1.0 ], [1.0 ]]]], dtype = torch .float32
1568+ ), # weight: 1x2x2x1 (NHWC)
1569+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1570+ (1 , 1 ), # stride
1571+ (0 , 0 ), # padding
1572+ (1 , 1 ), # dilation
1573+ 1 , # groups
1574+ (0 , 0 ), # output_padding
1575+ True , # channel_last=True
1576+ torch .tensor (
1577+ [
1578+ [
1579+ [[1.0 ], [3.0 ], [2.0 ]],
1580+ [[4.0 ], [10.0 ], [6.0 ]],
1581+ [[3.0 ], [7.0 ], [4.0 ]],
1582+ ]
1583+ ],
1584+ dtype = torch .float32 ,
1585+ ),
1586+ ),
1587+ # 2D transposed convolution with non-zero bias
1588+ (
1589+ "with_bias" ,
1590+ torch .tensor (
1591+ [[[[1.0 , 2.0 ], [3.0 , 4.0 ]]]], dtype = torch .float32
1592+ ), # input: 1x1x2x2
1593+ torch .tensor (
1594+ [[[[1.0 , 0.0 ], [0.0 , 1.0 ]]]], dtype = torch .float32
1595+ ), # weight: 1x1x2x2
1596+ torch .tensor ([5.0 ], dtype = torch .float32 ), # bias=5.0
1597+ (1 , 1 ), # stride
1598+ (0 , 0 ), # padding
1599+ (1 , 1 ), # dilation
1600+ 1 , # groups
1601+ (0 , 0 ), # output_padding
1602+ False , # channel_last
1603+ torch .tensor (
1604+ [[[[6.0 , 7.0 , 5.0 ], [8.0 , 10.0 , 7.0 ], [5.0 , 8.0 , 9.0 ]]]],
1605+ dtype = torch .float32 ,
1606+ ),
1607+ ),
1608+ # 1D transposed convolution (3D tensor, NLC format)
1609+ (
1610+ "conv1d_nlc" ,
1611+ torch .tensor (
1612+ [[[1.0 ], [2.0 ], [3.0 ]]], dtype = torch .float32
1613+ ), # input: 1x3x1 (NLC)
1614+ torch .tensor (
1615+ [[[1.0 ], [0.5 ]]], dtype = torch .float32
1616+ ), # weight: 1x2x1 (NLC)
1617+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1618+ (2 , 0 ), # stride
1619+ (0 , 0 ), # padding
1620+ (1 , 1 ), # dilation
1621+ 1 , # groups
1622+ (0 , 0 ), # output_padding
1623+ True , # channel_last=True
1624+ torch .tensor (
1625+ [[[1.0 ], [0.5 ], [2.0 ], [1.0 ], [3.0 ], [1.5 ]]], dtype = torch .float32
1626+ ),
1627+ ),
1628+ ]
1629+ )
1630+ def test_transposed_convolution (
1631+ self ,
1632+ name : str ,
1633+ input_tensor : torch .Tensor ,
1634+ weight : torch .Tensor ,
1635+ bias : torch .Tensor ,
1636+ stride : tuple [int , int ],
1637+ padding : tuple [int , int ],
1638+ dilation : tuple [int , int ],
1639+ groups : int ,
1640+ output_padding : tuple [int , int ],
1641+ channel_last : bool ,
1642+ expected_output : torch .Tensor ,
1643+ ) -> None :
1644+ output = torch .ops .cadence .transposed_convolution (
1645+ input_tensor ,
1646+ weight ,
1647+ bias ,
1648+ stride ,
1649+ padding ,
1650+ dilation ,
1651+ output_padding ,
1652+ groups ,
1653+ channel_last ,
1654+ )
1655+
1656+ # Verify output properties
1657+ self .assertEqual (
1658+ output .dtype ,
1659+ input_tensor .dtype ,
1660+ f"Output dtype should match input dtype in { name } " ,
1661+ )
1662+ self .assertEqual (
1663+ output .shape ,
1664+ expected_output .shape ,
1665+ f"Output shape should match expected shape in { name } " ,
1666+ )
1667+
1668+ # Verify output matches expected values
1669+ self .assertTrue (
1670+ torch .equal (output , expected_output ),
1671+ f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
1672+ )
1673+
15371674 @expand (
15381675 [
15391676 # Basic non-quantized average pooling
0 commit comments