@@ -1819,6 +1819,227 @@ def version_11(cls, ctx, node, **kwargs):
1819
1819
name = node .name , outputs = node .output , shapes = shapes , dtypes = dtypes )
1820
1820
1821
1821
1822
+ @tf_op (["MatrixDiagPartV2" , "MatrixDiagPartV3" ])
1823
+ class MatrixDiagPartV2V3 :
1824
+ @classmethod
1825
+ def version_11 (cls , ctx , node , ** kwargs ):
1826
+ # assemble MatrixDiagPart V2&V3 by looping k diagonals with proper pads
1827
+ input_tensor = node .input [0 ]
1828
+ k = ctx .make_node ('Cast' , [node .input [1 ]], attr = {'to' : TensorProto .INT64 }).output [0 ]
1829
+ padding = node .input [2 ]
1830
+ align = 'LEFT_LEFT'
1831
+ if node .op .op_type == 'MatrixDiagPartV3' :
1832
+ align = node .get_attr_str ('align' ) if 'align' in node .attr else 'LEFT_RIGHT'
1833
+ input_rank = len (ctx .get_shape (input_tensor ))
1834
+ raw_input_shape = [- 1 ] * input_rank
1835
+ per_loop_shape = raw_input_shape [:- 1 ]
1836
+ raw_output_shape = raw_input_shape [:- 2 ] + [- 1 ]
1837
+ loop_output_shape = raw_output_shape + [- 1 ]
1838
+ ctx .set_shape (node .output [0 ], raw_output_shape )
1839
+ for out in ctx .find_output_consumers (node .output [0 ]):
1840
+ if out .op .op_type == 'Identity' :
1841
+ ctx .set_shape (out .output [0 ], raw_output_shape )
1842
+ # define constants
1843
+ const_zero = ctx .make_const (utils .make_name (node .name ) + 'const_zero' , np .array ([0 ]).astype (np .int64 ))
1844
+ const_one = ctx .make_const (utils .make_name (node .name ) + 'const_one' , np .array ([1 ]).astype (np .int64 ))
1845
+ const_two = ctx .make_const (utils .make_name (node .name ) + 'const_two' , np .array ([2 ]).astype (np .int64 ))
1846
+ const_neg_one = ctx .make_const (utils .make_name (node .name ) + 'const_neg_one' , np .array ([- 1 ]).astype (np .int64 ))
1847
+ const_neg_two = ctx .make_const (utils .make_name (node .name ) + 'const_neg_two' , np .array ([- 2 ]).astype (np .int64 ))
1848
+ # prepare new_shape of input
1849
+ input_shape = ctx .make_node ('Shape' , [input_tensor ])
1850
+ shape_input_shape = ctx .make_node ('Shape' , [input_shape .output [0 ]])
1851
+ matrix_shape = ctx .make_node ('Slice' ,
1852
+ [input_shape .output [0 ], const_neg_two .output [0 ], shape_input_shape .output [0 ]])
1853
+ min_dim = ctx .make_node ('ReduceMin' , [matrix_shape .output [0 ]])
1854
+ input_depth = ctx .make_node ('Slice' , [matrix_shape .output [0 ], const_neg_two .output [0 ], const_neg_one .output [0 ]])
1855
+ input_width = ctx .make_node ('Slice' , [matrix_shape .output [0 ], const_neg_one .output [0 ], const_two .output [0 ]])
1856
+ temp_shape = ctx .make_node ('Concat' , [const_neg_one .output [0 ], matrix_shape .output [0 ]], attr = {'axis' : 0 })
1857
+ temp_input = ctx .make_node ('Reshape' , [input_tensor , temp_shape .output [0 ]])
1858
+ temp_transposed = ctx .make_node ('Transpose' , [temp_input .output [0 ]], attr = {'perm' : [0 , 2 , 1 ]})
1859
+ half_shape = ctx .make_node ('Slice' , [input_shape .output [0 ], const_zero .output [0 ], const_neg_two .output [0 ]])
1860
+ new_shape = ctx .make_node ('Concat' , [half_shape .output [0 ], input_width .output [0 ], input_depth .output [0 ]],
1861
+ attr = {'axis' : 0 })
1862
+ # define body graph for main loop
1863
+ k_shape = ctx .make_node ('Shape' , [k ])
1864
+ k_start = ctx .make_node ('Slice' , [k , const_zero .output [0 ], const_one .output [0 ]])
1865
+ k_end = ctx .make_node ('Slice' , [k , const_neg_one .output [0 ], k_shape .output [0 ]])
1866
+ raw_total_k = ctx .make_node ('Sub' , [k_end .output [0 ], k_start .output [0 ]])
1867
+ total_k = ctx .make_node ('Add' , [raw_total_k .output [0 ], const_one .output [0 ]])
1868
+ trip_name = utils .make_name (node .name + "_i" )
1869
+ cond_name = utils .make_name (node .name + "_cond" )
1870
+ body_graph = ctx .create_new_graph_with_same_config ()
1871
+ body_graph .add_graph_input (trip_name , TensorProto .INT64 , [1 ])
1872
+ body_graph .add_graph_input (cond_name , TensorProto .BOOL , [])
1873
+ body_graph .parent_graph = ctx
1874
+ # identity of input
1875
+ identity_input_graph = body_graph .create_new_graph_with_same_config ()
1876
+ identity_input_graph .parent_graph = body_graph
1877
+ identity_input = identity_input_graph .make_node ('Identity' , [input_tensor ])
1878
+ identity_input_graph .add_graph_output (identity_input .output [0 ], ctx .get_dtype (node .input [0 ]), raw_input_shape )
1879
+ # transposed input
1880
+ transposed_input_graph = body_graph .create_new_graph_with_same_config ()
1881
+ transposed_input_graph .parent_graph = body_graph
1882
+ next_shape = transposed_input_graph .make_node ('Concat' , [half_shape .output [0 ], input_width .output [0 ],
1883
+ input_depth .output [0 ]], attr = {'axis' : 0 })
1884
+ transposed_input = transposed_input_graph .make_node ('Reshape' ,
1885
+ [temp_transposed .output [0 ], next_shape .output [0 ]])
1886
+ transposed_input_graph .add_graph_output (transposed_input .output [0 ], ctx .get_dtype (node .input [0 ]),
1887
+ raw_input_shape )
1888
+ # compute current k of the loop
1889
+ current_k = body_graph .make_node ('Sub' , [k_end .output [0 ], trip_name ])
1890
+ is_k_noneg = body_graph .make_node ('Greater' , [current_k .output [0 ], const_neg_one .output [0 ]])
1891
+ processed_input = body_graph .make_node ('If' , [is_k_noneg .output [0 ]])
1892
+ processed_input .set_body_graph_as_attr ('then_branch' , identity_input_graph )
1893
+ processed_input .set_body_graph_as_attr ('else_branch' , transposed_input_graph )
1894
+ processed_shape = body_graph .make_node ('Shape' , [processed_input .output [0 ]])
1895
+ shape_processed_shape = body_graph .make_node ('Shape' , [processed_shape .output [0 ]])
1896
+ new_depth = body_graph .make_node ('Slice' ,
1897
+ [processed_shape .output [0 ], const_neg_two .output [0 ], const_neg_one .output [0 ]])
1898
+ new_width = body_graph .make_node ('Slice' , [processed_shape .output [0 ], const_neg_one .output [0 ],
1899
+ shape_processed_shape .output [0 ]])
1900
+ abs_k = body_graph .make_node ('Abs' , [current_k .output [0 ]])
1901
+ range_k = body_graph .make_node ('Range' , [abs_k .output [0 ], new_width .output [0 ], const_one .output [0 ]],
1902
+ domain = "com.microsoft" )
1903
+ sliced_range = body_graph .make_node ('Slice' , [range_k .output [0 ], const_zero .output [0 ], new_depth .output [0 ]])
1904
+ sliced_shape = body_graph .make_node ('Shape' , [sliced_range .output [0 ]])
1905
+ pad_length = body_graph .make_node ('Sub' , [new_depth .output [0 ], sliced_shape .output [0 ]])
1906
+ pad_length_2 = body_graph .make_node ('Concat' , [const_zero .output [0 ], pad_length .output [0 ]], attr = {'axis' : 0 })
1907
+ padded_range = body_graph .make_node ('Pad' , [sliced_range .output [0 ], pad_length_2 .output [0 ]])
1908
+ unsqueezed_range = body_graph .make_node ('Unsqueeze' , [padded_range .output [0 ]], attr = {'axes' : [1 ]})
1909
+ half_shape_x = body_graph .make_node ('Slice' ,
1910
+ [new_shape .output [0 ], const_zero .output [0 ], const_neg_two .output [0 ]])
1911
+ shape_range = body_graph .make_node ('Shape' , [unsqueezed_range .output [0 ]])
1912
+ full_shape = body_graph .make_node ('Concat' , [half_shape_x .output [0 ], shape_range .output [0 ]], attr = {'axis' : 0 })
1913
+ expanded_range = body_graph .make_node ('Expand' , [unsqueezed_range .output [0 ], full_shape .output [0 ]])
1914
+ gathered_input = body_graph .make_node ('GatherElements' , [processed_input .output [0 ], expanded_range .output [0 ]],
1915
+ attr = {'axis' : - 1 })
1916
+ squeezed_input = body_graph .make_node ('Squeeze' , [gathered_input .output [0 ]], attr = {'axes' : [- 1 ]})
1917
+ left_width = body_graph .make_node ('Sub' , [new_width .output [0 ], abs_k .output [0 ]])
1918
+ dims = body_graph .make_node ('Concat' , [left_width .output [0 ], new_depth .output [0 ]], attr = {'axis' : 0 })
1919
+ valid_dim = body_graph .make_node ('ReduceMin' , [dims .output [0 ]])
1920
+ raw_output = body_graph .make_node ('Slice' , [squeezed_input .output [0 ], const_zero .output [0 ], valid_dim .output [0 ],
1921
+ const_neg_one .output [0 ]])
1922
+ gap_output = body_graph .make_node ('Sub' , [min_dim .output [0 ], valid_dim .output [0 ]])
1923
+ gaps = body_graph .make_node ('Concat' , [const_zero .output [0 ], gap_output .output [0 ]], attr = {'axis' : 0 })
1924
+ processed_gap = body_graph .make_node ('ReduceMax' , [gaps .output [0 ]])
1925
+ pad_zero = body_graph .make_node ('Mul' , [new_shape .output [0 ], const_zero .output [0 ]])
1926
+ sliced_zero = body_graph .make_node ('Slice' , [pad_zero .output [0 ], const_zero .output [0 ], const_neg_two .output [0 ]])
1927
+ # gap_pos_k_graph
1928
+ gap_pos_k_graph = body_graph .create_new_graph_with_same_config ()
1929
+ gap_pos_k_graph .parent_graph = body_graph
1930
+ gap_pos_k = gap_pos_k_graph .make_node ('Concat' , [const_zero .output [0 ],
1931
+ processed_gap .output [0 ]],
1932
+ attr = {'axis' : 0 }) \
1933
+ if align .startswith ('LEFT' ) \
1934
+ else gap_pos_k_graph .make_node ('Concat' , [processed_gap .output [0 ],
1935
+ const_zero .output [0 ]],
1936
+ attr = {'axis' : 0 })
1937
+ gap_pos_k_graph .add_graph_output (gap_pos_k .output [0 ], TensorProto .INT64 , [- 1 ])
1938
+ # gap_neg_k_graph
1939
+ gap_neg_k_graph = body_graph .create_new_graph_with_same_config ()
1940
+ gap_neg_k_graph .parent_graph = body_graph
1941
+ gap_neg_k = gap_neg_k_graph .make_node ('Concat' , [const_zero .output [0 ],
1942
+ processed_gap .output [0 ]],
1943
+ attr = {'axis' : 0 }) \
1944
+ if align .endswith ('LEFT' ) \
1945
+ else gap_neg_k_graph .make_node ('Concat' , [processed_gap .output [0 ],
1946
+ const_zero .output [0 ]],
1947
+ attr = {'axis' : 0 })
1948
+ gap_neg_k_graph .add_graph_output (gap_neg_k .output [0 ], TensorProto .INT64 , [- 1 ])
1949
+ # pad output with gap
1950
+ gap_k = body_graph .make_node ('If' , [is_k_noneg .output [0 ]])
1951
+ gap_k .set_body_graph_as_attr ("then_branch" , gap_pos_k_graph )
1952
+ gap_k .set_body_graph_as_attr ("else_branch" , gap_neg_k_graph )
1953
+ gap_left = body_graph .make_node ('Slice' , [gap_k .output [0 ], const_zero .output [0 ], const_one .output [0 ]])
1954
+ gap_right = body_graph .make_node ('Slice' , [gap_k .output [0 ], const_one .output [0 ], const_two .output [0 ]])
1955
+ gap_all = body_graph .make_node ('Concat' , [sliced_zero .output [0 ], gap_left .output [0 ], sliced_zero .output [0 ],
1956
+ gap_right .output [0 ]], attr = {'axis' : 0 })
1957
+ padded_output = body_graph .make_node ('Pad' , [raw_output .output [0 ], gap_all .output [0 ], padding ])
1958
+ cond_output = body_graph .make_node ('Identity' , [cond_name ])
1959
+ body_graph .add_graph_output (cond_output .output [0 ], TensorProto .BOOL , [])
1960
+ body_graph .add_graph_output (padded_output .output [0 ], ctx .get_dtype (node .input [0 ]), per_loop_shape )
1961
+ body_graph .add_graph_output (gap_k .output [0 ], TensorProto .INT64 , [- 1 ])
1962
+ # make loop
1963
+ cond_const = ctx .make_const (utils .make_name ("cond" ), np .ones ((), dtype = np .bool ))
1964
+ main_loop = ctx .make_node ('Loop' , [total_k .output [0 ], cond_const .output [0 ]], output_count = 2 )
1965
+ main_loop .set_body_graph_as_attr ("body" , body_graph )
1966
+ # reshape output
1967
+ next_padded_shape = ctx .make_node ('Concat' , [total_k .output [0 ], const_neg_one .output [0 ], min_dim .output [0 ]],
1968
+ attr = {'axis' : 0 })
1969
+ reshaped_padded = ctx .make_node ('Reshape' , [main_loop .output [0 ], next_padded_shape .output [0 ]])
1970
+ transposed_padded = ctx .make_node ('Transpose' , [reshaped_padded .output [0 ]], attr = {'perm' : [1 , 0 , 2 ]})
1971
+ output_shape = ctx .make_node ('Concat' , [half_shape .output [0 ], total_k .output [0 ], const_neg_one .output [0 ]],
1972
+ attr = {'axis' : 0 })
1973
+ reshaped_output = ctx .make_node ('Reshape' , [transposed_padded .output [0 ], output_shape .output [0 ]])
1974
+ # compute pads
1975
+ left_pads = ctx .make_node ('Slice' , [main_loop .output [1 ], const_neg_two .output [0 ], const_neg_one .output [0 ],
1976
+ const_neg_one .output [0 ]])
1977
+ flattened_left_pads = ctx .make_node ('Reshape' , [left_pads .output [0 ], const_neg_one .output [0 ]])
1978
+ min_left_pads = ctx .make_node ('ReduceMin' , [flattened_left_pads .output [0 ]])
1979
+ right_pads = ctx .make_node ('Slice' , [main_loop .output [1 ], const_neg_one .output [0 ], const_two .output [0 ],
1980
+ const_neg_one .output [0 ]])
1981
+ flattened_right_pads = ctx .make_node ('Reshape' , [right_pads .output [0 ], const_neg_one .output [0 ]])
1982
+ min_right_pads = ctx .make_node ('ReduceMin' , [flattened_right_pads .output [0 ]])
1983
+ # trim left pads
1984
+ identity_left_sliced_graph = ctx .create_new_graph_with_same_config ()
1985
+ identity_left_sliced_graph .parent_graph = ctx
1986
+ identity_left_sliced = identity_left_sliced_graph .make_node ('Identity' , [reshaped_output .output [0 ]])
1987
+ identity_left_sliced_graph .add_graph_output (identity_left_sliced .output [0 ], ctx .get_dtype (node .input [0 ]),
1988
+ loop_output_shape )
1989
+ output_left_sliced_graph = ctx .create_new_graph_with_same_config ()
1990
+ output_left_sliced_graph .parent_graph = ctx
1991
+ output_left_sliced = output_left_sliced_graph .make_node ('Slice' ,
1992
+ [reshaped_output .output [0 ], min_left_pads .output [0 ],
1993
+ min_dim .output [0 ], const_neg_one .output [0 ]])
1994
+ output_left_sliced_graph .add_graph_output (output_left_sliced .output [0 ], ctx .get_dtype (node .input [0 ]),
1995
+ loop_output_shape )
1996
+ left_pads_greater_than_zero = ctx .make_node ('Greater' , [min_left_pads .output [0 ], const_zero .output [0 ]])
1997
+ final_output_left_sliced = ctx .make_node ('If' , [left_pads_greater_than_zero .output [0 ]])
1998
+ final_output_left_sliced .set_body_graph_as_attr ("then_branch" , output_left_sliced_graph )
1999
+ final_output_left_sliced .set_body_graph_as_attr ("else_branch" , identity_left_sliced_graph )
2000
+ # trim right pads
2001
+ valid_right_dim = ctx .make_node ('Sub' , [min_dim .output [0 ], min_right_pads .output [0 ]])
2002
+ identity_right_sliced_graph = ctx .create_new_graph_with_same_config ()
2003
+ identity_right_sliced_graph .parent_graph = ctx
2004
+ identity_right_sliced = identity_right_sliced_graph .make_node ('Identity' , [final_output_left_sliced .output [0 ]])
2005
+ identity_right_sliced_graph .add_graph_output (identity_right_sliced .output [0 ], ctx .get_dtype (node .input [0 ]),
2006
+ loop_output_shape )
2007
+ output_right_sliced_graph = ctx .create_new_graph_with_same_config ()
2008
+ output_right_sliced_graph .parent_graph = ctx
2009
+ output_right_sliced = output_right_sliced_graph .make_node ('Slice' , [final_output_left_sliced .output [0 ],
2010
+ const_zero .output [0 ],
2011
+ valid_right_dim .output [0 ],
2012
+ const_neg_one .output [0 ]])
2013
+ output_right_sliced_graph .add_graph_output (output_right_sliced .output [0 ], ctx .get_dtype (node .input [0 ]),
2014
+ loop_output_shape )
2015
+ right_dim_greater_than_valid = ctx .make_node ('Greater' , [min_dim .output [0 ], valid_right_dim .output [0 ]])
2016
+ final_output_right_sliced = ctx .make_node ('If' , [right_dim_greater_than_valid .output [0 ]])
2017
+ final_output_right_sliced .set_body_graph_as_attr ("then_branch" , output_right_sliced_graph )
2018
+ final_output_right_sliced .set_body_graph_as_attr ("else_branch" , identity_right_sliced_graph )
2019
+ # squeeze output
2020
+ latest_shape = ctx .make_node ('Shape' , [final_output_right_sliced .output [0 ]])
2021
+ latest_depth = ctx .make_node ('Slice' ,
2022
+ [latest_shape .output [0 ], const_neg_two .output [0 ], const_neg_one .output [0 ]])
2023
+ need_squeeze = ctx .make_node ('Equal' , [latest_depth .output [0 ], const_one .output [0 ]])
2024
+ identity_sliced_graph = ctx .create_new_graph_with_same_config ()
2025
+ identity_sliced_graph .parent_graph = ctx
2026
+ identity_sliced = identity_sliced_graph .make_node ('Identity' , [final_output_right_sliced .output [0 ]])
2027
+ identity_sliced_graph .add_graph_output (identity_sliced .output [0 ], ctx .get_dtype (node .input [0 ]),
2028
+ raw_output_shape + [- 1 ])
2029
+ squeeze_sliced_graph = ctx .create_new_graph_with_same_config ()
2030
+ squeeze_sliced_graph .parent_graph = ctx
2031
+ squeeze_sliced = squeeze_sliced_graph .make_node ('Squeeze' , [final_output_right_sliced .output [0 ]],
2032
+ attr = {'axes' : [- 2 ]})
2033
+ squeeze_sliced_graph .add_graph_output (squeeze_sliced .output [0 ], ctx .get_dtype (node .input [0 ]), raw_output_shape )
2034
+ shapes = node .output_shapes
2035
+ dtypes = node .output_dtypes
2036
+ ctx .remove_node (node .name )
2037
+ squeeze_if = ctx .make_node ('If' , [need_squeeze .output [0 ]], name = node .name , outputs = node .output , shapes = shapes ,
2038
+ dtypes = dtypes )
2039
+ squeeze_if .set_body_graph_as_attr ("then_branch" , squeeze_sliced_graph )
2040
+ squeeze_if .set_body_graph_as_attr ("else_branch" , identity_sliced_graph )
2041
+
2042
+
1822
2043
@tf_op ("BroadcastTo" )
1823
2044
class BroadcastTo :
1824
2045
@classmethod
0 commit comments