Skip to content

Commit 6dc71f3

Browse files
authored
[Relax][PyTorch] Add decomposed operator support for interpolate (#18462)
## Related Issue - #18401 ## How - Refactored `_index_tensor` to handle broadcast
1 parent ebd169b commit 6dc71f3

File tree

2 files changed

+455
-25
lines changed

2 files changed

+455
-25
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,20 +1693,42 @@ def _index_tensor(self, node: fx.Node) -> relax.Var:
16931693
axis, index_tensor = non_none_indices[0]
16941694
return self.block_builder.emit(relax.op.take(data, index_tensor, axis=axis))
16951695

1696-
# General case: multiple non-None indices require advanced indexing
1696+
# Check if all indices can be squeezed to 1D for sequential take
1697+
def is_squeezable(idx):
1698+
if idx.struct_info.ndim == 1:
1699+
return True
1700+
if idx.struct_info.ndim == 2:
1701+
shape = idx.struct_info.shape
1702+
for d in shape:
1703+
if isinstance(d, int) and d == 1:
1704+
return True
1705+
# Check for tir.IntImm
1706+
if hasattr(d, "value") and d.value == 1:
1707+
return True
1708+
return False
1709+
1710+
all_squeezable = all(is_squeezable(idx) for _, idx in non_none_indices)
1711+
if all_squeezable:
1712+
result = data
1713+
for axis, idx in reversed(non_none_indices):
1714+
if idx.struct_info.ndim > 1:
1715+
idx = self.block_builder.emit(relax.op.squeeze(idx))
1716+
result = self.block_builder.emit(relax.op.take(result, idx, axis=axis))
1717+
return result
1718+
1719+
# General case: replace None with arange, reshaped for broadcasting
1720+
max_ndim = max((idx.struct_info.ndim for _, idx in non_none_indices), default=1)
16971721
processed_indices = []
16981722
data_shape = self.shape_of(data)
16991723

17001724
for i, idx in enumerate(indices):
17011725
if idx is None:
1702-
dim_size = data_shape[i]
17031726
arange_idx = self.block_builder.emit(
1704-
relax.op.arange(
1705-
start=relax.PrimValue(0),
1706-
end=dim_size,
1707-
step=relax.PrimValue(1),
1708-
dtype="int64",
1709-
)
1727+
relax.op.arange(relax.PrimValue(0), data_shape[i], relax.PrimValue(1), "int64")
1728+
)
1729+
# Reshape to [dim_size, 1, 1, ...] for broadcasting
1730+
arange_idx = self.block_builder.emit(
1731+
relax.op.reshape(arange_idx, [data_shape[i]] + [1] * (max_ndim - 1))
17101732
)
17111733
processed_indices.append(arange_idx)
17121734
else:

0 commit comments

Comments
 (0)