@@ -49,7 +49,7 @@ def _make_gathernd_inner_loop(ctx, params, index, dtype):
49
49
return inner_loop
50
50
51
51
52
- def make_gathernd (ctx , params , indices , output , scope_name , t_params ):
52
+ def make_gathernd (ctx , params , indices , output , scope_name , t_params , shapes , dtypes ):
53
53
"""make GatherNd op."""
54
54
# Tparams output = GatherNd(Tparams params, Tidx indices)
55
55
scope_name = utils .make_name (scope_name )
@@ -131,7 +131,11 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params):
131
131
[output_shape_ .output [0 ]],
132
132
attr = {"axes" : [0 ], "ends" : [- 1 ], "starts" : [0 ]},
133
133
dtypes = [TensorProto .INT64 ])
134
- ctx .make_node ("Reshape" , [gathernd_loop .output [1 ], output_shape .output [0 ]], outputs = [output ])
134
+ ctx .make_node ("Reshape" ,
135
+ [gathernd_loop .output [1 ], output_shape .output [0 ]],
136
+ outputs = [output ],
137
+ shapes = shapes ,
138
+ dtypes = dtypes )
135
139
136
140
137
141
def gathernd_op (ctx , node , name , args ):
@@ -143,4 +147,7 @@ def gathernd_op(ctx, node, name, args):
143
147
# same as the attr Tparams
144
148
t_params = ctx .get_dtype (params )
145
149
utils .make_sure (t_params , "Dtype of {} is None" .format (indices ))
146
- make_gathernd (ctx , params , indices , output , name , t_params )
150
+ shapes = node .output_shapes
151
+ dtypes = node .output_dtypes
152
+ ctx .remove_node (node .name )
153
+ make_gathernd (ctx , params , indices , output , name , t_params , shapes , dtypes )
0 commit comments