File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed
onnxscript/function_libs/torch_lib/ops Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -4765,11 +4765,8 @@ def aten_index(
47654765 None in `indices` are like fillers for dimensions that cannot be removed in the process.
47664766 """
47674767 # Handle Boolean indexing first
4768- for index in indices :
4769- if index is None :
4770- continue
4771- if index .dtype == BOOL .dtype :
4772- return _aten_index_bool (self , indices )
4768+ if any (index is not None and index .dtype == ir .DataType .BOOL for index in indices ):
4769+ return _aten_index_bool (self , indices )
47734770
47744771 index_ranks = [len (index .shape ) for index in indices if index is not None ]
47754772
@@ -4844,6 +4841,9 @@ def aten_index_put(
48444841 See implementation of `torch.onnx.symbolic_opset11.index_put
48454842 <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
48464843 """
4844+ if any (index is not None and index .dtype == BOOL .dtype for index in indices ):
4845+ return _aten_index_put_bool (self , indices , values , accumulate )
4846+
48474847 # Ensure the number of indices matches the tensor rank by appending trailing Nones.
48484848 self_rank = len (self .shape )
48494849 if len (indices ) < self_rank :
You can’t perform that action at this time.
0 commit comments