Skip to content

Commit 1e3085f

Browse files
authored
Merge pull request #12 from eki-project/feature/attention-streamline
Streamlining of Scaled Dot-Product Attention
2 parents a0b9007 + 95ed158 commit 1e3085f

File tree

4 files changed

+194
-52
lines changed

4 files changed

+194
-52
lines changed

src/finn/transformation/streamline/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def apply(self, model):
7676
BatchNormToAffine(),
7777
ConvertSignToThres(),
7878
MoveMulPastMaxPool(),
79-
MoveScalarLinearPastInvariants(),
8079
AbsorbSignBiasIntoMultiThreshold(),
80+
MoveScalarLinearPastInvariants(),
8181
MoveAddPastMul(),
8282
MoveScalarAddPastMatMul(),
8383
MoveAddPastConv(),

src/finn/transformation/streamline/absorb.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
import qonnx.core.data_layout as DataLayout
3131
import warnings
3232
from onnx import helper as oh
33+
# Protobuf onnx graph node type
34+
from onnx import NodeProto # noqa
35+
# QONNX wrapper of ONNX model graphs
36+
from qonnx.core.modelwrapper import ModelWrapper
3337
from qonnx.core.datatype import DataType
3438

3539
# QONNX wrapper of ONNX model graphs
@@ -261,7 +265,7 @@ def apply(self, model):
261265

262266

