
Commit be9ea48

feat(nnx): add preferred_element_type to attention.py and recurrent.py, out_sharding to recurrent.py
1 parent 707a569 commit be9ea48

4 files changed: +194 -17 lines changed


flax/nnx/nn/attention.py

Lines changed: 35 additions & 4 deletions
@@ -63,6 +63,7 @@ def dot_product_attention_weights(
   module: Module | None = None,
   promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
   is_causal: bool = False,
+  preferred_element_type: Dtype | None = None,
 ):
   """Computes dot-product attention weights given query and key.

@@ -100,6 +101,9 @@ def dot_product_attention_weights(
       the logits to mask out the non-causal parts of the attention matrix,
       but other implementations like cudnn will avoid computing the
       non-causal regions, providing speedups.
+    preferred_element_type: optional parameter controlling the data type
+      output by the dot product. This argument is passed to the
+      ``dot_general`` function; see ``jax.lax.dot`` for details.

   Returns:
     Output of shape `[batch..., num_heads, q_length, kv_length]`.
@@ -117,7 +121,11 @@ def dot_product_attention_weights(
   query = query / jnp.sqrt(depth).astype(dtype)
   # attn weight shape is (batch..., num_heads, q_length, kv_length)
   attn_weights = jnp.einsum(
-    '...qhd,...khd->...hqk', query, key, precision=precision
+    '...qhd,...khd->...hqk',
+    query,
+    key,
+    precision=precision,
+    preferred_element_type=preferred_element_type,
   )

   # apply attention bias: masking, dropout, proximity bias, etc.
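
For context, `preferred_element_type` here is forwarded directly to `jnp.einsum`, where it sets the accumulation/output dtype of the underlying `dot_general`. A minimal stand-alone sketch in plain JAX (toy shapes invented for illustration):

import jax.numpy as jnp

# (batch, length, num_heads, head_dim) -- toy shapes
q = jnp.ones((2, 8, 4, 16), dtype=jnp.bfloat16)
k = jnp.ones((2, 8, 4, 16), dtype=jnp.bfloat16)

# Default: the contraction result keeps the promoted input dtype (bfloat16).
w_default = jnp.einsum('...qhd,...khd->...hqk', q, k)

# With preferred_element_type, dot_general accumulates and returns float32.
w_f32 = jnp.einsum(
  '...qhd,...khd->...hqk', q, k, preferred_element_type=jnp.float32
)
print(w_default.dtype, w_f32.dtype)  # bfloat16 float32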
@@ -172,6 +180,7 @@ def dot_product_attention(
   module: Module | None = None,
   promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
   is_causal: bool = False,
+  preferred_element_type: Dtype | None = None,
 ):
   """Computes dot-product attention given query, key, and value.

@@ -218,6 +227,9 @@ def dot_product_attention(
       the logits to mask out the non-causal parts of the attention matrix,
       but other implementations like cudnn will avoid computing the
       non-causal regions, providing speedups.
+    preferred_element_type: optional parameter controlling the data type
+      output by the dot product. This argument is passed to the
+      ``dot_general`` function; see ``jax.lax.dot`` for details.

   Returns:
     Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
@@ -244,10 +256,17 @@ def reshape_4d(x):
       reshape_4d, (query, key, value, bias, mask))
     if mask is not None:
       mask = mask.astype(jnp.bool)
-    out = jax.nn.dot_product_attention(query, key, value, bias, mask, is_causal=is_causal)
+    out = jax.nn.dot_product_attention(
+      query,
+      key,
+      value,
+      bias,
+      mask,
+      is_causal=is_causal,
+    )
     if len(query_shape) > 4:
       out = jnp.reshape(out, query_shape)
-    return out
+    return out.astype(preferred_element_type)

   # compute attention weights
   attn_weights = dot_product_attention_weights(
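
The fused `jax.nn.dot_product_attention` path above takes no `preferred_element_type` argument, so the commit casts its result after the fact via `out.astype(preferred_element_type)`. A hedged sketch of what that amounts to (toy shapes; note this is a post-hoc cast, not a change to the kernel's internal accumulation dtype):

import jax
import jax.numpy as jnp

q = k = v = jnp.ones((1, 8, 4, 16), dtype=jnp.bfloat16)  # (B, T, N, H)

# Fused attention kernel; output dtype follows the inputs.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True)

# The trailing cast emulates preferred_element_type on this path.
out = out.astype(jnp.float32)
print(out.dtype)  # float32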
@@ -264,11 +283,16 @@ def reshape_4d(x):
     module,
     promote_dtype,
     is_causal,
+    preferred_element_type=preferred_element_type,
   )

   # return weighted sum over values for each query position
   return jnp.einsum(
-    '...hqk,...khd->...qhd', attn_weights, value, precision=precision
+    '...hqk,...khd->...qhd',
+    attn_weights,
+    value,
+    precision=precision,
+    preferred_element_type=preferred_element_type,
   )

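With both einsums threaded through, the unfused path now honors the requested output dtype end to end. A minimal usage sketch, assuming `dot_product_attention` is imported from `flax.nnx` (toy shapes):

import jax.numpy as jnp
from flax import nnx

q = k = v = jnp.ones((1, 8, 4, 16), dtype=jnp.bfloat16)

out = nnx.dot_product_attention(
  q, k, v, preferred_element_type=jnp.float32
)
print(out.dtype)  # expected: float32
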
@@ -351,6 +375,8 @@ class MultiHeadAttention(Module):
       the scale of the query layer norm layer.
     key_ln_scale_metadata: Optional metadata dictionary to set when initializing
       the scale of the key layer norm layer.
+    preferred_element_type: optional dtype for the outputs of the internal
+      dot products; see `jax.lax.dot_general` for details.
   """

   def __init__(
@@ -378,6 +404,7 @@ def __init__(
     qkv_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
     out_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
     ln_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+    preferred_element_type: Dtype | None = None,
     # Deprecated, will be removed.
     qkv_dot_general: DotGeneralT | None = None,
     out_dot_general: DotGeneralT | None = None,
@@ -412,6 +439,7 @@ def __init__(
     self.use_bias = use_bias
     self.attention_fn = attention_fn
     self.decode = decode
+    self.preferred_element_type = preferred_element_type
     self.normalize_qk = normalize_qk
     self.qkv_promote_dtype = qkv_promote_dtype
     self.out_promote_dtype = out_promote_dtype
@@ -441,6 +469,7 @@ def __init__(
       promote_dtype=self.qkv_promote_dtype,
       dot_general=self.qkv_dot_general,
       dot_general_cls=self.qkv_dot_general_cls,
+      preferred_element_type=self.preferred_element_type,
       kernel_metadata=kernel_metadata,
       bias_metadata=bias_metadata,
     )
@@ -490,6 +519,7 @@ def __init__(
       promote_dtype=self.out_promote_dtype,
       dot_general=self.out_dot_general,
       dot_general_cls=self.out_dot_general_cls,
+      preferred_element_type=self.preferred_element_type,
       rngs=rngs,
       kernel_metadata=out_kernel_metadata or kernel_metadata,
       bias_metadata=out_bias_metadata or bias_metadata,
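
Both projection constructors above forward `preferred_element_type` to `LinearGeneral`, which implies `nnx.LinearGeneral` already accepts this argument (added to the linear layers separately). A hedged stand-alone sketch under that assumption:

import jax.numpy as jnp
from flax import nnx

# Assumes nnx.LinearGeneral exposes preferred_element_type, as the
# constructor calls in this diff imply.
proj = nnx.LinearGeneral(
  in_features=64,
  out_features=(4, 16),  # (num_heads, head_dim), as in the qkv projection
  param_dtype=jnp.bfloat16,
  preferred_element_type=jnp.float32,
  rngs=nnx.Rngs(0),
)
y = proj(jnp.ones((2, 10, 64), dtype=jnp.bfloat16))
print(y.dtype)  # expected: float32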
@@ -665,6 +695,7 @@ def __call__(
       deterministic=deterministic,
       dtype=self.dtype,
       precision=self.precision,
+      preferred_element_type=self.preferred_element_type,
       module=self if sow_weights else None,
     )
     # back to the original inputs dimensions
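
Putting it together: the new constructor argument is stored on the module, passed to the q/k/v and output projections, and forwarded to `attention_fn` in `__call__`. A usage sketch with toy sizes (`decode=False` skips the autoregressive cache):

import jax.numpy as jnp
from flax import nnx

layer = nnx.MultiHeadAttention(
  num_heads=4,
  in_features=64,
  dtype=jnp.bfloat16,
  preferred_element_type=jnp.float32,  # new in this commit
  decode=False,
  rngs=nnx.Rngs(0),
)
x = jnp.ones((2, 10, 64), dtype=jnp.bfloat16)
y = layer(x)  # self-attention over x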
