diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index d23ce8ac..c7b2c676 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -39,7 +39,7 @@ import qonnx.util.basic as util import qonnx.util.onnx as onnxutil from qonnx.core.datatype import DataType -from qonnx.custom_op.registry import is_custom_op +from qonnx.custom_op.registry import getCustomOp, is_custom_op from qonnx.transformation.double_to_single_float import DoubleToSingleFloat from qonnx.transformation.general import ( RemoveStaticGraphInputs, @@ -183,7 +183,7 @@ def transform( if self.fix_float64: (transformed_model, model_was_changed) = DoubleToSingleFloat().apply(transformed_model) - if apply_to_subgraphs and not use_preorder_traversal: + if apply_to_subgraphs and (use_preorder_traversal is False): transformed_model.transform_subgraphs( transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal ) @@ -738,3 +738,44 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict): qa.tensor_name = tensor_name qa.quant_parameter_tensor_names.append(dt) qnt_annotations.append(qa) + + def get_opset_imports(self): + """Returns a list of imported opsets as a {domain, version} dictionary.""" + return {opset.domain: opset.version for opset in self._model_proto.opset_import} + + def get_customop_wrapper(self, node, fallback_customop_version=util.get_preferred_qonnx_opset()): + """Return CustomOp instance for given node, respecting the + imported opset version in the model protobuf. If the node's domain + is not found in the model's opset imports, fallback_customop_version + will be used.""" + opset_imports = self.get_opset_imports() + try: + opset_import = opset_imports[node.domain] + return getCustomOp(node, onnx_opset_version=opset_import) + except KeyError: + # domain not found in imports, use fallback version + warnings.warn( + f"Domain {node.domain} not found in model opset imports, " + f"using fallback_customop_version={fallback_customop_version}" + ) + return getCustomOp(node, onnx_opset_version=fallback_customop_version) + + def set_opset_import(self, domain, version): + """Sets the opset version for a given domain in the model's opset imports. + If the domain already exists, its version will be updated. If not, a new + opset import will be added. + + Args: + domain (str): The domain name (e.g. "qonnx.custom_op.general") + version (int): The opset version number + """ + # find if domain already exists in opset imports + for opset in self._model_proto.opset_import: + if opset.domain == domain: + opset.version = version + return + # domain not found, add new opset import + new_opset = onnx.OperatorSetIdProto() + new_opset.domain = domain + new_opset.version = version + self._model_proto.opset_import.append(new_opset) diff --git a/src/qonnx/transformation/fixedpt_quantize.py b/src/qonnx/transformation/fixedpt_quantize.py index 894d7ea6..127fa4b1 100644 --- a/src/qonnx/transformation/fixedpt_quantize.py +++ b/src/qonnx/transformation/fixedpt_quantize.py @@ -41,14 +41,19 @@ def default_op_filter(op): class FixedPointQuantizeParamsFromDict(Transformation): """ - Quantize model parameters to a given fixed-point representation. - The self.max_err dictionary stores the maximum error for each quantized input after calling. - Parameters: - fixedpt_dict: Dictionary containing tensor names and their corresponding target fixed-point - data type or its canonical name - rounding_mode: Rounding mode used for conversion into fixed point. - Default is "ROUND", - possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", "HALF_UP", "HALF_DOWN"] + Quantize model parameters to a given fixed-point representation. + The self.max_err dictionary stores the maximum error for each quantized input after calling. + Parameters: + fixedpt_dict: Dictionary containing tensor names and their corresponding target fixed-point + <<<<<<< HEAD + data type or its canonical name + ======= + data type or its canonical name + >>>>>>> 7dfc4b8 ([Lint] rerun linter, fix errors) + rounding_mode: Rounding mode used for conversion into fixed point. + Default is "ROUND", + possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", + "HALF_UP", "HALF_DOWN"] """ def __init__(self, fixedpt_dict, rounding_mode="ROUND"): @@ -66,7 +71,7 @@ def apply(self, model: ModelWrapper): if current_dtype.is_fixed_point(): warn( f"Tensor {tname} is already a {current_dtype.get_canonical_name()} type. " - f"Recasting to {tdtype.get_canonical_name()}" + "Recasting to {tdtype.get_canonical_name()}" ) in1_t_new = self.round_func(in1_t.astype(np.float32) / tdtype.scale_factor()) * tdtype.scale_factor() diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 4e300dd1..17957d12 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -51,11 +51,19 @@ def get_preferred_onnx_opset(): return 11 +def get_preferred_qonnx_opset(): + "Return preferred ONNX opset version for QONNX" + return 1 + + def qonnx_make_model(graph_proto, **kwargs): "Wrapper around ONNX make_model with preferred qonnx opset version" opset_imports = kwargs.pop("opset_imports", None) if opset_imports is None: - opset_imports = [make_opsetid("", get_preferred_onnx_opset())] + opset_imports = [ + make_opsetid("", get_preferred_onnx_opset()), + make_opsetid("qonnx.custom_op.general", get_preferred_qonnx_opset()), + ] kwargs["opset_imports"] = opset_imports else: kwargs["opset_imports"] = opset_imports diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py index 722f0fb1..fb26e420 100644 --- a/tests/core/test_modelwrapper.py +++ b/tests/core/test_modelwrapper.py @@ -33,7 +33,7 @@ import qonnx.core.data_layout as DataLayout from qonnx.core.datatype import DataType from qonnx.core.modelwrapper import ModelWrapper -from qonnx.util.basic import qonnx_make_model +from qonnx.util.basic import get_preferred_onnx_opset, qonnx_make_model def test_modelwrapper(): @@ -68,6 +68,7 @@ def test_modelwrapper(): inp_sparsity = {"dw": {"kernel_shape": [3, 3]}} model.set_tensor_sparsity(first_conv_iname, inp_sparsity) assert model.get_tensor_sparsity(first_conv_iname) == inp_sparsity + assert model.get_opset_imports() == {"": 8} def test_modelwrapper_set_get_rm_initializer(): @@ -230,3 +231,31 @@ def test_modelwrapper_set_tensor_shape_multiple_inputs(): # check that order of inputs is preserved assert model.graph.input[0].name == "in1" assert model.graph.input[1].name == "in2" + + +def test_modelwrapper_set_opset_import(): + # Create a simple model + in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4]) + out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4]) + node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["out1"]) + graph = onnx.helper.make_graph( + nodes=[node], + name="single_node_graph", + inputs=[in1], + outputs=[out1], + ) + onnx_model = qonnx_make_model(graph, producer_name="opset-test-model") + model = ModelWrapper(onnx_model) + + # Test setting new domain + model.set_opset_import("qonnx.custom_op.general", 1) + preferred_onnx_opset = get_preferred_onnx_opset() + assert model.get_opset_imports() == {"": preferred_onnx_opset, "qonnx.custom_op.general": 1} + + # Test updating existing domain + model.set_opset_import("qonnx.custom_op.general", 2) + assert model.get_opset_imports() == {"": preferred_onnx_opset, "qonnx.custom_op.general": 2} + + # Test setting ONNX main domain + model.set_opset_import("", 13) + assert model.get_opset_imports() == {"": 13, "qonnx.custom_op.general": 2}