|
39 | 39 | import qonnx.util.basic as util |
40 | 40 | import qonnx.util.onnx as onnxutil |
41 | 41 | from qonnx.core.datatype import DataType |
42 | | -from qonnx.custom_op.registry import is_custom_op |
| 42 | +from qonnx.custom_op.registry import getCustomOp, is_custom_op |
43 | 43 | from qonnx.transformation.double_to_single_float import DoubleToSingleFloat |
44 | 44 | from qonnx.transformation.general import ( |
45 | 45 | RemoveStaticGraphInputs, |
@@ -183,7 +183,7 @@ def transform( |
183 | 183 | if self.fix_float64: |
184 | 184 | (transformed_model, model_was_changed) = DoubleToSingleFloat().apply(transformed_model) |
185 | 185 |
|
186 | | - if apply_to_subgraphs and not use_preorder_traversal: |
| 186 | + if apply_to_subgraphs and (use_preorder_traversal is False): |
187 | 187 | transformed_model.transform_subgraphs( |
188 | 188 | transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal |
189 | 189 | ) |
@@ -738,3 +738,44 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict): |
738 | 738 | qa.tensor_name = tensor_name |
739 | 739 | qa.quant_parameter_tensor_names.append(dt) |
740 | 740 | qnt_annotations.append(qa) |
| 741 | + |
| 742 | + def get_opset_imports(self): |
| 743 | + """Returns a list of imported opsets as a {domain, version} dictionary.""" |
| 744 | + return {opset.domain: opset.version for opset in self._model_proto.opset_import} |
| 745 | + |
| 746 | + def get_customop_wrapper(self, node, fallback_customop_version=util.get_preferred_qonnx_opset()): |
| 747 | + """Return CustomOp instance for given node, respecting the |
| 748 | + imported opset version in the model protobuf. If the node's domain |
| 749 | + is not found in the model's opset imports, fallback_customop_version |
| 750 | + will be used.""" |
| 751 | + opset_imports = self.get_opset_imports() |
| 752 | + try: |
| 753 | + opset_import = opset_imports[node.domain] |
| 754 | + return getCustomOp(node, onnx_opset_version=opset_import) |
| 755 | + except KeyError: |
| 756 | + # domain not found in imports, use fallback version |
| 757 | + warnings.warn( |
| 758 | + f"Domain {node.domain} not found in model opset imports, " |
| 759 | + f"using fallback_customop_version={fallback_customop_version}" |
| 760 | + ) |
| 761 | + return getCustomOp(node, onnx_opset_version=fallback_customop_version) |
| 762 | + |
| 763 | + def set_opset_import(self, domain, version): |
| 764 | + """Sets the opset version for a given domain in the model's opset imports. |
| 765 | + If the domain already exists, its version will be updated. If not, a new |
| 766 | + opset import will be added. |
| 767 | +
|
| 768 | + Args: |
| 769 | + domain (str): The domain name (e.g. "qonnx.custom_op.general") |
| 770 | + version (int): The opset version number |
| 771 | + """ |
| 772 | + # find if domain already exists in opset imports |
| 773 | + for opset in self._model_proto.opset_import: |
| 774 | + if opset.domain == domain: |
| 775 | + opset.version = version |
| 776 | + return |
| 777 | + # domain not found, add new opset import |
| 778 | + new_opset = onnx.OperatorSetIdProto() |
| 779 | + new_opset.domain = domain |
| 780 | + new_opset.version = version |
| 781 | + self._model_proto.opset_import.append(new_opset) |
0 commit comments