@@ -1772,9 +1772,12 @@ class MatrixDiagPart:
1772
1772
def version_11 (cls , ctx , node , ** kwargs ):
1773
1773
# MatrixDiagPart by slice and gather
1774
1774
const_zero = ctx .make_const (utils .make_name (node .name ) + 'const_zero' , np .array ([0 ]).astype (np .int64 ))
1775
+ const_zero_ = ctx .make_const (utils .make_name (node .name ) + 'const_zero_' , np .array (0 ).astype (np .int64 ))
1776
+
1775
1777
const_zero_zero = ctx .make_const (utils .make_name (node .name ) + 'const_zero_zero' ,
1776
1778
np .array ([0 , 0 ]).astype (np .int64 ))
1777
1779
const_one = ctx .make_const (utils .make_name (node .name ) + 'const_one' , np .array ([1 ]).astype (np .int64 ))
1780
+ const_one_ = ctx .make_const (utils .make_name (node .name ) + 'const_one_' , np .array (1 ).astype (np .int64 ))
1778
1781
const_two = ctx .make_const (utils .make_name (node .name ) + 'const_two' , np .array ([2 ]).astype (np .int64 ))
1779
1782
const_negative_one = ctx .make_const (utils .make_name (node .name ) + 'const_negative_one' ,
1780
1783
np .array ([- 1 ]).astype (np .int64 ))
@@ -1802,7 +1805,9 @@ def version_11(cls, ctx, node, **kwargs):
1802
1805
const_negative_one .output [0 ]])
1803
1806
sliced_input_shape_new = ctx .make_node ('Concat' , [sliced_input_shape_half .output [0 ], const_one .output [0 ]],
1804
1807
attr = {'axis' : - 1 })
1805
- matrice_range = ctx .make_node ('Range' , [const_zero .output [0 ], min_matrice_dim .output [0 ], const_one .output [0 ]])
1808
+ min_matrice_dim_ = ctx .make_node ('Squeeze' , [min_matrice_dim .output [0 ]], {'axes' : [0 ]})
1809
+ matrice_range = ctx .make_node ('Range' , [const_zero_ .output [0 ], min_matrice_dim_ .output [0 ],
1810
+ const_one_ .output [0 ]])
1806
1811
unsqueezed_matrice_range = ctx .make_node ('Unsqueeze' , [matrice_range .output [0 ]], attr = {"axes" : [- 1 ]})
1807
1812
expanded_range = ctx .make_node ('Expand' , [unsqueezed_matrice_range .output [0 ], sliced_input_shape_new .output [0 ]])
1808
1813
gathered_result = ctx .make_node ('GatherElements' , [sliced_input .output [0 ], expanded_range .output [0 ]],
@@ -1893,6 +1898,8 @@ def version_11(cls, ctx, node, **kwargs):
1893
1898
new_width = body_graph .make_node ('Slice' , [processed_shape .output [0 ], const_neg_one .output [0 ],
1894
1899
shape_processed_shape .output [0 ]])
1895
1900
abs_k = body_graph .make_node ('Abs' , [current_k .output [0 ]])
1901
+
1902
+
1896
1903
range_k = body_graph .make_node ('Range' , [abs_k .output [0 ], new_width .output [0 ], const_one .output [0 ]],
1897
1904
domain = "com.microsoft" )
1898
1905
sliced_range = body_graph .make_node ('Slice' , [range_k .output [0 ], const_zero .output [0 ], new_depth .output [0 ]])
0 commit comments