[Tutorial] Fix 06-fused-attention.py of FP8 dtype (#7043)
When the provider is `fp8`, `v` is permuted as shown below, and its new stride becomes `(H*N_CTX*HEAD_DIM, N_CTX*HEAD_DIM, 1, N_CTX)`.
```
if mode == "fwd" and "fp8" in provider:
    # Materialize v in transposed (Z, H, HEAD_DIM, N_CTX) memory order,
    # then permute back so the logical shape stays (Z, H, N_CTX, HEAD_DIM).
    v = v.permute(0, 1, 3, 2).contiguous()
    v = v.permute(0, 1, 3, 2)
```
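For reference, here is a minimal stand-alone check (example sizes, not the tutorial's benchmark values) showing that the double permute leaves the logical shape untouched while producing exactly that stride:
```
import torch

# Example sizes; the tutorial's benchmark uses its own Z, H, N_CTX, HEAD_DIM.
Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64
v = torch.randn(Z, H, N_CTX, HEAD_DIM)

# Materialize v in transposed (Z, H, HEAD_DIM, N_CTX) memory order,
# then permute back so the logical shape is (Z, H, N_CTX, HEAD_DIM) again.
v = v.permute(0, 1, 3, 2).contiguous()
v = v.permute(0, 1, 3, 2)

assert v.shape == (Z, H, N_CTX, HEAD_DIM)
assert v.stride() == (H * N_CTX * HEAD_DIM, N_CTX * HEAD_DIM, 1, N_CTX)
```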
This PR fixes the FP8 dtype handling in the fused-attention kernel by separating `k` and `v` offset calculations and updating related configuration details. Key changes include:
- Renaming and separating the offset variables used for the `k` and `v` computations (see the sketch after this list).
- Adjusting the offset calculation for the FP8 dtype and updating the tensor descriptor creation.
- Expanding the configuration options for `BLOCK_N` and refining the device-specific configuration conditions.
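A minimal sketch of why separate offsets are needed; the sizes and the `flat_offset` helper are illustrative, not code from this PR. The same logical element of `v` sits at a different storage offset once the FP8 layout is in effect, so offset arithmetic written for `k` cannot simply be reused for `v`:
```
import torch

Z, H, N_CTX, HEAD_DIM = 1, 2, 8, 4
v = torch.arange(Z * H * N_CTX * HEAD_DIM, dtype=torch.float32).reshape(Z, H, N_CTX, HEAD_DIM)
v_t = v.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)  # FP8-style transposed layout

def flat_offset(t, z, h, n, d):
    # Stride-weighted offset of logical element (z, h, n, d) into t's storage.
    sz, sh, sn, sd = t.stride()
    return z * sz + h * sh + n * sn + d * sd

z, h, n, d = 0, 1, 5, 3
print(flat_offset(v, z, h, n, d))    # 55: row-major offset, as used for q/k
print(flat_offset(v_t, z, h, n, d))  # 61: transposed-layout offset needed for the fp8 v
```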
Signed-off-by: Whitney Tsang <[email protected]>