
Commit baaaa86

Add transposed convolution
Differential Revision: D83602808
Pull Request resolved: #14708
1 parent fb66fb3 commit baaaa86

File tree: 2 files changed, +196 −0 lines

backends/cadence/aot/ref_implementations.py

Lines changed: 59 additions & 0 deletions
@@ -960,6 +960,7 @@ def convolution(
     _stride: tuple[int, int] | int = stride
     _padding: tuple[int, int] | int = padding
     _dilation: tuple[int, int] | int = dilation
+
     if conv_is_1d:
         conv = torch.nn.functional.conv1d
         _stride = stride[0]
@@ -978,6 +979,64 @@ def convolution(
     return conv_out
 
 
+@impl(m, "transposed_convolution")
+def transposed_convolution(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: tuple[int, int],
+    padding: tuple[int, int],
+    dilation: tuple[int, int],
+    output_padding: tuple[int, int],
+    groups: int,
+    channel_last: bool = False,
+) -> torch.Tensor:
+
+    conv_is_1d = len(input_tensor.shape) == 3
+    if channel_last:
+        if conv_is_1d:
+            input_tensor = input_tensor.movedim(-1, 1).contiguous()
+            if len(weight.shape) != 3:
+                raise ValueError("Weight tensor must be 3D if input is 3D")
+            weight = weight.movedim(-1, 1).contiguous()
+        else:
+            input_tensor = input_tensor.movedim(-1, -3)
+            if len(weight.shape) != 4:
+                raise ValueError("Weight tensor must be 4D if input is nd > 3")
+            weight = torch.permute(weight, (0, -1, 1, 2)).contiguous()
+
+    _stride: tuple[int, int] | int = stride
+    _padding: tuple[int, int] | int = padding
+    _dilation: tuple[int, int] | int = dilation
+    _output_padding: tuple[int, int] | int = output_padding
+    if conv_is_1d:
+        conv = torch.nn.functional.conv_transpose1d
+        _stride = stride[0]
+        _padding = padding[0]
+        _dilation = dilation[0]
+        _output_padding = output_padding[0]
+    else:
+        conv = torch.nn.functional.conv_transpose2d
+
+    conv_out = conv(
+        input_tensor,
+        weight,
+        bias,
+        _stride,
+        _padding,
+        _output_padding,
+        groups,
+        _dilation,
+    )
+    if channel_last:
+        if conv_is_1d:
+            conv_out = conv_out.movedim(1, -1).contiguous()
+        else:
+            conv_out = conv_out.movedim(-3, -1).contiguous()
+
+    return conv_out
+
+
 @impl(m, "avg_pool2d")
 def avg_pool2d(
     input_tensor: torch.Tensor,
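The new op mirrors the existing convolution reference implementation: it optionally permutes channels-last inputs to the layout PyTorch expects, dispatches to torch.nn.functional.conv_transpose1d or conv_transpose2d, and permutes the result back. The snippet below is an illustration only (not part of the commit); it reproduces the arithmetic of the basic 2x2 test case directly with conv_transpose2d, where each input element scatters a scaled copy of the kernel into the output and overlapping contributions are summed:

import torch
import torch.nn.functional as F

# 1x1x2x2 input and an all-ones 1x1x2x2 kernel (NCHW), stride 1, no padding.
x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
w = torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]])
b = torch.tensor([0.0])

out = F.conv_transpose2d(x, w, b, stride=1, padding=0, output_padding=0, groups=1, dilation=1)
print(out)
# tensor([[[[ 1.,  3.,  2.],
#           [ 4., 10.,  6.],
#           [ 3.,  7.,  4.]]]])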

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 137 additions & 0 deletions
@@ -1534,6 +1534,143 @@ def test_convolution(
             f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
         )
 
+    @expand(
+        [
+            # Basic 2D transposed convolution with stride=1 (current test case - corrected name)
+            (
+                "basic_2d_stride1",
+                torch.tensor(
+                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                torch.tensor(
+                    [[[[1.0, 1.0], [1.0, 1.0]]]], dtype=torch.float32
+                ),  # weight: 1x1x2x2
+                torch.tensor([0.0], dtype=torch.float32),  # bias
+                (1, 1),  # stride
+                (0, 0),  # padding
+                (1, 1),  # dilation
+                1,  # groups
+                (0, 0),  # output_padding
+                False,  # channel_last
+                torch.tensor(
+                    [[[[1.0, 3.0, 2.0], [4.0, 10.0, 6.0], [3.0, 7.0, 4.0]]]],
+                    dtype=torch.float32,
+                ),
+            ),
+            # 2D transposed convolution with channel_last=True (NHWC format)
+            (
+                "channel_last_nhwc",
+                torch.tensor(
+                    [[[[1.0], [2.0]], [[3.0], [4.0]]]], dtype=torch.float32
+                ),  # input: 1x2x2x1 (NHWC)
+                torch.tensor(
+                    [[[[1.0], [1.0]], [[1.0], [1.0]]]], dtype=torch.float32
+                ),  # weight: 1x2x2x1 (NHWC)
+                torch.tensor([0.0], dtype=torch.float32),  # bias
+                (1, 1),  # stride
+                (0, 0),  # padding
+                (1, 1),  # dilation
+                1,  # groups
+                (0, 0),  # output_padding
+                True,  # channel_last=True
+                torch.tensor(
+                    [
+                        [
+                            [[1.0], [3.0], [2.0]],
+                            [[4.0], [10.0], [6.0]],
+                            [[3.0], [7.0], [4.0]],
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),
+            ),
+            # 2D transposed convolution with non-zero bias
+            (
+                "with_bias",
+                torch.tensor(
+                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                torch.tensor(
+                    [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32
+                ),  # weight: 1x1x2x2
+                torch.tensor([5.0], dtype=torch.float32),  # bias=5.0
+                (1, 1),  # stride
+                (0, 0),  # padding
+                (1, 1),  # dilation
+                1,  # groups
+                (0, 0),  # output_padding
+                False,  # channel_last
+                torch.tensor(
+                    [[[[6.0, 7.0, 5.0], [8.0, 10.0, 7.0], [5.0, 8.0, 9.0]]]],
+                    dtype=torch.float32,
+                ),
+            ),
+            # 1D transposed convolution (3D tensor, NLC format)
+            (
+                "conv1d_nlc",
+                torch.tensor(
+                    [[[1.0], [2.0], [3.0]]], dtype=torch.float32
+                ),  # input: 1x3x1 (NLC)
+                torch.tensor(
+                    [[[1.0], [0.5]]], dtype=torch.float32
+                ),  # weight: 1x2x1 (NLC)
+                torch.tensor([0.0], dtype=torch.float32),  # bias
+                (2, 0),  # stride
+                (0, 0),  # padding
+                (1, 1),  # dilation
+                1,  # groups
+                (0, 0),  # output_padding
+                True,  # channel_last=True
+                torch.tensor(
+                    [[[1.0], [0.5], [2.0], [1.0], [3.0], [1.5]]], dtype=torch.float32
+                ),
+            ),
+        ]
+    )
+    def test_transposed_convolution(
+        self,
+        name: str,
+        input_tensor: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        stride: tuple[int, int],
+        padding: tuple[int, int],
+        dilation: tuple[int, int],
+        groups: int,
+        output_padding: tuple[int, int],
+        channel_last: bool,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.transposed_convolution(
+            input_tensor,
+            weight,
+            bias,
+            stride,
+            padding,
+            dilation,
+            output_padding,
+            groups,
+            channel_last,
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            input_tensor.dtype,
+            f"Output dtype should match input dtype in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"Output shape should match expected shape in {name}",
+        )
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
+
     @expand(
         [
             # Basic non-quantized average pooling
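As a sanity check on the expected values in the conv1d_nlc case (an illustration, not part of the commit), the same numbers fall out of torch.nn.functional.conv_transpose1d once the channel dimension is moved from NLC to NCL and back, which is exactly the permutation the reference implementation performs for channel_last inputs:

import torch
import torch.nn.functional as F

# NLC input (1x3x1) and NLC weight (1x2x1), as in the "conv1d_nlc" test case.
x_nlc = torch.tensor([[[1.0], [2.0], [3.0]]])
w_nlc = torch.tensor([[[1.0], [0.5]]])

# Move channels next to the batch dim (NLC -> NCL) before the transposed conv.
x_ncl = x_nlc.movedim(-1, 1).contiguous()
w_ncl = w_nlc.movedim(-1, 1).contiguous()

# stride=2 scatters the kernel [1.0, 0.5] at every other output position.
out_ncl = F.conv_transpose1d(x_ncl, w_ncl, torch.tensor([0.0]), stride=2)
out_nlc = out_ncl.movedim(1, -1).contiguous()
print(out_nlc)  # [[[1.0], [0.5], [2.0], [1.0], [3.0], [1.5]]]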
