Skip to content

Commit 779b51c

Browse files
committed
Fix MHSA requant fields needing to be ints when single item and tuples when multiitem
1 parent 05a2a4a commit 779b51c

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

Deeploy/OperatorDescriptor.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from enum import Enum, IntEnum
6-
from typing import Any, Dict, Tuple
6+
from typing import Any, Dict, Tuple, Union
77

88
import numpy as np
99
import onnx_graphsurgeon as gs
@@ -51,6 +51,13 @@ def FloatTupleUnpack(value: Any) -> Tuple[float, ...]:
5151
return (FloatUnpack(value),)
5252

5353

54+
def IntTupleIfNotSingleItemUnpack(value: Any) -> Union[int, Tuple[int, ...]]:
55+
try:
56+
return IntUnpack(value)
57+
except:
58+
return IntTupleUnpack(value)
59+
60+
5461
def attrToTensor(node: gs.Node, attr: str) -> None:
5562
values = node.attrs[attr]
5663
if isinstance(values, (int, float)):
@@ -609,18 +616,18 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
609616
"wo_bias"]),
610617
outputDescriptor = IoDesc("data_out"),
611618
attrDescriptors = [
612-
AttrDesc("preattn_requant_mul", IntUnpack),
613-
AttrDesc("preattn_requant_div", IntUnpack),
614-
AttrDesc("postattn_requant_mul", IntUnpack),
615-
AttrDesc("postattn_requant_div", IntUnpack),
616-
AttrDesc("wo_requant_mul", IntUnpack),
617-
AttrDesc("wo_requant_div", IntUnpack),
618-
AttrDesc("wq_requant_mul", IntUnpack),
619-
AttrDesc("wq_requant_div", IntUnpack),
620-
AttrDesc("wk_requant_mul", IntUnpack),
621-
AttrDesc("wk_requant_div", IntUnpack),
622-
AttrDesc("wv_requant_mul", IntUnpack),
623-
AttrDesc("wv_requant_div", IntUnpack),
619+
AttrDesc("preattn_requant_mul", IntTupleIfNotSingleItemUnpack),
620+
AttrDesc("preattn_requant_div", IntTupleIfNotSingleItemUnpack),
621+
AttrDesc("postattn_requant_mul", IntTupleIfNotSingleItemUnpack),
622+
AttrDesc("postattn_requant_div", IntTupleIfNotSingleItemUnpack),
623+
AttrDesc("wo_requant_mul", IntTupleIfNotSingleItemUnpack),
624+
AttrDesc("wo_requant_div", IntTupleIfNotSingleItemUnpack),
625+
AttrDesc("wq_requant_mul", IntTupleIfNotSingleItemUnpack),
626+
AttrDesc("wq_requant_div", IntTupleIfNotSingleItemUnpack),
627+
AttrDesc("wk_requant_mul", IntTupleIfNotSingleItemUnpack),
628+
AttrDesc("wk_requant_div", IntTupleIfNotSingleItemUnpack),
629+
AttrDesc("wv_requant_mul", IntTupleIfNotSingleItemUnpack),
630+
AttrDesc("wv_requant_div", IntTupleIfNotSingleItemUnpack),
624631
AttrDesc("n_levels", IntUnpack),
625632
AttrDesc("dim", IntUnpack),
626633
AttrDesc("dim_head", IntUnpack),

0 commit comments

Comments
 (0)