Skip to content

Commit 92e048c

Browse files
committed
example with keras
1 parent 2e440fc commit 92e048c

File tree

6 files changed

+112
-33
lines changed

6 files changed

+112
-33
lines changed

examples/example_keras_tf2onnx.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import numpy as np
2+
import onnx
3+
import tensorflow as tf
4+
import tf2onnx
5+
import tf2onnx.keras2onnx_api
6+
7+
8+
class LeNet(tf.keras.Model):
9+
def __init__(self):
10+
super(LeNet, self).__init__()
11+
self.conv2d_1 = tf.keras.layers.Conv2D(filters=6,
12+
kernel_size=(3, 3), activation='relu',
13+
input_shape=(32, 32, 1))
14+
self.average_pool = tf.keras.layers.AveragePooling2D((3, 3))
15+
self.conv2d_2 = tf.keras.layers.Conv2D(filters=16,
16+
kernel_size=(3, 3), activation='relu')
17+
self.flatten = tf.keras.layers.Flatten()
18+
self.fc_1 = tf.keras.layers.Dense(120, activation='relu')
19+
self.fc_2 = tf.keras.layers.Dense(84, activation='relu')
20+
self.out = tf.keras.layers.Dense(10, activation='softmax')
21+
22+
def call(self, inputs, **kwargs):
23+
x = self.conv2d_1(inputs)
24+
x = self.average_pool(x)
25+
x = self.conv2d_2(x)
26+
x = self.average_pool(x)
27+
x = self.flatten(x)
28+
x = self.fc_2(self.fc_1(x))
29+
return self.out(x)
30+
31+
32+
# Define a simple model
33+
model = LeNet()
34+
data = np.random.rand(2 * 416 * 416 * 3).astype(np.float32).reshape(2, 416, 416, 3)
35+
expected = model(data)
36+
37+
# Get ConcreteFunction
38+
# concrete_func = tf.function(model).get_concrete_function(tf.TensorSpec([None, None, None, None], tf.float32))
39+
oxml = tf2onnx.keras2onnx_api.convert_keras(model, input_signature=[tf.TensorSpec([None, None, None, None], tf.float32)])
40+
onnx.save(oxml, "model.onnx")

tests/keras2onnx_unit_tests/test_subclassing.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import tensorflow as tf
77
from test_utils import convert_keras_for_test as convert_keras
88
from mock_keras2onnx.proto import is_tensorflow_older_than
9+
import tf2onnx
910

1011
if (not mock_keras2onnx.proto.is_tf_keras) or (not mock_keras2onnx.proto.tfcompat.is_tf2):
1112
pytest.skip("Tensorflow 2.0 only tests.", allow_module_level=True)
@@ -17,7 +18,7 @@ def __init__(self):
1718
self.conv2d_1 = tf.keras.layers.Conv2D(filters=6,
1819
kernel_size=(3, 3), activation='relu',
1920
input_shape=(32, 32, 1))
20-
self.average_pool = tf.keras.layers.AveragePooling2D()
21+
self.average_pool = tf.keras.layers.AveragePooling2D((3, 3))
2122
self.conv2d_2 = tf.keras.layers.Conv2D(filters=16,
2223
kernel_size=(3, 3), activation='relu')
2324
self.flatten = tf.keras.layers.Flatten()
@@ -91,8 +92,9 @@ def test_lenet(runner):
9192
lenet = LeNet()
9293
data = np.random.rand(2 * 416 * 416 * 3).astype(np.float32).reshape(2, 416, 416, 3)
9394
expected = lenet(data)
94-
lenet._set_inputs(data)
95-
oxml = convert_keras(lenet)
95+
if hasattr(lenet, "_set_inputs"):
96+
lenet._set_inputs(data)
97+
oxml = convert_keras(lenet, input_signature=[tf.TensorSpec([None, None, None, None], tf.float32)])
9698
assert runner('lenet', oxml, data, expected)
9799

98100

@@ -234,15 +236,28 @@ def call(self, inputs, **kwargs):
234236
swm = Model()
235237
const_in = [tf.Variable([2, 4, 6, 8, 10], dtype=tf.int32, name="input")]
236238
expected = swm(const_in)
237-
if hasattr(swm, "_set_input"):
238-
swm._set_inputs(const_in)
239-
else:
240-
swm.inputs_spec = const_in
241-
if hasattr(swm, "_set_output"):
242-
swm._set_output(expected)
243-
else:
244-
swm.outputs_spec = expected
245-
oxml = convert_keras(swm)
239+
240+
"""
241+
for op in concrete_func.graph.get_operations():
242+
print("--", op.name)
243+
print(op)
244+
245+
print("***", concrete_func.inputs)
246+
print("***", concrete_func.outputs)
247+
"""
248+
run_model = tf.function(swm)
249+
concrete_func = run_model.get_concrete_function(tf.TensorSpec([None], tf.int32))
250+
model_proto, external_tensor_storage = tf2onnx.convert._convert_common(
251+
concrete_func.graph.as_graph_def(),
252+
input_names=[i.name for i in concrete_func.inputs],
253+
output_names=[i.name for i in concrete_func.outputs],
254+
large_model=False,
255+
output_path="where_test.onnx",
256+
)
257+
assert model_proto
258+
assert not external_tensor_storage
259+
260+
oxml = convert_keras(swm, input_signature=[tf.TensorSpec([None], tf.int32)])
246261
assert runner('where_test', oxml, const_in, expected)
247262

248263

