@@ -1447,7 +1447,7 @@ def add_output_qk_to_mha(model: OnnxModel, dtype: int = 0, skip_node_idxs: list[
14471447 return model
14481448
14491449
1450- def fix_past_sequence_length (model : ModelProto ):
1450+ def fix_past_sequence_length (model : OnnxModel ):
14511451 # Modify total_sequence_length = past_sequence_length + curr_sequence_length subgraph to calculate
14521452 # past_sequence_length from the new `past_sequence_length` input of size 1D and type int32 instead of
14531453 # from `past_key_self_0` since DecoderMaskedMultiHeadAttention (DMMHA) uses buffer sharing and
@@ -1480,56 +1480,119 @@ def fix_past_sequence_length(model: ModelProto):
14801480 # |
14811481 # Add
14821482
1483+ # Constant names to be used
1484+ past_seq_len_name = "past_sequence_length"
1485+ past_seq_len_int32 = "past_seq_len_int32"
1486+ past_seq_len_int64 = "past_seq_len_int64"
1487+
14831488 node = list (filter (lambda n : n .op_type == "LayerNormalization" , model .model .graph .node ))[0 ] # noqa: RUF015
14841489
1485- base_path = model .match_parent_path (
1490+ base_path_hf = model .match_parent_path (
1491+ node ,
1492+ ["Add" , "Gather" , "Tile" , "Expand" , "Unsqueeze" , "Range" ],
1493+ [0 , 1 , 1 , 0 , 0 , 0 ],
1494+ )
1495+ base_path_oai = model .match_parent_path (
14861496 node ,
14871497 ["Add" , "Slice" ],
14881498 [0 , 1 ],
14891499 )
1490- if base_path is None :
1500+ if base_path_hf is not None :
1501+ base_path = base_path_hf
1502+ elif base_path_oai is not None :
1503+ base_path = base_path_oai
1504+ else :
1505+ logger .info ("Cannot identify base path for fixing past_sequence_length subgraph" )
14911506 return
1507+ base_node = base_path [- 1 ]
14921508
1493- left_path = model .match_parent_path (
1494- base_path [- 1 ],
1495- ["Unsqueeze" , "Add" , "Gather" , "Shape" ],
1496- [2 , 0 , 0 , 0 ],
1497- )
1498- right_path = model .match_parent_path (
1499- base_path [- 1 ],
1500- ["Unsqueeze" , "Gather" , "Shape" ],
1501- [1 , 0 , 0 ],
1502- )
1503- long_right_path = model .match_parent_path (
1504- base_path [- 1 ],
1505- ["Unsqueeze" , "Gather" , "Shape" , "Reshape" , "Transpose" ],
1506- [1 , 0 , 0 , 0 , 0 ],
1507- )
1508- if left_path is None or right_path is None or left_path [- 2 :] != right_path [- 2 :]:
1509- return
1509+ if base_node .op_type == "Range" :
1510+ # Hugging Face implementation
1511+ range_node = base_path [- 1 ]
1512+
1513+ gather_path = model .match_parent_path (
1514+ range_node ,
1515+ ["Gather" , "Shape" ],
1516+ [0 , 0 ],
1517+ )
1518+ if gather_path is None :
1519+ logger .info ("Cannot identify gather path for fixing past_sequence_length subgraph" )
1520+ return
1521+
1522+ add_path = model .match_parent_path (
1523+ range_node ,
1524+ ["Add" , "Gather" , "Shape" ],
1525+ [1 , 0 , 0 ],
1526+ )
1527+ if add_path is None :
1528+ logger .info ("Cannot identify add path for fixing past_sequence_length subgraph" )
1529+ return
1530+ add_node = add_path [0 ]
1531+
1532+ if gather_path != add_path [1 :]:
1533+ logger .info ("Gather path and add path do not share the same nodes for calculating the past_sequence_length" )
1534+ return
1535+
1536+ # Remove `past_key_self_0 --> Shape --> Gather` connection
1537+ constant_in_gather = list (filter (lambda n : n .output [0 ] == gather_path [0 ].input [1 ], model .model .graph .node ))[0 ] # noqa: RUF015
1538+ model .model .graph .node .remove (constant_in_gather )
1539+ model .model .graph .node .remove (gather_path [0 ])
1540+ model .model .graph .node .remove (gather_path [1 ])
1541+
1542+ # Add `past_seq_len_int64` as an input name to existing nodes
1543+ range_node .input [0 ] = past_seq_len_int64
1544+ add_node .input [0 ] = past_seq_len_int64
15101545
1511- # Remove `past_key_self_0 --> [Transpose --> Reshape] --> Shape --> Gather` connection
1512- # where `Transpose --> Reshape` part may or may not exist. The OpenAI implementation of
1513- # Whisper has an extra `Transpose --> Reshape` connection to remove.
1514- constant_node = list (filter (lambda n : n .output [0 ] == left_path [- 2 ].input [1 ], model .model .graph .node ))[0 ] # noqa: RUF015
1515- model .model .graph .node .remove (left_path [- 2 ])
1516- model .model .graph .node .remove (left_path [- 1 ])
1517- model .model .graph .node .remove (constant_node )
1518- if long_right_path is not None :
1519- # Remove `Transpose --> Reshape` part
1520- model .model .graph .node .remove (long_right_path [- 2 ])
1521- model .model .graph .node .remove (long_right_path [- 1 ])
1546+ else :
1547+ # OpenAI implementation
1548+ input_ids_path = model .match_parent_path (
1549+ base_node ,
1550+ ["Unsqueeze" , "Add" , "Gather" , "Shape" , "Reshape" , "Transpose" ],
1551+ [2 , 0 , 0 , 0 , 0 , 0 ],
1552+ )
1553+ if input_ids_path is None :
1554+ logger .info ("Cannot identify input_ids path for fixing past_sequence_length subgraph" )
1555+ return
1556+ add_node = input_ids_path [1 ]
1557+
1558+ past_key_path = model .match_parent_path (
1559+ base_node ,
1560+ ["Unsqueeze" , "Gather" , "Shape" , "Reshape" , "Transpose" ],
1561+ [1 , 0 , 0 , 0 , 0 ],
1562+ )
1563+ if past_key_path is None :
1564+ logger .info ("Cannot identify past_key path for fixing past_sequence_length subgraph" )
1565+ return
1566+ unsqueeze_node = past_key_path [0 ]
1567+
1568+ if input_ids_path [2 :] != past_key_path [1 :]:
1569+ logger .info (
1570+ "The input_ids path and past_key path do not share the same nodes for calculating the past_sequence_length"
1571+ )
1572+ return
1573+
1574+ # Remove `past_key_self_0 --> Transpose --> Reshape --> Shape --> Gather` connection
1575+ constant_in_gather = list (filter (lambda n : n .output [0 ] == past_key_path [1 ].input [1 ], model .model .graph .node ))[0 ] # noqa: RUF015
1576+ model .model .graph .node .remove (constant_in_gather )
1577+ constant_in_reshape = list (filter (lambda n : n .output [0 ] == past_key_path [- 2 ].input [1 ], model .model .graph .node ))[ # noqa: RUF015
1578+ 0
1579+ ]
1580+ model .model .graph .node .remove (constant_in_reshape )
1581+ model .model .graph .node .remove (past_key_path [1 ])
1582+ model .model .graph .node .remove (past_key_path [2 ])
1583+ model .model .graph .node .remove (past_key_path [3 ])
1584+ model .model .graph .node .remove (past_key_path [4 ])
1585+
1586+ # Add `past_seq_len_int64` as an input name to existing nodes
1587+ unsqueeze_node .input [0 ] = past_seq_len_int64
1588+ add_node .input [0 ] = past_seq_len_int64
15221589
15231590 # Add `past_sequence_length` as model input
1524- past_seq_len_name = "past_sequence_length"
15251591 model .model .graph .input .append (
15261592 onnx .helper .make_tensor_value_info (past_seq_len_name , TensorProto .INT32 , shape = [1 ]),
15271593 )
15281594
15291595 # Add `past_sequence_length --> Squeeze --> Cast` connection
1530- past_seq_len_int32 = "past_seq_len_int32"
1531- past_seq_len_int64 = "past_seq_len_int64"
1532-
15331596 squeeze_node = onnx .helper .make_node (
15341597 "Squeeze" ,
15351598 inputs = [past_seq_len_name ],
@@ -1546,14 +1609,9 @@ def fix_past_sequence_length(model: ModelProto):
15461609 )
15471610 cast_output = onnx .helper .make_tensor_value_info (past_seq_len_int64 , TensorProto .INT64 , shape = [])
15481611
1549- model .model .graph .value_info .extend ([squeeze_output , cast_output ])
1550-
1551- # Add `past_seq_len_int64` as an input name to existing nodes
1552- left_path [1 ].input [0 ] = past_seq_len_int64
1553- right_path [0 ].input [0 ] = past_seq_len_int64
1554-
15551612 # Add new nodes to graph
15561613 model .model .graph .node .extend ([squeeze_node , cast_node ])
1614+ model .model .graph .value_info .extend ([squeeze_output , cast_output ])
15571615 model .topological_sort ()
15581616 return model , past_seq_len_name
15591617
0 commit comments