Skip to content

Commit a0e1d3b

Browse files
committed
first try
1 parent 03c615f commit a0e1d3b

File tree

5 files changed

+49
-5
lines changed

5 files changed

+49
-5
lines changed

Troubleshooting.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,7 @@ An example of this is the [ONNX Slice operator before opset-10](https://github.c
3636
You can pass the options ```--fold_const```(removed after tf2onnx-1.9.3) in the tf2onnx command line that allows tf2onnx to apply more aggressive constant folding which will increase chances to find a constant.
3737

3838
If this doesn't work the model is most likely not to be able to convert to ONNX. We used to see this a lot of issue with the ONNX Slice op and in opset-10 was updated for exactly this reason.
39+
40+
## cudaSetDevice() on GPU:0 failed. Status: CUDA-capable device(s) is/are busy or unavailable
41+
42+
See [Regression: TF 2.18 crashes with cudaSetDevice failing due to GPU being busy](https://github.com/tensorflow/tensorflow/issues/78784).

examples/end2end_tfkeras.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,9 @@
88
*onnxruntime*, *tensorflow* and *tensorflow.lite*.
99
"""
1010
from onnxruntime import InferenceSession
11-
import os
1211
import subprocess
1312
import timeit
1413
import numpy as np
15-
import tensorflow as tf
1614
from tensorflow import keras
1715
from tensorflow.keras import layers, Input
1816

examples/tf_custom_op/double_and_add_one_custom_op.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import os
88
from tf2onnx import utils
99
from tf2onnx.handler import tf_op
10-
from tf2onnx.tf_loader import tf_placeholder
1110

1211

1312
DIR_PATH = os.path.realpath(os.path.dirname(__file__))

tests/keras2onnx_unit_tests/test_subclassing.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,43 @@ def call(self, inputs, **kwargs):
4949
return output
5050

5151

52+
def get_save_spec(model, dynamic_batch=False):
53+
"""Returns the save spec of the subclassing keras model."""
54+
from tensorflow.python.framework import tensor_spec
55+
shapes_dict = getattr(model, '_build_shapes_dict', None)
56+
# TODO: restore dynamic_batch
57+
# assert not dynamic_batch, f"get_save_spec: dynamic_batch={dynamic_batch}, shapes_dict={shapes_dict}"
58+
if not shapes_dict:
59+
return None
60+
61+
if 'input_shape' not in shapes_dict:
62+
raise ValueError(
63+
'Model {} cannot be saved because the input shapes have not been set.'
64+
)
65+
66+
input_shape = shapes_dict['input_shape']
67+
if isinstance(input_shape, tuple):
68+
shape = input_shape
69+
shape = (None,) + shape[1:]
70+
return tensor_spec.TensorSpec(
71+
shape=shape, dtype=model.input_dtype
72+
)
73+
elif isinstance(input_shape, dict):
74+
specs = {}
75+
for key, shape in input_shape.items():
76+
shape = (None,) + shape[1:]
77+
specs[key] = tensor_spec.TensorSpec(
78+
shape=shape, dtype=model.input_dtype, name=key
79+
)
80+
return specs
81+
elif isinstance(input_shape, list):
82+
specs = []
83+
for shape in input_shape:
84+
shape = (None,) + shape[1:]
85+
specs.append(tensor_spec.TensorSpec(shape=shape, dtype=model.input_dtype))
86+
return specs
87+
88+
5289
class SimpleWrapperModel(tf.keras.Model):
5390
def __init__(self, func):
5491
super(SimpleWrapperModel, self).__init__()
@@ -57,6 +94,9 @@ def __init__(self, func):
5794
def call(self, inputs, **kwargs):
5895
return self.func(inputs)
5996

97+
def _get_save_spec(self, dynamic_batch=False):
98+
return get_save_spec(self, dynamic_batch=dynamic_batch)
99+
60100

61101
def test_lenet(runner):
62102
tf.keras.backend.clear_session()
@@ -198,7 +238,10 @@ def _tf_where(input_0):
198238
swm = SimpleWrapperModel(_tf_where)
199239
const_in = [np.array([2, 4, 6, 8, 10]).astype(np.int32)]
200240
expected = swm(const_in)
201-
swm._set_inputs(const_in)
241+
if hasattr(swm, "_set_input"):
242+
swm._set_inputs(const_in)
243+
else:
244+
swm.inputs_spec = const_in
202245
oxml = convert_keras(swm)
203246
assert runner('where_test', oxml, const_in, expected)
204247

tf2onnx/convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_
447447
function = _saving_utils.trace_model_call(model, input_signature)
448448
try:
449449
concrete_func = function.get_concrete_function()
450-
except TypeError as e:
450+
except (TypeError, AttributeError) as e:
451451
# Legacy keras models don't accept the training arg tf provides so we hack around it
452452
if "got an unexpected keyword argument 'training'" not in str(e):
453453
raise e

0 commit comments

Comments
 (0)