
Commit 9ccc3a4

add correct flash block sizes for flux.
Parent: 0bdefba

1 file changed

examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_flux.py

Lines changed: 14 additions & 2 deletions
@@ -21,7 +21,7 @@
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from transformers import CLIPTextModel, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast
-
+from torch_xla.experimental.custom_kernel import FlashAttention
 from diffusers import (
     AutoencoderKL,
     FlowMatchEulerDiscreteScheduler,
@@ -731,7 +731,19 @@ def main(args):
 
     #unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
     transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
-
+    FlashAttention.DEFAULT_BLOCK_SIZES = {
+        "block_q": 1536,
+        "block_k_major": 1536,
+        "block_k": 1536,
+        "block_b": 1536,
+        "block_q_major_dkv": 1536,
+        "block_k_major_dkv": 1536,
+        "block_q_dkv": 1536,
+        "block_k_dkv": 1536,
+        "block_q_dq": 1536,
+        "block_k_dq": 1536,
+        "block_k_major_dq": 1536,
+    }
     # For mixed precision training we cast all non-trainable weights (vae,
     # non-lora text_encoder and non-lora unet) to half-precision
     # as these weights are only used for inference, keeping weights in full
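
Note on the change above: FlashAttention.DEFAULT_BLOCK_SIZES is a class attribute of torch_xla.experimental.custom_kernel.FlashAttention, so reassigning it changes the tiling used by every subsequent XLA flash-attention call in the process, not just the Flux transformer patched here. The sketch below is not part of the commit; it assumes the library's stock dict exposes the same keys the commit sets, and shows one way to apply the 1536 block sizes while keeping the original defaults around so they can be restored.

from torch_xla.experimental.custom_kernel import FlashAttention

# Keep a copy of the library's stock block sizes so the override can be undone.
_stock_block_sizes = dict(FlashAttention.DEFAULT_BLOCK_SIZES)

# Apply the commit's value to every tiling knob. Assumes the stock dict uses the
# same keys as the literal in the diff ("block_q", "block_k_major", ...).
FlashAttention.DEFAULT_BLOCK_SIZES = {key: 1536 for key in _stock_block_sizes}

# ... training with transformer.enable_xla_flash_attention(...) runs here ...

# Hypothetical cleanup, not in the commit: restore the library defaults.
FlashAttention.DEFAULT_BLOCK_SIZES = _stock_block_sizes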
