We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f529292 commit 3156bedCopy full SHA for 3156bed
onnxscript/function_libs/torch_lib/ops/core.py
@@ -3814,11 +3814,15 @@ def aten_gather(
3814
else:
3815
return op.Expand(self, op.Shape(index))
3816
3817
- if len(index.shape) == 0:
3818
- return op.Identity(self)
+ is_scalar_index = len(index.shape) == 0
+ if is_scalar_index:
3819
+ index = op.Unsqueeze(index, [0])
3820
3821
index = op.Cast(index, to=INT64.dtype)
3822
result = op.GatherElements(self, index, axis=dim)
3823
+
3824
3825
+ result = op.Squeeze(result, [0])
3826
return result
3827
3828
0 commit comments