@@ -1256,3 +1256,280 @@ def test_rope(
1256
1256
torch .allclose (output , expected_output , rtol = 1e-4 , atol = 1e-4 ),
1257
1257
f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
1258
1258
)
1259
+
1260
+ @expand (
1261
+ [
1262
+ # Test case 1: Basic 2D convolution (NCHW format)
1263
+ (
1264
+ "basic_2d_nchw" ,
1265
+ torch .tensor (
1266
+ [[[[1.0 , 2.0 ], [3.0 , 4.0 ]]]], dtype = torch .float32
1267
+ ), # input: 1x1x2x2
1268
+ torch .tensor (
1269
+ [[[[1.0 , 0.0 ], [0.0 , 1.0 ]]]], dtype = torch .float32
1270
+ ), # weight: 1x1x2x2 (identity-like filter)
1271
+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1272
+ (1 , 1 ), # stride
1273
+ (0 , 0 ), # padding
1274
+ (1 , 1 ), # dilation
1275
+ 1 , # groups
1276
+ False , # channel_last
1277
+ torch .tensor (
1278
+ [[[[5.0 ]]]], dtype = torch .float32
1279
+ ), # expected: 1*1 + 4*1 = 5
1280
+ ),
1281
+ # Test case 2: Basic 2D convolution (NHWC format)
1282
+ (
1283
+ "basic_2d_nhwc" ,
1284
+ torch .tensor (
1285
+ [[[[1.0 ], [2.0 ]], [[3.0 ], [4.0 ]]]], dtype = torch .float32
1286
+ ), # input: 1x2x2x1 (NHWC)
1287
+ torch .tensor (
1288
+ [[[[1.0 ], [0.0 ]], [[0.0 ], [1.0 ]]]], dtype = torch .float32
1289
+ ), # weight: 1x2x2x1 (NHWC format)
1290
+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1291
+ (1 , 1 ), # stride
1292
+ (0 , 0 ), # padding
1293
+ (1 , 1 ), # dilation
1294
+ 1 , # groups
1295
+ True , # channel_last
1296
+ torch .tensor (
1297
+ [[[[5.0 ]]]], dtype = torch .float32
1298
+ ), # expected: 1*1 + 4*1 = 5
1299
+ ),
1300
+ # Test case 3: 2D convolution with stride=2
1301
+ (
1302
+ "conv2d_stride2" ,
1303
+ torch .tensor (
1304
+ [
1305
+ [
1306
+ [
1307
+ [1.0 , 2.0 , 3.0 , 4.0 ],
1308
+ [5.0 , 6.0 , 7.0 , 8.0 ],
1309
+ [9.0 , 10.0 , 11.0 , 12.0 ],
1310
+ [13.0 , 14.0 , 15.0 , 16.0 ],
1311
+ ]
1312
+ ]
1313
+ ],
1314
+ dtype = torch .float32 ,
1315
+ ), # input: 1x1x4x4
1316
+ torch .tensor (
1317
+ [[[[1.0 , 1.0 ], [1.0 , 1.0 ]]]], dtype = torch .float32
1318
+ ), # weight: 1x1x2x2 (sum filter)
1319
+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1320
+ (2 , 2 ), # stride=2
1321
+ (0 , 0 ), # padding
1322
+ (1 , 1 ), # dilation
1323
+ 1 , # groups
1324
+ False , # channel_last
1325
+ torch .tensor ([[[[14.0 , 22.0 ], [46.0 , 54.0 ]]]], dtype = torch .float32 ),
1326
+ ),
1327
+ # Test case 4: 2D convolution with padding=1
1328
+ (
1329
+ "conv2d_padding1" ,
1330
+ torch .tensor (
1331
+ [[[[1.0 , 2.0 ], [3.0 , 4.0 ]]]], dtype = torch .float32
1332
+ ), # input: 1x1x2x2
1333
+ torch .tensor (
1334
+ [[[[1.0 , 0.0 ], [0.0 , 1.0 ]]]], dtype = torch .float32
1335
+ ), # weight: 1x1x2x2
1336
+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1337
+ (1 , 1 ), # stride
1338
+ (1 , 1 ), # padding=1
1339
+ (1 , 1 ), # dilation
1340
+ 1 , # groups
1341
+ False , # channel_last
1342
+ torch .tensor (
1343
+ [[[[1.0 , 2.0 , 0.0 ], [3.0 , 5.0 , 2.0 ], [0.0 , 3.0 , 4.0 ]]]],
1344
+ dtype = torch .float32 ,
1345
+ ), # expected with padding
1346
+ ),
1347
+ # Test case 5: 2D convolution with dilation=2
1348
+ (
1349
+ "conv2d_dilation2" ,
1350
+ torch .tensor (
1351
+ [
1352
+ [
1353
+ [
1354
+ [1.0 , 2.0 , 3.0 , 4.0 ],
1355
+ [5.0 , 6.0 , 7.0 , 8.0 ],
1356
+ [9.0 , 10.0 , 11.0 , 12.0 ],
1357
+ [13.0 , 14.0 , 15.0 , 16.0 ],
1358
+ ]
1359
+ ]
1360
+ ],
1361
+ dtype = torch .float32 ,
1362
+ ), # input: 1x1x4x4
1363
+ torch .tensor (
1364
+ [[[[1.0 , 1.0 ], [1.0 , 1.0 ]]]], dtype = torch .float32
1365
+ ), # weight: 1x1x2x2
1366
+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1367
+ (1 , 1 ), # stride
1368
+ (0 , 0 ), # padding
1369
+ (2 , 2 ), # dilation=2
1370
+ 1 , # groups
1371
+ False , # channel_last
1372
+ torch .tensor ([[[[24.0 , 28.0 ], [40.0 , 44.0 ]]]], dtype = torch .float32 ),
1373
+ ),
1374
+ # Test case 6: 2D grouped convolution (groups=2)
1375
+ (
1376
+ "conv2d_groups2" ,
1377
+ torch .tensor (
1378
+ [
1379
+ [
1380
+ [[1.0 , 2.0 ], [3.0 , 4.0 ]], # first input channel
1381
+ [[5.0 , 6.0 ], [7.0 , 8.0 ]], # second input channel
1382
+ ]
1383
+ ],
1384
+ dtype = torch .float32 ,
1385
+ ), # input: 1x2x2x2
1386
+ torch .tensor (
1387
+ [
1388
+ [[[1.0 , 1.0 ], [1.0 , 1.0 ]]], # first group weight
1389
+ [[[0.5 , 0.5 ], [0.5 , 0.5 ]]], # second group weight
1390
+ ],
1391
+ dtype = torch .float32 ,
1392
+ ), # weight: 2x1x2x2
1393
+ torch .tensor ([0.0 , 1.0 ], dtype = torch .float32 ), # bias
1394
+ (1 , 1 ), # stride
1395
+ (0 , 0 ), # padding
1396
+ (1 , 1 ), # dilation
1397
+ 2 , # groups=2
1398
+ False , # channel_last
1399
+ torch .tensor ([[[[10.0 ]], [[14.0 ]]]], dtype = torch .float32 ),
1400
+ ),
1401
+ # Test case 7: 1D convolution (NCL format)
1402
+ (
1403
+ "conv1d_ncl" ,
1404
+ torch .tensor (
1405
+ [[[1.0 , 2.0 , 3.0 , 4.0 ]]], dtype = torch .float32
1406
+ ), # input: 1x1x4
1407
+ torch .tensor ([[[1.0 , 1.0 ]]], dtype = torch .float32 ), # weight: 1x1x2
1408
+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1409
+ (1 , 1 ), # stride (only stride[1] is used for 1D)
1410
+ (0 , 0 ), # padding (only padding[1] is used for 1D)
1411
+ (1 , 1 ), # dilation (only dilation[1] is used for 1D)
1412
+ 1 , # groups
1413
+ False , # channel_last
1414
+ torch .tensor (
1415
+ [[[3.0 , 5.0 , 7.0 ]]], dtype = torch .float32
1416
+ ), # expected: [1+2, 2+3, 3+4]
1417
+ ),
1418
+ # Test case 8: 1D convolution (NLC format)
1419
+ (
1420
+ "conv1d_nlc" ,
1421
+ torch .tensor (
1422
+ [[[1.0 ], [2.0 ], [3.0 ], [4.0 ]]], dtype = torch .float32
1423
+ ), # input: 1x4x1 (NLC)
1424
+ torch .tensor (
1425
+ [[[1.0 ], [1.0 ]]], dtype = torch .float32
1426
+ ), # weight: 1x2x1 (NLC)
1427
+ torch .tensor ([0.0 ], dtype = torch .float32 ), # bias
1428
+ (1 , 1 ), # stride
1429
+ (0 , 0 ), # padding
1430
+ (1 , 1 ), # dilation
1431
+ 1 , # groups
1432
+ True , # channel_last
1433
+ torch .tensor ([[[3.0 ], [5.0 ], [7.0 ]]], dtype = torch .float32 ),
1434
+ ),
1435
+ # Test case 9: Multi-channel input and output
1436
+ (
1437
+ "multi_channel" ,
1438
+ torch .tensor (
1439
+ [
1440
+ [
1441
+ [[1.0 , 2.0 ], [3.0 , 4.0 ]], # first input channel
1442
+ [[0.5 , 1.0 ], [1.5 , 2.0 ]], # second input channel
1443
+ ]
1444
+ ],
1445
+ dtype = torch .float32 ,
1446
+ ), # input: 1x2x2x2
1447
+ torch .tensor (
1448
+ [
1449
+ [ # first output channel
1450
+ [[1.0 , 0.0 ], [0.0 , 1.0 ]], # weights for first input channel
1451
+ [
1452
+ [2.0 , 0.0 ],
1453
+ [0.0 , 2.0 ],
1454
+ ], # weights for second input channel
1455
+ ],
1456
+ [ # second output channel
1457
+ [[0.5 , 0.5 ], [0.5 , 0.5 ]], # weights for first input channel
1458
+ [
1459
+ [1.0 , 1.0 ],
1460
+ [1.0 , 1.0 ],
1461
+ ], # weights for second input channel
1462
+ ],
1463
+ ],
1464
+ dtype = torch .float32 ,
1465
+ ), # weight: 2x2x2x2
1466
+ torch .tensor ([0.0 , 1.0 ], dtype = torch .float32 ), # bias
1467
+ (1 , 1 ), # stride
1468
+ (0 , 0 ), # padding
1469
+ (1 , 1 ), # dilation
1470
+ 1 , # groups
1471
+ False , # channel_last
1472
+ torch .tensor ([[[[10.0 ]], [[11.0 ]]]], dtype = torch .float32 ),
1473
+ ),
1474
+ # Test case 10: Convolution with non-zero bias
1475
+ (
1476
+ "conv2d_with_bias" ,
1477
+ torch .tensor (
1478
+ [[[[1.0 , 2.0 ], [3.0 , 4.0 ]]]], dtype = torch .float32
1479
+ ), # input: 1x1x2x2
1480
+ torch .tensor (
1481
+ [[[[1.0 , 0.0 ], [0.0 , 1.0 ]]]], dtype = torch .float32
1482
+ ), # weight: 1x1x2x2
1483
+ torch .tensor ([10.0 ], dtype = torch .float32 ), # bias=10
1484
+ (1 , 1 ), # stride
1485
+ (0 , 0 ), # padding
1486
+ (1 , 1 ), # dilation
1487
+ 1 , # groups
1488
+ False , # channel_last
1489
+ torch .tensor (
1490
+ [[[[15.0 ]]]], dtype = torch .float32
1491
+ ), # expected: 5 + 10 = 15
1492
+ ),
1493
+ ]
1494
+ )
1495
+ def test_convolution (
1496
+ self ,
1497
+ name : str ,
1498
+ input_tensor : torch .Tensor ,
1499
+ weight : torch .Tensor ,
1500
+ bias : torch .Tensor ,
1501
+ stride : tuple [int , int ],
1502
+ padding : tuple [int , int ],
1503
+ dilation : tuple [int , int ],
1504
+ groups : int ,
1505
+ channel_last : bool ,
1506
+ expected_output : torch .Tensor ,
1507
+ ) -> None :
1508
+ output = torch .ops .cadence .convolution (
1509
+ input_tensor ,
1510
+ weight ,
1511
+ bias ,
1512
+ stride ,
1513
+ padding ,
1514
+ dilation ,
1515
+ groups ,
1516
+ channel_last ,
1517
+ )
1518
+
1519
+ # Verify output properties
1520
+ self .assertEqual (
1521
+ output .dtype ,
1522
+ input_tensor .dtype ,
1523
+ f"Output dtype should match input dtype in { name } " ,
1524
+ )
1525
+ self .assertEqual (
1526
+ output .shape ,
1527
+ expected_output .shape ,
1528
+ f"Output shape should match expected shape in { name } " ,
1529
+ )
1530
+
1531
+ # Verify output matches expected values
1532
+ self .assertTrue (
1533
+ torch .equal (output , expected_output ),
1534
+ f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
1535
+ )
0 commit comments