|
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
5 | 5 | from enum import Enum, IntEnum |
6 | | -from typing import Any, Dict, Tuple |
| 6 | +from typing import Any, Dict, Tuple, Union |
7 | 7 |
|
8 | 8 | import numpy as np |
9 | 9 | import onnx_graphsurgeon as gs |
@@ -51,6 +51,13 @@ def FloatTupleUnpack(value: Any) -> Tuple[float, ...]: |
51 | 51 | return (FloatUnpack(value),) |
52 | 52 |
|
53 | 53 |
|
| 54 | +def IntTupleIfNotSingleItemUnpack(value: Any) -> Union[int, Tuple[int, ...]]: |
| 55 | + try: |
| 56 | + return IntUnpack(value) |
| 57 | + except: |
| 58 | + return IntTupleUnpack(value) |
| 59 | + |
| 60 | + |
54 | 61 | def attrToTensor(node: gs.Node, attr: str) -> None: |
55 | 62 | values = node.attrs[attr] |
56 | 63 | if isinstance(values, (int, float)): |
@@ -609,18 +616,18 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool: |
609 | 616 | "wo_bias"]), |
610 | 617 | outputDescriptor = IoDesc("data_out"), |
611 | 618 | attrDescriptors = [ |
612 | | - AttrDesc("preattn_requant_mul", IntUnpack), |
613 | | - AttrDesc("preattn_requant_div", IntUnpack), |
614 | | - AttrDesc("postattn_requant_mul", IntUnpack), |
615 | | - AttrDesc("postattn_requant_div", IntUnpack), |
616 | | - AttrDesc("wo_requant_mul", IntUnpack), |
617 | | - AttrDesc("wo_requant_div", IntUnpack), |
618 | | - AttrDesc("wq_requant_mul", IntUnpack), |
619 | | - AttrDesc("wq_requant_div", IntUnpack), |
620 | | - AttrDesc("wk_requant_mul", IntUnpack), |
621 | | - AttrDesc("wk_requant_div", IntUnpack), |
622 | | - AttrDesc("wv_requant_mul", IntUnpack), |
623 | | - AttrDesc("wv_requant_div", IntUnpack), |
| 619 | + AttrDesc("preattn_requant_mul", IntTupleIfNotSingleItemUnpack), |
| 620 | + AttrDesc("preattn_requant_div", IntTupleIfNotSingleItemUnpack), |
| 621 | + AttrDesc("postattn_requant_mul", IntTupleIfNotSingleItemUnpack), |
| 622 | + AttrDesc("postattn_requant_div", IntTupleIfNotSingleItemUnpack), |
| 623 | + AttrDesc("wo_requant_mul", IntTupleIfNotSingleItemUnpack), |
| 624 | + AttrDesc("wo_requant_div", IntTupleIfNotSingleItemUnpack), |
| 625 | + AttrDesc("wq_requant_mul", IntTupleIfNotSingleItemUnpack), |
| 626 | + AttrDesc("wq_requant_div", IntTupleIfNotSingleItemUnpack), |
| 627 | + AttrDesc("wk_requant_mul", IntTupleIfNotSingleItemUnpack), |
| 628 | + AttrDesc("wk_requant_div", IntTupleIfNotSingleItemUnpack), |
| 629 | + AttrDesc("wv_requant_mul", IntTupleIfNotSingleItemUnpack), |
| 630 | + AttrDesc("wv_requant_div", IntTupleIfNotSingleItemUnpack), |
624 | 631 | AttrDesc("n_levels", IntUnpack), |
625 | 632 | AttrDesc("dim", IntUnpack), |
626 | 633 | AttrDesc("dim_head", IntUnpack), |
|
0 commit comments