Skip to content

Conversation

@pcmoritz
Copy link
Collaborator

Since https://docs.jax.dev/en/latest/_autosummary/jax.nn.dot_product_attention.html supports the cudnn flash attention implementation, I was curious how the performance compares and whether we leave performance on the table by not using it. I ran it with

uv run --extra aws --extra gpu --extra tinker -m tx.tinker.api --base-model Qwen/Qwen3-4B --max-lora-adapters 3 --max-lora-rank 1 --tensor-parallel-size 4 --train-micro-batch-size 8 --sample-max-num-sequences 256 > out.log

and

uv run --with Pillow --with wandb --with tinker rl_loop.py base_url=http://localhost:8000/ model_name="Qwen/Qwen3-4B" lora_rank=1 max_length=1024 max_tokens=512 save_every=100

The timing with cudnn is slightly worse than the XLA timing, so there is no point in merging this PR, but I wanted to open it for documentation purposes.

Screenshot 2025-12-12 at 2 01 27 PM

@pcmoritz pcmoritz added the tx label Dec 12, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request tests the cudnn backend for jax.nn.dot_product_attention. The changes correctly prepare the attention mask for the cuDNN implementation by explicitly broadcasting it. However, there is one critical issue: the deterministic=True parameter, which is required for the cuDNN backend, is missing from the dot_product_attention call. I've added a suggestion to fix this.

Comment on lines 132 to 140
attn_output = jax.nn.dot_product_attention(
q,
k,
v,
scale=1.0 / self.head_dim**0.5,
mask=attention_mask[:, None, None, :].astype(bool),
mask=mask,
is_causal=kv_cache is None,
implementation="cudnn",
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

According to the JAX documentation for jax.nn.dot_product_attention with implementation="cudnn", the deterministic parameter must be set to True. Without this, the call might fail or not use the cuDNN kernel as intended.

Suggested change
attn_output = jax.nn.dot_product_attention(
q,
k,
v,
scale=1.0 / self.head_dim**0.5,
mask=attention_mask[:, None, None, :].astype(bool),
mask=mask,
is_causal=kv_cache is None,
implementation="cudnn",
)
attn_output = jax.nn.dot_product_attention(
q,
k,
v,
scale=1.0 / self.head_dim**0.5,
mask=mask,
is_causal=kv_cache is None,
implementation="cudnn",
deterministic=True,
)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant