
Commit de939f3

f-salvetti and fatcat-z authored
support cumprod op (#2133)
* support cumprod op

---------

Signed-off-by: Francesco Salvetti <[email protected]>
Co-authored-by: Jay Zhang <[email protected]>
1 parent fa6db66 commit de939f3

File tree

3 files changed: +251 -1 lines changed
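For context, the new conversion can be exercised end to end. Below is a minimal sketch, assuming tf2onnx at or after this commit plus onnxruntime for verification; the function, shapes, and values are illustrative, not taken from the diff:

    import numpy as np
    import tensorflow as tf
    import tf2onnx
    import onnxruntime as ort

    @tf.function(input_signature=[tf.TensorSpec([2, 3, 4], tf.float32)])
    def f(x):
        return tf.math.cumprod(x, axis=1, reverse=True, exclusive=True)

    # The Cumprod lowering needs opset >= 10 (Slice with starts/ends/axes/steps as inputs).
    model_proto, _ = tf2onnx.convert.from_function(f, input_signature=f.input_signature, opset=10)

    sess = ort.InferenceSession(model_proto.SerializeToString(), providers=["CPUExecutionProvider"])
    x = np.arange(1.0, 25.0, dtype=np.float32).reshape(2, 3, 4)
    onnx_out = sess.run(None, {sess.get_inputs()[0].name: x})[0]
    np.testing.assert_allclose(onnx_out, f(x).numpy(), rtol=1e-6)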

tests/test_backend.py

Lines changed: 47 additions & 0 deletions
@@ -4718,6 +4718,53 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    @check_opset_min_version(10, "Slice")
+    def test_cumprod(self):
+        x_val = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32).reshape((2, 2))
+        def func(x):
+            x_ = tf.math.cumprod(x, axis=0)
+            return tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
+    @check_opset_min_version(10, "Slice")
+    def test_cumprod_axis1(self):
+        x_val = np.array([1., 2., 3., 4.,
+                          5., 6., 7., 8.,
+                          9., 10., 11., 12.,
+                          13., 14., 15., 16.,
+                          17., 18., 19., 20.,
+                          21., 22., 23., 24.], dtype=np.float32).reshape((2, 3, 4))
+        def func(x):
+            x_ = tf.math.cumprod(x, axis=1)
+            return tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
+    @check_opset_min_version(10, "Slice")
+    def test_cumprod_axis1_reverse(self):
+        x_val = np.array([1., 2., 3., 4.,
+                          5., 6., 7., 8.,
+                          9., 10., 11., 12.,
+                          13., 14., 15., 16.,
+                          17., 18., 19., 20.,
+                          21., 22., 23., 24.], dtype=np.float32).reshape((2, 3, 4))
+        def func(x):
+            x_ = tf.math.cumprod(x, axis=1, reverse=True)
+            return tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
+    @check_opset_min_version(10, "Slice")
+    def test_cumprod_axis1_reverse_exclusive(self):
+        x_val = np.array([1., 2., 3., 4.,
+                          5., 6., 7., 8.,
+                          9., 10., 11., 12.,
+                          13., 14., 15., 16.,
+                          17., 18., 19., 20.,
+                          21., 22., 23., 24.], dtype=np.float32).reshape((2, 3, 4))
+        def func(x):
+            x_ = tf.math.cumprod(x, axis=1, reverse=True, exclusive=True)
+            return tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
     @check_opset_min_version(11, "Round")
     def test_round(self):
         x_val = np.array([-0.7, -0.5, -0.0, 0.0, +0.0, 0.3, 0.5, 0.7, float('nan')], dtype=np.float32)
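The four tests cover the default, reverse, and reverse+exclusive attribute combinations. As a reference for what they assert, here is a sketch of tf.math.cumprod semantics in plain NumPy; the flip/concatenate composition is an illustration, not code from the diff:

    import numpy as np

    x = np.arange(1.0, 25.0, dtype=np.float32).reshape(2, 3, 4)

    # Default: cumulative product along axis 1.
    plain = np.cumprod(x, axis=1)

    # reverse=True accumulates from the end of the axis toward the start.
    reverse = np.flip(np.cumprod(np.flip(x, axis=1), axis=1), axis=1)

    # exclusive=True drops each element's own factor, so the first slot of the
    # accumulation (the last slot when reversed) holds the identity element 1.
    ones = np.ones_like(x[:, :1, :])
    reverse_exclusive = np.concatenate([reverse[:, 1:, :], ones], axis=1)

    # Multiplying each element's own factor back in recovers the inclusive result.
    assert np.allclose(reverse_exclusive * x, reverse)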

tf2onnx/onnx_opset/math.py

Lines changed: 203 additions & 0 deletions
@@ -12,6 +12,8 @@
 from tf2onnx import constants, utils
 from tf2onnx.handler import tf_op
 from tf2onnx.onnx_opset import common
+from tf2onnx.graph_builder import GraphBuilder
+
 
 logger = logging.getLogger(__name__)

@@ -555,6 +557,207 @@ def version_11(cls, ctx, node, **kwargs):
         pass
 
 
+@tf_op("Cumprod")
+class CumProd:
+    @classmethod
+    def version_10(cls, ctx, node, **kwargs):
+        # opset 10 required for Slice to support starts/ends/axes/steps as inputs
+        axis_node = node.inputs[1]
+        is_axis_const = axis_node.is_const()
+        if is_axis_const:  # we can compute the axis value right now
+            axis = axis_node.get_tensor_value()
+            axis_node = ctx.make_const(utils.make_name("axis"), np.array([axis], dtype=np.int64))
+        else:
+            axis_node = ctx.make_node("Cast", inputs=[axis_node.output[0]], attr={"to": onnx_pb.TensorProto.INT64},
+                                      op_name_scope=node.name, outputs=[utils.make_name("axis")])
+            axis_node = GraphBuilder(ctx).make_unsqueeze({'data': axis_node.output[0], 'axes': [0]}, return_node=True)
+            axis = axis_node.output[0]
+
+        input_rank = len(ctx.get_shape(node.input[0]))
+        cond_true_node = ctx.make_const(utils.make_name("cond_in"), np.ones((), dtype=bool))
+        input_shape_node = ctx.make_node("Shape", inputs=[node.input[0]], op_name_scope=node.name,
+                                         outputs=[utils.make_name("input_shape")])
+        axis_length_node = ctx.make_node("Gather", inputs=[input_shape_node.output[0], node.input[1]],
+                                         op_name_scope=node.name, outputs=[utils.make_name("axis_length")])
+        one_node = ctx.make_const(utils.make_name("one"), np.array([1], "int64"))
+        axis_length_plus_one_node = ctx.make_node("Add", inputs=[axis_length_node.output[0], one_node.output[0]],
+                                                  op_name_scope=node.name,
+                                                  outputs=[utils.make_name("axis_length_plus_one")])
+        num_iter_node = ctx.make_node("Sub", inputs=[axis_length_node.output[0], one_node.output[0]],
+                                      op_name_scope=node.name, outputs=[utils.make_name("num_iter")])
+
+        if node.get_attr_value("exclusive"):  # one iteration fewer: crop the input, then pad the output
+            num_iter_node = ctx.make_node("Sub", inputs=[num_iter_node.output[0], one_node.output[0]],
+                                          op_name_scope=node.name, outputs=[utils.make_name("num_iter")])
+            zero_node = ctx.make_const(utils.make_name("zero"), np.array([0], "int64"))
+            if node.get_attr_value("reverse"):
+                pad_axis = [0, 1]
+                start_slice = one_node.output[0]
+                end_slice = axis_length_plus_one_node.output[0]
+            else:
+                minus_one_node = ctx.make_const(utils.make_name("minus_one"), np.array([-1], "int64"))
+                pad_axis = [1, 0]
+                start_slice = zero_node.output[0]
+                end_slice = minus_one_node.output[0]
+            pads_node = cls.get_pads_node(ctx, pad_axis, axis, input_rank, base_name=node.name)
+            slice_shape = [-1] * len(ctx.get_shape(node.input[0]))
+            inputs_node = ctx.make_node("Slice", inputs=[node.input[0], start_slice, end_slice, axis_node.output[0]],
+                                        op_name_scope=node.name, outputs=[utils.make_name("slice")],
+                                        shapes=[slice_shape], dtypes=[ctx.get_dtype(node.input[0])])
+            inputs = inputs_node.output[0]
+        else:
+            inputs = node.input[0]
+
+        loop_graph = cls.make_loop_graph(ctx, node, inputs, input_rank, axis)
+        loop_graph.parent_graph = ctx
+
+        loop_inputs = [num_iter_node.output[0], cond_true_node.output[0], inputs,
+                       axis_length_plus_one_node.output[0], inputs]
+        loop_outputs = [utils.make_name("loop_inputs_out"), utils.make_name("loop_axis_length_plus_one_out"),
+                        utils.make_name("loop_accumulator_out")]
+        if not is_axis_const:  # axis is a tensor, we need to feed it to the loop graph
+            loop_inputs.append(axis)
+            loop_outputs.append(utils.make_name("loop_axis_out"))
+        loop_outputs_shapes = [loop_graph.get_shape(o) for o in loop_graph.outputs[1:]]
+        loop_outputs_dtypes = [loop_graph.get_dtype(o) for o in loop_graph.outputs[1:]]
+
+        loop_node = ctx.make_node("Loop", inputs=loop_inputs, branches={"body": loop_graph}, outputs=loop_outputs,
+                                  shapes=loop_outputs_shapes, dtypes=loop_outputs_dtypes, op_name_scope=node.name)
+
+        if node.get_attr_value("exclusive"):  # pad the output
+            if ctx.get_dtype(loop_node.output[2]) != ctx.get_dtype(one_node.output[0]):
+                pad_const_node = ctx.make_node("Cast", inputs=[one_node.output[0]],
+                                               attr={"to": ctx.get_dtype(loop_node.output[2])},
+                                               op_name_scope=node.name, outputs=[utils.make_name("pad_const")])
+            else:
+                pad_const_node = one_node
+            output_node = ctx.make_node("Pad", op_name_scope=node.name, outputs=[utils.make_name("cumprod_out")],
+                                        inputs=[loop_node.output[2], pads_node.output[0], pad_const_node.output[0]])
+            output = output_node.output[0]
+        else:
+            output = loop_node.output[2]
+        output_node = ctx.make_node("Identity", inputs=[output], outputs=[utils.make_name("cumprod_out")],
+                                    shapes=[ctx.get_shape(node.input[0])], dtypes=[ctx.get_dtype(node.input[0])])
+        ctx.insert_node_on_output(output_node, node.output[0])
+        ctx.remove_node(node.name)
+
+    @classmethod
+    def make_loop_graph(cls, ctx, node, inputs_tensor, input_rank, axis):
+        inputs_tensor_shape = ctx.get_shape(inputs_tensor)
+        inputs_tensor_dtype = ctx.get_dtype(inputs_tensor)
+
+        graph = ctx.create_new_graph_with_same_config()
+        graph.add_graph_input(utils.make_name("iteration_num"), onnx_pb.TensorProto.INT64, [])
+        graph.add_graph_input(utils.make_name("condition_in"), onnx_pb.TensorProto.BOOL, [])
+        graph.add_graph_input(utils.make_name("inputs"), inputs_tensor_dtype, inputs_tensor_shape)
+        graph.add_graph_input(utils.make_name("axis_length_plus_one"), onnx_pb.TensorProto.INT64, [1])
+        graph.add_graph_input(utils.make_name("accumulator"), inputs_tensor_dtype, inputs_tensor_shape)
+        if not isinstance(axis, int):  # axis is a tensor, we need to feed it to the loop graph
+            graph.add_graph_input(utils.make_name("axis"), onnx_pb.TensorProto.INT64, [1])
+            axis = graph.input_names[-1]
+            axis_node = graph.get_node_by_output(axis)
+        else:
+            axis_node = graph.make_const(utils.make_name("axis"), np.array([axis], "int64"))
+
+        # main loop graph
+        loop_name = node.name + "/loop"
+        iter_num = GraphBuilder(graph).make_unsqueeze({'data': graph.input_names[0], 'axes': [0]})
+        one_node = graph.make_const(utils.make_name("one"), np.array(1, "int64"))
+        zero_node = graph.make_const(utils.make_name("zero"), np.array([0], "int64"))
+
+        add_node = graph.make_node("Add", inputs=[iter_num, one_node.output[0]],
+                                   outputs=[utils.make_name("add")], op_name_scope=loop_name)
+
+        if node.get_attr_value("reverse"):
+            pad_axis = [zero_node.output[0], add_node.output[0]]
+            start_slice = add_node.output[0]
+            end_slice = graph.input_names[3]
+        else:
+            neg_node = graph.make_node("Neg", inputs=[add_node.output[0]],
+                                       outputs=[utils.make_name("neg")], op_name_scope=loop_name)
+            pad_axis = [add_node.output[0], zero_node.output[0]]
+            start_slice = zero_node.output[0]
+            end_slice = neg_node.output[0]
+
+        pads_node = cls.get_pads_node(graph, pad_axis, axis, input_rank, is_pad_axis_const=False, base_name=loop_name)
+        slice_node = graph.make_node("Slice", op_name_scope=loop_name, outputs=[utils.make_name("slice")],
+                                     inputs=[graph.input_names[2], start_slice, end_slice, axis_node.output[0]])
+        if graph.get_dtype(slice_node.output[0]) != graph.get_dtype(one_node.output[0]):
+            pad_const_node = graph.make_node("Cast", inputs=[one_node.output[0]],
+                                             attr={"to": graph.get_dtype(slice_node.output[0])},
+                                             op_name_scope=loop_name, outputs=[utils.make_name("pad_const")])
+        else:
+            pad_const_node = one_node
+        pad_node = graph.make_node("Pad", inputs=[slice_node.output[0], pads_node.output[0], pad_const_node.output[0]],
+                                   op_name_scope=loop_name, outputs=[utils.make_name("pad")])
+        mul_node = graph.make_node("Mul", inputs=[graph.input_names[4], pad_node.output[0]],
+                                   op_name_scope=loop_name, outputs=[utils.make_name("mul")],
+                                   shapes=[inputs_tensor_shape], dtypes=[inputs_tensor_dtype])
+
+        # manage loop outputs
+        output_cond_node = graph.make_node("Identity", inputs=[graph.input_names[1]], op_name_scope=loop_name,
+                                           outputs=[utils.make_name("condition_out")])
+        output_inp_node = graph.make_node("Identity", inputs=[graph.input_names[2]], op_name_scope=loop_name,
+                                          outputs=[utils.make_name("inputs_out")])
+        output_axis_length_plus_one_node = graph.make_node("Identity", inputs=[graph.input_names[3]],
+                                                           op_name_scope=loop_name,
+                                                           outputs=[utils.make_name("axis_length_plus_one_out")])
+        output_acc_node = graph.make_node("Identity", inputs=[mul_node.output[0]], op_name_scope=loop_name,
+                                          outputs=[utils.make_name("accumulator_out")])
+
+        graph.add_graph_output(output_cond_node.output[0])  # 1 condition output
+        graph.add_graph_output(output_inp_node.output[0])  # N loop-carried dependency outputs
+        graph.add_graph_output(output_axis_length_plus_one_node.output[0])  # N loop-carried dependency outputs
+        graph.add_graph_output(output_acc_node.output[0])  # N loop-carried dependency outputs
+
+        if not isinstance(axis, int):  # axis is a tensor, we need to feed it to the loop graph
+            output_axis_node = graph.make_node("Identity", inputs=[axis], op_name_scope=loop_name,
+                                               outputs=[utils.make_name("axis_out")])
+            graph.add_graph_output(output_axis_node.output[0])  # N loop-carried dependency outputs
+        return graph
+
+    @classmethod
+    def get_pads_node(cls, graph, pad_axis, axis, rank, is_pad_axis_const=True, base_name=""):
+        if isinstance(axis, int):  # axis is a const, we directly compute padding values
+            pre_pad = np.zeros(axis, "int64")
+            post_pad = np.zeros(rank - axis - 1, "int64")
+            if is_pad_axis_const:  # pylint: disable=R1705
+                pads = np.concatenate([pre_pad, pad_axis[0:1], post_pad,
+                                       pre_pad, pad_axis[1:2], post_pad])
+                pads_node = graph.make_const(utils.make_name("pads"), pads)
+                return pads_node
+            else:
+                pre_pad_node = graph.make_const(utils.make_name("pre_pad"), pre_pad)
+                post_pad_node = graph.make_const(utils.make_name("post_pad"), post_pad)
+
+        else:  # axis is a tensor, we need to compute padding values at runtime
+            if is_pad_axis_const:
+                pad_axis = [graph.make_const(utils.make_name("pad"),
+                                             np.array([pad], "int64")).output[0] for pad in pad_axis]
+
+            rank_tensor = graph.make_const(utils.make_name("rank"), np.array([rank], "int64")).output[0]
+            zero_node = graph.make_const(utils.make_name("zero"), np.array([0], "int64"))
+            one_node = graph.make_const(utils.make_name("one"), np.array([1], "int64"))
+
+            post_repeat_node = graph.make_node("Sub", inputs=[rank_tensor, axis],
+                                               outputs=[utils.make_name("post_repeat")], op_name_scope=base_name)
+            post_repeat_node = graph.make_node("Sub", inputs=[post_repeat_node.output[0], one_node.output[0]],
+                                               outputs=[utils.make_name("post_repeat")], op_name_scope=base_name)
+
+            pre_pad_node = graph.make_node("Tile", inputs=[zero_node.output[0], axis], op_name_scope=base_name,
+                                           attr={"axis": 0}, outputs=[utils.make_name("pre_pad")])
+
+            post_pad_node = graph.make_node("Tile", inputs=[zero_node.output[0], post_repeat_node.output[0]],
+                                            attr={"axis": 0}, outputs=[utils.make_name("post_pad")],
+                                            op_name_scope=base_name)
+
+        pads_node = graph.make_node("Concat", attr={"axis": 0}, outputs=[utils.make_name("pads")],
+                                    op_name_scope=base_name,
+                                    inputs=[pre_pad_node.output[0], pad_axis[0], post_pad_node.output[0],
+                                            pre_pad_node.output[0], pad_axis[1], post_pad_node.output[0]])
+        return pads_node
+
+
 @tf_op("Round")
 class Round:
     @classmethod
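In short, the handler unrolls nothing at conversion time: it emits an ONNX Loop whose body, on iteration i, slices the running input, shifts it by i+1 positions along the cumulation axis by padding with the multiplicative identity 1, and multiplies the result into an accumulator. The exclusive attribute crops one element before the loop and pads the final output; reverse mirrors the slice/pad direction. The NumPy model below is a sketch of the forward, non-exclusive case, not code from the commit:

    import numpy as np

    def cumprod_by_loop(x, axis=0):
        # Mirrors the Loop body: Slice the input, Pad with ones so it shifts
        # by i+1 positions along `axis`, then Mul into the accumulator.
        n = x.shape[axis]
        acc = x.copy()
        for i in range(n - 1):  # num_iter = axis_length - 1
            shift = i + 1
            sliced = np.take(x, np.arange(n - shift), axis=axis)  # Slice [0 : -shift]
            pad = [(0, 0)] * x.ndim
            pad[axis] = (shift, 0)                                # pre-pad along the axis
            shifted = np.pad(sliced, pad, constant_values=1.0)    # Pad(value=1)
            acc = acc * shifted                                   # Mul
        return acc

    x = np.arange(1.0, 9.0).reshape(2, 4)
    assert np.allclose(cumprod_by_loop(x, axis=1), np.cumprod(x, axis=1))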

tf2onnx/version.py

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@
 
 
 version = '1.13.0'
-git_version = 'ddca3a5eb2d912f20fe7e0568dd1a3013aee9fa3'
+git_version = '82ae7c4659ab9ece121a6414ee037d68b6b2d907'
