Skip to content

Commit 310bcde

Browse files
Fix mixed dtypes issue in nn.dot_product_attention. (#21558)
1 parent e8771ad commit 310bcde

File tree

7 files changed: 55 additions and 7 deletions (+55 -7 lines changed)

keras/src/backend/jax/nn.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1241,6 +1241,12 @@ def dot_product_attention(
12411241
f"Received: query.shape={query.shape}, key.shape={key.shape}, "
12421242
f"value.shape={value.shape}."
12431243
)
1244+
compute_dtype = backend.result_type(query.dtype, key.dtype, value.dtype)
1245+
query = cast(query, compute_dtype)
1246+
key = cast(key, compute_dtype)
1247+
value = cast(value, compute_dtype)
1248+
if bias is not None:
1249+
bias = convert_to_tensor(bias, dtype=compute_dtype)
12441250

12451251
# Check platform
12461252
platform = jax.devices()[0].platform

keras/src/backend/numpy/nn.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1163,6 +1163,13 @@ def dot_product_attention(
11631163
f"Received: query.shape={query.shape}, key.shape={key.shape}, "
11641164
f"value.shape={value.shape}."
11651165
)
1166+
compute_dtype = backend.result_type(query.dtype, key.dtype, value.dtype)
1167+
query = cast(query, compute_dtype)
1168+
key = cast(key, compute_dtype)
1169+
value = cast(value, compute_dtype)
1170+
if bias is not None:
1171+
bias = convert_to_tensor(bias, dtype=compute_dtype)
1172+
11661173
_, _, _, H = key.shape
11671174
scale = (1.0 / np.sqrt(H)) if scale is None else scale
11681175
return _dot_product_attention_xla(

keras/src/backend/openvino/nn.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -486,3 +486,19 @@ def ctc_decode(
486486

487487
def psnr(x1, x2, max_val):
488488
raise NotImplementedError("`psnr` is not supported with openvino backend")
489+
490+
491+
def dot_product_attention(
492+
query,
493+
key,
494+
value,
495+
bias=None,
496+
mask=None,
497+
scale=None,
498+
is_causal=False,
499+
flash_attention=None,
500+
attn_logits_soft_cap=None,
501+
):
502+
raise NotImplementedError(
503+
"`dot_product_attention` is not supported with openvino backend"
504+
)

keras/src/backend/tensorflow/nn.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1065,6 +1065,13 @@ def dot_product_attention(
10651065
f"Received: query.shape={query.shape}, key.shape={key.shape}, "
10661066
f"value.shape={value.shape}."
10671067
)
1068+
compute_dtype = backend.result_type(query.dtype, key.dtype, value.dtype)
1069+
query = cast(query, compute_dtype)
1070+
key = cast(key, compute_dtype)
1071+
value = cast(value, compute_dtype)
1072+
if bias is not None:
1073+
bias = convert_to_tensor(bias, dtype=compute_dtype)
1074+
10681075
H = tf.shape(key)[-1]
10691076
scale = (1.0 / tf.sqrt(tf.cast(H, "float32"))) if scale is None else scale
10701077
return _dot_product_attention_xla(

keras/src/backend/torch/nn.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1043,6 +1043,11 @@ def dot_product_attention(
10431043
f"Received: query.shape={query.shape}, key.shape={key.shape}, "
10441044
f"value.shape={value.shape}."
10451045
)
1046+
compute_dtype = backend.result_type(query.dtype, key.dtype, value.dtype)
1047+
query = cast(query, compute_dtype)
1048+
key = cast(key, compute_dtype)
1049+
value = cast(value, compute_dtype)
1050+
10461051
mask = mask if mask is None else convert_to_tensor(mask, dtype="bool")
10471052
if mask is not None:
10481053
# Explicit set `is_causal` to `False` when `mask` is not `None`.

keras/src/ops/nn.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2641,7 +2641,8 @@ def compute_output_spec(
26412641
mask=None,
26422642
scale=None,
26432643
):
2644-
return KerasTensor(query.shape, dtype=query.dtype)
2644+
dtype = backend.result_type(query.dtype, key.dtype, value.dtype)
2645+
return KerasTensor(query.shape, dtype=dtype)
26452646

26462647

26472648
@keras_export(

keras/src/ops/nn_test.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3114,14 +3114,20 @@ def test_ctc_decode(self, dtype):
31143114
self.assertEqual(standardize_dtype(decoded.dtype), "int32")
31153115
self.assertEqual(standardize_dtype(scores.dtype), expected_dtype)
31163116

3117-
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
3118-
def test_dot_product_attention(self, dtype):
3117+
@parameterized.named_parameters(
3118+
named_product(
3119+
dtypes=list(combinations(FLOAT_DTYPES, 2))
3120+
+ [(dtype, dtype) for dtype in FLOAT_DTYPES]
3121+
)
3122+
)
3123+
def test_dot_product_attention(self, dtypes):
31193124
# TODO: Get expected output from jax if `jax.nn.dot_product_attention`
31203125
# is available.
3121-
query = knp.ones((2, 3, 3, 8), dtype=dtype)
3122-
key = knp.ones((2, 3, 3, 8), dtype=dtype)
3123-
value = knp.ones((2, 3, 3, 8), dtype=dtype)
3124-
expected_dtype = dtype
3126+
query_dtype, key_value_dtype = dtypes
3127+
query = knp.ones((2, 3, 3, 8), dtype=query_dtype)
3128+
key = knp.ones((2, 3, 3, 8), dtype=key_value_dtype)
3129+
value = knp.ones((2, 3, 3, 8), dtype=key_value_dtype)
3130+
expected_dtype = backend.result_type(*dtypes)
31253131

31263132
self.assertDType(
31273133
knn.dot_product_attention(query, key, value), expected_dtype

0 commit comments

Comments
 (0)