Skip to content

Commit 257e583

Browse files
committed
Consolidate index
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 98ceb0a commit 257e583

File tree

1 file changed

+5
-5
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+5
-5
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4765,11 +4765,8 @@ def aten_index(
47654765
None in `indices` are like fillers for dimensions that cannot be removed in the process.
47664766
"""
47674767
# Handle Boolean indexing first
4768-
for index in indices:
4769-
if index is None:
4770-
continue
4771-
if index.dtype == BOOL.dtype:
4772-
return _aten_index_bool(self, indices)
4768+
if any(index is not None and index.dtype == ir.DataType.BOOL for index in indices):
4769+
return _aten_index_bool(self, indices)
47734770

47744771
index_ranks = [len(index.shape) for index in indices if index is not None]
47754772

@@ -4844,6 +4841,9 @@ def aten_index_put(
48444841
See implementation of `torch.onnx.symbolic_opset11.index_put
48454842
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
48464843
"""
4844+
if any(index is not None and index.dtype == BOOL.dtype for index in indices):
4845+
return _aten_index_put_bool(self, indices, values, accumulate)
4846+
48474847
# Ensure the number of indices matches the tensor rank by appending trailing Nones.
48484848
self_rank = len(self.shape)
48494849
if len(indices) < self_rank:

0 commit comments

Comments
 (0)