Commit 90c8da6

Fix _can_use_flash_attention. (#21512)
1 parent e704b46 commit 90c8da6

File tree

  • keras/src/backend/jax/nn.py

1 file changed: +9 −9 lines changed

keras/src/backend/jax/nn.py

Lines changed: 9 additions & 9 deletions
@@ -1,4 +1,5 @@
 import builtins
+import inspect
 import math
 
 import jax
@@ -1054,16 +1055,15 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
         if not check_compute_capability("8.0"):
             raise RuntimeError("Require at least Ampere arch to run")
         # Check inputs layout
+        check_layout_params = list(
+            inspect.signature(check_layout).parameters.keys()
+        )
+        for known_param in ("query", "key", "value", "bias", "layout"):
+            check_layout_params.remove(known_param)
+        # Defaults to `None` when not specified.
+        kwargs = {key: None for key in check_layout_params}
         check_layout(
-            query,
-            key,
-            value,
-            bias,
-            q_seqlen=None,
-            kv_seqlen=None,
-            layout=_normalize_layout("BTNH"),
-            q_offsets=None,
-            kv_offsets=None,
+            query, key, value, bias, layout=_normalize_layout("BTNH"), **kwargs
         )
         check_is_flash_attention(
             query,
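
The change replaces a hard-coded keyword list with signature introspection: `inspect.signature(check_layout)` lists whatever parameters the installed JAX version exposes, the known ones are passed explicitly, and every remaining parameter is defaulted to `None`. Below is a minimal, self-contained sketch of that pattern; the stand-in `check_layout` and the `call_with_none_defaults` helper are illustrative names, not part of the commit, so the snippet runs without JAX installed.

```python
# Minimal sketch of the signature-introspection pattern from this commit.
# `check_layout` is a stand-in with a plausible signature, and
# `call_with_none_defaults` is a hypothetical helper; only the technique
# (defaulting unknown parameters to None via inspect.signature) mirrors
# the real change in keras/src/backend/jax/nn.py.
import inspect


def check_layout(
    query, key, value, bias, q_seqlen, kv_seqlen, layout, q_offsets, kv_offsets
):
    # Stand-in for JAX's cuDNN layout check; only the signature matters here.
    print("layout check called with layout:", layout)


def call_with_none_defaults(fn, /, **known_kwargs):
    # Fill every parameter of `fn` that the caller did not supply with None,
    # so the call keeps working if `fn` gains or loses optional parameters.
    param_names = inspect.signature(fn).parameters.keys()
    extra = {name: None for name in param_names if name not in known_kwargs}
    return fn(**known_kwargs, **extra)


call_with_none_defaults(
    check_layout,
    query="q",
    key="k",
    value="v",
    bias=None,
    layout="BTNH",
)
```

With this approach, optional parameters such as `q_seqlen`/`kv_seqlen` or `q_offsets`/`kv_offsets` no longer have to be spelled out at the call site, so the check keeps working when a JAX release adds or removes them from `check_layout`.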

0 commit comments