@@ -116,58 +116,133 @@ def apply(self, model):
116116 return model , graph_modified
117117
118118
119+ # Tests whether a tensor is a scalar, i.e., whether all dimensions are 1
120+ def is_scalar (tensor ):
121+ return tensor is not None and all (x == 1 for x in tensor .shape )
122+
123+
124+ # Tests whether a node is a scalar multiplication with a constant scale factor
125+ def is_const_scalar_mul (node , model ):
126+ # Only handle existing Mul type nodes
127+ if node is not None and node .op_type == "Mul" :
128+ # The constant must be an initializer
129+ # Note: Assumes the constant parameter to always be the second input
130+ scale = model .get_initializer (node .input [1 ])
131+ # Test for existence of a constant scale factor
132+ return scale is not None and is_scalar (scale )
133+ # Did not match the operator type
134+ return False
135+
136+
123137+ # Refactored version of the MoveScalarMulPastMatMul transform capable of
124138+ # transforming two-input MatMuls, like those that are part of the attention operator
119139class MoveScalarMulPastMatMul (Transformation ):
120140 """Move scalar mul operations past matmul operations. We want to have muls
121141 next to each other such that they can be collapsed into a single mul."""
122142
143+ # Applies the transform to a whole model graph
123144 def apply (self , model ):
145+ # Get the model graph out of the model wrapper object
124146 graph = model .graph
125- node_ind = 0
147+ # Keep track of whether the graph has been modified
126148 graph_modified = False
127- for n in graph .node :
128- node_ind += 1
129- if n .op_type == "Mul" and not model .is_fork_node (n ) and not model .is_join_node (n ):
130- consumer = model .find_consumer (n .output [0 ])
131- if (
132- consumer is not None
133- and consumer .op_type == "MatMul"
134- and not model .is_join_node (consumer )
135- ):
136- mul_weight_name = n .input [1 ]
137- matmul_weight_name = consumer .input [1 ]
138- A = model .get_initializer (mul_weight_name )
139- W = model .get_initializer (matmul_weight_name )
140- if (A is None ) or (W is None ):
141- warnings .warn ("MatMul or Mul params are not constant, skipping" )
149+
150+ # Iterate all nodes in the graph keeping track of the index
151+ for index , node in enumerate (graph .node ):
152+ # First pattern matching condition: For the transform to be
153+ # applicable, the node has to be a MatMul operator
154+ if node .op_type == "MatMul" :
155+ # Note: When touching the following code, remember to treat both
156+ # branches equivalently!
157+ # TODO: Can this be enforced or at least be made easier by
158+ # extracting common code patterns to a function?
159+
160+ # Get the left hand side and right hand side inputs
161+ # Note: Assumes the ordering of left to right inputs to match
162+ # indices 0 to 1. However, it does not "hurt" if it is
163+ # reversed as both sides are treated equivalently.
164+ lhs = model .find_producer (node .input [0 ])
165+ rhs = model .find_producer (node .input [1 ])
166+
167+ # Give precedence to the left hand side input testing for the
168+ # presence of a scalar multiplication
169+ if is_const_scalar_mul (lhs , model ):
170+ # Cannot handle fork nodes: We would have to distribute the
171+ # Mul into all branches
172+ # TODO: Maybe reconsider this at some point, there is
173+ # probably nothing preventing this in general, it is just
174+ # more difficult and apparently not necessary right now.
175+ if model .is_fork_node (lhs ):
176+ # Softly skip this node
142177 continue
143- start_name = n .input [0 ]
144- middle_name = n .output [0 ]
145- end_name = consumer .output [0 ]
146- mm_out_shape = model .get_tensor_shape (end_name )
147- if all (x == 1 for x in A .shape ):
148- # if the mul is scalar, we can simply swap the order of ops
149- # make and insert new nodes
150- new_matmul = oh .make_node (
151- "MatMul" ,
152- [start_name , matmul_weight_name ],
153- [middle_name ],
154- name = consumer .name ,
155- )
156- new_mul = oh .make_node (
157- "Mul" ,
158- [middle_name , mul_weight_name ],
159- [end_name ],
160- name = n .name ,
161- )
162- graph .node .insert (node_ind , new_matmul )
163- graph .node .insert (node_ind + 1 , new_mul )
164- model .set_tensor_shape (middle_name , mm_out_shape )
165- # remove old nodes
166- graph .node .remove (n )
167- graph .node .remove (consumer )
168- graph_modified = True
178+ # Unpack the connection pattern of a scalar mul feeding the
179+ # lhs input of the matmul
180+ # Names of the three input tensors to the mul-matmul complex
181+ a , b , c = lhs .input [0 ], lhs .input [1 ], node .input [1 ]
182+ # Names of the intermediate and the global output
183+ m , o = lhs .output [0 ], node .output [0 ] # noqa: Duplicate code
184+ # Rewire the operator connections locally, swapping mul and
185+ # matmul operator order
186+ matmul = oh .make_node ("MatMul" , [a , c ], [m ], node .name )
187+ mul = oh .make_node ("Mul" , [m , b ], [o ], lhs .name )
188+ # Insert the rewired nodes into the graph
189+ graph .node .insert (index , matmul )
190+ graph .node .insert (index + 1 , mul )
191+ # Adapt the shape of the intermediate tensor as it changed
192+ # according to the output shape of the matmul
193+ model .set_tensor_shape (m , model .get_tensor_shape (o ))
194+ # Remove the old nodes from the graph
195+ graph .node .remove (lhs )
196+ graph .node .remove (node )
197+ # The graph has been modified, this needs to be reported
198+ # back to the caller
199+ graph_modified = True
200+ # Cannot further modify the node (i.e., the rhs) as the
201+ # index and state of the nodes changed and need to be
202+ # queried again from the graph.node at the start of the next
203+ # iteration.
204+ continue
205+
206+ # Next try whether the right hand side matches the pattern of a
207+ # scalar multiplication
208+ if is_const_scalar_mul (rhs , model ):
209+ # Cannot handle fork nodes: We would have to distribute the
210+ # Mul into all branches
211+ # TODO: Maybe reconsider this at some point, there is
212+ # probably nothing preventing this in general, it is just
213+ # more difficult and apparently not necessary right now.
214+ if model .is_fork_node (rhs ):
215+ # Softly skip this node
216+ continue
217+ # Unpack the connection pattern of a scalar mul feeding the
218+ # rhs input of the matmul
219+ # Names of the three input tensors to the mul-matmul complex
220+ a , b , c = node .input [0 ], rhs .input [0 ], rhs .input [1 ]
221+ # Names of the intermediate and the global output
222+ m , o = rhs .output [0 ], node .output [0 ] # noqa: Duplicate code
223+ # Rewire the operator connections locally, swapping mul and
224+ # matmul operator order
225+ matmul = oh .make_node ("MatMul" , [a , b ], [m ], node .name )
226+ mul = oh .make_node ("Mul" , [m , c ], [o ], rhs .name )
227+ # Insert the rewired nodes into the graph
228+ graph .node .insert (index , matmul )
229+ graph .node .insert (index + 1 , mul )
230+ # Adapt the shape of the intermediate tensor as it changed
231+ # according to the output shape of the matmul
232+ model .set_tensor_shape (m , model .get_tensor_shape (o ))
233+ # Remove the old nodes from the graph
234+ graph .node .remove (rhs )
235+ graph .node .remove (node )
236+ # The graph has been modified, this needs to be reported
237+ # back to the caller
238+ graph_modified = True
239+
240+ # Finalize the transformation by inferring shapes again (as these might
241+ # have changed)
169242 model = model .transform (InferShapes ())
170- return (model , graph_modified )
243+ # Return the transformed model and indicate whether the graph actually
244+ # has been transformed
245+ return model , graph_modified
171246
172247
173248class MoveScalarAddPastMatMul (Transformation ):
@@ -617,6 +692,7 @@ def apply(self, model):
617692 graph_modified = True
618693 else :
619694 continue
695+
620696 # Note: Running shape inference is necessary as shape annotations have
621697 # been deleted above
622698 model = model .transform (InferShapes ())
@@ -634,6 +710,7 @@ class MoveScalarLinearPastInvariants(Transformation):
634710 GlobalAveragePool
635711 """
636712
636712 # Op-types of currently supported invariants
638715 SUPPORTED_INVARIANTS = {
639716 "GlobalAveragePool" ,
0 commit comments