Skip to content

Commit b6ebed8

Browse files
committed
fix reduction
1 parent f3731ed commit b6ebed8

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4345,7 +4345,7 @@ def _1dint(i: int):
43454345
assert len(n_none) == 1, f"Unable to handle that case: n_none={n_none}"
43464346
unsq = op.Unsqueeze(indices[n_none[0]], _1dint(1))
43474347
if n_none[0] == 0:
4348-
return op.ScatterND(x, unsq, values)
4348+
return op.ScatterND(x, unsq, values, reduction="add" if accumulate else "none")
43494349

43504350
perm = list(range(len(x.shape)))
43514351
perm[n_none[0]], perm[0] = perm[0], perm[n_none[0]]

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,24 @@ def forward(self, x, index, update):
287287
)
288288
_testing.assert_onnx_program(onnx_program)
289289

290+
def test_index_put_55_2_25(self):
291+
class Model(torch.nn.Module):
292+
def forward(self, x, index, update):
293+
return torch.ops.aten.index_put(x, [index], update, accumulate=True)
294+
295+
x = torch.ones((6, 5), dtype=torch.float32)
296+
index = torch.tensor([4, 3], dtype=torch.int64)
297+
update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32)
298+
onnx_program = torch.onnx.export(
299+
Model(),
300+
(x, index, update),
301+
input_names=["x", "index", "update"],
302+
output_names=["output"],
303+
opset_version=18,
304+
dynamo=True,
305+
)
306+
_testing.assert_onnx_program(onnx_program)
307+
290308
def test_index_put_scatter_nd(self):
291309
class Model(torch.nn.Module):
292310
def forward(self, x, index, update):

0 commit comments

Comments
 (0)