@@ -1835,94 +1835,97 @@ class MatrixDiagPartV2V3:
1835
1835
@classmethod
1836
1836
def version_11 (cls , ctx , node , ** kwargs ):
1837
1837
1838
- def mkconst (npval , desc ):
1839
- name = utils .make_name (node .name ) + f'_{ desc } '
1840
- return ctx .make_const (name , npval ).output [0 ]
1838
+ def mkconsts (values , dtype = np .int64 ):
1839
+ ret = []
1840
+ for value in values :
1841
+ name = utils .make_name (node .name + '_const' )
1842
+ ret .append (ctx .make_const (name , np .array (value , dtype = dtype )).output [0 ])
1843
+ return ret
1841
1844
1842
1845
# assemble MatrixDiagPart V2&V3
1843
1846
m = node .input [0 ]
1844
1847
m_shape = ctx .get_shape (m )
1845
- utils .make_sure (- 1 not in m_shape , 'At least one dim is unknown %s' , str (m_shape ))
1846
-
1847
- xlen = m_shape [- 1 ]
1848
- ylen = m_shape [- 2 ]
1849
- xlenp = xlen + 1
1850
- pads = np .zeros (2 * len (m_shape ), dtype = np .int64 )
1848
+ m_rank = len (m_shape )
1849
+ pads = np .zeros (2 * m_rank , dtype = np .int64 )
1851
1850
pads [- 2 :] = [1 , 1 ]
1851
+ utils .make_sure (m_rank > 1 , 'Input data should be at least 2D %s' , str (m_shape ))
1852
1852
1853
1853
align = 'LEFT_LEFT'
1854
1854
if node .op .op_type == 'MatrixDiagPartV3' :
1855
1855
align = node .get_attr_str ('align' ) if 'align' in node .attr else 'LEFT_RIGHT'
1856
1856
xalign , yalign = align .split ('_' )
1857
1857
1858
1858
# consts
1859
- const_zero = mkconst (np .array ([0 ], np .int64 ), 'const_zero_dtype' )
1860
- const_zero_float = mkconst (np .array ([0 ], np .float32 ), 'const_zero_dtype_f' )
1861
- const_one = mkconst (np .array ([1 ], np .int64 ), 'const_one_dtype' )
1862
- const_neg_one = mkconst (np .array ([- 1 ]).astype (np .int64 ), 'const_neg_one' )
1863
- const_neg_one_float = mkconst (np .array ([- 1 ]).astype (np .float32 ), 'const_neg_one_f' )
1864
- const_pad_vals = mkconst (pads , 'pads' )
1865
- const_t = mkconst (np .array ([- 1 , 1 ], np .int64 ), 'const_t' )
1866
- const_xlen = mkconst (np .array ([xlen ], np .int64 ), 'const_xlen' )
1867
- const_ylen = mkconst (np .array ([ylen ], np .int64 ), 'const_ylen' )
1868
- const_xlenp = mkconst (np .array ([xlenp ], np .int64 ), 'const_xlenp' )
1869
- const_stride = mkconst (np .array ([xlenp + 1 ], np .int64 ), 'const_stride' )
1870
- const_minxy_float = mkconst (np .array ([min (xlen , ylen )], np .float32 ), 'const_minxy_f' )
1871
- const_xmax = mkconst (np .array ([xlen * xlenp + xlenp - 1 ], np .int64 ), 'const_xmax' )
1872
- const_ymax = mkconst (np .array ([xlenp * ylen - 1 ], np .int64 ), 'const_ymax' )
1873
- const_ymax_float = mkconst (np .array ([xlenp * ylen - 1 ], np .float32 ), 'const_ymax_f' )
1874
- const_partial_shape = mkconst (np .asarray (m_shape [:- 2 ], np .int64 ), 'partial_shape' )
1875
- const_m2_shape = mkconst (np .asarray (m_shape [:- 2 ] + [- 1 ], np .int64 ), 'm2_shape' )
1876
- const_gather_shape = mkconst (np .asarray (m_shape [:- 2 ] + [1 ], np .int64 ), 'gather_shape' )
1859
+ const_zero_float , const_neg_one_float = mkconsts ([[0 ], [- 1 ]], np .float32 )
1860
+ const_zero , const_one , const_neg_one , const_neg_two , const_pad_vals , const_t = \
1861
+ mkconsts ([[0 ], [1 ], [- 1 ], [- 2 ], pads , [- 1 , 1 ]])
1862
+
1863
+ m_shape = ctx .make_node ('Shape' , [node .input [0 ]]).output [0 ]
1864
+ xlen = ctx .make_node ('Gather' , [m_shape , const_neg_one ]).output [0 ]
1865
+ ylen = ctx .make_node ('Gather' , [m_shape , const_neg_two ]).output [0 ]
1866
+ xlenp = ctx .make_node ('Add' , [xlen , const_one ]).output [0 ]
1867
+ stride = ctx .make_node ('Add' , [xlenp , const_one ]).output [0 ]
1868
+ minxy_0 = ctx .make_node ('Concat' , [xlen , ylen ], attr = {'axis' : 0 }).output [0 ]
1869
+ minxy = ctx .make_node ('ReduceMin' , [minxy_0 ]).output [0 ]
1870
+ minxy_float = ctx .make_node ('Cast' , [minxy ], attr = {'to' : TensorProto .FLOAT }).output [0 ]
1871
+ xmax_0 = ctx .make_node ('Mul' , [xlen , xlenp ]).output [0 ]
1872
+ xmax_1 = ctx .make_node ('Add' , [xmax_0 , xlenp ]).output [0 ]
1873
+ xmax = ctx .make_node ('Add' , [xmax_1 , const_neg_one ]).output [0 ]
1874
+ ymax_0 = ctx .make_node ('Mul' , [xlenp , ylen ]).output [0 ]
1875
+ ymax = ctx .make_node ('Add' , [ymax_0 , const_neg_one ]).output [0 ]
1876
+ ymax_float = ctx .make_node ('Cast' , [ymax ], attr = {'to' : TensorProto .FLOAT }).output [0 ]
1877
+ partial_shape = ctx .make_node ('Slice' , [m_shape , const_zero , const_neg_two ]).output [0 ]
1878
+ m2_shape = ctx .make_node ('Concat' , [partial_shape , const_neg_one ], attr = {'axis' : 0 }).output [0 ]
1879
+ gather_shape = ctx .make_node ('Concat' , [partial_shape , const_one ], attr = {'axis' : 0 }).output [0 ]
1877
1880
1878
1881
# get k0, k1 values. diags to be extracted
1879
1882
input1 = ctx .make_node ('Cast' , [node .input [1 ]], attr = {'to' : TensorProto .INT64 })
1880
- k0 = ctx .make_node ('ReduceMin' , [input1 .output [0 ]])
1881
- k1 = ctx .make_node ('ReduceMax' , [input1 .output [0 ]])
1882
- k1_scalar = ctx .make_node ('Squeeze' , [k1 .output [0 ]])
1883
+ k0 = ctx .make_node ('ReduceMin' , [input1 .output [0 ]]). output [ 0 ]
1884
+ k1 = ctx .make_node ('ReduceMax' , [input1 .output [0 ]]). output [ 0 ]
1885
+ k1_scalar = ctx .make_node ('Squeeze' , [k1 ]) .output [0 ]
1883
1886
m_padded = ctx .make_node ('Pad' , [m , const_pad_vals , node .input [2 ]])
1884
1887
1885
1888
# starting indexes for super diagonals
1886
- xstart_0 = ctx .make_node ('Cast' , [k0 . output [ 0 ] ], attr = {'to' : TensorProto .FLOAT })
1889
+ xstart_0 = ctx .make_node ('Cast' , [k0 ], attr = {'to' : TensorProto .FLOAT })
1887
1890
xstart_1 = ctx .make_node ('Max' , [const_zero_float , xstart_0 .output [0 ]])
1888
1891
xstart_2 = ctx .make_node ('Cast' , [xstart_1 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1889
1892
xstart_3 = ctx .make_node ('Add' , [xstart_2 .output [0 ], const_neg_one ])
1890
- xstart_4 = ctx .make_node ('Range' , [k1_scalar . output [ 0 ] , xstart_3 .output [0 ], const_neg_one ])
1893
+ xstart_4 = ctx .make_node ('Range' , [k1_scalar , xstart_3 .output [0 ], const_neg_one ])
1891
1894
xstart = ctx .make_node ('Reshape' , [xstart_4 .output [0 ], const_t ])
1892
1895
1893
1896
# starting indexes for sub diagonals
1894
- ystart_0 = ctx .make_node ('Cast' , [k1 . output [ 0 ] ], attr = {'to' : TensorProto .FLOAT })
1897
+ ystart_0 = ctx .make_node ('Cast' , [k1 ], attr = {'to' : TensorProto .FLOAT })
1895
1898
ystart_1 = ctx .make_node ('Min' , [const_neg_one_float , ystart_0 .output [0 ]])
1896
1899
ystart_2 = ctx .make_node ('Cast' , [ystart_1 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1897
1900
ystart_2_scalar = ctx .make_node ('Squeeze' , [ystart_2 .output [0 ]])
1898
- ystart_3 = ctx .make_node ('Add' , [k0 . output [ 0 ] , const_neg_one ])
1901
+ ystart_3 = ctx .make_node ('Add' , [k0 , const_neg_one ])
1899
1902
ystart_4 = ctx .make_node ('Range' , [ystart_2_scalar .output [0 ], ystart_3 .output [0 ], const_neg_one ])
1900
1903
ystart = ctx .make_node ('Reshape' , [ystart_4 .output [0 ], const_t ])
1901
1904
1902
- xmax_0 = ctx .make_node ('Mul' , [xstart .output [0 ], const_xlenp ])
1903
- xmax = ctx .make_node ('Sub' , [const_xmax , xmax_0 .output [0 ]])
1905
+ xmax_0 = ctx .make_node ('Mul' , [xstart .output [0 ], xlenp ])
1906
+ xmax = ctx .make_node ('Sub' , [xmax , xmax_0 .output [0 ]])
1904
1907
xmax_float = ctx .make_node ('Cast' , [xmax .output [0 ]], attr = {'to' : TensorProto .FLOAT })
1905
1908
1906
1909
# lengths of super/sub diags to extract
1907
- xsize_0 = ctx .make_node ('Sub' , [const_xlen , xstart .output [0 ]])
1910
+ xsize_0 = ctx .make_node ('Sub' , [xlen , xstart .output [0 ]])
1908
1911
xsize_1 = ctx .make_node ('Cast' , [xsize_0 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
1909
- xsize_2 = ctx .make_node ('Min' , [xsize_1 .output [0 ], const_minxy_float ])
1912
+ xsize_2 = ctx .make_node ('Min' , [xsize_1 .output [0 ], minxy_float ])
1910
1913
xsize = ctx .make_node ('Cast' , [xsize_2 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1911
- ysize_0 = ctx .make_node ('Add' , [const_ylen , ystart .output [0 ]])
1914
+ ysize_0 = ctx .make_node ('Add' , [ylen , ystart .output [0 ]])
1912
1915
ysize_1 = ctx .make_node ('Cast' , [ysize_0 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
1913
- ysize_2 = ctx .make_node ('Min' , [ysize_1 .output [0 ], const_minxy_float ])
1916
+ ysize_2 = ctx .make_node ('Min' , [ysize_1 .output [0 ], minxy_float ])
1914
1917
ysize = ctx .make_node ('Cast' , [ysize_2 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1915
1918
diagsize = ctx .make_node ('Concat' , [xsize .output [0 ], ysize .output [0 ]], attr = {'axis' : 0 })
1916
1919
maxsize = ctx .make_node ('ReduceMax' , [diagsize .output [0 ]], attr = {'keep_dims' : 0 })
1917
1920
maxsize_0 = ctx .make_node ('Reshape' , [maxsize .output [0 ], const_neg_one ])
1918
1921
maxsize_scalar = ctx .make_node ('Squeeze' , [maxsize .output [0 ]])
1919
1922
1920
1923
diagdistances_0 = ctx .make_node ('Range' , [const_zero , maxsize_scalar .output [0 ], const_one ])
1921
- diagdistances = ctx .make_node ('Mul' , [diagdistances_0 .output [0 ], const_stride ])
1924
+ diagdistances = ctx .make_node ('Mul' , [diagdistances_0 .output [0 ], stride ])
1922
1925
1923
1926
def right_align (sizes , indices , starts , maxval ):
1924
1927
op1 = ctx .make_node ('Sub' , [maxsize .output [0 ], sizes .output [0 ]])
1925
- op2 = ctx .make_node ('Mul' , [op1 .output [0 ], const_stride ])
1928
+ op2 = ctx .make_node ('Mul' , [op1 .output [0 ], stride ])
1926
1929
op3 = ctx .make_node ('Sub' , [indices .output [0 ], op2 .output [0 ]])
1927
1930
op4 = ctx .make_node ('Less' , [op3 .output [0 ], starts .output [0 ]])
1928
1931
op5 = ctx .make_node ('Where' , [op4 .output [0 ], maxval , op3 .output [0 ]])
@@ -1932,48 +1935,48 @@ def right_align(sizes, indices, starts, maxval):
1932
1935
xdiags_0 = ctx .make_node ('Add' , [xstart .output [0 ], diagdistances .output [0 ]])
1933
1936
xdiags_1 = ctx .make_node ('Cast' , [xdiags_0 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
1934
1937
if xalign == 'RIGHT' :
1935
- xdiags = right_align (xsize , xdiags_0 , xstart , const_ymax )
1938
+ xdiags = right_align (xsize , xdiags_0 , xstart , ymax )
1936
1939
else :
1937
1940
xdiags_2 = ctx .make_node ('Min' , [xdiags_1 .output [0 ], xmax_float .output [0 ]])
1938
1941
xdiags = ctx .make_node ('Cast' , [xdiags_2 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1939
1942
1940
1943
ydiags_0_ = ctx .make_node ('Abs' , [ystart .output [0 ]])
1941
- ydiags_1 = ctx .make_node ('Mul' , [ydiags_0_ .output [0 ], const_xlenp ])
1944
+ ydiags_1 = ctx .make_node ('Mul' , [ydiags_0_ .output [0 ], xlenp ])
1942
1945
ydiags_2 = ctx .make_node ('Add' , [ydiags_1 .output [0 ], diagdistances .output [0 ]])
1943
1946
ydiags_3 = ctx .make_node ('Cast' , [ydiags_2 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
1944
1947
if yalign == 'RIGHT' :
1945
- ydiags = right_align (ysize , ydiags_2 , ydiags_1 , const_ymax )
1948
+ ydiags = right_align (ysize , ydiags_2 , ydiags_1 , ymax )
1946
1949
else :
1947
- ydiags_4 = ctx .make_node ('Min' , [ydiags_3 .output [0 ], const_ymax_float ])
1950
+ ydiags_4 = ctx .make_node ('Min' , [ydiags_3 .output [0 ], ymax_float ])
1948
1951
ydiags = ctx .make_node ('Cast' , [ydiags_4 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1949
1952
1950
1953
# flatten last dimension of matrix
1951
- m2 = ctx .make_node ('Reshape' , [m_padded .output [0 ], const_m2_shape ])
1954
+ m2 = ctx .make_node ('Reshape' , [m_padded .output [0 ], m2_shape ])
1952
1955
1953
1956
diags_0 = ctx .make_node ('Concat' , [xdiags .output [0 ], ydiags .output [0 ]], attr = {'axis' : 0 })
1954
1957
diags_1 = ctx .make_node ('Reshape' , [diags_0 .output [0 ], const_neg_one ])
1955
- diags_2 = ctx .make_node ('Expand' , [diags_1 .output [0 ], const_gather_shape ])
1958
+ diags_2 = ctx .make_node ('Expand' , [diags_1 .output [0 ], gather_shape ])
1956
1959
diags = ctx .make_node ('GatherElements' , [m2 .output [0 ], diags_2 .output [0 ]], attr = {'axis' : - 1 })
1957
1960
1958
1961
def compute_out_shape (k0_k1_same = False ):
1959
1962
g = ctx .create_new_graph_with_same_config ()
1960
1963
g .parent_graph = ctx
1961
1964
if k0_k1_same :
1962
- dims = [const_partial_shape , maxsize_0 .output [0 ]]
1965
+ dims = [partial_shape , maxsize_0 .output [0 ]]
1963
1966
else :
1964
- dims = [const_partial_shape , const_neg_one , maxsize_0 .output [0 ]]
1967
+ dims = [partial_shape , const_neg_one , maxsize_0 .output [0 ]]
1965
1968
outshape = g .make_node ('Concat' , dims , attr = {'axis' : 0 })
1966
1969
g .add_graph_output (outshape .output [0 ], TensorProto .INT64 , [- 1 ])
1967
1970
return g
1968
1971
1969
1972
# if k0=k1, rank of output matrix is 1 less than usual
1970
1973
# hence, need 'If' to compute right output matrix shape
1971
- k0_k1_same = ctx .make_node ('Equal' , [k1 . output [ 0 ] , k0 . output [ 0 ] ])
1974
+ k0_k1_same = ctx .make_node ('Equal' , [k1 , k0 ])
1972
1975
if_node = ctx .make_node ('If' , [k0_k1_same .output [0 ]])
1973
1976
if_node .set_body_graph_as_attr ('then_branch' , compute_out_shape (True ))
1974
1977
if_node .set_body_graph_as_attr ('else_branch' , compute_out_shape (False ))
1975
1978
1976
- shapes = [- 1 ] * len ( m_shape )
1979
+ shapes = [- 1 ] * m_rank
1977
1980
dtypes = node .output_dtypes
1978
1981
ctx .remove_node (node .name )
1979
1982
ctx .make_node ('Reshape' , [diags .output [0 ], if_node .output [0 ]], name = node .name , outputs = node .output ,
0 commit comments