Skip to content

Commit 196dea1

Browse files
Update Whisper attention fusions (microsoft#24857)
### Description

This PR updates the attention fusions for Whisper to work with the latest `transformers` package (`4.52.3`).

### Motivation and Context

Previously, the attention fusions were maintained for many older `transformers` versions. The existing fusions do not work with the latest `transformers` versions.
1 parent 70de20b commit 196dea1

File tree

36 files changed

+1185
-2620
lines changed

36 files changed

+1185
-2620
lines changed

onnxruntime/python/tools/transformers/convert_generation.py

Lines changed: 99 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,7 +1447,7 @@ def add_output_qk_to_mha(model: OnnxModel, dtype: int = 0, skip_node_idxs: list[
14471447
return model
14481448

14491449

1450-
def fix_past_sequence_length(model: ModelProto):
1450+
def fix_past_sequence_length(model: OnnxModel):
14511451
# Modify total_sequence_length = past_sequence_length + curr_sequence_length subgraph to calculate
14521452
# past_sequence_length from the new `past_sequence_length` input of size 1D and type int32 instead of
14531453
# from `past_key_self_0` since DecoderMaskedMultiHeadAttention (DMMHA) uses buffer sharing and
@@ -1480,56 +1480,119 @@ def fix_past_sequence_length(model: ModelProto):
14801480
# |
14811481
# Add
14821482

1483+
# Constant names to be used
1484+
past_seq_len_name = "past_sequence_length"
1485+
past_seq_len_int32 = "past_seq_len_int32"
1486+
past_seq_len_int64 = "past_seq_len_int64"
1487+
14831488
node = list(filter(lambda n: n.op_type == "LayerNormalization", model.model.graph.node))[0] # noqa: RUF015
14841489

1485-
base_path = model.match_parent_path(
1490+
base_path_hf = model.match_parent_path(
1491+
node,
1492+
["Add", "Gather", "Tile", "Expand", "Unsqueeze", "Range"],
1493+
[0, 1, 1, 0, 0, 0],
1494+
)
1495+
base_path_oai = model.match_parent_path(
14861496
node,
14871497
["Add", "Slice"],
14881498
[0, 1],
14891499
)
1490-
if base_path is None:
1500+
if base_path_hf is not None:
1501+
base_path = base_path_hf
1502+
elif base_path_oai is not None:
1503+
base_path = base_path_oai
1504+
else:
1505+
logger.info("Cannot identify base path for fixing past_sequence_length subgraph")
14911506
return
1507+
base_node = base_path[-1]
14921508

1493-
left_path = model.match_parent_path(
1494-
base_path[-1],
1495-
["Unsqueeze", "Add", "Gather", "Shape"],
1496-
[2, 0, 0, 0],
1497-
)
1498-
right_path = model.match_parent_path(
1499-
base_path[-1],
1500-
["Unsqueeze", "Gather", "Shape"],
1501-
[1, 0, 0],
1502-
)
1503-
long_right_path = model.match_parent_path(
1504-
base_path[-1],
1505-
["Unsqueeze", "Gather", "Shape", "Reshape", "Transpose"],
1506-
[1, 0, 0, 0, 0],
1507-
)
1508-
if left_path is None or right_path is None or left_path[-2:] != right_path[-2:]:
1509-
return
1509+
if base_node.op_type == "Range":
1510+
# Hugging Face implementation
1511+
range_node = base_path[-1]
1512+
1513+
gather_path = model.match_parent_path(
1514+
range_node,
1515+
["Gather", "Shape"],
1516+
[0, 0],
1517+
)
1518+
if gather_path is None:
1519+
logger.info("Cannot identify gather path for fixing past_sequence_length subgraph")
1520+
return
1521+
1522+
add_path = model.match_parent_path(
1523+
range_node,
1524+
["Add", "Gather", "Shape"],
1525+
[1, 0, 0],
1526+
)
1527+
if add_path is None:
1528+
logger.info("Cannot identify add path for fixing past_sequence_length subgraph")
1529+
return
1530+
add_node = add_path[0]
1531+
1532+
if gather_path != add_path[1:]:
1533+
logger.info("Gather path and add path do not share the same nodes for calculating the past_sequence_length")
1534+
return
1535+
1536+
# Remove `past_key_self_0 --> Shape --> Gather` connection
1537+
constant_in_gather = list(filter(lambda n: n.output[0] == gather_path[0].input[1], model.model.graph.node))[0] # noqa: RUF015
1538+
model.model.graph.node.remove(constant_in_gather)
1539+
model.model.graph.node.remove(gather_path[0])
1540+
model.model.graph.node.remove(gather_path[1])
1541+
1542+
# Add `past_seq_len_int64` as an input name to existing nodes
1543+
range_node.input[0] = past_seq_len_int64
1544+
add_node.input[0] = past_seq_len_int64
15101545

1511-
# Remove `past_key_self_0 --> [Transpose --> Reshape] --> Shape --> Gather` connection
1512-
# where `Transpose --> Reshape` part may or may not exist. The OpenAI implementation of
1513-
# Whisper has an extra `Transpose --> Reshape` connection to remove.
1514-
constant_node = list(filter(lambda n: n.output[0] == left_path[-2].input[1], model.model.graph.node))[0] # noqa: RUF015
1515-
model.model.graph.node.remove(left_path[-2])
1516-
model.model.graph.node.remove(left_path[-1])
1517-
model.model.graph.node.remove(constant_node)
1518-
if long_right_path is not None:
1519-
# Remove `Transpose --> Reshape` part
1520-
model.model.graph.node.remove(long_right_path[-2])
1521-
model.model.graph.node.remove(long_right_path[-1])
1546+
else:
1547+
# OpenAI implementation
1548+
input_ids_path = model.match_parent_path(
1549+
base_node,
1550+
["Unsqueeze", "Add", "Gather", "Shape", "Reshape", "Transpose"],
1551+
[2, 0, 0, 0, 0, 0],
1552+
)
1553+
if input_ids_path is None:
1554+
logger.info("Cannot identify input_ids path for fixing past_sequence_length subgraph")
1555+
return
1556+
add_node = input_ids_path[1]
1557+
1558+
past_key_path = model.match_parent_path(
1559+
base_node,
1560+
["Unsqueeze", "Gather", "Shape", "Reshape", "Transpose"],
1561+
[1, 0, 0, 0, 0],
1562+
)
1563+
if past_key_path is None:
1564+
logger.info("Cannot identify past_key path for fixing past_sequence_length subgraph")
1565+
return
1566+
unsqueeze_node = past_key_path[0]
1567+
1568+
if input_ids_path[2:] != past_key_path[1:]:
1569+
logger.info(
1570+
"The input_ids path and past_key path do not share the same nodes for calculating the past_sequence_length"
1571+
)
1572+
return
1573+
1574+
# Remove `past_key_self_0 --> Transpose --> Reshape --> Shape --> Gather` connection
1575+
constant_in_gather = list(filter(lambda n: n.output[0] == past_key_path[1].input[1], model.model.graph.node))[0] # noqa: RUF015
1576+
model.model.graph.node.remove(constant_in_gather)
1577+
constant_in_reshape = list(filter(lambda n: n.output[0] == past_key_path[-2].input[1], model.model.graph.node))[ # noqa: RUF015
1578+
0
1579+
]
1580+
model.model.graph.node.remove(constant_in_reshape)
1581+
model.model.graph.node.remove(past_key_path[1])
1582+
model.model.graph.node.remove(past_key_path[2])
1583+
model.model.graph.node.remove(past_key_path[3])
1584+
model.model.graph.node.remove(past_key_path[4])
1585+
1586+
# Add `past_seq_len_int64` as an input name to existing nodes
1587+
unsqueeze_node.input[0] = past_seq_len_int64
1588+
add_node.input[0] = past_seq_len_int64
15221589

15231590
# Add `past_sequence_length` as model input
1524-
past_seq_len_name = "past_sequence_length"
15251591
model.model.graph.input.append(
15261592
onnx.helper.make_tensor_value_info(past_seq_len_name, TensorProto.INT32, shape=[1]),
15271593
)
15281594

15291595
# Add `past_sequence_length --> Squeeze --> Cast` connection
1530-
past_seq_len_int32 = "past_seq_len_int32"
1531-
past_seq_len_int64 = "past_seq_len_int64"
1532-
15331596
squeeze_node = onnx.helper.make_node(
15341597
"Squeeze",
15351598
inputs=[past_seq_len_name],
@@ -1546,14 +1609,9 @@ def fix_past_sequence_length(model: ModelProto):
15461609
)
15471610
cast_output = onnx.helper.make_tensor_value_info(past_seq_len_int64, TensorProto.INT64, shape=[])
15481611

1549-
model.model.graph.value_info.extend([squeeze_output, cast_output])
1550-
1551-
# Add `past_seq_len_int64` as an input name to existing nodes
1552-
left_path[1].input[0] = past_seq_len_int64
1553-
right_path[0].input[0] = past_seq_len_int64
1554-
15551612
# Add new nodes to graph
15561613
model.model.graph.node.extend([squeeze_node, cast_node])
1614+
model.model.graph.value_info.extend([squeeze_output, cast_output])
15571615
model.topological_sort()
15581616
return model, past_seq_len_name
15591617

onnxruntime/python/tools/transformers/fusion_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -663,12 +663,12 @@ def create_attention_node(
663663
first_input: str,
664664
output: str,
665665
add_qk_str: str = "",
666+
causal: bool = False,
666667
past_k: str = "",
667668
past_v: str = "",
668669
present_k: str = "",
669670
present_v: str = "",
670671
scale: float | None = None,
671-
causal: bool = False,
672672
) -> NodeProto | None:
673673
"""Create an Attention node.
674674
@@ -685,12 +685,12 @@ def create_attention_node(
685685
first_input (str): first input name
686686
output (str): output name
687687
add_qk_str (str): name of Add node after Q x K'
688+
causal: whether it is uni-directional mask.
688689
past_k (str): name of input for past K value
689690
past_v (str): name of input for past V value
690691
present_k (str): name of output to store present K value
691692
present_v (str): name of output to store present V value
692693
scale: scale before softmax
693-
causal: whether it is uni-directional mask.
694694
695695
Returns:
696696
Union[NodeProto, None]: the node created or None if failed.

0 commit comments

Comments
 (0)