Commit 0c0ec1a

Re-apply "Fixed issue with dot_product_attention when using TPU." (#21254) after addressing cuDNN/FlashAttention API updates (#21333)
* Update nn.py
* Update nn.py
* Update nn.py
* Update nn.py
* Update nn.py: corrected indentation in docstring
* Update nn.py
* Update random_grayscale.py: fixed an issue with passing a single image without a batch dimension.
* Update keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py (Co-authored-by: Jyotinder Singh <[email protected]>)
* Update random_grayscale_test.py: test case for unbatched inputs
* Code reformat
* Update random_grayscale_test.py: test case checking both unbatched and batched single-image inputs.
* Changed compute_output_spec: there was a bug causing a cycle in the graph.
* Update random_grayscale.py: removed the use of tree.map_structure
* Reapply "Fixed issue with dot_product_attention when using TPU. (#21254)" (#21329). This reverts commit 81821e0.
* Improve error handling in _can_use_flash_attention for better debugging. Enhanced the _can_use_flash_attention function to provide more detailed error messages when flash attention compatibility checks fail. Changes:
  - Replace generic exception catching with specific error propagation
  - When raise_error=True, directly re-raise the original exceptions from the check_layout() and check_is_flash_attention() functions
  - Preserve detailed error context from JAX internal validation functions
  - Maintain existing behavior when raise_error=False (returns False)
  This improves the debugging experience by surfacing specific technical details about tensor layout incompatibilities, cuDNN version requirements, and other flash attention compatibility issues. Relates to keras-hub PR #2257 and addresses flash attention debugging needs.
* Revert "Improve error handling in _can_use_flash_attention for better debugging". This reverts commit 7a0c547.
* Fix JAX API compatibility and improve error handling in `_can_use_flash_attention`. Changes:
  - Add the missing q_offsets=None and kv_offsets=None parameters to the check_layout() call to match the updated JAX function signature
  - Replace the bare `except:` with `except Exception as e:` and `raise e` to preserve detailed error messages from the JAX validation functions
  - Maintain the existing fallback behavior when raise_error=False
  This resolves compatibility issues with newer JAX versions and improves the debugging experience by surfacing specific technical details about flash attention compatibility failures.
* Updated `dot_product_attention`: simplified the check for `flash_attention` by removing redundant checks that are already done in `_can_use_flash_attention`.
* Update nn.py
* Update nn.py

---------

Co-authored-by: Jyotinder Singh <[email protected]>
1 parent 9286df3 commit 0c0ec1a
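
For orientation, a minimal usage sketch of the updated entry point (not part of this commit; it assumes a working JAX install and imports the private backend module touched here, so the path may change between Keras versions). With flash_attention=None the backend auto-detects compatibility, which is the behavior this commit refines; flash_attention=True instead raises with a detailed reason when the inputs are unsupported.

import os
os.environ["KERAS_BACKEND"] = "jax"  # assumption: the JAX backend is installed

import numpy as np
from keras.src.backend.jax import nn as jax_nn  # the module changed in this commit

# Shapes follow the docstring added below: [batch, time, heads, depth].
batch, seq_len, heads, head_dim = 2, 16, 4, 64
rng = np.random.default_rng(0)
query = rng.standard_normal((batch, seq_len, heads, head_dim), dtype=np.float32)
key = rng.standard_normal((batch, seq_len, heads, head_dim), dtype=np.float32)
value = rng.standard_normal((batch, seq_len, heads, head_dim), dtype=np.float32)

# flash_attention=None (the default) auto-detects compatibility:
# Splash kernel on TPU, cuDNN on supported GPUs, XLA otherwise.
out = jax_nn.dot_product_attention(
    query, key, value, is_causal=True, flash_attention=None
)
print(out.shape)  # (2, 16, 4, 64)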

File tree

  • keras/src/backend/jax

1 file changed: +178, -36 lines

keras/src/backend/jax/nn.py

Lines changed: 178 additions & 36 deletions
@@ -1062,6 +1062,8 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
             q_seqlen=None,
             kv_seqlen=None,
             layout=_normalize_layout("BTNH"),
+            q_offsets=None,
+            kv_offsets=None,
         )
         check_is_flash_attention(
             query,
@@ -1126,16 +1128,45 @@ def wrap_flash_attention(
     decoder_segment_ids,
     custom_mask=None,
     attn_logits_soft_cap=None,
+    head_shards=1,
+    q_seq_shards=1,
 ):
+    """Applies a wrapped flash attention mechanism using the Splash kernel.
+    This function prepares the appropriate attention mask (causal or custom),
+    constructs a multi-head mask, and applies the Splash multi-head attention
+    kernel to the provided query, key, and value tensors. It supports optional
+    sharding and soft capping of attention logits.
+    Args:
+        query: jax.Array. The query tensor of shape
+            (batch, num_heads, seq_len, head_dim).
+        key: jax.Array. The key tensor of shape
+            (batch, num_heads, seq_len, head_dim).
+        value: jax.Array. The value tensor of shape
+            (batch, num_heads, seq_len, head_dim).
+        decoder_segment_ids: Optional. Segment IDs for the decoder, used for
+            sharding or masking.
+        custom_mask: Optional[jax.Array]. A custom attention mask to apply. If
+            None, a causal mask is used.
+        attn_logits_soft_cap: Optional[float]. If provided, applies a soft cap
+            to the attention logits.
+        head_shards: int, default=1. Number of shards for the attention heads.
+        q_seq_shards: int, default=1. Number of shards for the query sequence
+            dimension.
+    Returns:
+        jax.Array: The result of applying the Splash multi-head attention
+        kernel to the inputs.
+    Raises:
+        AssertionError: If sharding along the sequence dimension is attempted
+            with decoder_segment_ids.
+    """
     if decoder_segment_ids is not None:
         assert query.shape[2] == decoder_segment_ids.q.shape[1], (
-            "Sharding along sequence dimension not allowed in tpu kernel "
-            "attention"
+            "Sharding along sequence dimension not allowed"
+            " in TPU kernel attention"
         )

     if custom_mask is not None:
         mask = splash_attention_mask.NumpyMask(array=custom_mask)
-
     else:
         mask = splash_attention_mask.CausalMask(
             shape=(query.shape[2], query.shape[2])
@@ -1147,8 +1178,8 @@ def wrap_flash_attention(
     )
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
-        head_shards=1,
-        q_seq_shards=1,
+        head_shards=head_shards,
+        q_seq_shards=q_seq_shards,
        attn_logits_soft_cap=attn_logits_soft_cap,
     )

@@ -1168,6 +1199,38 @@ def dot_product_attention(
     flash_attention=None,
     attn_logits_soft_cap=None,
 ):
+    """Computes dot-product attention given query, key, and value.
+
+    This is the core computation of attention that is used in transformers.
+    For TPU platforms, flash attention optimizations are automatically applied
+    when possible, and sharding parameters are inferred from the layout map
+    in the current distribution context.
+
+    Args:
+        query: Queries with shape `[batch, time, heads,
+            depth_k]`.
+        key: Keys with shape `[batch, time, heads,
+            depth_k]`.
+        value: Values with shape `[batch, time, heads,
+            depth_v]`.
+        bias: Optional bias with shape broadcastable to
+            `[batch, heads, dest_time, source_time]`.
+        mask: Optional mask with shape broadcastable to
+            `[batch, heads, dest_time, source_time]`.
+        scale: Float. Optional scale that is applied to the attention
+            computation.
+        is_causal: Boolean. Specifying whether causal masking is applied.
+        flash_attention: Boolean. Whether to use flash attention optimization
+            for increased performance. Defaults to None, which means it will
+            be auto-determined based on the platform, input shapes and
+            compatibility.
+        attn_logits_soft_cap: Float. Optional float to softly cap attention
+            logits to avoid numerical stability issues. Applied as:
+            `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`.
+
+    Returns:
+        JAX Array of shape `[batch, time, heads, depth_v]`.
+    """
     query = convert_to_tensor(query)
     key = convert_to_tensor(key)
     value = convert_to_tensor(value)
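As a quick numeric check of the soft-cap formula documented in the docstring above (a sketch, not code from the commit): large logits are squashed toward plus or minus attn_logits_soft_cap, while small logits pass through almost unchanged.

import jax.numpy as jnp

attn_logits_soft_cap = 30.0  # illustrative cap value
logits = jnp.array([-1000.0, -5.0, 0.0, 5.0, 1000.0])
capped = logits / (1.0 + jnp.abs(logits) / attn_logits_soft_cap)
print(capped)  # approximately [-29.1, -4.3, 0.0, 4.3, 29.1], bounded by +/- 30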
@@ -1177,47 +1240,123 @@ def dot_product_attention(
             f"Received: query.shape={query.shape}, key.shape={key.shape}, "
             f"value.shape={value.shape}."
         )
+
+    # Check platform
+    platform = jax.devices()[0].platform
+    is_tpu = platform == "tpu"
+
+    # Determine flash attention compatibility
     if flash_attention is None:
         flash_attention = _can_use_flash_attention(query, key, value, bias)
     elif flash_attention is True:
         # Use `raise_error=True` to provide more details if the inputs failed to
         # use flash attention
         _can_use_flash_attention(query, key, value, bias, raise_error=True)

-    if jax.devices()[0].platform == "tpu":
-        # Transpose to ('batch', 'heads', 'length', 'kv')
-        query = jnp.transpose(query, axes=(0, 2, 1, 3))
-        key = jnp.transpose(key, axes=(0, 2, 1, 3))
-        value = jnp.transpose(value, axes=(0, 2, 1, 3))
-        B, H, S, KV = query.shape
-
-        segment_ids = jnp.ones([B, S])
-        # {token_ids, padding_mask, segment_ids} enable packing
-        out = wrap_flash_attention(
-            query,
-            key,
-            value,
-            decoder_segment_ids=splash_attention_kernel.SegmentIds(
-                segment_ids, segment_ids
-            ),
-            custom_mask=mask,
-            attn_logits_soft_cap=attn_logits_soft_cap,
+    # TPU-specific flash attention path
+    if is_tpu and flash_attention:
+        # Get sharding parameters from distribution context
+        try:
+            from keras.src.distribution.distribution_lib import ModelParallel
+            from keras.src.distribution.distribution_lib import (
+                distribution as get_dist,
+            )
+
+            # Get current distribution if available
+            dist = get_dist()
+            if dist and isinstance(dist, ModelParallel):
+                mesh = dist.device_mesh
+                if "model" in mesh.axis_names:
+                    model_dim_index = mesh.axis_names.index("model")
+                    # Set head_shards based on the model dimension of the mesh
+                    head_shards = mesh.shape[model_dim_index]
+                    # Typically keep q_seq_shards=1 for best performance
+                    q_seq_shards = 1
+        except (ImportError, ValueError, AttributeError):
+            # Use default values if detection fails
+            head_shards = 1
+            q_seq_shards = 1
+        # Transpose to ('batch', 'heads', 'length', 'head_dim')
+        query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3))
+        key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3))
+        value_tpu_layout = jnp.transpose(value, axes=(0, 2, 1, 3))
+
+        bs, num_heads, q_len, head_dim = query_tpu_layout.shape
+
+        # Apply scale to query if provided
+        if scale is not None:
+            # TPU kernel applies 1/sqrt(head_dim) internally, to achieve
+            # overall QK^T * scale, scale query by (scale * sqrt(head_dim))
+            query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim))
+
+        # Create segment IDs for Splash Attention (for packing/batching)
+        segment_ids = jnp.zeros([bs, q_len], dtype=jnp.int32)
+        decoder_segment_ids = splash_attention_kernel.SegmentIds(
+            q=segment_ids, kv=segment_ids
         )
-        out = jnp.transpose(out, axes=(0, 2, 1, 3))
-        return out

-    # `dot_product_attention` is only available in jax>=0.4.31
+        # Process mask for Splash Attention
+        custom_mask = None
+        if mask is not None:
+            mask_bool = mask.astype("bool") if mask.dtype != jnp.bool_ else mask
+
+            if mask_bool.ndim == 3 and mask_bool.shape[0] == bs:
+                custom_mask = mask_bool[0]
+            elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs:
+                custom_mask = mask_bool[0, 0]
+
+            if is_causal and custom_mask is not None:
+                causal_mask = jnp.tril(
+                    jnp.ones((q_len, q_len), dtype=jnp.bool_)
+                )
+                custom_mask = jnp.logical_and(custom_mask, causal_mask)
+
+        if custom_mask is None and is_causal:
+            custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
+
+        try:
+            output = wrap_flash_attention(
+                query_tpu_layout,
+                key_tpu_layout,
+                value_tpu_layout,
+                decoder_segment_ids=decoder_segment_ids,
+                custom_mask=custom_mask,
+                attn_logits_soft_cap=attn_logits_soft_cap,
+                head_shards=head_shards,
+                q_seq_shards=q_seq_shards,
+            )
+            # Transpose output back to Keras layout
+            return jnp.transpose(output, axes=(0, 2, 1, 3))
+        except Exception:
+            flash_attention = False
+
+    # JAX native dot_product_attention for GPU or fallback for TPU
     if hasattr(jax.nn, "dot_product_attention"):
-        return jax.nn.dot_product_attention(
-            query,
-            key,
-            value,
-            bias=bias,
-            mask=mask,
-            scale=scale,
-            is_causal=is_causal,
-            implementation="cudnn" if flash_attention else "xla",
-        )
+        try:
+            return jax.nn.dot_product_attention(
+                query,
+                key,
+                value,
+                bias=bias,
+                mask=mask,
+                scale=scale,
+                is_causal=is_causal,
+                implementation="cudnn" if flash_attention else "xla",
+            )
+        except Exception:
+            # If flash attention fails, fall back to XLA implementation
+            if flash_attention:
+                return jax.nn.dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    bias=bias,
+                    mask=mask,
+                    scale=scale,
+                    is_causal=is_causal,
+                    implementation="xla",
+                )
+            raise

     if flash_attention:
         raise RuntimeError(
@@ -1228,6 +1367,9 @@ def dot_product_attention(
     # Ref: jax.nn.dot_product_attention
     # https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
     # Not support `query_seq_lengths` and `key_value_seq_lengths` args
+
+    # Fallback to custom XLA implementation
+    # This is the reference implementation from jax.nn.dot_product_attention
     output_shape = query.shape
     _, _, K, H = key.shape
     scale = (1.0 / jnp.sqrt(H)) if scale is None else scale
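
A small self-check (a sketch, not part of the diff) of the scale folding used in the TPU path above: pre-multiplying the query by scale * sqrt(head_dim) and letting the kernel apply its internal 1/sqrt(head_dim) reproduces QK^T * scale exactly.

import math
import jax.numpy as jnp

head_dim, scale = 64, 0.05  # illustrative values
q = jnp.full((2, head_dim), 0.1)
k = jnp.full((3, head_dim), 0.2)

# Desired logits: (Q @ K^T) * scale
direct = (q @ k.T) * scale
# What the TPU path does: pre-scale Q, then divide by sqrt(head_dim)
# (standing in for the kernel's internal 1/sqrt(head_dim) factor).
prescaled = ((q * (scale * math.sqrt(head_dim))) @ k.T) / math.sqrt(head_dim)
print(bool(jnp.allclose(direct, prescaled)))  # True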
