Skip to content

Commit f2602d1

Browse files
authored
Merge pull request #215 from fastmachinelearning/feature/opset_utils
Opset version utilities
2 parents e30f89c + a097b41 commit f2602d1

File tree

4 files changed

+96
-13
lines changed

4 files changed

+96
-13
lines changed

src/qonnx/core/modelwrapper.py

Lines changed: 43 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@
3939
import qonnx.util.basic as util
4040
import qonnx.util.onnx as onnxutil
4141
from qonnx.core.datatype import DataType
42-
from qonnx.custom_op.registry import is_custom_op
42+
from qonnx.custom_op.registry import getCustomOp, is_custom_op
4343
from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
4444
from qonnx.transformation.general import (
4545
RemoveStaticGraphInputs,
@@ -183,7 +183,7 @@ def transform(
183183
if self.fix_float64:
184184
(transformed_model, model_was_changed) = DoubleToSingleFloat().apply(transformed_model)
185185

186-
if apply_to_subgraphs and not use_preorder_traversal:
186+
if apply_to_subgraphs and (use_preorder_traversal is False):
187187
transformed_model.transform_subgraphs(
188188
transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal
189189
)
@@ -738,3 +738,44 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict):
738738
qa.tensor_name = tensor_name
739739
qa.quant_parameter_tensor_names.append(dt)
740740
qnt_annotations.append(qa)
741+
742+
def get_opset_imports(self):
    """Return the model's opset imports as a {domain: version} dictionary.

    Note: returns a dict (not a list); each entry maps an opset domain
    name (e.g. "" for the default ONNX domain) to its imported version.
    """
    return {opset.domain: opset.version for opset in self._model_proto.opset_import}
745+
746+
def get_customop_wrapper(self, node, fallback_customop_version=util.get_preferred_qonnx_opset()):
    """Return a CustomOp wrapper instance for the given node, respecting the
    opset version imported by the model protobuf for the node's domain.

    If the node's domain is not found in the model's opset imports,
    fallback_customop_version is used instead and a warning is emitted.
    Note: the fallback default is evaluated once, at definition time.
    """
    opset_imports = self.get_opset_imports()
    try:
        customop_version = opset_imports[node.domain]
    except KeyError:
        # Domain not found in imports, use fallback version. Only the dict
        # lookup is inside the try block: a KeyError raised from within
        # getCustomOp itself must propagate, not be silently misattributed
        # to a missing opset import.
        warnings.warn(
            f"Domain {node.domain} not found in model opset imports, "
            f"using fallback_customop_version={fallback_customop_version}"
        )
        customop_version = fallback_customop_version
    return getCustomOp(node, onnx_opset_version=customop_version)
762+
763+
def set_opset_import(self, domain, version):
    """Set the opset version for a given domain in the model's opset imports.

    If the domain is already imported, its version is overwritten in place;
    otherwise a fresh opset import entry is appended.

    Args:
        domain (str): The domain name (e.g. "qonnx.custom_op.general")
        version (int): The opset version number
    """
    # Update in place when the domain is already present.
    existing = next(
        (opset for opset in self._model_proto.opset_import if opset.domain == domain),
        None,
    )
    if existing is not None:
        existing.version = version
        return
    # Domain not imported yet: append a new OperatorSetId entry.
    new_opset = onnx.OperatorSetIdProto()
    new_opset.domain = domain
    new_opset.version = version
    self._model_proto.opset_import.append(new_opset)

src/qonnx/transformation/fixedpt_quantize.py

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -41,14 +41,19 @@ def default_op_filter(op):
4141

4242
class FixedPointQuantizeParamsFromDict(Transformation):
4343
"""
44-
Quantize model parameters to a given fixed-point representation.
45-
The self.max_err dictionary stores the maximum error for each quantized input after calling.
46-
Parameters:
47-
fixedpt_dict: Dictionary containing tensor names and their corresponding target fixed-point
48-
data type or its canonical name
49-
rounding_mode: Rounding mode used for conversion into fixed point.
50-
Default is "ROUND",
51-
possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", "HALF_UP", "HALF_DOWN"]
44+
Quantize model parameters to a given fixed-point representation.
45+
The self.max_err dictionary stores the maximum error for each quantized input after calling.
46+
Parameters:
47+
fixedpt_dict: Dictionary containing tensor names and their corresponding target fixed-point
48+
data type or its canonical name
53+
rounding_mode: Rounding mode used for conversion into fixed point.
54+
Default is "ROUND",
55+
possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN",
56+
"HALF_UP", "HALF_DOWN"]
5257
"""
5358

5459
def __init__(self, fixedpt_dict, rounding_mode="ROUND"):
@@ -66,7 +71,7 @@ def apply(self, model: ModelWrapper):
6671
if current_dtype.is_fixed_point():
6772
warn(
6873
f"Tensor {tname} is already a {current_dtype.get_canonical_name()} type. "
69-
f"Recasting to {tdtype.get_canonical_name()}"
74+
f"Recasting to {tdtype.get_canonical_name()}"
7075
)
7176

7277
in1_t_new = self.round_func(in1_t.astype(np.float32) / tdtype.scale_factor()) * tdtype.scale_factor()

src/qonnx/util/basic.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -51,11 +51,19 @@ def get_preferred_onnx_opset():
5151
return 11
5252

5353

54+
def get_preferred_qonnx_opset():
    """Return the preferred opset version for the QONNX custom-op domain."""
    # Currently pinned to the initial QONNX custom-op opset.
    return 1
57+
58+
5459
def qonnx_make_model(graph_proto, **kwargs):
5560
"Wrapper around ONNX make_model with preferred qonnx opset version"
5661
opset_imports = kwargs.pop("opset_imports", None)
5762
if opset_imports is None:
58-
opset_imports = [make_opsetid("", get_preferred_onnx_opset())]
63+
opset_imports = [
64+
make_opsetid("", get_preferred_onnx_opset()),
65+
make_opsetid("qonnx.custom_op.general", get_preferred_qonnx_opset()),
66+
]
5967
kwargs["opset_imports"] = opset_imports
6068
else:
6169
kwargs["opset_imports"] = opset_imports

tests/core/test_modelwrapper.py

Lines changed: 30 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333
import qonnx.core.data_layout as DataLayout
3434
from qonnx.core.datatype import DataType
3535
from qonnx.core.modelwrapper import ModelWrapper
36-
from qonnx.util.basic import qonnx_make_model
36+
from qonnx.util.basic import get_preferred_onnx_opset, qonnx_make_model
3737

3838

3939
def test_modelwrapper():
@@ -68,6 +68,7 @@ def test_modelwrapper():
6868
inp_sparsity = {"dw": {"kernel_shape": [3, 3]}}
6969
model.set_tensor_sparsity(first_conv_iname, inp_sparsity)
7070
assert model.get_tensor_sparsity(first_conv_iname) == inp_sparsity
71+
assert model.get_opset_imports() == {"": 8}
7172

7273

7374
def test_modelwrapper_set_get_rm_initializer():
@@ -230,3 +231,31 @@ def test_modelwrapper_set_tensor_shape_multiple_inputs():
230231
# check that order of inputs is preserved
231232
assert model.graph.input[0].name == "in1"
232233
assert model.graph.input[1].name == "in2"
234+
235+
236+
def test_modelwrapper_set_opset_import():
    """Check that set_opset_import both adds new domains and updates existing ones."""
    # Build a minimal single-node (Neg) model to exercise opset handling.
    inp_vi = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4])
    out_vi = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4])
    neg_node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["out1"])
    graph = onnx.helper.make_graph(
        nodes=[neg_node],
        name="single_node_graph",
        inputs=[inp_vi],
        outputs=[out_vi],
    )
    model = ModelWrapper(qonnx_make_model(graph, producer_name="opset-test-model"))

    preferred_onnx_opset = get_preferred_onnx_opset()

    # Setting a custom-op domain yields an entry alongside the default domain.
    model.set_opset_import("qonnx.custom_op.general", 1)
    assert model.get_opset_imports() == {"": preferred_onnx_opset, "qonnx.custom_op.general": 1}

    # Setting the same domain again overwrites its version instead of duplicating.
    model.set_opset_import("qonnx.custom_op.general", 2)
    assert model.get_opset_imports() == {"": preferred_onnx_opset, "qonnx.custom_op.general": 2}

    # The default ONNX domain ("") can be updated the same way.
    model.set_opset_import("", 13)
    assert model.get_opset_imports() == {"": 13, "qonnx.custom_op.general": 2}

0 commit comments

Comments
 (0)