 from tf2onnx import constants, utils
 from tf2onnx.handler import tf_op
 from tf2onnx.onnx_opset import common
+from tf2onnx.graph_builder import GraphBuilder
+
 
 logger = logging.getLogger(__name__)
 
@@ -555,6 +557,207 @@ def version_11(cls, ctx, node, **kwargs):
         pass
 
 
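+# There is no CumProd operator in ONNX, so tf.math.cumprod is rewritten as an ONNX Loop:
+# each iteration multiplies the running accumulator by a copy of the input shifted one step
+# further along the cumulation axis and padded with ones where elements were dropped.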
+@tf_op("Cumprod")
+class CumProd:
+    @classmethod
+    def version_10(cls, ctx, node, **kwargs):
+        # opset 10 is required for Slice to accept starts/ends/axes/steps as inputs
+        axis_node = node.inputs[1]
+        is_axis_const = axis_node.is_const()
+        if is_axis_const:  # we can compute the axis value right away
+            axis = axis_node.get_tensor_value()
+            axis_node = ctx.make_const(utils.make_name("axis"), np.array([axis], dtype=np.int64))
+        else:
+            axis_node = ctx.make_node("Cast", inputs=[axis_node.output[0]], attr={"to": onnx_pb.TensorProto.INT64},
+                                      op_name_scope=node.name, outputs=[utils.make_name("axis")])
+            axis_node = GraphBuilder(ctx).make_unsqueeze({'data': axis_node.output[0], 'axes': [0]}, return_node=True)
+            axis = axis_node.output[0]
+
+        input_rank = len(ctx.get_shape(node.input[0]))
+        cond_true_node = ctx.make_const(utils.make_name("cond_in"), np.ones((), dtype=bool))
+        input_shape_node = ctx.make_node("Shape", inputs=[node.input[0]], op_name_scope=node.name,
+                                         outputs=[utils.make_name("input_shape")])
+        axis_length_node = ctx.make_node("Gather", inputs=[input_shape_node.output[0], node.input[1]],
+                                         op_name_scope=node.name, outputs=[utils.make_name("axis_length")])
+        one_node = ctx.make_const(utils.make_name("one"), np.array([1], "int64"))
+        axis_length_plus_one_node = ctx.make_node("Add", inputs=[axis_length_node.output[0], one_node.output[0]],
+                                                  op_name_scope=node.name,
+                                                  outputs=[utils.make_name("axis_length_plus_one")])
+        num_iter_node = ctx.make_node("Sub", inputs=[axis_length_node.output[0], one_node.output[0]],
+                                      op_name_scope=node.name, outputs=[utils.make_name("num_iter")])
+
+        if node.get_attr_value("exclusive"):  # one iteration fewer: crop the input, then pad the output
+            num_iter_node = ctx.make_node("Sub", inputs=[num_iter_node.output[0], one_node.output[0]],
+                                          op_name_scope=node.name, outputs=[utils.make_name("num_iter")])
+            zero_node = ctx.make_const(utils.make_name("zero"), np.array([0], "int64"))
+            if node.get_attr_value("reverse"):
+                pad_axis = [0, 1]
+                start_slice = one_node.output[0]
+                end_slice = axis_length_plus_one_node.output[0]
+            else:
+                minus_one_node = ctx.make_const(utils.make_name("minus_one"), np.array([-1], "int64"))
+                pad_axis = [1, 0]
+                start_slice = zero_node.output[0]
+                end_slice = minus_one_node.output[0]
+            pads_node = cls.get_pads_node(ctx, pad_axis, axis, input_rank, base_name=node.name)
+            slice_shape = [-1] * len(ctx.get_shape(node.input[0]))
+            inputs_node = ctx.make_node("Slice", inputs=[node.input[0], start_slice, end_slice, axis_node.output[0]],
+                                        op_name_scope=node.name, outputs=[utils.make_name("slice")],
+                                        shapes=[slice_shape], dtypes=[ctx.get_dtype(node.input[0])])
+            inputs = inputs_node.output[0]
+        else:
+            inputs = node.input[0]
+
+        loop_graph = cls.make_loop_graph(ctx, node, inputs, input_rank, axis)
+        loop_graph.parent_graph = ctx
+
+        loop_inputs = [num_iter_node.output[0], cond_true_node.output[0], inputs,
+                       axis_length_plus_one_node.output[0], inputs]
+        loop_outputs = [utils.make_name("loop_inputs_out"), utils.make_name("loop_axis_length_plus_one_out"),
+                        utils.make_name("loop_accumulator_out")]
+        if not is_axis_const:  # axis is a tensor, we need to feed it to the loop graph
+            loop_inputs.append(axis)
+            loop_outputs.append(utils.make_name("loop_axis_out"))
+        loop_outputs_shapes = [loop_graph.get_shape(o) for o in loop_graph.outputs[1:]]
+        loop_outputs_dtypes = [loop_graph.get_dtype(o) for o in loop_graph.outputs[1:]]
+
+        loop_node = ctx.make_node("Loop", inputs=loop_inputs, branches={"body": loop_graph}, outputs=loop_outputs,
+                                  shapes=loop_outputs_shapes, dtypes=loop_outputs_dtypes, op_name_scope=node.name)
+
+        if node.get_attr_value("exclusive"):  # pad the output back to the original size with ones
+            if ctx.get_dtype(loop_node.output[2]) != ctx.get_dtype(one_node.output[0]):
+                pad_const_node = ctx.make_node("Cast", inputs=[one_node.output[0]],
+                                               attr={"to": ctx.get_dtype(loop_node.output[2])},
+                                               op_name_scope=node.name, outputs=[utils.make_name("pad_const")])
+            else:
+                pad_const_node = one_node
+            output_node = ctx.make_node("Pad", op_name_scope=node.name, outputs=[utils.make_name("cumprod_out")],
+                                        inputs=[loop_node.output[2], pads_node.output[0], pad_const_node.output[0]])
+            output = output_node.output[0]
+        else:
+            output = loop_node.output[2]
+        output_node = ctx.make_node("Identity", inputs=[output], outputs=[utils.make_name("cumprod_out")],
+                                    shapes=[ctx.get_shape(node.input[0])], dtypes=[ctx.get_dtype(node.input[0])])
+        ctx.insert_node_on_output(output_node, node.output[0])
+        ctx.remove_node(node.name)
+
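+    # The body graph built below follows the ONNX Loop contract: its inputs are the iteration
+    # number, the loop condition and the loop-carried values (inputs, axis_length_plus_one,
+    # accumulator, and optionally axis), and its outputs are the condition followed by the
+    # updated carried values.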
+    @classmethod
+    def make_loop_graph(cls, ctx, node, inputs_tensor, input_rank, axis):
+        inputs_tensor_shape = ctx.get_shape(inputs_tensor)
+        inputs_tensor_dtype = ctx.get_dtype(inputs_tensor)
+
+        graph = ctx.create_new_graph_with_same_config()
+        graph.add_graph_input(utils.make_name("iteration_num"), onnx_pb.TensorProto.INT64, [])
+        graph.add_graph_input(utils.make_name("condition_in"), onnx_pb.TensorProto.BOOL, [])
+        graph.add_graph_input(utils.make_name("inputs"), inputs_tensor_dtype, inputs_tensor_shape)
+        graph.add_graph_input(utils.make_name("axis_length_plus_one"), onnx_pb.TensorProto.INT64, [1])
+        graph.add_graph_input(utils.make_name("accumulator"), inputs_tensor_dtype, inputs_tensor_shape)
+        if not isinstance(axis, int):  # axis is a tensor, we need to feed it to the loop graph
+            graph.add_graph_input(utils.make_name("axis"), onnx_pb.TensorProto.INT64, [1])
+            axis = graph.input_names[-1]
+            axis_node = graph.get_node_by_output(axis)
+        else:
+            axis_node = graph.make_const(utils.make_name("axis"), np.array([axis], "int64"))
+
+        # main loop graph
+        loop_name = node.name + "/loop"
+        iter_num = GraphBuilder(graph).make_unsqueeze({'data': graph.input_names[0], 'axes': [0]})
+        one_node = graph.make_const(utils.make_name("one"), np.array(1, "int64"))
+        zero_node = graph.make_const(utils.make_name("zero"), np.array([0], "int64"))
+
+        add_node = graph.make_node("Add", inputs=[iter_num, one_node.output[0]],
+                                   outputs=[utils.make_name("add")], op_name_scope=loop_name)
+
+        if node.get_attr_value("reverse"):
+            pad_axis = [zero_node.output[0], add_node.output[0]]
+            start_slice = add_node.output[0]
+            end_slice = graph.input_names[3]
+        else:
+            neg_node = graph.make_node("Neg", inputs=[add_node.output[0]],
+                                       outputs=[utils.make_name("neg")], op_name_scope=loop_name)
+            pad_axis = [add_node.output[0], zero_node.output[0]]
+            start_slice = zero_node.output[0]
+            end_slice = neg_node.output[0]
+
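+        # At iteration i the input is sliced with an offset of (i + 1) along the axis, the removed
+        # positions are padded with ones, and the result is multiplied into the accumulator, so the
+        # accumulator holds the cumulative product after axis_length - 1 iterations.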
+        pads_node = cls.get_pads_node(graph, pad_axis, axis, input_rank, is_pad_axis_const=False, base_name=loop_name)
+        slice_node = graph.make_node("Slice", op_name_scope=loop_name, outputs=[utils.make_name("slice")],
+                                     inputs=[graph.input_names[2], start_slice, end_slice, axis_node.output[0]])
+        if graph.get_dtype(slice_node.output[0]) != graph.get_dtype(one_node.output[0]):
+            pad_const_node = graph.make_node("Cast", inputs=[one_node.output[0]],
+                                             attr={"to": graph.get_dtype(slice_node.output[0])},
+                                             op_name_scope=loop_name, outputs=[utils.make_name("pad_const")])
+        else:
+            pad_const_node = one_node
+        pad_node = graph.make_node("Pad", inputs=[slice_node.output[0], pads_node.output[0], pad_const_node.output[0]],
+                                   op_name_scope=loop_name, outputs=[utils.make_name("pad")])
+        mul_node = graph.make_node("Mul", inputs=[graph.input_names[4], pad_node.output[0]],
+                                   op_name_scope=loop_name, outputs=[utils.make_name("mul")],
+                                   shapes=[inputs_tensor_shape], dtypes=[inputs_tensor_dtype])
+
+        # manage loop outputs
+        output_cond_node = graph.make_node("Identity", inputs=[graph.input_names[1]], op_name_scope=loop_name,
+                                           outputs=[utils.make_name("condition_out")])
+        output_inp_node = graph.make_node("Identity", inputs=[graph.input_names[2]], op_name_scope=loop_name,
+                                          outputs=[utils.make_name("inputs_out")])
+        output_axis_length_plus_one_node = graph.make_node("Identity", inputs=[graph.input_names[3]],
+                                                           op_name_scope=loop_name,
+                                                           outputs=[utils.make_name("axis_length_plus_one_out")])
+        output_acc_node = graph.make_node("Identity", inputs=[mul_node.output[0]], op_name_scope=loop_name,
+                                          outputs=[utils.make_name("accumulator_out")])
+
+        graph.add_graph_output(output_cond_node.output[0])  # 1 condition output
+        graph.add_graph_output(output_inp_node.output[0])  # N loop-carried dependency outputs
+        graph.add_graph_output(output_axis_length_plus_one_node.output[0])  # N loop-carried dependency outputs
+        graph.add_graph_output(output_acc_node.output[0])  # N loop-carried dependency outputs
+
+        if not isinstance(axis, int):  # axis is a tensor, it must also flow through the loop graph
+            output_axis_node = graph.make_node("Identity", inputs=[axis], op_name_scope=loop_name,
+                                               outputs=[utils.make_name("axis_out")])
+            graph.add_graph_output(output_axis_node.output[0])  # N loop-carried dependency outputs
+        return graph
+
+    @classmethod
+    def get_pads_node(cls, graph, pad_axis, axis, rank, is_pad_axis_const=True, base_name=""):
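+        # ONNX Pad expects `pads` as [x1_begin, x2_begin, ..., x1_end, x2_end, ...], so the single
+        # non-zero begin/end pair for `axis` is surrounded by zeros for every other dimension.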
+        if isinstance(axis, int):  # axis is a const, we directly compute padding values
+            pre_pad = np.zeros(axis, "int64")
+            post_pad = np.zeros(rank - axis - 1, "int64")
+            if is_pad_axis_const:  # pylint: disable=R1705
+                pads = np.concatenate([pre_pad, pad_axis[0:1], post_pad,
+                                       pre_pad, pad_axis[1:2], post_pad])
+                pads_node = graph.make_const(utils.make_name("pads"), pads)
+                return pads_node
+            else:
+                pre_pad_node = graph.make_const(utils.make_name("pre_pad"), pre_pad)
+                post_pad_node = graph.make_const(utils.make_name("post_pad"), post_pad)
+
+        else:  # axis is a tensor, we need to compute padding values at runtime
+            if is_pad_axis_const:
+                pad_axis = [graph.make_const(utils.make_name("pad"),
+                                             np.array([pad], "int64")).output[0] for pad in pad_axis]
+
+            rank_tensor = graph.make_const(utils.make_name("rank"), np.array([rank], "int64")).output[0]
+            zero_node = graph.make_const(utils.make_name("zero"), np.array([0], "int64"))
+            one_node = graph.make_const(utils.make_name("one"), np.array([1], "int64"))
+
+            post_repeat_node = graph.make_node("Sub", inputs=[rank_tensor, axis],
+                                               outputs=[utils.make_name("post_repeat")], op_name_scope=base_name)
+            post_repeat_node = graph.make_node("Sub", inputs=[post_repeat_node.output[0], one_node.output[0]],
+                                               outputs=[utils.make_name("post_repeat")], op_name_scope=base_name)
+
+            pre_pad_node = graph.make_node("Tile", inputs=[zero_node.output[0], axis], op_name_scope=base_name,
+                                           attr={"axis": 0}, outputs=[utils.make_name("pre_pad")])
+
+            post_pad_node = graph.make_node("Tile", inputs=[zero_node.output[0], post_repeat_node.output[0]],
+                                            attr={"axis": 0}, outputs=[utils.make_name("post_pad")],
+                                            op_name_scope=base_name)
+
+        pads_node = graph.make_node("Concat", attr={"axis": 0}, outputs=[utils.make_name("pads")],
+                                    op_name_scope=base_name,
+                                    inputs=[pre_pad_node.output[0], pad_axis[0], post_pad_node.output[0],
+                                            pre_pad_node.output[0], pad_axis[1], post_pad_node.output[0]])
+        return pads_node
+
+
 @tf_op("Round")
 class Round:
     @classmethod
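For reference, the handler above is meant to reproduce the semantics of tf.math.cumprod, including its exclusive and reverse attributes. A minimal NumPy sketch of those semantics, for illustration only (not part of the patch; cumprod_reference is a hypothetical helper name):

import numpy as np

def cumprod_reference(x, axis=0, exclusive=False, reverse=False):
    """NumPy model of tf.math.cumprod, used only to illustrate the expected output."""
    x = np.asarray(x)
    if reverse:
        x = np.flip(x, axis=axis)
    out = np.cumprod(x, axis=axis)
    if exclusive:
        # drop the last element along `axis` and prepend a slice of ones
        ones = np.ones_like(np.take(out, [0], axis=axis))
        kept = np.take(out, range(out.shape[axis] - 1), axis=axis)
        out = np.concatenate([ones, kept], axis=axis)
    if reverse:
        out = np.flip(out, axis=axis)
    return out

# cumprod_reference([1, 2, 3, 4])                  -> [1, 2, 6, 24]
# cumprod_reference([1, 2, 3, 4], exclusive=True)  -> [1, 1, 2, 6]
# cumprod_reference([1, 2, 3, 4], reverse=True)    -> [24, 24, 12, 4]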