@@ -1835,94 +1835,98 @@ class MatrixDiagPartV2V3:
1835
1835
@classmethod
1836
1836
def version_11 (cls , ctx , node , ** kwargs ):
1837
1837
1838
- def mkconst (npval , desc ):
1839
- name = utils .make_name (node .name ) + f'_{ desc } '
1840
- return ctx .make_const (name , npval ).output [0 ]
1838
+ def mkconsts (values , dtype = np .int64 ):
1839
+ ret = []
1840
+ for value in values :
1841
+ name = utils .make_name (node .name + '_const' )
1842
+ ret .append (ctx .make_const (name , np .array (value , dtype = dtype )).output [0 ])
1843
+ return ret
1841
1844
1842
1845
# assemble MatrixDiagPart V2&V3
1843
1846
m = node .input [0 ]
1844
1847
m_shape = ctx .get_shape (m )
1845
- utils .make_sure (- 1 not in m_shape , 'At least one dim is unknown %s' , str (m_shape ))
1846
-
1847
- xlen = m_shape [- 1 ]
1848
- ylen = m_shape [- 2 ]
1849
- xlenp = xlen + 1
1850
- pads = np .zeros (2 * len (m_shape ), dtype = np .int64 )
1848
+ m_rank = len (m_shape )
1849
+ pads = np .zeros (2 * m_rank , dtype = np .int64 )
1851
1850
pads [- 2 :] = [1 , 1 ]
1851
+ utils .make_sure (m_rank > 1 , 'Input data should be at least 2D %s' , str (m_shape ))
1852
1852
1853
1853
align = 'LEFT_LEFT'
1854
1854
if node .op .op_type == 'MatrixDiagPartV3' :
1855
1855
align = node .get_attr_str ('align' ) if 'align' in node .attr else 'LEFT_RIGHT'
1856
1856
xalign , yalign = align .split ('_' )
1857
1857
1858
1858
# consts
1859
- const_zero = mkconst (np .array ([0 ], np .int64 ), 'const_zero_dtype' )
1860
- const_zero_float = mkconst (np .array ([0 ], np .float32 ), 'const_zero_dtype_f' )
1861
- const_one = mkconst (np .array ([1 ], np .int64 ), 'const_one_dtype' )
1862
- const_neg_one = mkconst (np .array ([- 1 ]).astype (np .int64 ), 'const_neg_one' )
1863
- const_neg_one_float = mkconst (np .array ([- 1 ]).astype (np .float32 ), 'const_neg_one_f' )
1864
- const_pad_vals = mkconst (pads , 'pads' )
1865
- const_t = mkconst (np .array ([- 1 , 1 ], np .int64 ), 'const_t' )
1866
- const_xlen = mkconst (np .array ([xlen ], np .int64 ), 'const_xlen' )
1867
- const_ylen = mkconst (np .array ([ylen ], np .int64 ), 'const_ylen' )
1868
- const_xlenp = mkconst (np .array ([xlenp ], np .int64 ), 'const_xlenp' )
1869
- const_stride = mkconst (np .array ([xlenp + 1 ], np .int64 ), 'const_stride' )
1870
- const_minxy_float = mkconst (np .array ([min (xlen , ylen )], np .float32 ), 'const_minxy_f' )
1871
- const_xmax = mkconst (np .array ([xlen * xlenp + xlenp - 1 ], np .int64 ), 'const_xmax' )
1872
- const_ymax = mkconst (np .array ([xlenp * ylen - 1 ], np .int64 ), 'const_ymax' )
1873
- const_ymax_float = mkconst (np .array ([xlenp * ylen - 1 ], np .float32 ), 'const_ymax_f' )
1874
- const_partial_shape = mkconst (np .asarray (m_shape [:- 2 ], np .int64 ), 'partial_shape' )
1875
- const_m2_shape = mkconst (np .asarray (m_shape [:- 2 ] + [- 1 ], np .int64 ), 'm2_shape' )
1876
- const_gather_shape = mkconst (np .asarray (m_shape [:- 2 ] + [1 ], np .int64 ), 'gather_shape' )
1859
+ const_zero_float , const_neg_one_float = mkconsts ([[0 ], [- 1 ]], np .float32 )
1860
+ const_zero , const_one , const_neg_one , const_neg_two , const_pad_vals , const_t = \
1861
+ mkconsts ([[0 ], [1 ], [- 1 ], [- 2 ], pads , [- 1 , 1 ]])
1862
+ const_zero_scalar , const_one_scalar , const_neg_one_scalar = mkconsts ([0 , 1 , - 1 ])
1863
+
1864
+ m_shape = ctx .make_node ('Shape' , [node .input [0 ]]).output [0 ]
1865
+ xlen = ctx .make_node ('Gather' , [m_shape , const_neg_one ]).output [0 ]
1866
+ ylen = ctx .make_node ('Gather' , [m_shape , const_neg_two ]).output [0 ]
1867
+ xlenp = ctx .make_node ('Add' , [xlen , const_one ]).output [0 ]
1868
+ stride = ctx .make_node ('Add' , [xlenp , const_one ]).output [0 ]
1869
+ minxy_0 = ctx .make_node ('Concat' , [xlen , ylen ], attr = {'axis' : 0 }).output [0 ]
1870
+ minxy = ctx .make_node ('ReduceMin' , [minxy_0 ]).output [0 ]
1871
+ minxy_float = ctx .make_node ('Cast' , [minxy ], attr = {'to' : TensorProto .FLOAT }).output [0 ]
1872
+ xmax_0 = ctx .make_node ('Mul' , [xlen , xlenp ]).output [0 ]
1873
+ xmax_1 = ctx .make_node ('Add' , [xmax_0 , xlenp ]).output [0 ]
1874
+ xmax = ctx .make_node ('Add' , [xmax_1 , const_neg_one ]).output [0 ]
1875
+ ymax_0 = ctx .make_node ('Mul' , [xlenp , ylen ]).output [0 ]
1876
+ ymax = ctx .make_node ('Add' , [ymax_0 , const_neg_one ]).output [0 ]
1877
+ ymax_float = ctx .make_node ('Cast' , [ymax ], attr = {'to' : TensorProto .FLOAT }).output [0 ]
1878
+ partial_shape = ctx .make_node ('Slice' , [m_shape , const_zero , const_neg_two ]).output [0 ]
1879
+ m2_shape = ctx .make_node ('Concat' , [partial_shape , const_neg_one ], attr = {'axis' : 0 }).output [0 ]
1880
+ gather_shape = ctx .make_node ('Concat' , [partial_shape , const_one ], attr = {'axis' : 0 }).output [0 ]
1877
1881
1878
1882
# get k0, k1 (the min and max of input 1): the range of diagonals to be extracted
1879
1883
input1 = ctx .make_node ('Cast' , [node .input [1 ]], attr = {'to' : TensorProto .INT64 })
1880
- k0 = ctx .make_node ('ReduceMin' , [input1 .output [0 ]])
1881
- k1 = ctx .make_node ('ReduceMax' , [input1 .output [0 ]])
1882
- k1_scalar = ctx .make_node ('Squeeze' , [k1 .output [0 ]])
1884
+ k0 = ctx .make_node ('ReduceMin' , [input1 .output [0 ]]).output [0 ]
1885
+ k1 = ctx .make_node ('ReduceMax' , [input1 .output [0 ]]).output [0 ]
1886
+ k0_scalar = ctx .make_node ('Squeeze' , [k0 ]).output [0 ]
1887
+ k1_scalar = ctx .make_node ('Squeeze' , [k1 ]).output [0 ]
1883
1888
m_padded = ctx .make_node ('Pad' , [m , const_pad_vals , node .input [2 ]])
1884
1889
1885
1890
# starting indexes for super diagonals
1886
- xstart_0 = ctx .make_node ('Cast' , [k0 . output [ 0 ] ], attr = {'to' : TensorProto .FLOAT })
1891
+ xstart_0 = ctx .make_node ('Cast' , [k0_scalar ], attr = {'to' : TensorProto .FLOAT })
1887
1892
xstart_1 = ctx .make_node ('Max' , [const_zero_float , xstart_0 .output [0 ]])
1888
1893
xstart_2 = ctx .make_node ('Cast' , [xstart_1 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1889
- xstart_3 = ctx .make_node ('Add' , [xstart_2 .output [0 ], const_neg_one ])
1890
- xstart_4 = ctx .make_node ('Range' , [k1_scalar . output [ 0 ] , xstart_3 .output [0 ], const_neg_one ])
1894
+ xstart_3 = ctx .make_node ('Add' , [xstart_2 .output [0 ], const_neg_one_scalar ])
1895
+ xstart_4 = ctx .make_node ('Range' , [k1_scalar , xstart_3 .output [0 ], const_neg_one_scalar ])
1891
1896
xstart = ctx .make_node ('Reshape' , [xstart_4 .output [0 ], const_t ])
1892
1897
1893
1898
# starting indexes for sub diagonals
1894
- ystart_0 = ctx .make_node ('Cast' , [k1 . output [ 0 ] ], attr = {'to' : TensorProto .FLOAT })
1899
+ ystart_0 = ctx .make_node ('Cast' , [k1_scalar ], attr = {'to' : TensorProto .FLOAT })
1895
1900
ystart_1 = ctx .make_node ('Min' , [const_neg_one_float , ystart_0 .output [0 ]])
1896
1901
ystart_2 = ctx .make_node ('Cast' , [ystart_1 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1897
- ystart_2_scalar = ctx .make_node ('Squeeze' , [ystart_2 .output [0 ]])
1898
- ystart_3 = ctx .make_node ('Add' , [k0 .output [0 ], const_neg_one ])
1899
- ystart_4 = ctx .make_node ('Range' , [ystart_2_scalar .output [0 ], ystart_3 .output [0 ], const_neg_one ])
1902
+ ystart_3 = ctx .make_node ('Add' , [k0_scalar , const_neg_one_scalar ])
1903
+ ystart_4 = ctx .make_node ('Range' , [ystart_2 .output [0 ], ystart_3 .output [0 ], const_neg_one_scalar ])
1900
1904
ystart = ctx .make_node ('Reshape' , [ystart_4 .output [0 ], const_t ])
1901
1905
1902
- xmax_0 = ctx .make_node ('Mul' , [xstart .output [0 ], const_xlenp ])
1903
- xmax = ctx .make_node ('Sub' , [const_xmax , xmax_0 .output [0 ]])
1906
+ xmax_0 = ctx .make_node ('Mul' , [xstart .output [0 ], xlenp ])
1907
+ xmax = ctx .make_node ('Sub' , [xmax , xmax_0 .output [0 ]])
1904
1908
xmax_float = ctx .make_node ('Cast' , [xmax .output [0 ]], attr = {'to' : TensorProto .FLOAT })
1905
1909
1906
1910
# lengths of super/sub diags to extract
1907
- xsize_0 = ctx .make_node ('Sub' , [const_xlen , xstart .output [0 ]])
1911
+ xsize_0 = ctx .make_node ('Sub' , [xlen , xstart .output [0 ]])
1908
1912
xsize_1 = ctx .make_node ('Cast' , [xsize_0 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
1909
- xsize_2 = ctx .make_node ('Min' , [xsize_1 .output [0 ], const_minxy_float ])
1913
+ xsize_2 = ctx .make_node ('Min' , [xsize_1 .output [0 ], minxy_float ])
1910
1914
xsize = ctx .make_node ('Cast' , [xsize_2 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1911
- ysize_0 = ctx .make_node ('Add' , [const_ylen , ystart .output [0 ]])
1915
+ ysize_0 = ctx .make_node ('Add' , [ylen , ystart .output [0 ]])
1912
1916
ysize_1 = ctx .make_node ('Cast' , [ysize_0 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
1913
- ysize_2 = ctx .make_node ('Min' , [ysize_1 .output [0 ], const_minxy_float ])
1917
+ ysize_2 = ctx .make_node ('Min' , [ysize_1 .output [0 ], minxy_float ])
1914
1918
ysize = ctx .make_node ('Cast' , [ysize_2 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1915
1919
diagsize = ctx .make_node ('Concat' , [xsize .output [0 ], ysize .output [0 ]], attr = {'axis' : 0 })
1916
1920
maxsize = ctx .make_node ('ReduceMax' , [diagsize .output [0 ]], attr = {'keep_dims' : 0 })
1917
1921
maxsize_0 = ctx .make_node ('Reshape' , [maxsize .output [0 ], const_neg_one ])
1918
1922
maxsize_scalar = ctx .make_node ('Squeeze' , [maxsize .output [0 ]])
1919
1923
1920
- diagdistances_0 = ctx .make_node ('Range' , [const_zero , maxsize_scalar .output [0 ], const_one ])
1921
- diagdistances = ctx .make_node ('Mul' , [diagdistances_0 .output [0 ], const_stride ])
1924
+ diagdistances_0 = ctx .make_node ('Range' , [const_zero_scalar , maxsize_scalar .output [0 ], const_one_scalar ])
1925
+ diagdistances = ctx .make_node ('Mul' , [diagdistances_0 .output [0 ], stride ])
1922
1926
1923
1927
def right_align (sizes , indices , starts , maxval ):
1924
1928
op1 = ctx .make_node ('Sub' , [maxsize .output [0 ], sizes .output [0 ]])
1925
- op2 = ctx .make_node ('Mul' , [op1 .output [0 ], const_stride ])
1929
+ op2 = ctx .make_node ('Mul' , [op1 .output [0 ], stride ])
1926
1930
op3 = ctx .make_node ('Sub' , [indices .output [0 ], op2 .output [0 ]])
1927
1931
op4 = ctx .make_node ('Less' , [op3 .output [0 ], starts .output [0 ]])
1928
1932
op5 = ctx .make_node ('Where' , [op4 .output [0 ], maxval , op3 .output [0 ]])
@@ -1932,48 +1936,48 @@ def right_align(sizes, indices, starts, maxval):
1932
1936
xdiags_0 = ctx .make_node ('Add' , [xstart .output [0 ], diagdistances .output [0 ]])
1933
1937
xdiags_1 = ctx .make_node ('Cast' , [xdiags_0 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
1934
1938
if xalign == 'RIGHT' :
1935
- xdiags = right_align (xsize , xdiags_0 , xstart , const_ymax )
1939
+ xdiags = right_align (xsize , xdiags_0 , xstart , ymax )
1936
1940
else :
1937
1941
xdiags_2 = ctx .make_node ('Min' , [xdiags_1 .output [0 ], xmax_float .output [0 ]])
1938
1942
xdiags = ctx .make_node ('Cast' , [xdiags_2 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1939
1943
1940
1944
ydiags_0_ = ctx .make_node ('Abs' , [ystart .output [0 ]])
1941
- ydiags_1 = ctx .make_node ('Mul' , [ydiags_0_ .output [0 ], const_xlenp ])
1945
+ ydiags_1 = ctx .make_node ('Mul' , [ydiags_0_ .output [0 ], xlenp ])
1942
1946
ydiags_2 = ctx .make_node ('Add' , [ydiags_1 .output [0 ], diagdistances .output [0 ]])
1943
1947
ydiags_3 = ctx .make_node ('Cast' , [ydiags_2 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
1944
1948
if yalign == 'RIGHT' :
1945
- ydiags = right_align (ysize , ydiags_2 , ydiags_1 , const_ymax )
1949
+ ydiags = right_align (ysize , ydiags_2 , ydiags_1 , ymax )
1946
1950
else :
1947
- ydiags_4 = ctx .make_node ('Min' , [ydiags_3 .output [0 ], const_ymax_float ])
1951
+ ydiags_4 = ctx .make_node ('Min' , [ydiags_3 .output [0 ], ymax_float ])
1948
1952
ydiags = ctx .make_node ('Cast' , [ydiags_4 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1949
1953
1950
1954
# flatten the last two dimensions of the padded matrix into a single dimension
1951
- m2 = ctx .make_node ('Reshape' , [m_padded .output [0 ], const_m2_shape ])
1955
+ m2 = ctx .make_node ('Reshape' , [m_padded .output [0 ], m2_shape ])
1952
1956
1953
1957
diags_0 = ctx .make_node ('Concat' , [xdiags .output [0 ], ydiags .output [0 ]], attr = {'axis' : 0 })
1954
1958
diags_1 = ctx .make_node ('Reshape' , [diags_0 .output [0 ], const_neg_one ])
1955
- diags_2 = ctx .make_node ('Expand' , [diags_1 .output [0 ], const_gather_shape ])
1959
+ diags_2 = ctx .make_node ('Expand' , [diags_1 .output [0 ], gather_shape ])
1956
1960
diags = ctx .make_node ('GatherElements' , [m2 .output [0 ], diags_2 .output [0 ]], attr = {'axis' : - 1 })
1957
1961
1958
1962
def compute_out_shape (k0_k1_same = False ):
1959
1963
g = ctx .create_new_graph_with_same_config ()
1960
1964
g .parent_graph = ctx
1961
1965
if k0_k1_same :
1962
- dims = [const_partial_shape , maxsize_0 .output [0 ]]
1966
+ dims = [partial_shape , maxsize_0 .output [0 ]]
1963
1967
else :
1964
- dims = [const_partial_shape , const_neg_one , maxsize_0 .output [0 ]]
1968
+ dims = [partial_shape , const_neg_one , maxsize_0 .output [0 ]]
1965
1969
outshape = g .make_node ('Concat' , dims , attr = {'axis' : 0 })
1966
1970
g .add_graph_output (outshape .output [0 ], TensorProto .INT64 , [- 1 ])
1967
1971
return g
1968
1972
1969
1973
# if k0 == k1, the output matrix has one fewer dimension than usual
1970
1974
# hence, an 'If' node is needed to compute the right output matrix shape
1971
- k0_k1_same = ctx .make_node ('Equal' , [k1 . output [ 0 ] , k0 . output [ 0 ] ])
1975
+ k0_k1_same = ctx .make_node ('Equal' , [k1 , k0 ])
1972
1976
if_node = ctx .make_node ('If' , [k0_k1_same .output [0 ]])
1973
1977
if_node .set_body_graph_as_attr ('then_branch' , compute_out_shape (True ))
1974
1978
if_node .set_body_graph_as_attr ('else_branch' , compute_out_shape (False ))
1975
1979
1976
- shapes = [ - 1 ] * len ( m_shape )
1980
+ shapes = ctx . get_shape ( node . output [ 0 ] )
1977
1981
dtypes = node .output_dtypes
1978
1982
ctx .remove_node (node .name )
1979
1983
ctx .make_node ('Reshape' , [diags .output [0 ], if_node .output [0 ]], name = node .name , outputs = node .output ,
0 commit comments