Skip to content

Commit 87d31af

Browse files
authored
Support convolution
Differential Revision: D83536133 Pull Request resolved: #14680
1 parent f7c009e commit 87d31af

File tree

2 files changed

+322
-0
lines changed

2 files changed

+322
-0
lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,51 @@ def quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
933933
def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
934934

935935

936+
@impl(m, "convolution")
def convolution(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: tuple[int, int],
    padding: tuple[int, int],
    dilation: tuple[int, int],
    groups: int,
    channel_last: bool = False,
) -> torch.Tensor:
    """Reference implementation of the cadence ``convolution`` op.

    Dispatches to ``torch.nn.functional.conv1d`` for 3-D inputs and to
    ``conv2d`` for 4-D inputs. When ``channel_last`` is set, the input and
    weight are converted from channels-last (NLC / NHWC) to channels-first
    before the convolution, and the result is converted back afterwards so
    the output matches the caller's layout.

    Args:
        input_tensor: 3-D activation (N, C, L) — or (N, L, C) when
            ``channel_last`` — or the 4-D equivalent for 2-D convolution.
        weight: Convolution kernel, laid out to match ``input_tensor``.
        bias: Per-output-channel bias tensor.
        stride: (H, W) strides; only index 0 is used in the 1-D case.
        padding: (H, W) padding; only index 0 is used in the 1-D case.
        dilation: (H, W) dilation; only index 0 is used in the 1-D case.
        groups: Number of convolution groups.
        channel_last: True when tensors use channels-last layout.

    Returns:
        The convolution result, in the same layout as ``input_tensor``.

    Raises:
        ValueError: if, in the channels-last path, the weight rank does not
            match the input rank (3-D weight for 3-D input, 4-D for 4-D).
    """
    one_dimensional = len(input_tensor.shape) == 3

    if channel_last:
        if one_dimensional:
            # NLC -> NCL
            input_tensor = input_tensor.movedim(-1, 1).contiguous()
            if len(weight.shape) != 3:
                raise ValueError("Weight tensor must be 3D if input is 3D")
            weight = weight.movedim(-1, 1).contiguous()
        else:
            # NHWC -> NCHW
            input_tensor = input_tensor.movedim(-1, -3)
            if len(weight.shape) != 4:
                raise ValueError("Weight tensor must be 4D if input is nd > 3")
            weight = torch.permute(weight, (0, -1, 1, 2)).contiguous()

    if one_dimensional:
        # conv1d takes scalar stride/padding/dilation; use the leading entry.
        out = torch.nn.functional.conv1d(
            input_tensor, weight, bias, stride[0], padding[0], dilation[0], groups
        )
    else:
        out = torch.nn.functional.conv2d(
            input_tensor, weight, bias, stride, padding, dilation, groups
        )

    if channel_last:
        # Restore the caller's channels-last layout.
        out = (
            out.movedim(1, -1).contiguous()
            if one_dimensional
            else out.movedim(-3, -1).contiguous()
        )

    return out
936981
def quantized_relu_common(
937982
X: torch.Tensor,
938983
X_zero_point: torch.Tensor | int,

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,3 +1256,280 @@ def test_rope(
12561256
torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
12571257
f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
12581258
)
1259+
1260+
    @expand(
        [
            # Test case 1: Basic 2D convolution (NCHW format)
            (
                "basic_2d_nchw",
                torch.tensor(
                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
                ),  # input: 1x1x2x2
                torch.tensor(
                    [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32
                ),  # weight: 1x1x2x2 (identity-like filter)
                torch.tensor([0.0], dtype=torch.float32),  # bias
                (1, 1),  # stride
                (0, 0),  # padding
                (1, 1),  # dilation
                1,  # groups
                False,  # channel_last
                torch.tensor(
                    [[[[5.0]]]], dtype=torch.float32
                ),  # expected: 1*1 + 4*1 = 5
            ),
            # Test case 2: Basic 2D convolution (NHWC format)
            (
                "basic_2d_nhwc",
                torch.tensor(
                    [[[[1.0], [2.0]], [[3.0], [4.0]]]], dtype=torch.float32
                ),  # input: 1x2x2x1 (NHWC)
                torch.tensor(
                    [[[[1.0], [0.0]], [[0.0], [1.0]]]], dtype=torch.float32
                ),  # weight: 1x2x2x1 (NHWC format)
                torch.tensor([0.0], dtype=torch.float32),  # bias
                (1, 1),  # stride
                (0, 0),  # padding
                (1, 1),  # dilation
                1,  # groups
                True,  # channel_last
                torch.tensor(
                    [[[[5.0]]]], dtype=torch.float32
                ),  # expected: 1*1 + 4*1 = 5
            ),
            # Test case 3: 2D convolution with stride=2
            (
                "conv2d_stride2",
                torch.tensor(
                    [
                        [
                            [
                                [1.0, 2.0, 3.0, 4.0],
                                [5.0, 6.0, 7.0, 8.0],
                                [9.0, 10.0, 11.0, 12.0],
                                [13.0, 14.0, 15.0, 16.0],
                            ]
                        ]
                    ],
                    dtype=torch.float32,
                ),  # input: 1x1x4x4
                torch.tensor(
                    [[[[1.0, 1.0], [1.0, 1.0]]]], dtype=torch.float32
                ),  # weight: 1x1x2x2 (sum filter)
                torch.tensor([0.0], dtype=torch.float32),  # bias
                (2, 2),  # stride=2
                (0, 0),  # padding
                (1, 1),  # dilation
                1,  # groups
                False,  # channel_last
                torch.tensor([[[[14.0, 22.0], [46.0, 54.0]]]], dtype=torch.float32),
            ),
            # Test case 4: 2D convolution with padding=1
            (
                "conv2d_padding1",
                torch.tensor(
                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
                ),  # input: 1x1x2x2
                torch.tensor(
                    [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32
                ),  # weight: 1x1x2x2
                torch.tensor([0.0], dtype=torch.float32),  # bias
                (1, 1),  # stride
                (1, 1),  # padding=1
                (1, 1),  # dilation
                1,  # groups
                False,  # channel_last
                torch.tensor(
                    [[[[1.0, 2.0, 0.0], [3.0, 5.0, 2.0], [0.0, 3.0, 4.0]]]],
                    dtype=torch.float32,
                ),  # expected with padding
            ),
            # Test case 5: 2D convolution with dilation=2
            (
                "conv2d_dilation2",
                torch.tensor(
                    [
                        [
                            [
                                [1.0, 2.0, 3.0, 4.0],
                                [5.0, 6.0, 7.0, 8.0],
                                [9.0, 10.0, 11.0, 12.0],
                                [13.0, 14.0, 15.0, 16.0],
                            ]
                        ]
                    ],
                    dtype=torch.float32,
                ),  # input: 1x1x4x4
                torch.tensor(
                    [[[[1.0, 1.0], [1.0, 1.0]]]], dtype=torch.float32
                ),  # weight: 1x1x2x2
                torch.tensor([0.0], dtype=torch.float32),  # bias
                (1, 1),  # stride
                (0, 0),  # padding
                (2, 2),  # dilation=2
                1,  # groups
                False,  # channel_last
                torch.tensor([[[[24.0, 28.0], [40.0, 44.0]]]], dtype=torch.float32),
            ),
            # Test case 6: 2D grouped convolution (groups=2)
            (
                "conv2d_groups2",
                torch.tensor(
                    [
                        [
                            [[1.0, 2.0], [3.0, 4.0]],  # first input channel
                            [[5.0, 6.0], [7.0, 8.0]],  # second input channel
                        ]
                    ],
                    dtype=torch.float32,
                ),  # input: 1x2x2x2
                torch.tensor(
                    [
                        [[[1.0, 1.0], [1.0, 1.0]]],  # first group weight
                        [[[0.5, 0.5], [0.5, 0.5]]],  # second group weight
                    ],
                    dtype=torch.float32,
                ),  # weight: 2x1x2x2
                torch.tensor([0.0, 1.0], dtype=torch.float32),  # bias
                (1, 1),  # stride
                (0, 0),  # padding
                (1, 1),  # dilation
                2,  # groups=2
                False,  # channel_last
                torch.tensor([[[[10.0]], [[14.0]]]], dtype=torch.float32),
            ),
            # Test case 7: 1D convolution (NCL format)
            (
                "conv1d_ncl",
                torch.tensor(
                    [[[1.0, 2.0, 3.0, 4.0]]], dtype=torch.float32
                ),  # input: 1x1x4
                torch.tensor([[[1.0, 1.0]]], dtype=torch.float32),  # weight: 1x1x2
                torch.tensor([0.0], dtype=torch.float32),  # bias
                (1, 1),  # stride (only stride[0] is used for 1D)
                (0, 0),  # padding (only padding[0] is used for 1D)
                (1, 1),  # dilation (only dilation[0] is used for 1D)
                1,  # groups
                False,  # channel_last
                torch.tensor(
                    [[[3.0, 5.0, 7.0]]], dtype=torch.float32
                ),  # expected: [1+2, 2+3, 3+4]
            ),
            # Test case 8: 1D convolution (NLC format)
            (
                "conv1d_nlc",
                torch.tensor(
                    [[[1.0], [2.0], [3.0], [4.0]]], dtype=torch.float32
                ),  # input: 1x4x1 (NLC)
                torch.tensor(
                    [[[1.0], [1.0]]], dtype=torch.float32
                ),  # weight: 1x2x1 (NLC)
                torch.tensor([0.0], dtype=torch.float32),  # bias
                (1, 1),  # stride
                (0, 0),  # padding
                (1, 1),  # dilation
                1,  # groups
                True,  # channel_last
                torch.tensor([[[3.0], [5.0], [7.0]]], dtype=torch.float32),
            ),
            # Test case 9: Multi-channel input and output
            (
                "multi_channel",
                torch.tensor(
                    [
                        [
                            [[1.0, 2.0], [3.0, 4.0]],  # first input channel
                            [[0.5, 1.0], [1.5, 2.0]],  # second input channel
                        ]
                    ],
                    dtype=torch.float32,
                ),  # input: 1x2x2x2
                torch.tensor(
                    [
                        [  # first output channel
                            [[1.0, 0.0], [0.0, 1.0]],  # weights for first input channel
                            [
                                [2.0, 0.0],
                                [0.0, 2.0],
                            ],  # weights for second input channel
                        ],
                        [  # second output channel
                            [[0.5, 0.5], [0.5, 0.5]],  # weights for first input channel
                            [
                                [1.0, 1.0],
                                [1.0, 1.0],
                            ],  # weights for second input channel
                        ],
                    ],
                    dtype=torch.float32,
                ),  # weight: 2x2x2x2
                torch.tensor([0.0, 1.0], dtype=torch.float32),  # bias
                (1, 1),  # stride
                (0, 0),  # padding
                (1, 1),  # dilation
                1,  # groups
                False,  # channel_last
                torch.tensor([[[[10.0]], [[11.0]]]], dtype=torch.float32),
            ),
            # Test case 10: Convolution with non-zero bias
            (
                "conv2d_with_bias",
                torch.tensor(
                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
                ),  # input: 1x1x2x2
                torch.tensor(
                    [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32
                ),  # weight: 1x1x2x2
                torch.tensor([10.0], dtype=torch.float32),  # bias=10
                (1, 1),  # stride
                (0, 0),  # padding
                (1, 1),  # dilation
                1,  # groups
                False,  # channel_last
                torch.tensor(
                    [[[[15.0]]]], dtype=torch.float32
                ),  # expected: 5 + 10 = 15
            ),
        ]
    )
    def test_convolution(
        self,
        name: str,
        input_tensor: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        stride: tuple[int, int],
        padding: tuple[int, int],
        dilation: tuple[int, int],
        groups: int,
        channel_last: bool,
        expected_output: torch.Tensor,
    ) -> None:
        """Exercise cadence.convolution against hand-computed outputs.

        Cases cover 1D (NCL/NLC) and 2D (NCHW/NHWC) layouts, stride,
        padding, dilation, grouped convolution, multi-channel weights,
        and non-zero bias.
        """
        output = torch.ops.cadence.convolution(
            input_tensor,
            weight,
            bias,
            stride,
            padding,
            dilation,
            groups,
            channel_last,
        )

        # Verify output properties
        self.assertEqual(
            output.dtype,
            input_tensor.dtype,
            f"Output dtype should match input dtype in {name}",
        )
        self.assertEqual(
            output.shape,
            expected_output.shape,
            f"Output shape should match expected shape in {name}",
        )

        # Verify output matches expected values
        self.assertTrue(
            torch.equal(output, expected_output),
            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
        )

0 commit comments

Comments
 (0)