263267
class Absorb1BitMulIntoMatMul(Transformation):
264-
"""Absorb bipolar or binary multiplications into the preciding matrix
268+
"""Absorb bipolar or binary multiplications into the preceding matrix
265269
multiply."""
266270

267271
def apply(self, model):
@@ -270,16 +274,28 @@ def apply(self, model):
270274
graph_modified = False
271275
for n in graph.node:
272276
node_ind += 1
273-
if n.op_type == "MatMul":
277+
# Note: Join-node test is implicitly covered by testing for the
278+
# initializer below
279+
# Note: This cannot handle fork-nodes, as only the first consumer is
280+
# considered below.
281+
# TODO: Fork-nodes could be handled if the muls are the same in all
282+
# branches, but this is not checked nor rewired at all right now.
283+
if n.op_type == "MatMul" and not model.is_fork_node(n):
274284
matmul_weight_name = n.input[1]
275285
W = model.get_initializer(matmul_weight_name)
276286
Wdt = model.get_tensor_datatype(matmul_weight_name)
277-
assert W is not None, "Initializer for matmul weights is not set."
287+
# Just skip matmuls with non-existing weight initializers
288+
if W is None:
289+
continue
278290
consumer = model.find_consumer(n.output[0])
291+
# Note: Join-node test is implicitly covered by testing for the
292+
# initializer below
279293
if consumer is not None and consumer.op_type == "Mul":
280294
mul_weight_name = consumer.input[1]
281295
A = model.get_initializer(mul_weight_name)
282-
assert A is not None, "Initializer for mul weights is not set."
296+
# Just skip muls with non-existing scale initializers
297+
if A is None:
298+
continue
283299
is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
284300
if is_1bit:
285301
Wnew = A * W
@@ -298,24 +314,36 @@ def apply(self, model):
298314

299315

300316
class Absorb1BitMulIntoConv(Transformation):
301-
"""Absorb bipolar or binary multiplications into the preciding convolution."""
317+
"""Absorb bipolar or binary multiplications into the preceding convolution."""
302318

303319
def apply(self, model):
304320
graph = model.graph
305321
node_ind = 0
306322
graph_modified = False
307323
for n in graph.node:
308324
node_ind += 1
309-
if n.op_type == "Conv":
325+
# Note: Join-node test is implicitly covered by testing for the
326+
# initializer below
327+
# Note: This cannot handle fork-nodes, as only the first consumer is
328+
# considered below.
329+
# TODO: Fork-nodes could be handled if the muls are the same in all
330+
# branches, but this is not checked nor rewired at all right now.
331+
if n.op_type == "Conv" and not model.is_fork_node(n):
310332
conv_weight_name = n.input[1]
311333
W = model.get_initializer(conv_weight_name)
312334
Wdt = model.get_tensor_datatype(conv_weight_name)
313-
assert W is not None, "Initializer for conv weights is not set."
335+
# Just skip convs with non-existing weight initializers
336+
if W is None:
337+
continue
314338
consumer = model.find_consumer(n.output[0])
339+
# Note: Join-node test is implicitly covered by testing for the
340+
# initializer below
315341
if consumer is not None and consumer.op_type == "Mul":
316342
mul_weight_name = consumer.input[1]
317343
A = model.get_initializer(mul_weight_name)
318-
assert A is not None, "Initializer for mul weights is not set."
344+
# Just skip muls with non-existing scale initializers
345+
if A is None:
346+
continue
319347
is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
320348
is_scalar = np.prod(A.shape) == 1
321349
actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))

src/finn/transformation/streamline/reorder.py

Lines changed: 120 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -116,58 +116,133 @@ def apply(self, model):
116116
return model, graph_modified
117117

118118

119+
# Tests whether a tensor is a scalar, i.e., whether all dimensions are 1
120+
def is_scalar(tensor):
121+
return tensor is not None and all(x == 1 for x in tensor.shape)
122+
123+
124+
# Tests whether a node is a scalar multiplication with a constant scale factor
125+
def is_const_scalar_mul(node, model):
126+
# Only handle existing Mul type nodes
127+
if node is not None and node.op_type == "Mul":
128+
# The constant must be an initializer
129+
# Note: Assumes the constant parameter to always be the second input
130+
scale = model.get_initializer(node.input[1])
131+
# Test for existence of a constant scale factor
132+
return scale is not None and is_scalar(scale)
133+
# Did not match the operator type
134+
return False
135+
136+
137+
# Refactored version of the MoveScalarMulPastMatMul transform capable of
138+
# transforming two-input MatMul, like those being part of the attention operator
119139
class MoveScalarMulPastMatMul(Transformation):
120140
"""Move scalar mul operations past matmul operations. We want to have muls
121141
next to each other such that they can be collapsed into a single mul."""
122142

143+
# Applies the transform to a whole model graph
123144
def apply(self, model):
145+
# Get the model graph out of the model wrapper object
124146
graph = model.graph
125-
node_ind = 0
147+
# Keep track of whether the graph has been modified
126148
graph_modified = False
127-
for n in graph.node:
128-
node_ind += 1
129-
if n.op_type == "Mul" and not model.is_fork_node(n) and not model.is_join_node(n):
130-
consumer = model.find_consumer(n.output[0])
131-
if (
132-
consumer is not None
133-
and consumer.op_type == "MatMul"
134-
and not model.is_join_node(consumer)
135-
):
136-
mul_weight_name = n.input[1]
137-
matmul_weight_name = consumer.input[1]
138-
A = model.get_initializer(mul_weight_name)
139-
W = model.get_initializer(matmul_weight_name)
140-
if (A is None) or (W is None):
141-
warnings.warn("MatMul or Mul params are not constant, skipping")
149+
150+
# Iterate all nodes in the graph keeping track of the index
151+
for index, node in enumerate(graph.node):
152+
# First pattern matching condition: For the transform to be
153+
# applicable, the node has to be a MatMul operator
154+
if node.op_type == "MatMul":
155+
# Note: When touching the following code, remember to treat both
156+
# branches equivalently!
157+
# TODO: Can this be enforced or at least be made easier by
158+
# extracting common code patterns to a function?
159+
160+
# Get the left hand side and right hand side inputs
161+
# Note: Assumes the ordering of left to right inputs to match
162+
# indices 0 to 1. However, it does not "hurt" if it is
163+
# reversed as both sides are treated equivalently.
164+
lhs = model.find_producer(node.input[0])
165+
rhs = model.find_producer(node.input[1])
166+
167+
# Give precedence to the left hand side input testing for the
168+
# presence of a scalar multiplication
169+
if is_const_scalar_mul(lhs, model):
170+
# Cannot handle fork nodes: We would have to distribute the
171+
# Mul into all branches
172+
# TODO: Maybe reconsider this at some point, there is
173+
# probably nothing preventing this in general, it is just
174+
# more difficult and apparently not necessary right now.
175+
if model.is_fork_node(lhs):
176+
# Softly skip this node
142177
continue
143-
start_name = n.input[0]
144-
middle_name = n.output[0]
145-
end_name = consumer.output[0]
146-
mm_out_shape = model.get_tensor_shape(end_name)
147-
if all(x == 1 for x in A.shape):
148-
# if the mul is scalar, we can simply swap the order of ops
149-
# make and insert new nodes
150-
new_matmul = oh.make_node(
151-
"MatMul",
152-
[start_name, matmul_weight_name],
153-
[middle_name],
154-
name=consumer.name,
155-
)
156-
new_mul = oh.make_node(
157-
"Mul",
158-
[middle_name, mul_weight_name],
159-
[end_name],
160-
name=n.name,
161-
)
162-
graph.node.insert(node_ind, new_matmul)
163-
graph.node.insert(node_ind + 1, new_mul)
164-
model.set_tensor_shape(middle_name, mm_out_shape)
165-
# remove old nodes
166-
graph.node.remove(n)
167-
graph.node.remove(consumer)
168-
graph_modified = True
178+
# Unpack the connection pattern of a scalar mul feeding the
179+
# lhs input of the matmul
180+
# Names of the three input tensors to the mul-matmul complex
181+
a, b, c = lhs.input[0], lhs.input[1], node.input[1]
182+
# Names of the intermediate and the global output
183+
m, o = lhs.output[0], node.output[0] # noqa: Duplicate code
184+
# Rewire the operator connections locally, swapping mul and
185+
# matmul operator order
186+
matmul = oh.make_node("MatMul", [a, c], [m], node.name)
187+
mul = oh.make_node("Mul", [m, b], [o], lhs.name)
188+
# Insert the rewired nodes into the graph
189+
graph.node.insert(index, matmul)
190+
graph.node.insert(index + 1, mul)
191+
# Adapt the shape of the intermediate tensor as it changed
192+
# according to the output shape of the matmul
193+
model.set_tensor_shape(m, model.get_tensor_shape(o))
194+
# Remove the old nodes from the graph
195+
graph.node.remove(lhs)
196+
graph.node.remove(node)
197+
# The graph has been modified, this needs to be reported
198+
# back to the caller
199+
graph_modified = True
200+
# Cannot further modify the node (i.e., the rhs) as the
201+
# index and state of the nodes changed and need to be
202+
# queried again from the graph.node at the start of the next
203+
# iteration.
204+
continue
205+
206+
# Next try whether the right hand side matches the pattern of a
207+
# scalar multiplication
208+
if is_const_scalar_mul(rhs, model):
209+
# Cannot handle fork nodes: We would have to distribute the
210+
# Mul into all branches
211+
# TODO: Maybe reconsider this at some point, there is
212+
# probably nothing preventing this in general, it is just
213+
# more difficult and apparently not necessary right now.
214+
if model.is_fork_node(rhs):
215+
# Softly skip this node
216+
continue
217+
# Unpack the connection pattern of a scalar mul feeding the
218+
# rhs input of the matmul
219+
# Names of the three input tensors to the mul-matmul complex
220+
a, b, c = node.input[0], rhs.input[0], rhs.input[1]
221+
# Names of the intermediate and the global output
222+
m, o = rhs.output[0], node.output[0] # noqa: Duplicate code
223+
# Rewire the operator connections locally, swapping mul and
224+
# matmul operator order
225+
matmul = oh.make_node("MatMul", [a, b], [m], node.name)
226+
mul = oh.make_node("Mul", [m, c], [o], rhs.name)
227+
# Insert the rewired nodes into the graph
228+
graph.node.insert(index, matmul)
229+
graph.node.insert(index + 1, mul)
230+
# Adapt the shape of the intermediate tensor as it changed
231+
# according to the output shape of the matmul
232+
model.set_tensor_shape(m, model.get_tensor_shape(o))
233+
# Remove the old nodes from the graph
234+
graph.node.remove(rhs)
235+
graph.node.remove(node)
236+
# The graph has been modified, this needs to be reported
237+
# back to the caller
238+
graph_modified = True
239+
240+
# Finalize the transformation by inferring shapes again (as these might
241+
# have changed)
169242
model = model.transform(InferShapes())
170-
return (model, graph_modified)
243+
# Return the transformed model and indicate whether the graph actually
244+
# has been transformed
245+
return model, graph_modified
171246

172247

173248
class MoveScalarAddPastMatMul(Transformation):
@@ -617,6 +692,7 @@ def apply(self, model):
617692
graph_modified = True
618693
else:
619694
continue
695+
620696
# Note: Running shape inference is necessary as shape annotations have
621697
# been deleted above
622698
model = model.transform(InferShapes())
@@ -634,6 +710,7 @@ class MoveScalarLinearPastInvariants(Transformation):
634710
GlobalAveragePool
635711
"""
636712

