Skip to content

Commit 60ebdf6

Browse files
authored
Fix import of Keras (#2420)
* Fix import of Keras * Fix import of Keras * Fix import of Keras * Fix Keras2 import
1 parent c60112e commit 60ebdf6

File tree

3 files changed

+67
-66
lines changed

3 files changed

+67
-66
lines changed

keras_cv/backend/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,14 @@
2626
- `random`: `keras.random` for Keras 3 or `keras_core.ops` for Keras 2.
2727
"""
2828
from keras_cv.backend import config # noqa: E402
29-
from keras_cv.backend import keras # noqa: E402
29+
30+
if config.keras_3():
31+
import keras # noqa: E402
32+
33+
keras.backend.name_scope = keras.name_scope
34+
else:
35+
import keras_cv.backend.keras2 as keras # noqa: E402
36+
3037
from keras_cv.backend import ops # noqa: E402
3138
from keras_cv.backend import random # noqa: E402
3239
from keras_cv.backend import tf_ops # noqa: E402

keras_cv/backend/keras.py

Lines changed: 0 additions & 65 deletions
This file was deleted.

keras_cv/backend/keras2.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2023 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import types
16+
17+
from tensorflow import keras # noqa: F403, F401
18+
from tensorflow.keras import * # noqa: F403, F401
19+
20+
from keras_cv.backend import config # noqa: F403, F401
21+
22+
_KERAS_CORE_ALIASES = {
23+
"utils->saving": [
24+
"register_keras_serializable",
25+
"deserialize_keras_object",
26+
"serialize_keras_object",
27+
"get_registered_object",
28+
],
29+
"models->saving": ["load_model"],
30+
}
31+
32+
if not hasattr(keras, "saving"):
33+
keras.saving = types.SimpleNamespace()
34+
35+
# add aliases
36+
for key, value in _KERAS_CORE_ALIASES.items():
37+
src, _, dst = key.partition("->")
38+
src = src.split(".")
39+
dst = dst.split(".")
40+
41+
src_mod, dst_mod = keras, keras
42+
43+
# navigate to where we want to alias the attributes
44+
for mod in src:
45+
src_mod = getattr(src_mod, mod)
46+
for mod in dst:
47+
dst_mod = getattr(dst_mod, mod)
48+
49+
# add an alias for each attribute
50+
for attr in value:
51+
if isinstance(attr, tuple):
52+
src_attr, dst_attr = attr
53+
else:
54+
src_attr, dst_attr = attr, attr
55+
attr_val = getattr(src_mod, src_attr)
56+
setattr(dst_mod, dst_attr, attr_val)
57+
58+
# TF Keras doesn't have this rename.
59+
keras.activations.silu = keras.activations.swish

0 commit comments

Comments
 (0)