Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions src/qonnx/core/modelwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import qonnx.util.basic as util
import qonnx.util.onnx as onnxutil
from qonnx.core.datatype import DataType
from qonnx.custom_op.registry import is_custom_op
from qonnx.custom_op.registry import getCustomOp, is_custom_op
from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
from qonnx.transformation.general import (
RemoveStaticGraphInputs,
Expand Down Expand Up @@ -183,7 +183,7 @@ def transform(
if self.fix_float64:
(transformed_model, model_was_changed) = DoubleToSingleFloat().apply(transformed_model)

if apply_to_subgraphs and not use_preorder_traversal:
if apply_to_subgraphs and (use_preorder_traversal is False):
transformed_model.transform_subgraphs(
transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal
)
Expand Down Expand Up @@ -738,3 +738,44 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict):
qa.tensor_name = tensor_name
qa.quant_parameter_tensor_names.append(dt)
qnt_annotations.append(qa)

def get_opset_imports(self):
    """Return the model's imported opsets as a {domain: version} dictionary."""
    imports = {}
    for opset in self._model_proto.opset_import:
        imports[opset.domain] = opset.version
    return imports

def get_customop_wrapper(self, node, fallback_customop_version=util.get_preferred_qonnx_opset()):
    """Return a CustomOp wrapper instance for the given node, respecting the
    opset version the model protobuf imports for the node's domain.

    Args:
        node: ONNX NodeProto to wrap.
        fallback_customop_version (int): opset version to use (with a warning)
            when the node's domain is absent from the model's opset imports.
            NOTE: the default is evaluated once at definition time.

    Returns:
        The CustomOp instance produced by getCustomOp for the resolved version.
    """
    opset_imports = self.get_opset_imports()
    # Keep the try narrow: only the domain lookup may legitimately raise
    # KeyError here. A KeyError escaping getCustomOp itself must not be
    # misreported as a missing domain import.
    try:
        opset_version = opset_imports[node.domain]
    except KeyError:
        warnings.warn(
            f"Domain {node.domain} not found in model opset imports, "
            f"using fallback_customop_version={fallback_customop_version}"
        )
        opset_version = fallback_customop_version
    return getCustomOp(node, onnx_opset_version=opset_version)

def set_opset_import(self, domain, version):
    """Set the opset version for a given domain in the model's opset imports.

    An existing entry for the domain is updated in place; otherwise a new
    opset import entry is appended.

    Args:
        domain (str): The domain name (e.g. "qonnx.custom_op.general")
        version (int): The opset version number
    """
    existing = next(
        (op for op in self._model_proto.opset_import if op.domain == domain),
        None,
    )
    if existing is not None:
        existing.version = version
    else:
        entry = onnx.OperatorSetIdProto()
        entry.domain = domain
        entry.version = version
        self._model_proto.opset_import.append(entry)
23 changes: 14 additions & 9 deletions src/qonnx/transformation/fixedpt_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,19 @@ def default_op_filter(op):

class FixedPointQuantizeParamsFromDict(Transformation):
"""
Quantize model parameters to a given fixed-point representation.
The self.max_err dictionary stores the maximum error for each quantized input after calling.
Parameters:
fixedpt_dict: Dictionary containing tensor names and their corresponding target fixed-point
data type or its canonical name
rounding_mode: Rounding mode used for conversion into fixed point.
Default is "ROUND",
possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", "HALF_UP", "HALF_DOWN"]
Quantize model parameters to a given fixed-point representation.
The self.max_err dictionary stores the maximum error for each quantized input after calling.
Parameters:
fixedpt_dict: Dictionary containing tensor names and their corresponding target fixed-point
data type or its canonical name
rounding_mode: Rounding mode used for conversion into fixed point.
Default is "ROUND",
possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN",
"HALF_UP", "HALF_DOWN"]
"""

def __init__(self, fixedpt_dict, rounding_mode="ROUND"):
Expand All @@ -66,7 +71,7 @@ def apply(self, model: ModelWrapper):
if current_dtype.is_fixed_point():
warn(
f"Tensor {tname} is already a {current_dtype.get_canonical_name()} type. "
f"Recasting to {tdtype.get_canonical_name()}"
"Recasting to {tdtype.get_canonical_name()}"
)

in1_t_new = self.round_func(in1_t.astype(np.float32) / tdtype.scale_factor()) * tdtype.scale_factor()
Expand Down
10 changes: 9 additions & 1 deletion src/qonnx/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,19 @@ def get_preferred_onnx_opset():
return 11


def get_preferred_qonnx_opset():
    """Return the preferred opset version for the QONNX custom-op domain."""
    preferred_qonnx_opset = 1
    return preferred_qonnx_opset


def qonnx_make_model(graph_proto, **kwargs):
"Wrapper around ONNX make_model with preferred qonnx opset version"
opset_imports = kwargs.pop("opset_imports", None)
if opset_imports is None:
opset_imports = [make_opsetid("", get_preferred_onnx_opset())]
opset_imports = [
make_opsetid("", get_preferred_onnx_opset()),
make_opsetid("qonnx.custom_op.general", get_preferred_qonnx_opset()),
]
kwargs["opset_imports"] = opset_imports
else:
kwargs["opset_imports"] = opset_imports
Expand Down
31 changes: 30 additions & 1 deletion tests/core/test_modelwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import qonnx.core.data_layout as DataLayout
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.basic import qonnx_make_model
from qonnx.util.basic import get_preferred_onnx_opset, qonnx_make_model


def test_modelwrapper():
Expand Down Expand Up @@ -68,6 +68,7 @@ def test_modelwrapper():
inp_sparsity = {"dw": {"kernel_shape": [3, 3]}}
model.set_tensor_sparsity(first_conv_iname, inp_sparsity)
assert model.get_tensor_sparsity(first_conv_iname) == inp_sparsity
assert model.get_opset_imports() == {"": 8}


def test_modelwrapper_set_get_rm_initializer():
Expand Down Expand Up @@ -230,3 +231,31 @@ def test_modelwrapper_set_tensor_shape_multiple_inputs():
# check that order of inputs is preserved
assert model.graph.input[0].name == "in1"
assert model.graph.input[1].name == "in2"


def test_modelwrapper_set_opset_import():
    # Build a minimal single-node (Neg) model to exercise opset handling.
    inp = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4])
    outp = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4])
    neg_node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["out1"])
    graph = onnx.helper.make_graph([neg_node], "single_node_graph", [inp], [outp])
    model = ModelWrapper(qonnx_make_model(graph, producer_name="opset-test-model"))

    onnx_opset = get_preferred_onnx_opset()

    # Adding a previously-absent domain creates a new opset import entry.
    model.set_opset_import("qonnx.custom_op.general", 1)
    assert model.get_opset_imports() == {"": onnx_opset, "qonnx.custom_op.general": 1}

    # Setting an already-present domain updates its version in place.
    model.set_opset_import("qonnx.custom_op.general", 2)
    assert model.get_opset_imports() == {"": onnx_opset, "qonnx.custom_op.general": 2}

    # The default ONNX domain ("") can be updated the same way.
    model.set_opset_import("", 13)
    assert model.get_opset_imports() == {"": 13, "qonnx.custom_op.general": 2}
Loading