Skip to content

Commit e7ea9b2

Browse files
Merge branch 'master' into my-fix-branch
2 parents 5ac523e + dfc36cd commit e7ea9b2

File tree

180 files changed

+3854
-1142
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

180 files changed

+3854
-1142
lines changed

api_gen.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,13 @@ def create_legacy_directory(package_dir):
8484
for fname in fnames:
8585
if fname.endswith(".py"):
8686
legacy_fpath = os.path.join(root, fname)
87-
tf_keras_root = root.replace("/_legacy", "/_tf_keras/keras")
87+
tf_keras_root = root.replace(
88+
os.path.join(os.path.sep, "_legacy"),
89+
os.path.join(os.path.sep, "_tf_keras", "keras"),
90+
)
8891
core_api_fpath = os.path.join(
89-
root.replace("/_legacy", ""), fname
92+
root.replace(os.path.join(os.path.sep, "_legacy"), ""),
93+
fname,
9094
)
9195
if not os.path.exists(tf_keras_root):
9296
os.makedirs(tf_keras_root)
@@ -125,7 +129,7 @@ def create_legacy_directory(package_dir):
125129
r"\n",
126130
core_api_contents,
127131
)
128-
legacy_contents = core_api_contents + "\n" + legacy_contents
132+
legacy_contents = f"{core_api_contents}\n{legacy_contents}"
129133
with open(tf_keras_fpath, "w") as f:
130134
f.write(legacy_contents)
131135

conftest.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
import os
2-
3-
# When using jax.experimental.enable_x64 in unit test, we want to keep the
4-
# default dtype with 32 bits, aligning it with Keras's default.
5-
os.environ["JAX_DEFAULT_DTYPE_BITS"] = "32"
6-
71
try:
82
# When using torch and tensorflow, torch needs to be imported first,
93
# otherwise it will segfault upon import. This should force the torch

guides/distributed_training_with_tensorflow.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,8 @@ def make_or_restore_model():
194194
# Either restore the latest model, or create a fresh one
195195
# if there is no checkpoint available.
196196
checkpoints = [
197-
checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)
197+
os.path.join(checkpoint_dir, name)
198+
for name in os.listdir(checkpoint_dir)
198199
]
199200
if checkpoints:
200201
latest_checkpoint = max(checkpoints, key=os.path.getctime)
@@ -216,7 +217,7 @@ def run_training(epochs=1):
216217
# This callback saves a SavedModel every epoch
217218
# We include the current epoch in the folder name.
218219
keras.callbacks.ModelCheckpoint(
219-
filepath=checkpoint_dir + "/ckpt-{epoch}.keras",
220+
filepath=os.path.join(checkpoint_dir, "ckpt-{epoch}.keras"),
220221
save_freq="epoch",
221222
)
222223
]

guides/training_with_built_in_methods.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,7 +1133,8 @@ def make_or_restore_model():
11331133
# Either restore the latest model, or create a fresh one
11341134
# if there is no checkpoint available.
11351135
checkpoints = [
1136-
checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)
1136+
os.path.join(checkpoint_dir, name)
1137+
for name in os.listdir(checkpoint_dir)
11371138
]
11381139
if checkpoints:
11391140
latest_checkpoint = max(checkpoints, key=os.path.getctime)
@@ -1148,7 +1149,8 @@ def make_or_restore_model():
11481149
# This callback saves the model every 100 batches.
11491150
# We include the training loss in the saved model name.
11501151
keras.callbacks.ModelCheckpoint(
1151-
filepath=checkpoint_dir + "/model-loss={loss:.2f}.keras", save_freq=100
1152+
filepath=os.path.join(checkpoint_dir, "model-loss={loss:.2f}.keras"),
1153+
save_freq=100,
11521154
)
11531155
]
11541156
model.fit(x_train, y_train, epochs=1, callbacks=callbacks)

