|
13 | 13 | import sys
|
14 | 14 |
|
15 | 15 | import numpy as np
|
16 |
| -from onnx import numpy_helper |
17 | 16 | from onnx import onnx_pb
|
18 | 17 | from onnx.onnx_pb import TensorProto
|
19 | 18 |
|
@@ -517,22 +516,11 @@ def version_11(cls, ctx, node, **kwargs):
|
517 | 516 | class ScatterND:
|
518 | 517 | @classmethod
|
519 | 518 | def version_11(cls, ctx, node, **kwargs):
|
520 |
| - |
521 |
| - # onnx requires pre-generated tensor for data |
522 |
| - np_val = np.array([0], dtype=np.int64) |
523 |
| - onnx_tensor = numpy_helper.from_array(np_val, node.child_name()) |
524 |
| - const_of_shape = ctx.insert_new_node_on_input(node, "ConstantOfShape", node.input[2], value=onnx_tensor) |
525 |
| - |
526 |
| - # cast edge to INT64 if not already |
527 |
| - input0 = const_of_shape.input[0] |
528 |
| - if ctx.get_dtype(input0) != TensorProto.INT64: |
529 |
| - ctx.insert_new_node_on_input(const_of_shape, "Cast", input0, to=TensorProto.INT64) |
530 |
| - |
531 |
| - # cast edge to INT64 if not already |
532 |
| - input0 = node.input[0] |
533 |
| - if ctx.get_dtype(input0) != TensorProto.INT64: |
534 |
| - ctx.insert_new_node_on_input(node, "Cast", input0, to=TensorProto.INT64) |
535 |
| - |
| 519 | + onnxdtype = ctx.get_dtype(node.input[1]) |
| 520 | + const_of_shape = ctx.insert_new_node_on_input(node, "ConstantOfShape", node.input[2]) |
| 521 | + ctx.insert_new_node_on_input(const_of_shape, "Cast", const_of_shape.input[0], to=TensorProto.INT64) |
| 522 | + ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=TensorProto.INT64) |
| 523 | + ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnxdtype) |
536 | 524 | # reorder inputs to match onnx
|
537 | 525 | node.input = [node.input[2], node.input[0], node.input[1]]
|
538 | 526 |
|
|
0 commit comments