
Commit 5978f1d

[JAX] Default to fused attention in JAX DPA (NVIDIA#2363)
* Default to fused attention in JAX DPA
* Consolidate documentation for DPA in JAX
* Correctly update the documentation for defaults in JAX DPA

Signed-off-by: Kshitij Lakhani <[email protected]>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent 26aad6b commit 5978f1d
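For users upgrading across this commit: fused attention is now on by default, so the previous behavior must be requested explicitly. A minimal sketch (not part of this commit) of opting back out; the variable is read inside DotProductAttention.__call__, so it must be set before the forward pass runs:

    import os

    # Restore the pre-commit default (unfused attention) by setting the
    # environment variable before any DotProductAttention call.
    os.environ["NVTE_FUSED_ATTN"] = "0"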

File tree

1 file changed: +6 -5 lines

transformer_engine/jax/flax/transformer.py

Lines changed: 6 additions & 5 deletions
@@ -407,10 +407,10 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
     Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
     variable:
 
-    * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default).
-    * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention
-      kernel is not available on the system, a warning will be issued, and the module will
-      automatically fall back to the unfused backend.
+    * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention.
+    * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention (default). If the required cuDNN fused
+      attention kernel is not available on the system, a warning will be issued, and the module
+      will automatically fall back to the unfused backend.
 
     .. note::
         The DotProductAttention default setting enables non-deterministic kernels for reduced
@@ -602,7 +602,8 @@ def __call__(
         else:
            assert bias is not None
 
-        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
+        # Use fused attn (if kernel check below passes) by default
+        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))
 
         sequence_dim = 0 if self.transpose_batch_sequence else 1
         seqlen_q = query.shape[sequence_dim]
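To make the new default concrete, below is an illustrative, self-contained sketch of the selection logic this commit changes. The function name resolve_attention_backend and the fused_kernel_available parameter are hypothetical names for this example only, not helpers from transformer.py; only the NVTE_FUSED_ATTN handling mirrors the committed code.

    import os
    import warnings

    def resolve_attention_backend(fused_kernel_available: bool) -> str:
        # Fused attention is now the default: a missing env var reads as "1".
        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))
        if enable_fused_attn and not fused_kernel_available:
            # Documented fallback path: warn, then use the unfused backend.
            warnings.warn(
                "cuDNN fused attention kernel unavailable; "
                "falling back to unfused attention."
            )
            return "unfused"
        return "fused" if enable_fused_attn else "unfused"

    # Example: with NVTE_FUSED_ATTN unset, a system lacking the cuDNN kernel
    # warns and returns "unfused"; a system that has it returns "fused".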
