@@ -1259,7 +1259,7 @@ def test_rope(
12591259
12601260 @expand (
12611261 [
1262- # Test case 1: Basic 2D convolution (NCHW format)
1262+ # Basic 2D convolution (NCHW format)
12631263 (
12641264 "basic_2d_nchw" ,
12651265 torch .tensor (
@@ -1274,11 +1274,12 @@ def test_rope(
12741274 (1 , 1 ), # dilation
12751275 1 , # groups
12761276 False , # channel_last
1277+
12771278 torch .tensor (
12781279 [[[[5.0 ]]]], dtype = torch .float32
12791280 ), # expected: 1*1 + 4*1 = 5
12801281 ),
1281- # Test case 2: Basic 2D convolution (NHWC format)
1282+ # Basic 2D convolution (NHWC format)
12821283 (
12831284 "basic_2d_nhwc" ,
12841285 torch .tensor (
@@ -1293,11 +1294,12 @@ def test_rope(
12931294 (1 , 1 ), # dilation
12941295 1 , # groups
12951296 True , # channel_last
1297+
12961298 torch .tensor (
12971299 [[[[5.0 ]]]], dtype = torch .float32
12981300 ), # expected: 1*1 + 4*1 = 5
12991301 ),
1300- # Test case 3: 2D convolution with stride=2
1302+ # 2D convolution with stride=2
13011303 (
13021304 "conv2d_stride2" ,
13031305 torch .tensor (
@@ -1322,9 +1324,10 @@ def test_rope(
13221324 (1 , 1 ), # dilation
13231325 1 , # groups
13241326 False , # channel_last
1327+
13251328 torch .tensor ([[[[14.0 , 22.0 ], [46.0 , 54.0 ]]]], dtype = torch .float32 ),
13261329 ),
1327- # Test case 4: 2D convolution with padding=1
1330+ # 2D convolution with padding=1
13281331 (
13291332 "conv2d_padding1" ,
13301333 torch .tensor (
@@ -1339,12 +1342,13 @@ def test_rope(
13391342 (1 , 1 ), # dilation
13401343 1 , # groups
13411344 False , # channel_last
1345+
13421346 torch .tensor (
13431347 [[[[1.0 , 2.0 , 0.0 ], [3.0 , 5.0 , 2.0 ], [0.0 , 3.0 , 4.0 ]]]],
13441348 dtype = torch .float32 ,
13451349 ), # expected with padding
13461350 ),
1347- # Test case 5: 2D convolution with dilation=2
1351+ # 2D convolution with dilation=2
13481352 (
13491353 "conv2d_dilation2" ,
13501354 torch .tensor (
@@ -1369,9 +1373,10 @@ def test_rope(
13691373 (2 , 2 ), # dilation=2
13701374 1 , # groups
13711375 False , # channel_last
1376+
13721377 torch .tensor ([[[[24.0 , 28.0 ], [40.0 , 44.0 ]]]], dtype = torch .float32 ),
13731378 ),
1374- # Test case 6: 2D grouped convolution (groups=2)
1379+ # 2D grouped convolution (groups=2)
13751380 (
13761381 "conv2d_groups2" ,
13771382 torch .tensor (
@@ -1396,9 +1401,10 @@ def test_rope(
13961401 (1 , 1 ), # dilation
13971402 2 , # groups=2
13981403 False , # channel_last
1404+
13991405 torch .tensor ([[[[10.0 ]], [[14.0 ]]]], dtype = torch .float32 ),
14001406 ),
1401- # Test case 7: 1D convolution (NCL format)
1407+ # 1D convolution (NCL format)
14021408 (
14031409 "conv1d_ncl" ,
14041410 torch .tensor (
@@ -1411,11 +1417,12 @@ def test_rope(
14111417 (1 , 1 ), # dilation (only dilation[1] is used for 1D)
14121418 1 , # groups
14131419 False , # channel_last
1420+
14141421 torch .tensor (
14151422 [[[3.0 , 5.0 , 7.0 ]]], dtype = torch .float32
14161423 ), # expected: [1+2, 2+3, 3+4]
14171424 ),
1418- # Test case 8: 1D convolution (NLC format)
1425+ # 1D convolution (NLC format)
14191426 (
14201427 "conv1d_nlc" ,
14211428 torch .tensor (
@@ -1430,9 +1437,10 @@ def test_rope(
14301437 (1 , 1 ), # dilation
14311438 1 , # groups
14321439 True , # channel_last
1440+
14331441 torch .tensor ([[[3.0 ], [5.0 ], [7.0 ]]], dtype = torch .float32 ),
14341442 ),
1435- # Test case 9: Multi-channel input and output
1443+ # Multi-channel input and output
14361444 (
14371445 "multi_channel" ,
14381446 torch .tensor (
@@ -1469,9 +1477,10 @@ def test_rope(
14691477 (1 , 1 ), # dilation
14701478 1 , # groups
14711479 False , # channel_last
1480+
14721481 torch .tensor ([[[[10.0 ]], [[11.0 ]]]], dtype = torch .float32 ),
14731482 ),
1474- # Test case 10: Convolution with non-zero bias
1483+ # Convolution with non-zero bias
14751484 (
14761485 "conv2d_with_bias" ,
14771486 torch .tensor (
@@ -1486,6 +1495,7 @@ def test_rope(
14861495 (1 , 1 ), # dilation
14871496 1 , # groups
14881497 False , # channel_last
1498+
14891499 torch .tensor (
14901500 [[[[15.0 ]]]], dtype = torch .float32
14911501 ), # expected: 5 + 10 = 15
@@ -1534,6 +1544,75 @@ def test_convolution(
15341544 f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
15351545 )
15361546
1547+ @expand (
1548+ [
1549+ (
1550+ "conv2d_transposed_stride2" ,
1551+ torch .tensor (
1552+ [[[[1.0 , 2.0 ], [3.0 , 4.0 ]]]], dtype = torch .float32
1553+ ), # input: 1x1x2x2
1554+ torch .tensor (
1555+ [[[[1.0 , 1.0 ], [1.0 , 1.0 ]]]], dtype = torch .float32
1556+ ), # weight: 1x1x2x2
1557+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1558+ (1 , 1 ), # stride=2
1559+ (0 , 0 ), # padding
1560+ (1 , 1 ), # dilation
1561+ 1 , # groups
1562+ (0 , 0 ), # output_padding
1563+ False , # channel_last
1564+ torch .tensor (
1565+ [[[[1.0 , 3.0 , 2.0 ],
1566+ [4.0 , 10.0 , 6.0 ],
1567+ [3.0 , 7.0 , 4.0 ]]]], dtype = torch .float32
1568+ ),
1569+ ),
1570+ ]
1571+ )
1572+ def test_transposed_convolution (
1573+ self ,
1574+ name : str ,
1575+ input_tensor : torch .Tensor ,
1576+ weight : torch .Tensor ,
1577+ bias : torch .Tensor ,
1578+ stride : tuple [int , int ],
1579+ padding : tuple [int , int ],
1580+ dilation : tuple [int , int ],
1581+ groups : int ,
1582+ output_padding : tuple [int , int ],
1583+ channel_last : bool ,
1584+ expected_output : torch .Tensor ,
1585+ ) -> None :
1586+ output = torch .ops .cadence .transposed_convolution (
1587+ input_tensor ,
1588+ weight ,
1589+ bias ,
1590+ stride ,
1591+ padding ,
1592+ dilation ,
1593+ output_padding ,
1594+ groups ,
1595+ channel_last ,
1596+ )
1597+
1598+ # Verify output properties
1599+ self .assertEqual (
1600+ output .dtype ,
1601+ input_tensor .dtype ,
1602+ f"Output dtype should match input dtype in { name } " ,
1603+ )
1604+ self .assertEqual (
1605+ output .shape ,
1606+ expected_output .shape ,
1607+ f"Output shape should match expected shape in { name } " ,
1608+ )
1609+
1610+ # Verify output matches expected values
1611+ self .assertTrue (
1612+ torch .equal (output , expected_output ),
1613+ f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
1614+ )
1615+
15371616 @expand (
15381617 [
15391618 # Basic non-quantized average pooling
0 commit comments