Skip to content

Commit 7631c1a

Browse files
ibadrwenbingl
andauthored
Extend CoreML: ReshapeStatic/LoadConstantND (#430)
Signed-off-by: Islam <[email protected]> Co-authored-by: Wenbing Li <[email protected]>
1 parent a849358 commit 7631c1a

File tree

6 files changed

+123
-0
lines changed

6 files changed

+123
-0
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
7+
from .....proto import helper
8+
from .....proto import onnx_proto
9+
from ....common._registration import register_converter
10+
from ....common._apply_operation import apply_constant
11+
12+
def convert_load_constant_nd(scope, operator, container):
13+
params = operator.raw_operator.loadConstantND
14+
constant_name = scope.get_unique_variable_name('constant')
15+
constant = helper.make_tensor(constant_name, onnx_proto.TensorProto.FLOAT,
16+
params.shape, params.data.floatValue)
17+
18+
apply_constant(scope, operator.output_full_names, container,
19+
operator_name=operator.full_name, value=constant)
20+
21+
register_converter('loadConstantND', convert_load_constant_nd)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
7+
from ....common._apply_operation import apply_reshape
8+
from ....common._registration import register_converter
9+
10+
11+
def convert_reshape_static(scope, operator, container):
12+
from coremltools.proto.NeuralNetwork_pb2 import ReshapeLayerParams as Params
13+
14+
params = operator.raw_operator.reshapeStatic
15+
16+
# print(params)
17+
intra_variable_name = operator.inputs[0].full_name
18+
19+
N = operator.inputs[0].type.shape[0]
20+
if N == 'None':
21+
N = -1
22+
if len(params.targetShape) == 4:
23+
output_shape = [int(d) for d in params.targetShape]
24+
output_shape[0] = N # Overwrite bad default CoreML setting
25+
elif len(params.targetShape) == 3:
26+
output_shape = [N] + [int(d) for d in params.targetShape]
27+
elif len(params.targetShape) == 2:
28+
output_shape = [N] + [int(d) for d in params.targetShape]
29+
else:
30+
raise ValueError('The targeted shape of Reshape (name: %s) must be 3-element or 4-element array but got %s'\
31+
% (operator.full_name, params.targetShape))
32+
33+
apply_reshape(scope=scope, input_name=intra_variable_name, output_name=operator.outputs[0].full_name,
34+
container=container, operator_name=operator.full_name, desired_shape=output_shape)
35+
36+
37+
register_converter('reshapeStatic', convert_reshape_static)

onnxmltools/convert/coreml/operator_converters/neural_network/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from . import InnerProduct
2323
from . import L2Normalize
2424
from . import LoadConstant
25+
from . import LoadConstantND
2526
from . import LRN
2627
from . import LSTM
2728
from . import Max
@@ -35,6 +36,7 @@
3536
from . import Reduce
3637
from . import ReorganizeData
3738
from . import Reshape
39+
from . import ReshapeStatic
3840
from . import Scale
3941
from . import SequenceRepeat
4042
from . import SimpleRNN
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#-------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
#--------------------------------------------------------------------------
6+
7+
from ....common._registration import register_shape_calculator
8+
from ....common.data_types import TensorType, FloatTensorType
9+
from ....common.utils import check_input_and_output_numbers
10+
11+
def calculate_load_constant_nd_output_shapes(operator):
12+
check_input_and_output_numbers(operator, input_count_range=None, output_count_range=1)
13+
14+
output = operator.outputs[0]
15+
16+
# CoreML's constant is always 3-D tensor, so we assume its shape is [C, H, W].
17+
const_shape = operator.raw_operator.loadConstantND.shape
18+
# We convert [C, H, W] to [1, C, H, W] because our parsing code use [N, C, H, W]
19+
const_shape = [1] + [int(d) for d in const_shape]
20+
if output.type is None:
21+
# Use default type
22+
output.type = FloatTensorType(const_shape, doc_string=output.type.doc_string)
23+
else:
24+
if not isinstance(output.type, TensorType):
25+
raise RuntimeError('Type conflict detected. Output must be a tensor.')
26+
# If output type exists, we just modify its shape.
27+
output.type.shape = const_shape
28+
29+
30+
register_shape_calculator('loadConstantND', calculate_load_constant_nd_output_shapes)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
7+
from ....common._registration import register_shape_calculator
8+
from ....common.data_types import FloatTensorType
9+
from ....common.utils import check_input_and_output_numbers, check_input_and_output_types
10+
11+
def calculate_reshape_static_output_shapes(operator):
12+
'''
13+
Allowed input/output patterns are
14+
1. [N, C, H, W] ---> [N, C', H', W']
15+
16+
Note that C*H*W should equal to C'*H'*W'.
17+
'''
18+
check_input_and_output_numbers(operator, input_count_range=1, output_count_range=1)
19+
check_input_and_output_types(operator, good_input_types=[FloatTensorType])
20+
21+
params = operator.raw_operator.reshapeStatic
22+
23+
output_shape = list(int(i) for i in params.targetShape)
24+
25+
if len(output_shape) == 3:
26+
output_shape = [operator.inputs[0].type.shape[0]] + output_shape
27+
28+
operator.outputs[0].type.shape = output_shape
29+
30+
31+
register_shape_calculator('reshapeStatic', calculate_reshape_static_output_shapes)

onnxmltools/convert/coreml/shape_calculators/neural_network/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from . import IdentityFloat
1818
from . import InnerProduct
1919
from . import LoadConstant
20+
from . import LoadConstantND
2021
from . import LSTM
2122
from . import Merge
2223
from . import Pad
@@ -25,6 +26,7 @@
2526
from . import Reduce
2627
from . import ReorganizeData
2728
from . import Reshape
29+
from . import ReshapeStatic
2830
from . import SequenceRepeat
2931
from . import Slice
3032
from . import Split

0 commit comments

Comments
 (0)