@@ -1859,6 +1859,7 @@ def mkconsts(values, dtype=np.int64):
1859
1859
const_zero_float , const_neg_one_float = mkconsts ([[0 ], [- 1 ]], np .float32 )
1860
1860
const_zero , const_one , const_neg_one , const_neg_two , const_pad_vals , const_t = \
1861
1861
mkconsts ([[0 ], [1 ], [- 1 ], [- 2 ], pads , [- 1 , 1 ]])
1862
+ const_zero_scalar , const_one_scalar , const_neg_one_scalar = mkconsts ([0 , 1 , - 1 ])
1862
1863
1863
1864
m_shape = ctx .make_node ('Shape' , [node .input [0 ]]).output [0 ]
1864
1865
xlen = ctx .make_node ('Gather' , [m_shape , const_neg_one ]).output [0 ]
@@ -1882,24 +1883,24 @@ def mkconsts(values, dtype=np.int64):
1882
1883
input1 = ctx .make_node ('Cast' , [node .input [1 ]], attr = {'to' : TensorProto .INT64 })
1883
1884
k0 = ctx .make_node ('ReduceMin' , [input1 .output [0 ]]).output [0 ]
1884
1885
k1 = ctx .make_node ('ReduceMax' , [input1 .output [0 ]]).output [0 ]
1886
+ k0_scalar = ctx .make_node ('Squeeze' , [k0 ]).output [0 ]
1885
1887
k1_scalar = ctx .make_node ('Squeeze' , [k1 ]).output [0 ]
1886
1888
m_padded = ctx .make_node ('Pad' , [m , const_pad_vals , node .input [2 ]])
1887
1889
1888
1890
# starting indexes for super diagonals
1889
- xstart_0 = ctx .make_node ('Cast' , [k0 ], attr = {'to' : TensorProto .FLOAT })
1891
+ xstart_0 = ctx .make_node ('Cast' , [k0_scalar ], attr = {'to' : TensorProto .FLOAT })
1890
1892
xstart_1 = ctx .make_node ('Max' , [const_zero_float , xstart_0 .output [0 ]])
1891
1893
xstart_2 = ctx .make_node ('Cast' , [xstart_1 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1892
- xstart_3 = ctx .make_node ('Add' , [xstart_2 .output [0 ], const_neg_one ])
1893
- xstart_4 = ctx .make_node ('Range' , [k1_scalar , xstart_3 .output [0 ], const_neg_one ])
1894
+ xstart_3 = ctx .make_node ('Add' , [xstart_2 .output [0 ], const_neg_one_scalar ])
1895
+ xstart_4 = ctx .make_node ('Range' , [k1_scalar , xstart_3 .output [0 ], const_neg_one_scalar ])
1894
1896
xstart = ctx .make_node ('Reshape' , [xstart_4 .output [0 ], const_t ])
1895
1897
1896
1898
# starting indexes for sub diagonals
1897
- ystart_0 = ctx .make_node ('Cast' , [k1 ], attr = {'to' : TensorProto .FLOAT })
1899
+ ystart_0 = ctx .make_node ('Cast' , [k1_scalar ], attr = {'to' : TensorProto .FLOAT })
1898
1900
ystart_1 = ctx .make_node ('Min' , [const_neg_one_float , ystart_0 .output [0 ]])
1899
1901
ystart_2 = ctx .make_node ('Cast' , [ystart_1 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1900
- ystart_2_scalar = ctx .make_node ('Squeeze' , [ystart_2 .output [0 ]])
1901
- ystart_3 = ctx .make_node ('Add' , [k0 , const_neg_one ])
1902
- ystart_4 = ctx .make_node ('Range' , [ystart_2_scalar .output [0 ], ystart_3 .output [0 ], const_neg_one ])
1902
+ ystart_3 = ctx .make_node ('Add' , [k0_scalar , const_neg_one_scalar ])
1903
+ ystart_4 = ctx .make_node ('Range' , [ystart_2 .output [0 ], ystart_3 .output [0 ], const_neg_one_scalar ])
1903
1904
ystart = ctx .make_node ('Reshape' , [ystart_4 .output [0 ], const_t ])
1904
1905
1905
1906
xmax_0 = ctx .make_node ('Mul' , [xstart .output [0 ], xlenp ])
@@ -1920,7 +1921,7 @@ def mkconsts(values, dtype=np.int64):
1920
1921
maxsize_0 = ctx .make_node ('Reshape' , [maxsize .output [0 ], const_neg_one ])
1921
1922
maxsize_scalar = ctx .make_node ('Squeeze' , [maxsize .output [0 ]])
1922
1923
1923
- diagdistances_0 = ctx .make_node ('Range' , [const_zero , maxsize_scalar .output [0 ], const_one ])
1924
+ diagdistances_0 = ctx .make_node ('Range' , [const_zero_scalar , maxsize_scalar .output [0 ], const_one_scalar ])
1924
1925
diagdistances = ctx .make_node ('Mul' , [diagdistances_0 .output [0 ], stride ])
1925
1926
1926
1927
def right_align (sizes , indices , starts , maxval ):
@@ -1976,7 +1977,7 @@ def compute_out_shape(k0_k1_same=False):
1976
1977
if_node .set_body_graph_as_attr ('then_branch' , compute_out_shape (True ))
1977
1978
if_node .set_body_graph_as_attr ('else_branch' , compute_out_shape (False ))
1978
1979
1979
- shapes = [ - 1 ] * m_rank
1980
+ shapes = ctx . get_shape ( node . output [ 0 ])
1980
1981
dtypes = node .output_dtypes
1981
1982
ctx .remove_node (node .name )
1982
1983
ctx .make_node ('Reshape' , [diags .output [0 ], if_node .output [0 ]], name = node .name , outputs = node .output ,
0 commit comments