Commit b4f9507
Author: Elisa Tsai
Parent: f1ff3cc

Reapply cross self attention switch and attention bug fixes

This commit reapplies both:
- Cross self attention switch (#251) (#288)
- Attention bug fixes, tokamax splash defaulting logic (#282) (#287)

28 files changed, +616 -381 lines

.github/workflows/UnitTests.yml
Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@ jobs:
 pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
 - name: PyTest
 run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
-  HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
+  HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=65472" python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
 # add_pull_ready:
 #   if: github.ref != 'refs/heads/main'
 #   permissions:

.gitignore
Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,6 @@
 __pycache__/
 *.py[cod]
 *$py.class
-
 # C extensions
 *.so

@@ -98,6 +97,7 @@ celerybeat-schedule

 # Environments
 .env
+.history
 .venv
 env/
 venv/

docs/attention_blocks_flowchart.md (new file)
Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
# Attention block sizes

## Description

- "block_q": Block size (HBM to VMEM and VREG) to tile along the Q sequence in the forward pass.
- "block_kv_compute": Sub-block size (VMEM to VREG) of "block_kv" on which compute is performed in the forward pass. It must be a factor of, or equal to, "block_kv".
- "block_kv": Block size (HBM to VMEM) to tile along the KV sequence in the forward pass.
- "block_q_dkv": Block size along the Q sequence in the backward pass with the fused kernel that computes the gradients of q, k, and v. It must be a factor of, or equal to, "block_q".
- "block_kv_dkv": Block size along the KV sequence in the backward pass. It must be a factor of, or equal to, "block_kv".
- "block_kv_dkv_compute": Sub-block size of "block_kv_dkv". It must be a factor of, or equal to, "block_kv_dkv".
- "block_q_dq": Block size along the Q sequence in the backward pass with the unfused kernel that computes the gradient of q only. It must be a factor of, or equal to, "block_q".
- "block_kv_dq": Block size to tile along the KV sequence in the backward pass with the unfused kernel that computes the gradient of q only. It must be a factor of, or equal to, "block_kv".
- "use_fused_bwd_kernel": When True, the fused backward kernel is used, computing dQ, dK, and dV in a single kernel. It is usually more performant but comes with a slight HBM memory overhead.
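These names mirror the fields of the splash-attention `BlockSizes` dataclass, which `src/maxdiffusion/common_types.py` re-exports and which the `flash_block_sizes` config override presumably populates with the same keys. As an illustrative sketch only (placeholder values to tune, not recommended defaults), a set of block sizes that satisfies the divisibility rules above could be constructed as:

```python
# Hypothetical values for illustration; every *_compute / backward block size
# below divides its forward counterpart, as the rules above require.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

block_sizes = splash_attention_kernel.BlockSizes(
    block_q=512,               # forward Q tile (HBM -> VMEM and VREG)
    block_kv_compute=512,      # factor of block_kv (VMEM -> VREG compute tile)
    block_kv=1024,             # forward KV tile (HBM -> VMEM)
    block_q_dkv=512,           # factor of block_q, fused backward pass
    block_kv_dkv=1024,         # factor of block_kv, backward pass
    block_kv_dkv_compute=512,  # factor of block_kv_dkv
    block_q_dq=512,            # factor of block_q, unfused backward pass
    block_kv_dq=1024,          # factor of block_kv, unfused backward pass
    use_fused_bwd_kernel=False,
)
```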
## Flowchart

Maxdiffusion automatically adheres to this flowchart to ensure a working configuration, and a log will inform you of the modifications Maxdiffusion makes to the specified block sizes.

![alt text](attention_blocks_flowchart.png)

> "tokamax_flash" uses the splash attention implementation in the [tokamax repo](https://github.com/openxla/tokamax/blob/main/tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_kernel.py). This kernel only supports the fused backward pass, where the gradients for q, k, and v are computed in a single kernel, so "block_q_dq" and "block_kv_dq" are not used.
## How block sizes matter for performance and accuracy

Block sizes are key to saturating HBM bandwidth and ensuring the maximum possible overlap of compute on the cores with HBM use and VMEM-to-VREG transfers. It is highly recommended to tune them.

Block sizes also interact with the sequence length. The sequence length is a multiple of the resolution and the number of frames (for video), together with the VAE downscaling factors and patchifying ratios. This sequence length, or a shard of it, needs to be a multiple of the specified block sizes, so Maxdiffusion pads the sequence length to the nearest multiple of the block sizes. It is advisable to choose block sizes that are factors of the sequence length, at least for the Q block sizes.

> In cross attention, image or video tokens attend to text tokens. The text token sequence length is very small, and potentially smaller than the specified block size, so the KV block sizes are overwritten to safe values.

> KV block sizes must be a multiple of 128, since the register size is 8x128 and in attention the KV sequence dimension lies on the 128-lane dimension for the multiplications, as K is transposed.
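To make the padding rule above concrete, here is a minimal sketch (a hypothetical helper and example numbers, not MaxDiffusion code) of rounding a sequence length up to the next multiple of a block size:

```python
# Sketch of the padding rule described above: the (sharded) sequence length is
# rounded up to the next multiple of the block size, so choosing a block size
# that divides the sequence length avoids wasted compute on padded tokens.
def pad_to_block_multiple(seq_len: int, block: int) -> int:
  return ((seq_len + block - 1) // block) * block

# e.g. a hypothetical 75,600-token video sequence with block_q = 512:
seq_len, block_q = 75_600, 512
padded = pad_to_block_multiple(seq_len, block_q)
print(padded, padded - seq_len)  # 75776 tokens after padding, 176 of them padding
```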
docs/attention_blocks_flowchart.png (binary image, 229 KB)

preview-xpk.sh

Lines changed: 0 additions & 93 deletions
This file was deleted.

requirements.txt
Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@ ftfy
 tensorboard>=2.17.0
 tensorboardx>=2.6.2.2
 tensorboard-plugin-profile>=2.15.2
+tokamax
 Jinja2
 scikit-image
 parameterized

src/maxdiffusion/common_types.py
Lines changed: 33 additions & 1 deletion

@@ -33,7 +33,11 @@
 BlockSizes = splash_attention_kernel.BlockSizes

 AxisNames = tuple[str, ...]
-
+# Physical axis names for device meshes.
+DATA = "data"
+FSDP = "fsdp"
+TENSOR = "tensor"
+# Logical axis names for model parameters and activations.
 BATCH = "activation_batch"
 LENGTH = "activation_length"
 KV_LENGTH = "activation_kv_length"

@@ -44,4 +48,32 @@
 KEEP_2 = "activation_keep_2"
 CONV_OUT = "activation_conv_out_channels"

+# For setting self/cross attention independently in splash kernel
+SELF_ATTN_HEAD = "activation_self_attn_heads"
+SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
+SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"
+CROSS_ATTN_HEAD = "activation_cross_attn_heads"
+CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length"
+CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length"
+
+
 WAN_MODEL = "Wan2.1"
+
+### Common axis rules for ring attention ###
+RING_ATTENTION_AXIS_RULES = [
+    [SELF_ATTN_HEAD, None],
+    [SELF_ATTN_Q_LENGTH, FSDP],
+    [SELF_ATTN_KV_LENGTH, FSDP],
+    [CROSS_ATTN_HEAD, None],
+    [CROSS_ATTN_Q_LENGTH, FSDP],
+    [CROSS_ATTN_KV_LENGTH, FSDP],
+]
+
+SEQUENCE_PARALLEL_AXIS_RULES = [
+    [SELF_ATTN_HEAD, None],
+    [SELF_ATTN_Q_LENGTH, FSDP],
+    [SELF_ATTN_KV_LENGTH, None],
+    [CROSS_ATTN_HEAD, None],
+    [CROSS_ATTN_Q_LENGTH, FSDP],
+    [CROSS_ATTN_KV_LENGTH, None],
+]
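For context, rule lists like these follow Flax's logical-axis-rules format: each pair maps a logical activation axis name to a physical mesh axis, with `None` meaning the axis stays replicated. The sketch below shows how such rules could resolve to a concrete sharding; the mesh shape, activation layout, and use of a subset of the ring-attention rules are assumptions for illustration, not MaxDiffusion's actual wiring.

```python
# Sketch only: resolve logical axis names to a mesh sharding via Flax's
# logical partitioning helpers. Mesh shape and activation layout are assumptions.
import jax
import flax.linen as nn
from jax.sharding import Mesh, PartitionSpec
from jax.experimental import mesh_utils

FSDP = "fsdp"
SELF_ATTN_HEAD = "activation_self_attn_heads"
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"

# Ring-attention style rules: both Q and KV sequence axes map to the fsdp mesh axis.
rules = [
    (SELF_ATTN_HEAD, None),
    (SELF_ATTN_Q_LENGTH, FSDP),
    (SELF_ATTN_KV_LENGTH, FSDP),
]

devices = mesh_utils.create_device_mesh((1, jax.device_count(), 1))
mesh = Mesh(devices, axis_names=("data", FSDP, "tensor"))

# A [batch, heads, q_length, head_dim] activation: only q_length ends up sharded here.
logical_spec = PartitionSpec(None, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, None)
sharding = nn.logical_to_mesh_sharding(logical_spec, mesh, rules)
print(sharding)  # NamedSharding mapping the q_length dimension to the "fsdp" axis
```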

src/maxdiffusion/configs/base14.yml
Lines changed: 9 additions & 0 deletions

@@ -50,6 +50,15 @@ jit_initializers: True
 from_pt: False
 split_head_dim: True
 attention: 'dot_product' # Supported attention: dot_product, flash
+# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
+# Otherwise we do not pass segment ids; on VPU-bound hardware like Trillium this is faster.
+# However, when padding tokens are significant, this leads to worse quality and it should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
+# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
+#    is sharded for q in cross attention.
+attention_sharding_uniform: True
 flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32
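To illustrate what the `mask_padding_tokens` comment describes, here is a conceptual sketch (a hypothetical helper, not MaxDiffusion's implementation) of deriving segment ids from a padding mask. Splash attention only lets tokens with matching segment ids attend to each other, so real tokens (id 1) never attend to padding (id 0).

```python
# Conceptual sketch: turn a per-token validity mask into segment ids for
# splash attention. With mask_padding_tokens: False these ids are simply not
# passed, which is faster on VPU-bound chips but lets attention see padding.
import jax.numpy as jnp

def make_segment_ids(valid_mask: jnp.ndarray) -> jnp.ndarray:
  """valid_mask: [batch, seq], 1 for real tokens, 0 for padding."""
  return valid_mask.astype(jnp.int32)  # real tokens -> segment 1, padding -> 0

tokens_per_example = jnp.array([3, 5])  # hypothetical unpadded lengths
seq_len = 8                             # padded (e.g. block-aligned) length
valid = jnp.arange(seq_len)[None, :] < tokens_per_example[:, None]
q_segment_ids = kv_segment_ids = make_segment_ids(valid)
# q_segment_ids[0] == [1, 1, 1, 0, 0, 0, 0, 0]; queries in segment 1 ignore keys in segment 0.
```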

src/maxdiffusion/configs/base21.yml
Lines changed: 10 additions & 0 deletions

@@ -49,6 +49,16 @@ jit_initializers: True
 from_pt: False
 split_head_dim: True
 attention: 'dot_product' # Supported attention: dot_product, flash
+# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
+# Otherwise we do not pass segment ids; on VPU-bound hardware like Trillium this is faster.
+# However, when padding tokens are significant, this leads to worse quality and it should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
+# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
+#    is sharded for q in cross attention.
+attention_sharding_uniform: True
+
 flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32

src/maxdiffusion/configs/base_2_base.yml
Lines changed: 10 additions & 0 deletions

@@ -50,6 +50,16 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash
+# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
+# Otherwise we do not pass segment ids; on VPU-bound hardware like Trillium this is faster.
+# However, when padding tokens are significant, this leads to worse quality and it should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
+# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
+#    is sharded for q in cross attention.
+attention_sharding_uniform: True
+
 flash_block_sizes: {}
 # to override default block sizes for flash attention
 # flash_block_sizes:
