Commit 0f08e27

Andrew Grebenisan authored and facebook-github-bot committed
Add transposed convolution
Summary: Continued support of Cadence Python references.

Differential Revision: D83602808
1 parent 965d1bc commit 0f08e27

File tree

2 files changed (+136, -11 lines)


backends/cadence/aot/ref_implementations.py

Lines changed: 47 additions & 1 deletion
@@ -932,7 +932,6 @@ def quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8, is_1d=True)
 def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...

-
 @impl(m, "convolution")
 def convolution(
     input_tensor: torch.Tensor,
@@ -960,6 +959,7 @@ def convolution(
     _stride: tuple[int, int] | int = stride
     _padding: tuple[int, int] | int = padding
     _dilation: tuple[int, int] | int = dilation
+
     if conv_is_1d:
         conv = torch.nn.functional.conv1d
         _stride = stride[0]
@@ -977,6 +977,52 @@ def convolution(

     return conv_out

+@impl(m, "transposed_convolution")
+def transposed_convolution(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: tuple[int, int],
+    padding: tuple[int, int],
+    dilation: tuple[int, int],
+    output_padding: tuple[int, int],
+    groups: int,
+    channel_last: bool = False,
+) -> torch.Tensor:
+
+    conv_is_1d = len(input_tensor.shape) == 3
+    if channel_last:
+        if conv_is_1d:
+            input_tensor = input_tensor.movedim(-1, 1).contiguous()
+            if len(weight.shape) != 3:
+                raise ValueError("Weight tensor must be 3D if input is 3D")
+            weight = weight.movedim(-1, 1).contiguous()
+        else:
+            input_tensor = input_tensor.movedim(-1, -3)
+            if len(weight.shape) != 4:
+                raise ValueError("Weight tensor must be 4D if input is nd > 3")
+            weight = torch.permute(weight, (0, -1, 1, 2)).contiguous()
+
+    _stride: tuple[int, int] | int = stride
+    _padding: tuple[int, int] | int = padding
+    _dilation: tuple[int, int] | int = dilation
+
+    if conv_is_1d:
+        conv = torch.nn.functional.conv_transpose1d
+        _stride = stride[0]
+        _padding = padding[0]
+        _dilation = dilation[0]
+    else:
+        conv = torch.nn.functional.conv_transpose2d
+
+    conv_out = conv(input_tensor, weight, bias, _stride, _padding, output_padding, groups, _dilation)
+    if channel_last:
+        if conv_is_1d:
+            conv_out = conv_out.movedim(1, -1).contiguous()
+        else:
+            conv_out = conv_out.movedim(-3, -1).contiguous()
+
+    return conv_out

 @impl(m, "avg_pool2d")
 def avg_pool2d(
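
For the 2D, channel-first path, the new reference is a thin wrapper around torch.nn.functional.conv_transpose2d. A minimal standalone sketch of that reduction, using the same fixture values as the test added below and assuming only a stock PyTorch install (no Cadence op registration), would look like:

import torch
import torch.nn.functional as F

# Same values as the conv2d_transposed_stride2 fixture in the test diff below.
x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])  # input: 1x1x2x2
w = torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]])  # weight: 1x1x2x2
b = torch.tensor([0.0])                         # bias

# An NCHW input with channel_last=False maps directly onto conv_transpose2d.
out = F.conv_transpose2d(
    x, w, b, stride=(1, 1), padding=(0, 0), output_padding=(0, 0), groups=1, dilation=(1, 1)
)

expected = torch.tensor([[[[1.0, 3.0, 2.0], [4.0, 10.0, 6.0], [3.0, 7.0, 4.0]]]])
assert torch.equal(out, expected)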

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 89 additions & 10 deletions
@@ -1259,7 +1259,7 @@ def test_rope(

     @expand(
         [
-            # Test case 1: Basic 2D convolution (NCHW format)
+            # Basic 2D convolution (NCHW format)
             (
                 "basic_2d_nchw",
                 torch.tensor(
@@ -1274,11 +1274,12 @@ def test_rope(
                 (1, 1),  # dilation
                 1,  # groups
                 False,  # channel_last
+
                 torch.tensor(
                     [[[[5.0]]]], dtype=torch.float32
                 ),  # expected: 1*1 + 4*1 = 5
             ),
-            # Test case 2: Basic 2D convolution (NHWC format)
+            # Basic 2D convolution (NHWC format)
             (
                 "basic_2d_nhwc",
                 torch.tensor(
@@ -1293,11 +1294,12 @@ def test_rope(
                 (1, 1),  # dilation
                 1,  # groups
                 True,  # channel_last
+
                 torch.tensor(
                     [[[[5.0]]]], dtype=torch.float32
                 ),  # expected: 1*1 + 4*1 = 5
             ),
-            # Test case 3: 2D convolution with stride=2
+            # 2D convolution with stride=2
             (
                 "conv2d_stride2",
                 torch.tensor(
@@ -1322,9 +1324,10 @@ def test_rope(
                 (1, 1),  # dilation
                 1,  # groups
                 False,  # channel_last
+
                 torch.tensor([[[[14.0, 22.0], [46.0, 54.0]]]], dtype=torch.float32),
             ),
-            # Test case 4: 2D convolution with padding=1
+            # 2D convolution with padding=1
             (
                 "conv2d_padding1",
                 torch.tensor(
@@ -1339,12 +1342,13 @@ def test_rope(
                 (1, 1),  # dilation
                 1,  # groups
                 False,  # channel_last
+
                 torch.tensor(
                     [[[[1.0, 2.0, 0.0], [3.0, 5.0, 2.0], [0.0, 3.0, 4.0]]]],
                     dtype=torch.float32,
                 ),  # expected with padding
             ),
-            # Test case 5: 2D convolution with dilation=2
+            # 2D convolution with dilation=2
             (
                 "conv2d_dilation2",
                 torch.tensor(
@@ -1369,9 +1373,10 @@ def test_rope(
                 (2, 2),  # dilation=2
                 1,  # groups
                 False,  # channel_last
+
                 torch.tensor([[[[24.0, 28.0], [40.0, 44.0]]]], dtype=torch.float32),
             ),
-            # Test case 6: 2D grouped convolution (groups=2)
+            # 2D grouped convolution (groups=2)
             (
                 "conv2d_groups2",
                 torch.tensor(
@@ -1396,9 +1401,10 @@ def test_rope(
                 (1, 1),  # dilation
                 2,  # groups=2
                 False,  # channel_last
+
                 torch.tensor([[[[10.0]], [[14.0]]]], dtype=torch.float32),
             ),
-            # Test case 7: 1D convolution (NCL format)
+            # 1D convolution (NCL format)
             (
                 "conv1d_ncl",
                 torch.tensor(
@@ -1411,11 +1417,12 @@ def test_rope(
                 (1, 1),  # dilation (only dilation[1] is used for 1D)
                 1,  # groups
                 False,  # channel_last
+
                 torch.tensor(
                     [[[3.0, 5.0, 7.0]]], dtype=torch.float32
                 ),  # expected: [1+2, 2+3, 3+4]
             ),
-            # Test case 8: 1D convolution (NLC format)
+            # 1D convolution (NLC format)
             (
                 "conv1d_nlc",
                 torch.tensor(
@@ -1430,9 +1437,10 @@ def test_rope(
                 (1, 1),  # dilation
                 1,  # groups
                 True,  # channel_last
+
                 torch.tensor([[[3.0], [5.0], [7.0]]], dtype=torch.float32),
             ),
-            # Test case 9: Multi-channel input and output
+            # Multi-channel input and output
             (
                 "multi_channel",
                 torch.tensor(
@@ -1469,9 +1477,10 @@ def test_rope(
                 (1, 1),  # dilation
                 1,  # groups
                 False,  # channel_last
+
                 torch.tensor([[[[10.0]], [[11.0]]]], dtype=torch.float32),
             ),
-            # Test case 10: Convolution with non-zero bias
+            # Convolution with non-zero bias
             (
                 "conv2d_with_bias",
                 torch.tensor(
@@ -1486,6 +1495,7 @@ def test_rope(
                 (1, 1),  # dilation
                 1,  # groups
                 False,  # channel_last
+
                 torch.tensor(
                     [[[[15.0]]]], dtype=torch.float32
                 ),  # expected: 5 + 10 = 15
@@ -1534,6 +1544,75 @@ def test_convolution(
             f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
         )

+    @expand(
+        [
+            (
+                "conv2d_transposed_stride2",
+                torch.tensor(
+                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                torch.tensor(
+                    [[[[1.0, 1.0], [1.0, 1.0]]]], dtype=torch.float32
+                ),  # weight: 1x1x2x2
+                torch.tensor([0.0], dtype=torch.float32),  # bias
+                (1, 1),  # stride
+                (0, 0),  # padding
+                (1, 1),  # dilation
+                1,  # groups
+                (0, 0),  # output_padding
+                False,  # channel_last
+                torch.tensor(
+                    [[[[1.0, 3.0, 2.0],
+                       [4.0, 10.0, 6.0],
+                       [3.0, 7.0, 4.0]]]], dtype=torch.float32
+                ),
+            ),
+        ]
+    )
+    def test_transposed_convolution(
+        self,
+        name: str,
+        input_tensor: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        stride: tuple[int, int],
+        padding: tuple[int, int],
+        dilation: tuple[int, int],
+        groups: int,
+        output_padding: tuple[int, int],
+        channel_last: bool,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.transposed_convolution(
+            input_tensor,
+            weight,
+            bias,
+            stride,
+            padding,
+            dilation,
+            output_padding,
+            groups,
+            channel_last,
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            input_tensor.dtype,
+            f"Output dtype should match input dtype in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"Output shape should match expected shape in {name}",
+        )
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
+
     @expand(
         [
             # Basic non-quantized average pooling
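
A quick sanity check of the expected tensor in the conv2d_transposed_stride2 fixture: with stride (1, 1), no padding, and an all-ones 2x2 kernel, each input element x[i][j] contributes x[i][j] to every cell of the 2x2 output patch anchored at (i, j), and overlapping patches sum. For the input [[1, 2], [3, 4]] this gives, for example, output[0][1] = 1 + 2 = 3, output[1][1] = 1 + 2 + 3 + 4 = 10, and output[2][2] = 4, matching the expected [[1, 3, 2], [4, 10, 6], [3, 7, 4]].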