tf2onnx/convert.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,7 @@ def from_keras3(model, input_signature=None, opset=None, custom_ops=None, custom
478478
reverse_lookup = {v: k for k, v in tensors_to_rename.items()}
479479

480480
valid_names = []
481+
model_output = None
481482
if hasattr(model, "outputs"):
482483
model_output = model.outputs
483484
else:
@@ -486,7 +487,7 @@ def from_keras3(model, input_signature=None, opset=None, custom_ops=None, custom
486487
elif model_input and len(model_input) == 1:
487488
# Let's try something to make unit test work. This should be replaced.
488489
model_output = [tf.Variable(model_input[0], name="output")]
489-
else:
490+
elif not output_names:
490491
raise RuntimeError(
491492
"You should set attribute 'outputs_spec' with your outputs "
492493
"so that the expected can use that information."
@@ -498,23 +499,20 @@ def _get_name(t, i):
498499
except AttributeError:
499500
return f"output:{i}"
500501

501-
for out in [_get_name(t, i) for i, t in enumerate(model_output)]:
502-
if out in reverse_lookup:
503-
valid_names.append(reverse_lookup[out])
504-
else:
505-
print(f"Warning: Output name '{out}' not found in reverse_lookup.")
506-
# Fallback: verwende TensorFlow-Ausgangsnamen direkt
507-
valid_names = [
508-
_get_name(t, i)
509-
for i, t in enumerate(concrete_func.outputs)
510-
if t.dtype != tf.dtypes.resource
511-
]
512-
break
513-
output_names = valid_names
514-
515-
516-
#if old_out_names is not None:
517-
#model.output_names = old_out_names
502+
if model_output:
503+
for out in [_get_name(t, i) for i, t in enumerate(model_output)]:
504+
if out in reverse_lookup:
505+
valid_names.append(reverse_lookup[out])
506+
else:
507+
print(f"Warning: Output name '{out}' not found in reverse_lookup.")
508+
# Fallback: verwende TensorFlow-Ausgangsnamen direkt
509+
valid_names = [
510+
_get_name(t, i)
511+
for i, t in enumerate(concrete_func.outputs)
512+
if t.dtype != tf.dtypes.resource
513+
]
514+
break
515+
output_names = valid_names
518516

519517
with tf.device("/cpu:0"):
520518
frozen_graph, initialized_tables = \

tf2onnx/keras2onnx_api.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def get_maximum_opset_supported():
4444
return min(max(OPSET_TO_IR_VERSION.keys()), defs.onnx_opset_version())
4545

4646
def convert_keras(model, name=None, doc_string='', target_opset=None, initial_types=None,
47-
channel_first_inputs=None, debug_mode=False, custom_op_conversions=None):
47+
channel_first_inputs=None, debug_mode=False, custom_op_conversions=None,
48+
input_signature=None):
4849
"""
4950
:param model: keras model
5051
:param name: the converted onnx model internal name
@@ -54,16 +55,18 @@ def convert_keras(model, name=None, doc_string='', target_opset=None, initial_ty
5455
:param channel_first_inputs: A list of channel first input
5556
:param debug_mode: ignored
5657
:param custom_op_conversions: ignored
58+
:param input_signature: takes precedence on initial_types if specified,
59+
example: ``[tf.TensorSpec([None], tf.int32)]``
5760
:return an ONNX ModelProto
5861
"""
5962
if target_opset is None:
6063
target_opset = get_maximum_opset_supported()
61-
input_signature = _process_initial_types(initial_types, unknown_dim=None)
64+
if input_signature is None:
65+
input_signature = _process_initial_types(initial_types, unknown_dim=None)
6266
name = name or model.name
6367

6468
model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=target_opset,
6569
inputs_as_nchw=channel_first_inputs)
6670
model.graph.name = name
6771
model.graph.doc_string = doc_string
68-
6972
return model

tf2onnx/onnx_opset/misc.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,13 @@ class NukeNode:
4242
@classmethod
4343
def version_1(cls, ctx, node, **kwargs):
4444
ctx.remove_node(node.name)
45+
46+
47+
@tf_op("StatefulPartitionedCall")
48+
class StatefulPartitionedCall:
49+
@classmethod
50+
def version_1(cls, ctx, node, **kwargs):
51+
raise NotImplementedError(
52+
"This node appears if the graph has local function. It should be inlined first. "
53+
"Inline did not work on that model. It seems Conv2D is no longer inlined."
54+
)

tf2onnx/tf_loader.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,19 @@ def tf_optimize_grappler(input_names, output_names, graph_def):
698698
from tensorflow.core.protobuf import meta_graph_pb2 as meta_graph_pb2, config_pb2, rewriter_config_pb2
699699
from tensorflow.python.grappler import tf_optimizer as tf_opt
700700

701+
rewriter_config = rewriter_config_pb2.RewriterConfig(
702+
function_optimization=rewriter_config_pb2.RewriterConfig.ON
703+
)
704+
705+
config = tf.compat.v1.ConfigProto()
706+
config.graph_options.rewrite_options.CopyFrom(rewriter_config)
707+
708+
with tf.compat.v1.Session(config=config) as sess:
709+
tf.import_graph_def(graph_def, name="")
710+
optimized_graph_def = sess.graph.as_graph_def(add_shapes=True)
711+
712+
return optimized_graph_def
713+
701714
config = config_pb2.ConfigProto()
702715
rewrite_options = config.graph_options.rewrite_options
703716
config.graph_options.infer_shapes = True

0 commit comments

Comments
 (0)