@@ -126,7 +126,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
126126 if node_before_layer_norm is None :
127127 continue
128128 child = self .model .find_first_child_by_type (
129- node_before_layer_norm , "LayerNormalization" , input_name_to_nodes , False
129+ node_before_layer_norm ,
130+ "LayerNormalization" ,
131+ input_name_to_nodes ,
132+ False ,
130133 )
131134 if child is None :
132135 continue
@@ -146,19 +149,26 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
146149 qkv_nodes = self .model .match_parent_path (
147150 normalize_node ,
148151 ["Add" , "MatMul" , "Reshape" , "Transpose" , "MatMul" ],
149- [1 , 1 , 0 , 0 , 0 ],
152+ [1 , None , 0 , 0 , 0 ],
150153 )
151154 if qkv_nodes is None :
152155 logger .debug ("fuse_attention: failed to match qkv path" )
153156 return
154-
155- reshape_qkv , transpose_qkv , matmul_qkv = qkv_nodes [2 ], qkv_nodes [3 ], qkv_nodes [- 1 ]
157+ reshape_qkv , transpose_qkv , matmul_qkv = (
158+ qkv_nodes [2 ],
159+ qkv_nodes [3 ],
160+ qkv_nodes [- 1 ],
161+ )
156162
157163 v_nodes = self .model .match_parent_path (
158- matmul_qkv , ["Reshape" , "Transpose" , "Reshape" , "Add" , "MatMul" ], [1 , 0 , 0 , 0 , None ]
164+ matmul_qkv ,
165+ ["Reshape" , "Transpose" , "Reshape" , "Add" , "MatMul" ],
166+ [1 , 0 , 0 , 0 , None ],
159167 )
160168 if v_nodes is None :
161- v_nodes = self .model .match_parent_path (matmul_qkv , ["Transpose" , "Reshape" , "Add" , "MatMul" ], [1 , 0 , 0 , 1 ])
169+ v_nodes = self .model .match_parent_path (
170+ matmul_qkv , ["Transpose" , "Reshape" , "Add" , "MatMul" ], [1 , 0 , 0 , None ]
171+ )
162172 if v_nodes is None :
163173 logger .debug ("fuse_attention: failed to match v path" )
164174 return
@@ -182,17 +192,30 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
182192 )
183193 if qk_nodes is None :
184194 qk_nodes = self .model .match_parent_path (matmul_qkv , ["Softmax" , "Add" , "Mul" , "MatMul" ], [0 , 0 , 0 , 0 ])
185- if qk_nodes is None :
186- qk_nodes = self .model .match_parent_path (
187- matmul_qkv , ["Cast" , "Cast" , "Softmax" , "Add" , "Mul" , "MatMul" ], [0 , 0 , 0 , 0 , 0 , 0 ]
188- )
189- if qk_nodes is None :
190- logger .debug ("fuse_attention: failed to match qk path" )
191- return
192- else :
193- add_mask = qk_nodes [3 ]
194- else :
195+ if qk_nodes is not None :
195196 add_mask = qk_nodes [1 ]
197+ else :
198+ # If attention mask is not used, we can still match the qk path.
199+ qk_nodes = self .model .match_parent_path (matmul_qkv , ["Softmax" , "Mul" , "MatMul" ], [0 , 0 , 0 ])
200+ if qk_nodes is None :
201+ # Cast nodes are added in the model for fp16.
202+ qk_nodes = self .model .match_parent_path (
203+ matmul_qkv ,
204+ ["Cast" , "Cast" , "Softmax" , "Add" , "Mul" , "MatMul" ],
205+ [0 , 0 , 0 , 0 , 0 , 0 ],
206+ )
207+ if qk_nodes is not None :
208+ add_mask = qk_nodes [3 ]
209+ else :
210+ # If attention mask is not used, we can still match the qk path.
211+ qk_nodes = self .model .match_parent_path (
212+ matmul_qkv ,
213+ ["Cast" , "Cast" , "Softmax" , "Mul" , "MatMul" ],
214+ [0 , 0 , 0 , 0 , 0 ],
215+ )
216+ if qk_nodes is None :
217+ logger .debug ("fuse_attention: failed to match qk path" )
218+ return
196219 else :
197220 assert len (add_mask_indices ) == 1
198221 causal_mask_input_index = 1 - add_mask_indices [0 ]
@@ -201,10 +224,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
201224 matmul_qk = qk_nodes [- 1 ]
202225
203226 q_nodes = self .model .match_parent_path (
204- matmul_qk , ["Reshape" , "Transpose" , "Reshape" , "Mul" , "Add" , "MatMul" ], [0 , 0 , 0 , 0 , None , None ]
227+ matmul_qk ,
228+ ["Reshape" , "Transpose" , "Reshape" , "Mul" , "Add" , "MatMul" ],
229+ [0 , 0 , 0 , 0 , None , None ],
205230 )
206231 if q_nodes is None :
207- q_nodes = self .model .match_parent_path (matmul_qk , ["Transpose" , "Reshape" , "Add" , "MatMul" ], [0 , 0 , 0 , 1 ])
232+ q_nodes = self .model .match_parent_path (
233+ matmul_qk , ["Transpose" , "Reshape" , "Add" , "MatMul" ], [0 , 0 , 0 , None ]
234+ )
208235 if q_nodes is None :
209236 logger .debug ("fuse_attention: failed to match q path" )
210237 return
@@ -216,10 +243,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
216243 add_q , matmul_q = q_nodes [- 2 ], q_nodes [- 1 ]
217244
218245 k_nodes = self .model .match_parent_path (
219- matmul_qk , ["Transpose" , "Reshape" , "Transpose" , "Reshape" , "Add" , "MatMul" ], [1 , 0 , 0 , 0 , 0 , None ]
246+ matmul_qk ,
247+ ["Transpose" , "Reshape" , "Transpose" , "Reshape" , "Add" , "MatMul" ],
248+ [1 , 0 , 0 , 0 , 0 , None ],
220249 )
221250 if k_nodes is None :
222- k_nodes = self .model .match_parent_path (matmul_qk , ["Transpose" , "Reshape" , "Add" , "MatMul" ], [1 , 0 , 0 , 1 ])
251+ k_nodes = self .model .match_parent_path (
252+ matmul_qk , ["Transpose" , "Reshape" , "Add" , "MatMul" ], [1 , 0 , 0 , None ]
253+ )
223254 if k_nodes is None :
224255 logger .debug ("fuse_attention: failed to match k path" )
225256 return
@@ -242,7 +273,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
242273 # 4D Add after Q x K'
243274 add_qk_nodes = self .model .match_parent_path (
244275 add_mask ,
245- ["Where" , "Sub" , "Cast" , "Expand" , "Unsqueeze" , "Unsqueeze" , "Reshape" , "Reshape" , "Cast" ],
276+ [
277+ "Where" ,
278+ "Sub" ,
279+ "Cast" ,
280+ "Expand" ,
281+ "Unsqueeze" ,
282+ "Unsqueeze" ,
283+ "Reshape" ,
284+ "Reshape" ,
285+ "Cast" ,
286+ ],
246287 [1 , 2 , 1 , 0 , 0 , 0 , 0 , 0 , 0 ],
247288 )
248289 if add_qk_nodes is not None :
0 commit comments