@@ -63,6 +63,7 @@ def dot_product_attention_weights(
   module: Module | None = None,
   promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
   is_causal: bool = False,
+  preferred_element_type: Dtype | None = None,
 ):
   """Computes dot-product attention weights given query and key.
 
@@ -100,6 +101,9 @@ def dot_product_attention_weights(
       the logits to mask out the non-causal parts of the attention matrix,
       but other implementations like cudnn will avoid computing the
       non-causal regions, providing speedups.
+    preferred_element_type: Optional dtype for the output of the dot product.
+      This argument is passed on to the underlying ``dot_general`` call; see
+      ``jax.lax.dot_general`` for details.
 
   Returns:
     Output of shape `[batch..., num_heads, q_length, kv_length]`.
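As a rough illustration of what the new argument does (not part of the diff; the shapes below are made up), ``preferred_element_type`` asks the underlying contraction to produce a specific dtype, so bfloat16 inputs can yield float32 attention logits:

    import jax.numpy as jnp

    # Hypothetical inputs with layout (batch, length, num_heads, head_dim).
    q = jnp.ones((2, 5, 4, 8), dtype=jnp.bfloat16)
    k = jnp.ones((2, 7, 4, 8), dtype=jnp.bfloat16)

    # Default: the logits keep the input dtype (bfloat16).
    w_default = jnp.einsum('...qhd,...khd->...hqk', q, k)

    # Requesting float32 output for the same contraction.
    w_f32 = jnp.einsum(
        '...qhd,...khd->...hqk', q, k, preferred_element_type=jnp.float32
    )

    print(w_default.dtype, w_f32.dtype)  # bfloat16 float32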
@@ -117,7 +121,11 @@ def dot_product_attention_weights(
   query = query / jnp.sqrt(depth).astype(dtype)
   # attn weight shape is (batch..., num_heads, q_length, kv_length)
   attn_weights = jnp.einsum(
-    '...qhd,...khd->...hqk', query, key, precision=precision
+    '...qhd,...khd->...hqk',
+    query,
+    key,
+    precision=precision,
+    preferred_element_type=preferred_element_type,
   )
 
   # apply attention bias: masking, dropout, proximity bias, etc.
@@ -172,6 +180,7 @@ def dot_product_attention(
   module: Module | None = None,
   promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
   is_causal: bool = False,
+  preferred_element_type: Dtype | None = None,
 ):
   """Computes dot-product attention given query, key, and value.
 
@@ -218,6 +227,9 @@ def dot_product_attention(
       the logits to mask out the non-causal parts of the attention matrix,
       but other implementations like cudnn will avoid computing the
       non-causal regions, providing speedups.
+    preferred_element_type: Optional dtype for the output of the dot product.
+      This argument is passed on to the underlying ``dot_general`` call; see
+      ``jax.lax.dot_general`` for details.
 
   Returns:
     Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
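A usage sketch for the updated function (illustrative only; it assumes this change is applied and that the function is exposed as ``flax.nnx.dot_product_attention``, with made-up toy shapes):

    import jax.numpy as jnp
    from flax import nnx

    # (batch, length, num_heads, head_dim), all in bfloat16.
    q = jnp.ones((1, 5, 4, 8), dtype=jnp.bfloat16)
    k = jnp.ones((1, 5, 4, 8), dtype=jnp.bfloat16)
    v = jnp.ones((1, 5, 4, 8), dtype=jnp.bfloat16)

    # Request float32 outputs from the attention dot products.
    out = nnx.dot_product_attention(
        q, k, v, preferred_element_type=jnp.float32
    )
    print(out.shape, out.dtype)  # expected: (1, 5, 4, 8) float32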
@@ -244,10 +256,17 @@ def reshape_4d(x):
       reshape_4d, (query, key, value, bias, mask))
     if mask is not None:
       mask = mask.astype(jnp.bool)
-    out = jax.nn.dot_product_attention(query, key, value, bias, mask, is_causal=is_causal)
+    out = jax.nn.dot_product_attention(
+      query,
+      key,
+      value,
+      bias,
+      mask,
+      is_causal=is_causal,
+    )
     if len(query_shape) > 4:
       out = jnp.reshape(out, query_shape)
-    return out
+    return out if preferred_element_type is None else out.astype(preferred_element_type)
 
   # compute attention weights
   attn_weights = dot_product_attention_weights(
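One note on the fast path above (a sketch, not part of the diff): ``jax.nn.dot_product_attention`` itself has no ``preferred_element_type`` argument, so the requested dtype has to be applied to its result afterwards, e.g.:

    import jax
    import jax.numpy as jnp

    # jax.nn.dot_product_attention expects 4D (batch, length, heads, head_dim).
    q = k = v = jnp.ones((1, 5, 4, 8), dtype=jnp.bfloat16)

    out = jax.nn.dot_product_attention(q, k, v, is_causal=True)
    print(out.dtype)  # typically bfloat16: the fused kernel keeps the input dtype

    preferred_element_type = jnp.float32  # stand-in for the function argument
    if preferred_element_type is not None:
      out = out.astype(preferred_element_type)
    print(out.dtype)  # float32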
@@ -264,11 +283,16 @@ def reshape_4d(x):
     module,
     promote_dtype,
     is_causal,
+    preferred_element_type=preferred_element_type,
   )
 
   # return weighted sum over values for each query position
   return jnp.einsum(
-    '...hqk,...khd->...qhd', attn_weights, value, precision=precision
+    '...hqk,...khd->...qhd',
+    attn_weights,
+    value,
+    precision=precision,
+    preferred_element_type=preferred_element_type,
   )
 
 
@@ -351,6 +375,8 @@ class MultiHeadAttention(Module):
       the scale of the query layer norm layer.
     key_ln_scale_metadata: Optional metadata dictionary to set when initializing
       the scale of the key layer norm layer.
+    preferred_element_type: Optional dtype for the output of the attention
+      dot products; see `jax.lax.dot_general` for details.
   """
 
   def __init__(
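A constructor-level usage sketch (hypothetical sizes; assumes Flax NNX with this change applied):

    import jax.numpy as jnp
    from flax import nnx

    layer = nnx.MultiHeadAttention(
        num_heads=4,
        in_features=32,
        qkv_features=32,
        decode=False,
        preferred_element_type=jnp.float32,  # new argument from this change
        rngs=nnx.Rngs(0),
    )

    x = jnp.ones((2, 7, 32))
    y = layer(x)  # self-attention over x; y.shape == (2, 7, 32)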
@@ -378,6 +404,7 @@ def __init__(
     qkv_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
     out_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
     ln_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+    preferred_element_type: Dtype | None = None,
     # Deprecated, will be removed.
     qkv_dot_general: DotGeneralT | None = None,
     out_dot_general: DotGeneralT | None = None,
@@ -412,6 +439,7 @@ def __init__(
     self.use_bias = use_bias
     self.attention_fn = attention_fn
     self.decode = decode
+    self.preferred_element_type = preferred_element_type
     self.normalize_qk = normalize_qk
     self.qkv_promote_dtype = qkv_promote_dtype
     self.out_promote_dtype = out_promote_dtype
@@ -441,6 +469,7 @@ def __init__(
       promote_dtype=self.qkv_promote_dtype,
       dot_general=self.qkv_dot_general,
       dot_general_cls=self.qkv_dot_general_cls,
+      preferred_element_type=self.preferred_element_type,
       kernel_metadata=kernel_metadata,
       bias_metadata=bias_metadata,
     )
@@ -490,6 +519,7 @@ def __init__(
       promote_dtype=self.out_promote_dtype,
       dot_general=self.out_dot_general,
       dot_general_cls=self.out_dot_general_cls,
+      preferred_element_type=self.preferred_element_type,
       rngs=rngs,
       kernel_metadata=out_kernel_metadata or kernel_metadata,
       bias_metadata=out_bias_metadata or bias_metadata,
@@ -665,6 +695,7 @@ def __call__(
       deterministic=deterministic,
       dtype=self.dtype,
       precision=self.precision,
+      preferred_element_type=self.preferred_element_type,
       module=self if sow_weights else None,
     )
     # back to the original inputs dimensions