Skip to content

Commit 914d7a0

Browse files
singliswschin
authored andcommitted
Handling of model version information (#9)
* - Moving model version, domain, producer, and producer version into functions. This allows for other packages that are using onnmltools to override these variables as needed. * - Updating based on feedback.
1 parent b028b08 commit 914d7a0

File tree

3 files changed

+39
-11
lines changed

3 files changed

+39
-11
lines changed

onnxmltools/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
__producer_version__ = __version__
1313
__domain__ = "onnxml"
1414
__model_version__ = 0
15-
__operator_set_version__ = 0
15+
1616

1717
from .convert import convert_coreml
1818
from .convert import convert_sklearn
1919

2020
from .utils import load_model
2121
from .utils import save_model
22-
from .utils import save_text
22+
from .utils import save_text
23+

onnxmltools/convert/common/ModelBuilder.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,6 @@
88
from uuid import uuid4
99
from ..common import model_util
1010
from ...proto import onnx_proto
11-
from ... import __domain__
12-
from ... import __producer__
13-
from ... import __producer_version__
14-
from ... import __model_version__
15-
1611

1712
class ModelBuilder:
1813
def __init__(self, name=None, doc_string='', metadata_props=[]):
@@ -47,10 +42,10 @@ def add_domain_version_pair(self, pair):
4742
def make_model(self):
4843
return model_util.make_model(self._name,
4944
onnx_proto.IR_VERSION,
50-
__producer__,
51-
__producer_version__,
52-
__domain__,
53-
__model_version__,
45+
model_util.get_producer(),
46+
model_util.get_producer_version(),
47+
model_util.get_domain(),
48+
model_util.get_model_version(),
5449
self._doc_string,
5550
self._metadata_props,
5651
self._operator_domain_version_pairs,

onnxmltools/convert/common/model_util.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,38 @@
1515
onnx_proto.TensorProto.INT16, onnx_proto.TensorProto.INT32, onnx_proto.TensorProto.INT64]
1616

1717

18+
def get_producer():
19+
"""
20+
Internal helper function to return the producer
21+
"""
22+
from ... import __producer__
23+
return __producer__
24+
25+
26+
def get_producer_version():
27+
"""
28+
Internal helper function to return the producer version
29+
"""
30+
from ... import __producer_version__
31+
return __producer_version__
32+
33+
34+
def get_domain():
35+
"""
36+
Internal helper function to return the model domain
37+
"""
38+
from ... import __domain__
39+
return __domain__
40+
41+
42+
def get_model_version():
43+
"""
44+
Internal helper function to return the model version
45+
"""
46+
from ... import __model_version__
47+
return __model_version__
48+
49+
1850
def make_tensor_value_info(name, elem_type=None, shape=None, doc_string=''):
1951
"""
2052
Makes a TypeProto based on the data type and shape.

0 commit comments

Comments
 (0)