Skip to content

Commit 0d6c98e

Browse files
committed
Change the way to call legacy functions.
Signed-off-by: Jay Zhang <[email protected]>
1 parent c68331d commit 0d6c98e

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

tests/keras2onnx_unit_tests/conftest.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515
def is_keras_3():
1616
return tf.__version__.startswith("2.18") or tf.__version__.startswith("2.17") or tf.__version__.startswith("2.16")
1717

18-
if is_keras_3():
19-
import tf_keras
20-
K = tf_keras.backend
21-
2218
@pytest.fixture(scope='function')
2319
def runner():
2420
np.random.seed(42)
@@ -31,10 +27,15 @@ def runner():
3127
def runner_func(*args, **kwargs):
3228
return run_onnx_runtime(*args, model_files, **kwargs)
3329

34-
# Ensure Keras layer naming is reset for each function
35-
K.reset_uids()
36-
# Reset the TensorFlow session to avoid resource leaking between tests
37-
K.clear_session()
30+
if is_keras_3():
31+
import tf_keras
32+
tf_keras.backend.reset_uids()
33+
tf_keras.backend.clear_session()
34+
else:
35+
# Ensure Keras layer naming is reset for each function
36+
K.reset_uids()
37+
# Reset the TensorFlow session to avoid resource leaking between tests
38+
K.clear_session()
3839

3940
# Provide wrapped run_onnx_runtime function
4041
yield runner_func

0 commit comments

Comments
 (0)