Skip to content

Commit 8914427

Browse files
Patch release commits for 3.13.1 (#22005)
* Remove NumPy warning with NumPy >= 2. (#21949) Merely importing keras currently triggers this warning with NumPy 2. ``` keras/src/export/tf2onnx_lib.py:8: FutureWarning: In the future `np.object` will be defined as the corresponding NumPy scalar. ``` Only patch NumPy if and when needed. * Fix CUDNN flash attention for JAX > 0.6.2. (#21970) The signature of `check_is_flash_attention` changed with JAX 0.7.0. In addition to `query` and `key`, a positional argument of `value` was added. This was not caught because our kokoro tests use JAX 0.6.2, which is the last version that supports Python 3.10. This change was tested here: #21957 * Do not always make the batch size dynamic during export. (#21944) This is a follow-up of #21674. That PR changed the signature of `make_tf_tensor_spec` from `(x)` to `(x, dynamic_batch=True)`, thereby adding the ability to make the batch size dynamic. That PR also added `_get_save_spec(self, dynamic_batch=True)`, which uses `make_tf_tensor_spec` and forwards the `dynamic_batch` argument. However, the default before this change for other export paths (SavedModel, ONNX) was to keep the batch size untouched. In particular, when a user manually provides an `input_signature` to [`ExportArchive.add_endpoint`](https://github.com/keras-team/keras/blob/master/keras/src/export/saved_model.py#L362), we should honor it: the user controls whether the batch size is dynamic or not via the `input_signature`. This PR changes the default of `make_tf_tensor_spec` back to `dynamic_batch=False` to revert SavedModel and ONNX exports to the previous behavior. Also removed the call to `return super()._get_save_spec(dynamic_batch)`, which can never succeed because `TFLayer` is a top-level class (ignoring the auto-tracking machinery). * Cherry pick & patch release --------- Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com>
1 parent 986ff97 commit 8914427

File tree

5 files changed

+43
-30
lines changed

5 files changed

+43
-30
lines changed

keras/src/backend/jax/nn.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,25 +1471,42 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
14711471
# Only support at least Ampere
14721472
if not check_compute_capability("8.0"):
14731473
raise RuntimeError("Require at least Ampere arch to run")
1474-
# Check inputs layout
1474+
1475+
# Inspect inputs of `check_layout`
14751476
check_layout_params = list(
14761477
inspect.signature(check_layout).parameters.keys()
14771478
)
14781479
for known_param in ("query", "key", "value", "bias", "layout"):
14791480
check_layout_params.remove(known_param)
14801481
# Defaults to `None` when not specified.
1481-
kwargs = {key: None for key in check_layout_params}
1482+
check_layout_kwargs = {key: None for key in check_layout_params}
14821483
check_layout(
1483-
query, key, value, bias, layout=_normalize_layout("BTNH"), **kwargs
1484-
)
1485-
check_is_flash_attention(
14861484
query,
14871485
key,
1488-
_normalize_layout("BTNH"),
1489-
cudnn_version,
1490-
bias is not None,
1491-
is_training=False,
1486+
value,
1487+
bias,
1488+
layout=_normalize_layout("BTNH"),
1489+
**check_layout_kwargs,
14921490
)
1491+
1492+
# Inspect inputs of `check_is_flash_attention`
1493+
check_is_flash_attention_params = inspect.signature(
1494+
check_is_flash_attention
1495+
).parameters
1496+
check_is_flash_attention_kwargs = {
1497+
"query": query,
1498+
"key": key,
1499+
"value": value,
1500+
"layout": _normalize_layout("BTNH"),
1501+
"cudnn_version": cudnn_version,
1502+
"has_bias": bias is not None,
1503+
"is_training": False,
1504+
}
1505+
# Remove unsupported arguments
1506+
for param in list(check_is_flash_attention_kwargs.keys()):
1507+
if param not in check_is_flash_attention_params:
1508+
check_is_flash_attention_kwargs.pop(param)
1509+
check_is_flash_attention(**check_is_flash_attention_kwargs)
14931510
return True
14941511
except:
14951512
if raise_error:

keras/src/backend/tensorflow/layer.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -94,22 +94,18 @@ def _get_save_spec(self, dynamic_batch=True):
9494
A TensorSpec, list or dict mirroring the model inputs, or
9595
`None` when specs cannot be inferred.
9696
"""
97-
# Prefer the base implementation if available
98-
try:
99-
return super()._get_save_spec(dynamic_batch)
100-
except AttributeError:
101-
# Lazy import to avoid circular dependency
102-
from keras.src.export.export_utils import make_tf_tensor_spec
103-
104-
# Fall back to building specs from `self.inputs`
105-
inputs = getattr(self, "inputs", None)
106-
if inputs is None:
107-
return None
108-
109-
return tree.map_structure(
110-
lambda x: make_tf_tensor_spec(x, dynamic_batch=dynamic_batch),
111-
inputs,
112-
)
97+
# Lazy import to avoid circular dependency
98+
from keras.src.export.export_utils import make_tf_tensor_spec
99+
100+
# Fall back to building specs from `self.inputs`
101+
inputs = getattr(self, "inputs", None)
102+
if inputs is None:
103+
return None
104+
105+
return tree.map_structure(
106+
lambda x: make_tf_tensor_spec(x, dynamic_batch=dynamic_batch),
107+
inputs,
108+
)
113109

114110
@property
115111
def _default_save_signature(self):

keras/src/export/export_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def make_input_spec(x):
102102
return input_spec
103103

104104

105-
def make_tf_tensor_spec(x, dynamic_batch=True):
105+
def make_tf_tensor_spec(x, dynamic_batch=False):
106106
"""Create a TensorSpec from various input types.
107107
108108
Args:

keras/src/export/tf2onnx_lib.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55

66
import numpy as np
77

8-
if not hasattr(np, "object"):
9-
np.object = object
10-
118

129
@functools.lru_cache()
1310
def patch_tf2onnx():
@@ -20,6 +17,9 @@ def patch_tf2onnx():
2017

2118
logger = logging.getLogger(tf2onnx.__name__)
2219

20+
if not hasattr(np, "object"):
21+
np.object = object
22+
2323
def patched_rewrite_constant_fold(g, ops):
2424
"""
2525
We call tensorflow transform with constant folding but in some cases

keras/src/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from keras.src.api_export import keras_export
22

33
# Unique source of truth for the version number.
4-
__version__ = "3.13.0"
4+
__version__ = "3.13.1"
55

66

77
@keras_export("keras.version")

0 commit comments

Comments
 (0)