
Commit f8471f8

[Relax][PyTorch] Add decomposed operator support for Pad (#18449)
## Related Issue

- #18401

## Why

- When run_ep_decomposition=True is enabled, PyTorch decomposes pad operators into lower-level operations:
  - Constant mode → `constant_pad_nd.default`
  - Reflect/Replicate modes → `index.Tensor` with None indices
  - Circular mode → `copy.default` and `slice` operations
- Some of the decomposed operators were not supported, causing failures.

## How

- Added support for the `constant_pad_nd.default` and `copy.default` operators.
- Fixed `_index_tensor` to handle None indices by:
  - Using the `take` operation when only one dimension is indexed (an optimization).
  - Converting `None` to an explicit `arange` for the general case.
- Updated test_pad to use run_ep_decomposition=True.
1 parent 6c7ed24 commit f8471f8
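For context, here is a minimal sketch (not part of the commit) of the decomposition behaviour described above. It assumes a recent PyTorch build; the exact aten op names can vary across versions, and `PadModel` here is just an illustrative wrapper around `F.pad`:

```python
# Hedged sketch: list the aten ops torch.export emits for F.pad after decomposition.
import torch
import torch.nn.functional as F


class PadModel(torch.nn.Module):
    """Illustrative stand-in for the PadModel used in the test."""

    def __init__(self, pad, mode="constant"):
        super().__init__()
        self.pad, self.mode = pad, mode

    def forward(self, x):
        return F.pad(x, self.pad, mode=self.mode)


example_args = (torch.randn(1, 3, 10, 10),)
for mode in ("constant", "reflect", "replicate", "circular"):
    ep = torch.export.export(PadModel([1, 1, 2, 2], mode=mode), example_args)
    ep = ep.run_decompositions()  # roughly what run_ep_decomposition=True triggers
    ops = {str(n.target) for n in ep.graph.nodes if n.op == "call_function"}
    print(mode, sorted(ops))  # e.g. constant mode should show aten.constant_pad_nd.default
```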

File tree

3 files changed (+228 / -23 lines changed)

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 47 additions & 1 deletion
```diff
@@ -1379,6 +1379,23 @@ def _pad(self, node: fx.Node) -> relax.Var:
 
         return self.block_builder.emit(relax.op.nn.pad(x, pad_width, mode, value))
 
+    def _constant_pad_nd(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        pad = node.args[1]
+        value = node.args[2] if len(node.args) > 2 else node.kwargs.get("value", 0.0)
+        value = 0.0 if value is None else value
+
+        # Calculate the symmetric padding width for each dimension
+        # and apply the pairs in reverse order to match the input dimensions.
+        input_ndim = x.struct_info.ndim
+        pad_width = [0] * (input_ndim * 2)
+        pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)]
+        reversed_pairs = list(reversed(pad_pairs))
+        flattened = [v for pair in reversed_pairs for v in pair]
+        pad_width[-len(flattened) :] = flattened
+
+        return self.block_builder.emit(relax.op.nn.pad(x, pad_width, "constant", value))
+
     def _pixel_shuffle(self, node: fx.Node) -> relax.Var:
         data = self.env[node.args[0]]
         upscale_factor = node.args[1]
```
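As a standalone illustration of the pad-width conversion in `_constant_pad_nd` above: PyTorch's `pad` argument lists (before, after) pairs starting from the last dimension, while `relax.op.nn.pad` expects pairs in input-dimension order, so the pairs are reversed and left-filled with zeros. A minimal sketch outside the translator, with a hypothetical helper name:

```python
# Hedged sketch of the conversion performed in _constant_pad_nd; the helper
# name is hypothetical, and the 4-D example mirrors the test case below.
def torch_pad_to_relax_pad_width(pad, input_ndim):
    pad_width = [0] * (input_ndim * 2)
    pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)]
    flattened = [v for pair in reversed(pad_pairs) for v in pair]
    pad_width[-len(flattened) :] = flattened
    return pad_width


# F.pad(x, [1, 1, 2, 2]) on a (1, 3, 10, 10) tensor: the last dim gets (1, 1),
# the second-to-last gets (2, 2), and the leading dims stay unpadded.
assert torch_pad_to_relax_pad_width([1, 1, 2, 2], input_ndim=4) == [0, 0, 0, 0, 2, 2, 1, 1]
```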
```diff
@@ -1665,8 +1682,37 @@ def _index_put(self, node: fx.Node) -> relax.Var:
 
     def _index_tensor(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
+        data = args[0]
         indices = args[1]
-        return self.block_builder.emit(relax.op.index_tensor(args[0], indices))
+
+        # In PyTorch's aten.index.Tensor, None means "select all elements" for that dimension
+        non_none_indices = [(i, idx) for i, idx in enumerate(indices) if idx is not None]
+
+        # Special case: if there's only one non-None index, use take operation
+        if len(non_none_indices) == 1:
+            axis, index_tensor = non_none_indices[0]
+            return self.block_builder.emit(relax.op.take(data, index_tensor, axis=axis))
+
+        # General case: multiple non-None indices require advanced indexing
+        processed_indices = []
+        data_shape = self.shape_of(data)
+
+        for i, idx in enumerate(indices):
+            if idx is None:
+                dim_size = data_shape[i]
+                arange_idx = self.block_builder.emit(
+                    relax.op.arange(
+                        start=relax.PrimValue(0),
+                        end=dim_size,
+                        step=relax.PrimValue(1),
+                        dtype="int64",
+                    )
+                )
+                processed_indices.append(arange_idx)
+            else:
+                processed_indices.append(idx)
+
+        return self.block_builder.emit(relax.op.index_tensor(data, processed_indices))
 
     def _meshgrid(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
```
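A small plain-PyTorch sanity check of the single-index fast path introduced above (shapes are illustrative only): when `aten.index.Tensor` receives exactly one non-None index, the indexing reduces to a gather along that axis, which is what `relax.op.take(..., axis=k)` expresses.

```python
# Hedged sketch: one non-None index along axis 2 behaves like index_select/take.
import torch

x = torch.randn(1, 3, 10, 10)
idx = torch.tensor([2, 1, 0, 1])  # e.g. reflect-style gather indices
# An aten.index.Tensor indices list of [None, None, idx] selects everything on
# axes 0 and 1 and gathers along axis 2:
assert torch.equal(x[:, :, idx, :], torch.index_select(x, dim=2, index=idx))
```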

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -862,6 +862,8 @@ def create_convert_map(
             "_log_softmax.default": self._log_softmax,
             "neg.default": self._unary_op(relax.op.negative),
             "pad.default": self._pad,
+            "constant_pad_nd.default": self._constant_pad_nd,
+            "copy.default": self._copy_,
             "pixel_shuffle.default": self._pixel_shuffle,
             "prelu.default": self._prelu,
             "reciprocal.default": self._reciprocal,
```

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 179 additions & 22 deletions
```diff
@@ -2715,13 +2715,25 @@ def main(
            x: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
            with R.dataflow():
-                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
-                    x,
-                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
-                    pad_mode="reflect",
-                    pad_value=0.0,
+                lv: R.Tensor((14,), dtype="int64") = R.arange(
+                    R.prim_value(-2), R.prim_value(12), R.prim_value(1), dtype="int64"
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                lv1: R.Tensor((14,), dtype="int64") = R.abs(lv)
+                lv2: R.Tensor((14,), dtype="int64") = R.subtract(R.const(9, "int64"), lv1)
+                lv3: R.Tensor((14,), dtype="int64") = R.abs(lv2)
+                lv4: R.Tensor((14,), dtype="int64") = R.subtract(R.const(9, "int64"), lv3)
+                lv5: R.Tensor((1, 3, 14, 10), dtype="float32") = R.take(x, lv4, axis=2, mode="fast")
+                lv6: R.Tensor((12,), dtype="int64") = R.arange(
+                    R.prim_value(-1), R.prim_value(11), R.prim_value(1), dtype="int64"
+                )
+                lv7: R.Tensor((12,), dtype="int64") = R.abs(lv6)
+                lv8: R.Tensor((12,), dtype="int64") = R.subtract(R.const(9, "int64"), lv7)
+                lv9: R.Tensor((12,), dtype="int64") = R.abs(lv8)
+                lv10: R.Tensor((12,), dtype="int64") = R.subtract(R.const(9, "int64"), lv9)
+                lv11: R.Tensor((1, 3, 14, 12), dtype="float32") = R.take(
+                    lv5, lv10, axis=3, mode="fast"
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv11,)
                 R.output(gv)
             return gv
 
```

```diff
@@ -2732,13 +2744,19 @@ def main(
            x: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
            with R.dataflow():
-                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
-                    x,
-                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
-                    pad_mode="replicate",
-                    pad_value=0.0,
+                lv: R.Tensor((14,), dtype="int64") = R.arange(
+                    R.prim_value(-2), R.prim_value(12), R.prim_value(1), dtype="int64"
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                lv1: R.Tensor((14,), dtype="int64") = R.clip(lv, R.prim_value(0), R.prim_value(9))
+                lv2: R.Tensor((1, 3, 14, 10), dtype="float32") = R.take(x, lv1, axis=2, mode="fast")
+                lv3: R.Tensor((12,), dtype="int64") = R.arange(
+                    R.prim_value(-1), R.prim_value(11), R.prim_value(1), dtype="int64"
+                )
+                lv4: R.Tensor((12,), dtype="int64") = R.clip(lv3, R.prim_value(0), R.prim_value(9))
+                lv5: R.Tensor((1, 3, 14, 12), dtype="float32") = R.take(
+                    lv2, lv4, axis=3, mode="fast"
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv5,)
                 R.output(gv)
             return gv
 
```

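The arange/abs/subtract chain in the reflect module and the arange/clip chain in the replicate module above are the decomposed gather-index formulas. A quick plain-Python check for the test's dimension size of 10 with padding (2, 2):

```python
# Hedged sketch of the gather indices emitted above: for a dimension of size n,
# reflect maps output position i in [-pad_before, n + pad_after) to
# (n - 1) - abs((n - 1) - abs(i)), while replicate clips i into [0, n - 1].
n = 10
reflect = [(n - 1) - abs((n - 1) - abs(i)) for i in range(-2, 12)]
replicate = [min(max(i, 0), n - 1) for i in range(-2, 12)]
assert reflect == [2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7]
assert replicate == [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9]
```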
```diff
@@ -2749,21 +2767,160 @@ def main(
            x: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
            with R.dataflow():
-                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.zeros(
+                    R.shape([1, 3, 14, 12]), dtype="float32"
+                )
+                lv1: R.Tensor((1, 3, 14, 10), dtype="float32") = R.strided_slice(
+                    lv,
+                    (R.prim_value(3),),
+                    (R.prim_value(1),),
+                    (R.prim_value(11),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
                     x,
-                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
-                    pad_mode="circular",
-                    pad_value=0.0,
+                    (R.prim_value(3),),
+                    (R.prim_value(0),),
+                    (R.prim_value(10),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
+                    lv1,
+                    (R.prim_value(2),),
+                    (R.prim_value(2),),
+                    (R.prim_value(12),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
+                    lv2,
+                    (R.prim_value(2),),
+                    (R.prim_value(0),),
+                    (R.prim_value(10),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv5: R.Tensor((1, 3, 14, 10), dtype="float32") = R.strided_slice(
+                    lv,
+                    (R.prim_value(3),),
+                    (R.prim_value(1),),
+                    (R.prim_value(11),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv6: R.Tensor((1, 3, 14, 10), dtype="float32") = R.slice_scatter(
+                    lv5, lv4, R.prim_value(2), R.prim_value(12), R.prim_value(1), axis=2
+                )
+                lv7: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
+                    lv, lv6, R.prim_value(1), R.prim_value(11), R.prim_value(1), axis=3
+                )
+                lv8: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
+                    lv7,
+                    (R.prim_value(3),),
+                    (R.prim_value(0),),
+                    (R.prim_value(1),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv9: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
+                    lv7,
+                    (R.prim_value(3),),
+                    (R.prim_value(10),),
+                    (R.prim_value(11),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv10: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
+                    lv7, lv9, R.prim_value(0), R.prim_value(1), R.prim_value(1), axis=3
+                )
+                lv11: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
+                    lv10,
+                    (R.prim_value(3),),
+                    (R.prim_value(11),),
+                    (R.prim_value(12),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv12: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
+                    lv10,
+                    (R.prim_value(3),),
+                    (R.prim_value(1),),
+                    (R.prim_value(2),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv13: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
+                    lv10, lv12, R.prim_value(11), R.prim_value(12), R.prim_value(1), axis=3
+                )
+                lv14: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
+                    lv13,
+                    (R.prim_value(2),),
+                    (R.prim_value(0),),
+                    (R.prim_value(2),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv15: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
+                    lv13,
+                    (R.prim_value(2),),
+                    (R.prim_value(10),),
+                    (R.prim_value(12),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv16: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
+                    lv13, lv15, R.prim_value(0), R.prim_value(2), R.prim_value(1), axis=2
+                )
+                lv17: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
+                    lv16,
+                    (R.prim_value(2),),
+                    (R.prim_value(12),),
+                    (R.prim_value(14),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv18: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
+                    lv16,
+                    (R.prim_value(2),),
+                    (R.prim_value(2),),
+                    (R.prim_value(4),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv19: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
+                    lv16, lv18, R.prim_value(12), R.prim_value(14), R.prim_value(1), axis=2
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv19,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(PadModel(pad=[1, 1, 2, 2]), example_args, {}, expected_constant)
-    verify_model(PadModel(pad=[1, 1, 2, 2], mode="reflect"), example_args, {}, expected_reflect)
-    verify_model(PadModel(pad=[1, 1, 2, 2], mode="replicate"), example_args, {}, expected_replicate)
-    verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), example_args, {}, expected_circular)
+    verify_model(
+        PadModel(pad=[1, 1, 2, 2]), example_args, {}, expected_constant, run_ep_decomposition=True
+    )
+    verify_model(
+        PadModel(pad=[1, 1, 2, 2], mode="reflect"),
+        example_args,
+        {},
+        expected_reflect,
+        run_ep_decomposition=True,
+    )
+    verify_model(
+        PadModel(pad=[1, 1, 2, 2], mode="replicate"),
+        example_args,
+        {},
+        expected_replicate,
+        run_ep_decomposition=True,
+    )
+    verify_model(
+        PadModel(pad=[1, 1, 2, 2], mode="circular"),
+        example_args,
+        {},
+        expected_circular,
+        run_ep_decomposition=True,
+    )
 
 
 def test_pixel_shuffle():
```
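The strided_slice/slice_scatter sequence above is the decomposed form of circular padding, which wraps values around from the opposite edge of each padded dimension. A minimal plain-PyTorch check of that semantics (shapes illustrative only):

```python
# Hedged sketch: circular padding copies values from the opposite edge.
import torch
import torch.nn.functional as F

x = torch.arange(12.0).reshape(1, 1, 3, 4)
y = F.pad(x, [1, 1, 2, 2], mode="circular")  # same pad layout as the test
assert y.shape == (1, 1, 7, 6)
assert torch.equal(y[:, :, 2:5, 1:5], x)             # original block is preserved
assert torch.equal(y[:, :, 2:5, 0], x[:, :, :, -1])  # left pad wraps from the right edge
```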
```diff
@@ -5949,7 +6106,7 @@ def main(
        ) -> R.Tuple(R.Tensor((3,), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((5,), dtype="float32") = R.reshape(data, R.shape([5]))
-                lv1: R.Tensor((3,), dtype="float32") = R.index_tensor(lv, (indices,))
+                lv1: R.Tensor((3,), dtype="float32") = R.take(lv, indices, axis=0, mode="fast")
                gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv
```