713+
# Op-types of currently supported invariants
637714
# Op-types of currently supported invariants
638715
SUPPORTED_INVARIANTS = {
639716
"GlobalAveragePool",

tests/transformation/streamline/test_move_scalar_past_matmul.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,43 @@ def test_move_scalar_mul_past_matmul():
7272
assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
7373

7474

75+
@pytest.mark.streamline
76+
def test_move_scalar_mul_past_join_matmul():
77+
top_in1 = oh.make_tensor_value_info("top_in1", TensorProto.FLOAT, [1, 2])
78+
top_in2 = oh.make_tensor_value_info("top_in2", TensorProto.FLOAT, [2, 1])
79+
mul1_param = oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1, 1])
80+
mul2_param = oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1, 1])
81+
top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [1, 1])
82+
modelproto = qonnx_make_model(
83+
oh.make_graph(
84+
name="test",
85+
inputs=[top_in1, top_in2],
86+
outputs=[top_out],
87+
value_info=[mul1_param, mul2_param],
88+
nodes=[
89+
oh.make_node("Mul", ["top_in1", "mul1_param"], ["middle1"]),
90+
oh.make_node("Mul", ["top_in2", "mul2_param"], ["middle2"]),
91+
oh.make_node("MatMul", ["middle1", "middle2"], ["top_out"]),
92+
],
93+
)
94+
)
95+
model = ModelWrapper(modelproto)
96+
model = model.transform(InferShapes())
97+
model.set_initializer("mul1_param", np.asarray([[3]], dtype=np.float32))
98+
model.set_initializer("mul2_param", np.asarray([[3]], dtype=np.float32))
99+
new_model = model.transform(MoveScalarMulPastMatMul())
100+
inp_dict = {
101+
"top_in1": np.asarray([[-1.0, 1.0]], dtype=np.float32),
102+
"top_in2": np.asarray([[1.0], [-1.0]], dtype=np.float32),
103+
}
104+
assert ox.compare_execution(model, new_model, inp_dict)
105+
assert new_model.graph.node[0].op_type == "MatMul"
106+
assert new_model.graph.node[1].op_type == "Mul"
107+
assert new_model.graph.node[2].op_type == "Mul"
108+
assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
109+
assert new_model.graph.node[1].output[0] == new_model.graph.node[2].input[0]
110+
111+
75112
@pytest.mark.streamline
76113
def test_move_scalar_add_past_matmul():
77114
top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [1, 2])

0 commit comments

Comments
 (0)