@@ -1256,3 +1256,280 @@ def test_rope(
12561256 torch .allclose (output , expected_output , rtol = 1e-4 , atol = 1e-4 ),
12571257 f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
12581258 )
1259+
1260+ @expand (
1261+ [
1262+ # Test case 1: Basic 2D convolution (NCHW format)
1263+ (
1264+ "basic_2d_nchw" ,
1265+ torch .tensor (
1266+ [[[[1.0 , 2.0 ], [3.0 , 4.0 ]]]], dtype = torch .float32
1267+ ), # input: 1x1x2x2
1268+ torch .tensor (
1269+ [[[[1.0 , 0.0 ], [0.0 , 1.0 ]]]], dtype = torch .float32
1270+ ), # weight: 1x1x2x2 (identity-like filter)
1271+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1272+ (1 , 1 ), # stride
1273+ (0 , 0 ), # padding
1274+ (1 , 1 ), # dilation
1275+ 1 , # groups
1276+ False , # channel_last
1277+ torch .tensor (
1278+ [[[[5.0 ]]]], dtype = torch .float32
1279+ ), # expected: 1*1 + 4*1 = 5
1280+ ),
1281+ # Test case 2: Basic 2D convolution (NHWC format)
1282+ (
1283+ "basic_2d_nhwc" ,
1284+ torch .tensor (
1285+ [[[[1.0 ], [2.0 ]], [[3.0 ], [4.0 ]]]], dtype = torch .float32
1286+ ), # input: 1x2x2x1 (NHWC)
1287+ torch .tensor (
1288+ [[[[1.0 ], [0.0 ]], [[0.0 ], [1.0 ]]]], dtype = torch .float32
1289+ ), # weight: 1x2x2x1 (NHWC format)
1290+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1291+ (1 , 1 ), # stride
1292+ (0 , 0 ), # padding
1293+ (1 , 1 ), # dilation
1294+ 1 , # groups
1295+ True , # channel_last
1296+ torch .tensor (
1297+ [[[[5.0 ]]]], dtype = torch .float32
1298+ ), # expected: 1*1 + 4*1 = 5
1299+ ),
1300+ # Test case 3: 2D convolution with stride=2
1301+ (
1302+ "conv2d_stride2" ,
1303+ torch .tensor (
1304+ [
1305+ [
1306+ [
1307+ [1.0 , 2.0 , 3.0 , 4.0 ],
1308+ [5.0 , 6.0 , 7.0 , 8.0 ],
1309+ [9.0 , 10.0 , 11.0 , 12.0 ],
1310+ [13.0 , 14.0 , 15.0 , 16.0 ],
1311+ ]
1312+ ]
1313+ ],
1314+ dtype = torch .float32 ,
1315+ ), # input: 1x1x4x4
1316+ torch .tensor (
1317+ [[[[1.0 , 1.0 ], [1.0 , 1.0 ]]]], dtype = torch .float32
1318+ ), # weight: 1x1x2x2 (sum filter)
1319+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1320+ (2 , 2 ), # stride=2
1321+ (0 , 0 ), # padding
1322+ (1 , 1 ), # dilation
1323+ 1 , # groups
1324+ False , # channel_last
1325+ torch .tensor ([[[[14.0 , 22.0 ], [46.0 , 54.0 ]]]], dtype = torch .float32 ),
1326+ ),
1327+ # Test case 4: 2D convolution with padding=1
1328+ (
1329+ "conv2d_padding1" ,
1330+ torch .tensor (
1331+ [[[[1.0 , 2.0 ], [3.0 , 4.0 ]]]], dtype = torch .float32
1332+ ), # input: 1x1x2x2
1333+ torch .tensor (
1334+ [[[[1.0 , 0.0 ], [0.0 , 1.0 ]]]], dtype = torch .float32
1335+ ), # weight: 1x1x2x2
1336+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1337+ (1 , 1 ), # stride
1338+ (1 , 1 ), # padding=1
1339+ (1 , 1 ), # dilation
1340+ 1 , # groups
1341+ False , # channel_last
1342+ torch .tensor (
1343+ [[[[1.0 , 2.0 , 0.0 ], [3.0 , 5.0 , 2.0 ], [0.0 , 3.0 , 4.0 ]]]],
1344+ dtype = torch .float32 ,
1345+ ), # expected with padding
1346+ ),
1347+ # Test case 5: 2D convolution with dilation=2
1348+ (
1349+ "conv2d_dilation2" ,
1350+ torch .tensor (
1351+ [
1352+ [
1353+ [
1354+ [1.0 , 2.0 , 3.0 , 4.0 ],
1355+ [5.0 , 6.0 , 7.0 , 8.0 ],
1356+ [9.0 , 10.0 , 11.0 , 12.0 ],
1357+ [13.0 , 14.0 , 15.0 , 16.0 ],
1358+ ]
1359+ ]
1360+ ],
1361+ dtype = torch .float32 ,
1362+ ), # input: 1x1x4x4
1363+ torch .tensor (
1364+ [[[[1.0 , 1.0 ], [1.0 , 1.0 ]]]], dtype = torch .float32
1365+ ), # weight: 1x1x2x2
1366+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1367+ (1 , 1 ), # stride
1368+ (0 , 0 ), # padding
1369+ (2 , 2 ), # dilation=2
1370+ 1 , # groups
1371+ False , # channel_last
1372+ torch .tensor ([[[[24.0 , 28.0 ], [40.0 , 44.0 ]]]], dtype = torch .float32 ),
1373+ ),
1374+ # Test case 6: 2D grouped convolution (groups=2)
1375+ (
1376+ "conv2d_groups2" ,
1377+ torch .tensor (
1378+ [
1379+ [
1380+ [[1.0 , 2.0 ], [3.0 , 4.0 ]], # first input channel
1381+ [[5.0 , 6.0 ], [7.0 , 8.0 ]], # second input channel
1382+ ]
1383+ ],
1384+ dtype = torch .float32 ,
1385+ ), # input: 1x2x2x2
1386+ torch .tensor (
1387+ [
1388+ [[[1.0 , 1.0 ], [1.0 , 1.0 ]]], # first group weight
1389+ [[[0.5 , 0.5 ], [0.5 , 0.5 ]]], # second group weight
1390+ ],
1391+ dtype = torch .float32 ,
1392+ ), # weight: 2x1x2x2
1393+ torch .tensor ([0.0 , 1.0 ], dtype = torch .float32 ), # bias
1394+ (1 , 1 ), # stride
1395+ (0 , 0 ), # padding
1396+ (1 , 1 ), # dilation
1397+ 2 , # groups=2
1398+ False , # channel_last
1399+ torch .tensor ([[[[10.0 ]], [[14.0 ]]]], dtype = torch .float32 ),
1400+ ),
1401+ # Test case 7: 1D convolution (NCL format)
1402+ (
1403+ "conv1d_ncl" ,
1404+ torch .tensor (
1405+ [[[1.0 , 2.0 , 3.0 , 4.0 ]]], dtype = torch .float32
1406+ ), # input: 1x1x4
1407+ torch .tensor ([[[1.0 , 1.0 ]]], dtype = torch .float32 ), # weight: 1x1x2
1408+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1409+ (1 , 1 ), # stride (only stride[1] is used for 1D)
1410+ (0 , 0 ), # padding (only padding[1] is used for 1D)
1411+ (1 , 1 ), # dilation (only dilation[1] is used for 1D)
1412+ 1 , # groups
1413+ False , # channel_last
1414+ torch .tensor (
1415+ [[[3.0 , 5.0 , 7.0 ]]], dtype = torch .float32
1416+ ), # expected: [1+2, 2+3, 3+4]
1417+ ),
1418+ # Test case 8: 1D convolution (NLC format)
1419+ (
1420+ "conv1d_nlc" ,
1421+ torch .tensor (
1422+ [[[1.0 ], [2.0 ], [3.0 ], [4.0 ]]], dtype = torch .float32
1423+ ), # input: 1x4x1 (NLC)
1424+ torch .tensor (
1425+ [[[1.0 ], [1.0 ]]], dtype = torch .float32
1426+ ), # weight: 1x2x1 (NLC)
1427+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1428+ (1 , 1 ), # stride
1429+ (0 , 0 ), # padding
1430+ (1 , 1 ), # dilation
1431+ 1 , # groups
1432+ True , # channel_last
1433+ torch .tensor ([[[3.0 ], [5.0 ], [7.0 ]]], dtype = torch .float32 ),
1434+ ),
1435+ # Test case 9: Multi-channel input and output
1436+ (
1437+ "multi_channel" ,
1438+ torch .tensor (
1439+ [
1440+ [
1441+ [[1.0 , 2.0 ], [3.0 , 4.0 ]], # first input channel
1442+ [[0.5 , 1.0 ], [1.5 , 2.0 ]], # second input channel
1443+ ]
1444+ ],
1445+ dtype = torch .float32 ,
1446+ ), # input: 1x2x2x2
1447+ torch .tensor (
1448+ [
1449+ [ # first output channel
1450+ [[1.0 , 0.0 ], [0.0 , 1.0 ]], # weights for first input channel
1451+ [
1452+ [2.0 , 0.0 ],
1453+ [0.0 , 2.0 ],
1454+ ], # weights for second input channel
1455+ ],
1456+ [ # second output channel
1457+ [[0.5 , 0.5 ], [0.5 , 0.5 ]], # weights for first input channel
1458+ [
1459+ [1.0 , 1.0 ],
1460+ [1.0 , 1.0 ],
1461+ ], # weights for second input channel
1462+ ],
1463+ ],
1464+ dtype = torch .float32 ,
1465+ ), # weight: 2x2x2x2
1466+ torch .tensor ([0.0 , 1.0 ], dtype = torch .float32 ), # bias
1467+ (1 , 1 ), # stride
1468+ (0 , 0 ), # padding
1469+ (1 , 1 ), # dilation
1470+ 1 , # groups
1471+ False , # channel_last
1472+ torch .tensor ([[[[10.0 ]], [[11.0 ]]]], dtype = torch .float32 ),
1473+ ),
1474+ # Test case 10: Convolution with non-zero bias
1475+ (
1476+ "conv2d_with_bias" ,
1477+ torch .tensor (
1478+ [[[[1.0 , 2.0 ], [3.0 , 4.0 ]]]], dtype = torch .float32
1479+ ), # input: 1x1x2x2
1480+ torch .tensor (
1481+ [[[[1.0 , 0.0 ], [0.0 , 1.0 ]]]], dtype = torch .float32
1482+ ), # weight: 1x1x2x2
1483+ torch .tensor ([10.0 ], dtype = torch .float32 ), # bias=10
1484+ (1 , 1 ), # stride
1485+ (0 , 0 ), # padding
1486+ (1 , 1 ), # dilation
1487+ 1 , # groups
1488+ False , # channel_last
1489+ torch .tensor (
1490+ [[[[15.0 ]]]], dtype = torch .float32
1491+ ), # expected: 5 + 10 = 15
1492+ ),
1493+ ]
1494+ )
1495+ def test_convolution (
1496+ self ,
1497+ name : str ,
1498+ input_tensor : torch .Tensor ,
1499+ weight : torch .Tensor ,
1500+ bias : torch .Tensor ,
1501+ stride : tuple [int , int ],
1502+ padding : tuple [int , int ],
1503+ dilation : tuple [int , int ],
1504+ groups : int ,
1505+ channel_last : bool ,
1506+ expected_output : torch .Tensor ,
1507+ ) -> None :
1508+ output = torch .ops .cadence .convolution (
1509+ input_tensor ,
1510+ weight ,
1511+ bias ,
1512+ stride ,
1513+ padding ,
1514+ dilation ,
1515+ groups ,
1516+ channel_last ,
1517+ )
1518+
1519+ # Verify output properties
1520+ self .assertEqual (
1521+ output .dtype ,
1522+ input_tensor .dtype ,
1523+ f"Output dtype should match input dtype in { name } " ,
1524+ )
1525+ self .assertEqual (
1526+ output .shape ,
1527+ expected_output .shape ,
1528+ f"Output shape should match expected shape in { name } " ,
1529+ )
1530+
1531+ # Verify output matches expected values
1532+ self .assertTrue (
1533+ torch .equal (output , expected_output ),
1534+ f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
1535+ )
0 commit comments