Skip to content

Commit c97b5b0

Browse files
authored
Support none reduction for scatterND (#1013)
Signed-off-by: Kevin Chen <[email protected]>
1 parent 118ed0a commit c97b5b0

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

docs/operators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ TensorRT supports the following ONNX data types: DOUBLE, FLOAT32, FLOAT16, BFLOA
168168
| Scan | Y | FP32, FP16, BF16|
169169
| Scatter | Y | FP32, FP16, BF16, INT32, INT64 |
170170
| ScatterElements | Y | FP32, FP16, BF16, INT32, INT64 |
171-
| ScatterND | Y | FP32, FP16, BF16, INT32, INT64 | `reduction` is not supported
171+
| ScatterND | Y | FP32, FP16, BF16, INT32, INT64 | `reduction` other than `none` is not supported
172172
| Selu | Y | FP32, FP16, BF16, |
173173
| SequenceAt | N |
174174
| SequenceConstruct | N |

onnxOpImporters.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5334,8 +5334,9 @@ DEFINE_BUILTIN_OP_IMPORTER(GridSample)
53345334
DEFINE_BUILTIN_OP_IMPORTER(ScatterND)
53355335
{
53365336
OnnxAttrs attrs(node, ctx);
5337-
ONNXTRT_CHECK_NODE(!attrs.count("reduction"), "Attribute reduction is not supported.", node, nodeIdx,
5338-
ErrorCode::kUNSUPPORTED_NODE_ATTR);
5337+
auto mode = attrs.get<std::string>("reduction", "none");
5338+
ONNXTRT_CHECK_NODE(
5339+
mode == "none", "ScatterND with reduction is not supported.", node, nodeIdx, ErrorCode::kUNSUPPORTED_NODE_ATTR);
53395340
return addScatterLayer(ctx, node, nodeIdx, inputs, nvinfer1::ScatterMode::kND);
53405341
}
53415342

0 commit comments

Comments
 (0)