Commit 4660e38

yushangdi authored and pytorchmergebot committed
write conv1d decomposition (pytorch#163080)
In Unified Runtime we cannot have any fallback ops (for now), and not all conv1d ops can avoid fallbacks today, so we write a decomposition for conv1d. It is not registered in the default decomposition table, since currently only ExecuTorch/Unified Runtime needs it. It might benefit Inductor as well, because conv2d can generate Triton kernels while there is no Triton codegen for conv1d; whether the conv2d Triton kernel outperforms aten::conv1d is unknown, so it is not registered by default yet. To register it, one just needs to do `import torch._decomp as decomp; decomp.register_decomposition(torch.ops.aten.conv1d.default, conv1d_to_conv2d)`.

Pull Request resolved: pytorch#163080
Approved by: https://github.com/angelayi
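A minimal opt-in sketch (an assumption, not part of the commit: it imports conv1d_to_conv2d from torch._inductor.decomposition as the test does, and applies register_decomposition in its decorator-factory form):

import torch
import torch._decomp as decomp
from torch._inductor.decomposition import conv1d_to_conv2d

# Opt-in: route aten.conv1d through the conv2d-based decomposition.
# register_decomposition(op) is used here as a decorator factory; the exact
# call form is an assumption about the torch._decomp API.
decomp.register_decomposition(torch.ops.aten.conv1d.default)(conv1d_to_conv2d)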
1 parent 5236007 commit 4660e38

2 files changed: +91 −0 lines

test/test_decomp.py

Lines changed: 49 additions & 0 deletions
@@ -1343,6 +1343,55 @@ def test_aten_core_operators(self):
        core_aten_ops = useful_decomps - core_decomps
        self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops)))

    def test_conv1d_decomposition(self):
        from torch._inductor.decomposition import conv1d_to_conv2d

        def check_case(
            N=2,
            C_in=3,
            C_out=5,
            L=37,
            K=5,
            stride=2,
            padding=3,
            dilation=1,
            groups=1,
            dtype=torch.float32,
            device="cpu",
        ):
            torch.manual_seed(0)
            x = torch.randn(N, C_in, L, dtype=dtype, device=device)
            w = torch.randn(C_out, C_in // groups, K, dtype=dtype, device=device)
            b = torch.randn(C_out, dtype=dtype, device=device)

            ref = torch.ops.aten.conv1d.default(
                x,
                w,
                b,
                stride=[stride],
                padding=[padding],
                dilation=[dilation],
                groups=groups,
            )
            got = conv1d_to_conv2d(
                x,
                w,
                b,
                stride=[stride],
                padding=[padding],
                dilation=[dilation],
                groups=groups,
            )
            self.assertTrue(torch.allclose(ref, got, atol=1e-5, rtol=1e-5))

        # A few cases
        check_case()  # default
        check_case(stride=1, padding=0, K=3)
        check_case(stride=3, padding=4, K=7)
        check_case(dilation=2, padding=6, K=5)  # dilation
        check_case(groups=1, C_in=8, C_out=12)  # groups=1, bigger
        check_case(groups=2, C_in=8, C_out=12)  # grouped conv


if __name__ == "__main__":
    run_tests()

torch/_inductor/decomposition.py

Lines changed: 42 additions & 0 deletions
@@ -1172,3 +1172,45 @@ def repeat_interleave_Tensor(
    return torch.searchsorted(
        cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
    )


# intentionally not registered
def conv1d_to_conv2d(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    stride: tuple[int] = (1,),
    padding: tuple[int] = (0,),
    dilation: tuple[int] = (1,),
    groups: int = 1,
) -> torch.Tensor:
    # Shapes:
    #   input:  (N, C_in, L_in)
    #   weight: (C_out, C_in // groups, K)
    #   bias:   (C_out,)
    assert input.dim() == 3 and weight.dim() == 3, (
        "Expect (N,C_in,L) and (C_out,C_in//groups,K)"
    )

    stride = stride[0]
    padding = padding[0]
    dilation = dilation[0]

    # Unsqueeze to make input 2D: (N,C,L) -> (N,C,L,1)
    input_2d = input.unsqueeze(-1)
    # Unsqueeze kernel: (C_out,C_in/groups,K) -> (C_out,C_in/groups,K,1)
    weight_2d = weight.unsqueeze(-1)

    # Call conv2d with adjusted args
    out_2d = aten.conv2d.default(
        input_2d,
        weight_2d,
        bias,
        stride=(stride, 1),
        padding=(padding, 0),
        dilation=(dilation, 1),
        groups=groups,
    )

    # Squeeze dummy dimension back out: (N,C_out,L_out,1) -> (N,C_out,L_out)
    return out_2d.squeeze(-1)
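For intuition, a small hypothetical sanity check of the shape math (not part of the PR): with the defaults used in the test above, the standard convolution output-length formula shows why the appended dummy width dimension passes through conv2d unchanged.

# Standard conv output-length formula, applied to the test defaults
# (L_in=37, K=5, stride=2, padding=3, dilation=1):
L_in, K, stride, padding, dilation = 37, 5, 2, 3, 1
L_out = (L_in + 2 * padding - dilation * (K - 1) - 1) // stride + 1  # -> 20
# The appended width is 1, and conv2d sees kernel 1, stride 1, padding 0 there,
# so that axis stays at 1 and squeeze(-1) restores the conv1d output shape:
W_out = (1 + 2 * 0 - 1 * (1 - 1) - 1) // 1 + 1  # -> 1
print(L_out, W_out)  # 20 1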
