
Commit c090ba6

tiagoshibata authored and wschin committed
Manipulation of metadata props (#115)
* Add image metadata to CoreML conversion
* Remove extra properties in data_types
* Add set_denotation
* Remove properties with the same case-insensitive name
* Add denotation automatically for images
* Check ONNX version before adding metadata
* Use StrictVersion for comparison
* Fix code style
* Move metadata and denotation to Topology/TensorType
* Change Topology default metadata_props value
* Match color_space with ONNX definitions
* Refactor case-insensitive dictionaries
* Use warnings module
* Put valid metadata props in a separate constant
* Avoid casefold() for compatibility with Python 2
* Don't guess the color space and nominal pixel range from CoreML
* Add tests for image metadata
* Don't set denotation in ONNX < 1.2.2
* Use unittest.skipIf instead of running an empty test
* Add docstrings
* Add usage examples
* Forward tensor denotation when removing redundant variables
* Only check denotation in tensors
1 parent be66467 commit c090ba6
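As a quick illustration of what this change enables (a hedged sketch, not code from the commit; 'image_model.mlmodel' and the single image input are placeholders): converting a Core ML model with an image input now records the pixel format as a model-level metadata property and annotates the input with type and dimension denotations.

# Hedged usage sketch; file name and input layout are assumptions.
import coremltools
from onnxmltools import convert_coreml

spec = coremltools.utils.load_spec('image_model.mlmodel')
onnx_model = convert_coreml(spec)

# The pixel format of the image input is recorded as a metadata property...
print({p.key: p.value for p in onnx_model.metadata_props})   # e.g. {'Image.BitmapPixelFormat': 'Rgb8'}
# ...and the input carries a type denotation (on ONNX >= 1.2.1).
print(onnx_model.graph.input[0].type.denotation)             # e.g. 'IMAGE'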

9 files changed, +220 −21 lines changed

onnxmltools/convert/common/_topology.py

Lines changed: 24 additions & 11 deletions
@@ -9,6 +9,7 @@
 from distutils.version import StrictVersion
 from ...proto import onnx
 from ...proto import helper
+from ...utils.metadata_props import add_metadata_props
 from .data_types import *
 from ._container import ModelComponentContainer
 from . import _registration
@@ -234,7 +235,7 @@ class Topology:
 
     def __init__(self, model, default_batch_size=1, initial_types=None,
                  reserved_variable_names=None, reserved_operator_names=None, targeted_onnx=None,
-                 custom_conversion_functions=None, custom_shape_calculators=None):
+                 custom_conversion_functions=None, custom_shape_calculators=None, metadata_props=None):
         '''
         Initialize a Topology object, which is an intermediate representation of a computational graph.
@@ -253,6 +254,7 @@ def __init__(self, model, default_batch_size=1, initial_types=None,
         self.variable_name_set = reserved_variable_names if reserved_variable_names is not None else set()
         self.operator_name_set = reserved_operator_names if reserved_operator_names is not None else set()
         self.initial_types = initial_types if initial_types else list()
+        self.metadata_props = metadata_props if metadata_props else dict()
         self.default_batch_size = default_batch_size
         self.targeted_onnx_version = StrictVersion(targeted_onnx)
         self.custom_conversion_functions = custom_conversion_functions if custom_conversion_functions else {}
@@ -520,19 +522,29 @@ def _resolve_duplicates(self):
                         continue
                     another_operator.inputs[i] = original
 
-            # When original variable's document string is empty but duplicate's document string is not, we
-            # copy that non-empty string to the original variable to avoid information loss.
+            # When original variable's documentation string or denotation is empty but duplicate's is not, we
+            # copy that field to the original variable to avoid information loss.
             if not original.type.doc_string and duplicate.type.doc_string:
                 original.type.doc_string = duplicate.type.doc_string
 
-            # Sometime, shapes of duplicates are different. We try to replace the original variable's unknown dimensions
-            # as many as possible because we will get rid of the duplicate.
-            if isinstance(original.type, TensorType) and isinstance(duplicate.type, TensorType) and \
-                    len(original.type.shape) == len(duplicate.type.shape):
-                for i in range(len(original.type.shape)):
-                    if original.type.shape[i] != 'None':
-                        continue
-                    original.type.shape[i] = duplicate.type.shape[i]
+            if isinstance(original.type, TensorType) and isinstance(duplicate.type, TensorType):
+                if not original.type.denotation and duplicate.type.denotation:
+                    original.type.denotation = duplicate.type.denotation
+                if not original.type.channel_denotations:
+                    original.type.channel_denotations = duplicate.type.channel_denotations
+                elif duplicate.type.channel_denotations:
+                    # Merge the channel denotations if available in both the original and the duplicate
+                    for i in range(len(original.type.channel_denotations)):
+                        if original.type.channel_denotations[i]:
+                            continue
+                        original.type.channel_denotations[i] = duplicate.type.channel_denotations[i]
+                # Sometime, shapes of duplicates are different. We try to replace the original variable's unknown dimensions
+                # as many as possible because we will get rid of the duplicate.
+                if len(original.type.shape) == len(duplicate.type.shape):
+                    for i in range(len(original.type.shape)):
+                        if original.type.shape[i] != 'None':
+                            continue
+                        original.type.shape[i] = duplicate.type.shape[i]
 
             # Because we're iterating through the topology, we cannot delete any operator or variable. Otherwise,
             # the traversing function may be broken. We will delete those abandoned ones later.
@@ -735,6 +747,7 @@ def convert_topology(topology, model_name, doc_string, targeted_onnx):
         i += 1
 
     # Add extra information
+    add_metadata_props(onnx_model, topology.metadata_props)
    onnx_model.ir_version = onnx_proto.IR_VERSION
    onnx_model.producer_name = utils.get_producer()
    onnx_model.producer_version = utils.get_producer_version()
onnxmltools/convert/common/case_insensitive_dict.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+try:
+    from collections.abc import Mapping, MutableMapping
+except ImportError:
+    from collections import Mapping, MutableMapping
+from collections import OrderedDict
+
+
+class CaseInsensitiveDict(MutableMapping):
+    def __init__(self, data=None, **kwargs):
+        self._dict = OrderedDict()
+        if data:
+            self.update(data, **kwargs)
+
+    def __setitem__(self, key, value):
+        self._dict[key.lower()] = (key, value)
+
+    def __getitem__(self, key):
+        return self._dict[key.lower()][1]
+
+    def __delitem__(self, key):
+        del self._dict[key.lower()]
+
+    def __iter__(self):
+        return (key for key, _ in self._dict.values())
+
+    def __len__(self):
+        return len(self._dict)
+
+    def lower_key_iteritems(self):
+        """Like iteritems(), but with lowercase keys."""
+        return (
+            (lower_key, keyval[1])
+            for lower_key, keyval
+            in self._dict.items()
+        )
+
+    def __eq__(self, other):
+        if isinstance(other, Mapping):
+            other = CaseInsensitiveDict(other)
+        else:
+            return NotImplemented
+        return dict(self.lower_key_iteritems()) == dict(other.lower_key_iteritems())
+
+    def copy(self):
+        return CaseInsensitiveDict(self._dict.values())
+
+    def __repr__(self):
+        return str(dict(self.items()))
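A minimal sketch of the behaviour this helper provides (the module path is inferred from the import in metadata_props.py further below): lookups and equality ignore key case, while iteration keeps the original spelling.

# Hedged sketch; module path inferred from the metadata_props.py import.
from onnxmltools.convert.common.case_insensitive_dict import CaseInsensitiveDict

d = CaseInsensitiveDict({'Image.BitmapPixelFormat': 'Rgb8'})
assert d['image.bitmappixelformat'] == 'Rgb8'      # case-insensitive lookup
assert list(d) == ['Image.BitmapPixelFormat']      # original casing preserved on iteration
d['IMAGE.BitmapPixelFormat'] = 'Bgr8'              # overwrites the same logical key
assert len(d) == 1 and d == {'image.bitmappixelformat': 'Bgr8'}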

onnxmltools/convert/common/data_types.py

Lines changed: 12 additions & 3 deletions
@@ -54,8 +54,10 @@ def to_onnx_type(self):
 
 
 class TensorType(DataType):
-    def __init__(self, shape=None, doc_string=''):
+    def __init__(self, shape=None, doc_string='', denotation=None, channel_denotations=None):
         super(TensorType, self).__init__([] if not shape else shape, doc_string)
+        self.denotation = denotation
+        self.channel_denotations = channel_denotations
 
     def _get_element_onnx_type(self):
         raise NotImplementedError()
@@ -71,6 +73,13 @@ def to_onnx_type(self):
                 s.dim_param = 'None'
             else:
                 raise ValueError('Unsupported dimension type: %s' % type(d))
+        if getattr(onnx_type, 'denotation', None) is not None:
+            if self.denotation:
+                onnx_type.denotation = self.denotation
+            if self.channel_denotations:
+                for d, denotation in zip(onnx_type.tensor_type.shape.dim, self.channel_denotations):
+                    if denotation:
+                        d.denotation = denotation
         return onnx_type
 
 
@@ -83,8 +92,8 @@ def _get_element_onnx_type(self):
 
 
 class FloatTensorType(TensorType):
-    def __init__(self, shape=None, color_space=None, doc_string=''):
-        super(FloatTensorType, self).__init__(shape, doc_string)
+    def __init__(self, shape=None, color_space=None, doc_string='', denotation=None, channel_denotations=None):
+        super(FloatTensorType, self).__init__(shape, doc_string, denotation, channel_denotations)
         self.color_space = color_space
 
     def _get_element_onnx_type(self):
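An illustrative sketch (not part of the diff) of how the new arguments surface when a tensor type is turned into an ONNX TypeProto; the 224x224 shape is an arbitrary example, and the denotation fields are only written when the installed ONNX build exposes them (ONNX >= 1.2.1).

# Hedged sketch; shape values are illustrative.
from onnxmltools.convert.common.data_types import FloatTensorType

t = FloatTensorType(shape=['None', 3, 224, 224], color_space='Rgb8',
                    denotation='IMAGE',
                    channel_denotations=['DATA_BATCH', 'DATA_CHANNEL', 'DATA_FEATURE', 'DATA_FEATURE'])
onnx_type = t.to_onnx_type()
# On ONNX >= 1.2.1 the TypeProto carries both levels of denotation:
#   onnx_type.denotation == 'IMAGE'
#   [d.denotation for d in onnx_type.tensor_type.shape.dim]
#       == ['DATA_BATCH', 'DATA_CHANNEL', 'DATA_FEATURE', 'DATA_FEATURE']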

onnxmltools/convert/coreml/_parse.py

Lines changed: 9 additions & 3 deletions
@@ -5,6 +5,7 @@
 # --------------------------------------------------------------------------
 
 from distutils.version import StrictVersion
+import warnings
 from ...proto import onnx
 from ..common._container import CoremlModelContainer
 from ..common._topology import Topology
@@ -59,8 +60,9 @@ def _parse_coreml_feature(feature_info, targeted_onnx_version, batch_size=1):
            raise ValueError('Unknown image format. Only gray-level, RGB, and BGR are supported')
        shape.append(raw_type.imageType.height)
        shape.append(raw_type.imageType.width)
-        color_space_map = {10: 'GRAY', 20: 'RGB', 30: 'BGR'}
-        return FloatTensorType(shape, color_space_map[color_space], doc_string=doc_string)
+        color_space_map = {10: 'Gray8', 20: 'Rgb8', 30: 'Bgr8'}
+        return FloatTensorType(shape, color_space_map[color_space], doc_string=doc_string,
+                               denotation='IMAGE', channel_denotations=['DATA_BATCH', 'DATA_CHANNEL', 'DATA_FEATURE', 'DATA_FEATURE'])
    elif type_name == 'multiArrayType':
        element_type_id = raw_type.multiArrayType.dataType
        shape = [d for d in raw_type.multiArrayType.shape]
@@ -463,8 +465,12 @@ def parse_coreml(model, initial_types=None, targeted_onnx=onnx.__version__, cust
     _parse_model(topology, scope, model)
     topology.compile()
 
-    # Use original CoreML names for model-level input(s)/output(s)
     for variable in topology.find_root_and_sink_variables():
+        color_space = getattr(variable.type, 'color_space', None)
+        if color_space:
+            if topology.metadata_props.setdefault('Image.BitmapPixelFormat', color_space) != color_space:
+                warnings.warn('Conflicting pixel formats found. In ONNX, all input/output images must use the same pixel format.')
+        # Use original CoreML names for model-level input(s)/output(s)
         if variable.raw_name not in reserved_variable_names:
             continue
         topology.rename_variable(variable.onnx_name, variable.raw_name)
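The conflict check above leans on dict.setdefault(): it stores the value only when the key is absent and always returns the stored value, so a mismatch on the return value means a second image with a different pixel format was encountered. A standalone sketch of that behaviour (plain Python, not converter code):

# setdefault() as a one-line conflict check.
metadata_props = {}
first = metadata_props.setdefault('Image.BitmapPixelFormat', 'Rgb8')   # key absent: stores and returns 'Rgb8'
second = metadata_props.setdefault('Image.BitmapPixelFormat', 'Bgr8')  # key present: returns the stored 'Rgb8'
assert first == 'Rgb8' and second == 'Rgb8'
# second != 'Bgr8' here, which is exactly the condition parse_coreml uses to warn about
# conflicting pixel formats across image inputs/outputs.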

onnxmltools/convert/coreml/operator_converters/neural_network/ImageScaler.py

Lines changed: 3 additions & 3 deletions
@@ -14,11 +14,11 @@ def convert_preprocessing_scaler(scope, operator, container):
 
     attrs = {'name': operator.full_name, 'scale': params.channelScale}
     color_space = operator.inputs[0].type.color_space
-    if color_space == 'GRAY':
+    if color_space == 'Gray8':
         attrs['bias'] = [params.grayBias]
-    elif color_space == 'RGB':
+    elif color_space == 'Rgb8':
         attrs['bias'] = [params.redBias, params.greenBias, params.blueBias]
-    elif color_space == 'BGR':
+    elif color_space == 'Bgr8':
         attrs['bias'] = [params.blueBias, params.greenBias, params.redBias]
     else:
         raise ValueError('Unknown color space for tensor {}'.format(operator.inputs[0].full_name))

onnxmltools/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,5 +10,6 @@
 from .main import set_model_version
 from .main import set_model_domain
 from .main import set_model_doc_string
+from .metadata_props import add_metadata_props, set_denotation
 from .visualize import visualize_model
 from .float16_converter import convert_float_to_float16
onnxmltools/utils/metadata_props.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+import warnings
+from ..convert.common.case_insensitive_dict import CaseInsensitiveDict
+from ..proto import onnx, onnx_proto
+from distutils.version import StrictVersion
+
+
+KNOWN_METADATA_PROPS = CaseInsensitiveDict({
+    'Image.BitmapPixelFormat': ['gray8', 'rgb8', 'bgr8', 'rgba8', 'bgra8'],
+    'Image.ColorSpaceGamma': ['linear', 'srgb'],
+    'Image.NominalPixelRange': ['nominalrange_0_255', 'normalized_0_1', 'normalized_1_1', 'nominalrange_16_235'],
+})
+
+
+def _validate_metadata(metadata_props):
+    '''
+    Validate metadata properties and possibly show warnings or throw exceptions.
+
+    :param metadata_props: A dictionary of metadata properties, with property names and values (see :func:`~onnxmltools.utils.metadata_props.add_metadata_props` for examples)
+    '''
+    if len(CaseInsensitiveDict(metadata_props)) != len(metadata_props):
+        raise RuntimeError('Duplicate metadata props found')
+
+    for key, value in metadata_props.items():
+        valid_values = KNOWN_METADATA_PROPS.get(key)
+        if valid_values and value.lower() not in valid_values:
+            warnings.warn('Key {} has invalid value {}. Valid values are {}'.format(key, value, valid_values))
+
+
+def add_metadata_props(onnx_model, metadata_props, targeted_onnx=onnx.__version__):
+    '''
+    Add metadata properties to the model. See recommended key names at:
+    `Extensibility - Metadata <https://github.com/onnx/onnx/blob/296953db87b79c0137c5d9c1a8f26dfaa2495afc/docs/IR.md#metadata>`_ and
+    `Optional Metadata <https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-metadata>`_
+
+
+    :param onnx_model: ONNX model object
+    :param metadata_props: A dictionary of metadata properties, with property names and values (example: `{'model_author': 'Alice', 'model_license': 'MIT'}`)
+    :param targeted_onnx: Target ONNX version
+    '''
+    if StrictVersion(targeted_onnx) < StrictVersion('1.2.1'):
+        warnings.warn('Metadata properties are not supported in targeted ONNX-%s' % targeted_onnx)
+        return
+    _validate_metadata(metadata_props)
+    new_metadata = CaseInsensitiveDict({x.key: x.value for x in onnx_model.metadata_props})
+    new_metadata.update(metadata_props)
+    del onnx_model.metadata_props[:]
+    onnx_model.metadata_props.extend(
+        onnx_proto.StringStringEntryProto(key=key, value=value)
+        for key, value in metadata_props.items()
+    )
+
+
+def set_denotation(onnx_model, input_name, denotation, dimension_denotation=None, targeted_onnx=onnx.__version__):
+    '''
+    Set input type denotation and dimension denotation.
+
+    Type denotation is a feature in ONNX 1.2.1 that lets the model specify the content of a tensor (e.g. IMAGE or AUDIO).
+    This information can be used by the backend. One example where it is useful is in images: whenever data is bound to
+    a tensor with type denotation IMAGE, the backend can process the data (such as transforming the color space and
+    pixel format) based on model metadata properties.
+
+    :param onnx_model: ONNX model object
+    :param input_name: Name of input tensor to edit (example: `'data0'`)
+    :param denotation: Input type denotation (`documentation <https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition>`_)
+        (example: `'IMAGE'`)
+    :param dimension_denotation: List of dimension type denotations. The length of the list must be the same as the number of dimensions in the tensor
+        (`documentation <https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition>`_)
+        (example: `['DATA_BATCH', 'DATA_CHANNEL', 'DATA_FEATURE', 'DATA_FEATURE']`)
+    :param targeted_onnx: Target ONNX version
+    '''
+    if StrictVersion(targeted_onnx) < StrictVersion('1.2.1'):
+        warnings.warn('Denotation is not supported in targeted ONNX-%s' % targeted_onnx)
+        return
+    for graph_input in onnx_model.graph.input:
+        if graph_input.name == input_name:
+            graph_input.type.denotation = denotation
+            if dimension_denotation:
+                dimensions = graph_input.type.tensor_type.shape.dim
+                if len(dimension_denotation) != len(dimensions):
+                    raise RuntimeError('Wrong number of dimensions: input "{}" has {} dimensions'.format(input_name, len(dimensions)))
+                for dimension, channel_denotation in zip(dimensions, dimension_denotation):
+                    dimension.denotation = channel_denotation
+            return onnx_model
+    raise RuntimeError('Input "{}" not found'.format(input_name))
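A hedged end-to-end usage sketch of the two public helpers (the file name and the input name 'image' are placeholders; the input is assumed to be a 4-D image tensor, hence the four dimension denotations):

# Hedged sketch; file name and input name are assumptions.
from onnxmltools.utils import load_model, save_model, add_metadata_props, set_denotation

model = load_model('model.onnx')
add_metadata_props(model, {
    'Image.BitmapPixelFormat': 'Rgb8',
    'Image.ColorSpaceGamma': 'SRGB',
    'Image.NominalPixelRange': 'NominalRange_0_255',
})
set_denotation(model, 'image', 'IMAGE',
               dimension_denotation=['DATA_BATCH', 'DATA_CHANNEL', 'DATA_FEATURE', 'DATA_FEATURE'])
save_model(model, 'model_with_image_metadata.onnx')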

tests/coreml/test_AllNeuralNetworkConverters.py

Lines changed: 26 additions & 0 deletions
@@ -2,7 +2,10 @@
 import numpy
 from coremltools.models.neural_network import NeuralNetworkBuilder
 from coremltools.models import datatypes
+from coremltools.proto.FeatureTypes_pb2 import ImageFeatureType
+from distutils.version import StrictVersion
 from onnxmltools import convert_coreml
+from onnxmltools.proto import onnx
 
 class TestNeuralNetworkLayerConverter(unittest.TestCase):
 
@@ -444,3 +447,26 @@ def test_bidirectional_lstm_converter(self):
         model_onnx = convert_coreml(builder.spec)
         self.assertTrue(model_onnx is not None)
 
+    def test_image_input_type_converter(self):
+        dim = (3, 15, 25)
+        inputs = [('input', datatypes.Array(*dim))]
+        outputs = [('output', datatypes.Array(*dim))]
+        builder = NeuralNetworkBuilder(inputs, outputs)
+        builder.add_elementwise(name='Identity', input_names=['input'],
+                                output_name='output', mode='ADD', alpha=0.0)
+        spec = builder.spec
+        input = spec.description.input[0]
+        input.type.imageType.height = dim[1]
+        input.type.imageType.width = dim[2]
+        for coreml_colorspace, onnx_colorspace in (('RGB', 'Rgb8'), ('BGR', 'Bgr8'), ('GRAYSCALE', 'Gray8')):
+            input.type.imageType.colorSpace = ImageFeatureType.ColorSpace.Value(coreml_colorspace)
+            model_onnx = convert_coreml(spec)
+            dims = [(d.dim_param or d.dim_value) for d in model_onnx.graph.input[0].type.tensor_type.shape.dim]
+            self.assertEqual(dims, ['None', 1 if onnx_colorspace == 'Gray8' else 3, 15, 25])
+
+            if StrictVersion(onnx.__version__) >= StrictVersion('1.2.1'):
+                metadata = {prop.key: prop.value for prop in model_onnx.metadata_props}
+                self.assertEqual(metadata, {'Image.BitmapPixelFormat': onnx_colorspace})
+                self.assertEqual(model_onnx.graph.input[0].type.denotation, 'IMAGE')
+                channel_denotations = [d.denotation for d in model_onnx.graph.input[0].type.tensor_type.shape.dim]
+                self.assertEqual(channel_denotations, ['DATA_BATCH', 'DATA_CHANNEL', 'DATA_FEATURE', 'DATA_FEATURE'])

tests/utils/test_utils.py

Lines changed: 13 additions & 1 deletion
@@ -1,12 +1,14 @@
 """
 Tests utilities.
 """
+from distutils.version import StrictVersion
 import filecmp
+from onnxmltools.proto import onnx
 import os
 import unittest
 
 from onnxmltools.utils import load_model, save_model, save_text
-from onnxmltools.utils import set_model_version, set_model_domain, set_model_doc_string
+from onnxmltools.utils import set_denotation, set_model_version, set_model_domain, set_model_doc_string
 
 
 class TestUtils(unittest.TestCase):
@@ -69,3 +71,13 @@ def test_set_docstring_blank(self):
         self.assertRaises(ValueError, set_model_doc_string, onnx_model.doc_string, "sample")
         set_model_doc_string(onnx_model, "", True)
         self.assertEqual(onnx_model.doc_string, "")
+
+    @unittest.skipIf(StrictVersion(onnx.__version__) < StrictVersion('1.2.1'),
+                     "not supported in this ONNX version")
+    def test_set_denotation(self):
+        this = os.path.dirname(__file__)
+        onnx_file = os.path.join(this, "models", "coreml_OneHotEncoder_BikeSharing.onnx")
+        onnx_model = load_model(onnx_file)
+        set_denotation(onnx_model, "1", "IMAGE", dimension_denotation=["DATA_FEATURE"])
+        self.assertEqual(onnx_model.graph.input[0].type.denotation, "IMAGE")
+        self.assertEqual(onnx_model.graph.input[0].type.tensor_type.shape.dim[0].denotation, "DATA_FEATURE")