integration_tests/import_test.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,19 @@ def create_virtualenv():
4242
# Create virtual environment
4343
"python3 -m venv test_env",
4444
]
45-
os.environ["PATH"] = (
46-
"/test_env/bin/" + os.pathsep + os.environ.get("PATH", "")
45+
os.environ["PATH"] = os.pathsep.join(
46+
(
47+
os.path.join(os.getcwd(), "test_env", "bin"),
48+
os.environ.get("PATH", ""),
49+
)
4750
)
51+
if os.name == "nt":
52+
os.environ["PATH"] = os.pathsep.join(
53+
(
54+
os.path.join(os.getcwd(), "test_env", "Scripts"),
55+
os.environ["PATH"],
56+
)
57+
)
4858
run_commands_local(env_setup)
4959

5060

@@ -53,18 +63,17 @@ def manage_venv_installs(whl_path):
5363
backend_pkg, backend_extra_url = BACKEND_REQ[backend.backend()]
5464
install_setup = [
5565
# Installs the backend's package and common requirements
56-
"pip install " + backend_extra_url + backend_pkg,
66+
f"pip install {backend_extra_url}{backend_pkg}",
5767
"pip install -r requirements-common.txt",
5868
"pip install pytest",
5969
# Ensure other backends are uninstalled
60-
"pip uninstall -y "
61-
+ BACKEND_REQ[other_backends[0]][0]
62-
+ " "
63-
+ BACKEND_REQ[other_backends[1]][0]
64-
+ " "
65-
+ BACKEND_REQ[other_backends[2]][0],
70+
"pip uninstall -y {0} {1} {2}".format(
71+
BACKEND_REQ[other_backends[0]][0],
72+
BACKEND_REQ[other_backends[1]][0],
73+
BACKEND_REQ[other_backends[2]][0],
74+
),
6675
# Install `.whl` package
67-
"pip install " + whl_path,
76+
f"pip install {whl_path}",
6877
]
6978
# Install flax for JAX when NNX is enabled
7079
if backend.backend() == "jax" and config.is_nnx_enabled():
@@ -102,7 +111,11 @@ def run_commands_venv(commands):
102111
for command in commands:
103112
print(f"Running command: {command}")
104113
cmd_with_args = command.split(" ")
105-
cmd_with_args[0] = "test_env/bin/" + cmd_with_args[0]
114+
cmd_with_args[0] = os.path.join(
115+
"test_env",
116+
"Scripts" if os.name == "nt" else "bin",
117+
cmd_with_args[0],
118+
)
106119
p = subprocess.Popen(cmd_with_args)
107120
assert p.wait() == 0
108121

integration_tests/model_visualization_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def get_node_dict(graph, path=""):
4444

4545
for subgraph in graph.get_subgraphs():
4646
sub_nodes = get_node_dict(
47-
subgraph, path=path + subgraph.get_label() + " > "
47+
subgraph, path=f"{path}{subgraph.get_label()} > "
4848
)
4949
nodes.update(sub_nodes)
5050

@@ -85,7 +85,7 @@ def get_edges(graph):
8585
class ModelVisualizationTest(testing.TestCase):
8686
def multi_plot_model(self, model, name, expand_nested=False):
8787
if expand_nested:
88-
name = name + "-expand_nested"
88+
name = f"{name}-expand_nested"
8989

9090
TEST_CASES = [
9191
{},
@@ -130,7 +130,7 @@ def multi_plot_model(self, model, name, expand_nested=False):
130130

131131
for test_case in TEST_CASES:
132132
tags = [v if k == "rankdir" else k for k, v in test_case.items()]
133-
file_name = "-".join([name] + tags) + ".png"
133+
file_name = f"{'-'.join([name] + tags)}.png"
134134
plot_model(
135135
model, file_name, expand_nested=expand_nested, **test_case
136136
)

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@
195195
from keras.src.ops.numpy import heaviside as heaviside
196196
from keras.src.ops.numpy import histogram as histogram
197197
from keras.src.ops.numpy import hstack as hstack
198+
from keras.src.ops.numpy import hypot as hypot
198199
from keras.src.ops.numpy import identity as identity
199200
from keras.src.ops.numpy import imag as imag
200201
from keras.src.ops.numpy import inner as inner
@@ -204,6 +205,7 @@
204205
from keras.src.ops.numpy import isinf as isinf
205206
from keras.src.ops.numpy import isnan as isnan
206207
from keras.src.ops.numpy import isneginf as isneginf
208+
from keras.src.ops.numpy import isposinf as isposinf
207209
from keras.src.ops.numpy import kaiser as kaiser
208210
from keras.src.ops.numpy import left_shift as left_shift
209211
from keras.src.ops.numpy import less as less

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
from keras.src.ops.numpy import heaviside as heaviside
8484
from keras.src.ops.numpy import histogram as histogram
8585
from keras.src.ops.numpy import hstack as hstack
86+
from keras.src.ops.numpy import hypot as hypot
8687
from keras.src.ops.numpy import identity as identity
8788
from keras.src.ops.numpy import imag as imag
8889
from keras.src.ops.numpy import inner as inner
@@ -92,6 +93,7 @@
9293
from keras.src.ops.numpy import isinf as isinf
9394
from keras.src.ops.numpy import isnan as isnan
9495
from keras.src.ops.numpy import isneginf as isneginf
96+
from keras.src.ops.numpy import isposinf as isposinf
9597
from keras.src.ops.numpy import kaiser as kaiser
9698
from keras.src.ops.numpy import left_shift as left_shift
9799
from keras.src.ops.numpy import less as less

keras/api/_tf_keras/keras/quantizers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from keras.src.quantizers import deserialize as deserialize
88
from keras.src.quantizers import get as get
99
from keras.src.quantizers import serialize as serialize
10+
from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
1011
from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
1112
from keras.src.quantizers.quantizers import Quantizer as Quantizer
1213
from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize

keras/api/ops/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@
195195
from keras.src.ops.numpy import heaviside as heaviside
196196
from keras.src.ops.numpy import histogram as histogram
197197
from keras.src.ops.numpy import hstack as hstack
198+
from keras.src.ops.numpy import hypot as hypot
198199
from keras.src.ops.numpy import identity as identity
199200
from keras.src.ops.numpy import imag as imag
200201
from keras.src.ops.numpy import inner as inner
@@ -204,6 +205,7 @@
204205
from keras.src.ops.numpy import isinf as isinf
205206
from keras.src.ops.numpy import isnan as isnan
206207
from keras.src.ops.numpy import isneginf as isneginf
208+
from keras.src.ops.numpy import isposinf as isposinf
207209
from keras.src.ops.numpy import kaiser as kaiser
208210
from keras.src.ops.numpy import left_shift as left_shift
209211
from keras.src.ops.numpy import less as less

0 commit comments

Comments
 (0)