Commit 2001ed2

Update dependencies, use custom attention.

1 parent a9513c1 commit 2001ed2

File tree

3 files changed: +31 -6 lines

examples/research_projects/pytorch_xla/training/text_to_image/README_sdxl.md

Lines changed: 7 additions & 3 deletions
@@ -44,8 +44,12 @@ Install PyTorch and PyTorch/XLA nightly versions:
 gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
 --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
 --command='
-pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
-pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+pip install torch==2.6.0+cpu.cxx11.abi \
+https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl \
+'torch_xla[tpu]' \
+-f https://storage.googleapis.com/libtpu-releases/index.html \
+-f https://storage.googleapis.com/libtpu-wheels/index.html \
+-f https://download.pytorch.org/whl/torch
 pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 '
 ```
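
As a quick sanity check after the pinned install above (a hypothetical snippet, not part of this commit), the wheel versions and the XLA device can be verified on the TPU VM:

```python
# Hypothetical post-install check; run on the TPU VM. Not part of the commit.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

print(torch.__version__)      # expect a 2.6.0+cpu.cxx11.abi build
print(torch_xla.__version__)  # expect a 2.6.0 build
print(xm.xla_device())        # should resolve to an XLA device, e.g. xla:0
```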
@@ -97,7 +101,7 @@ export DATASET_NAME=lambdalabs/naruto-blip-captions
 export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
 export TRAIN_STEPS=50
 export OUTPUT_DIR=/tmp/trained-model/
-python diffusers/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
+python diffusers/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=1024 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
 ```
 
 Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.

examples/research_projects/pytorch_xla/training/text_to_image/requirements_sdxl.txt

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 accelerate>=0.16.0
 torch==2.5.1
-torchvision==0.21.0
+torchvision==0.20.1
 transformers>=4.25.1
 datasets>=2.19.1
 ftfy

src/diffusers/models/attention_processor.py

Lines changed: 23 additions & 2 deletions
@@ -3238,6 +3238,26 @@ def __call__(
 
         return hidden_states
 
+def xla_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+    attn_bias = torch.zeros(L, S, dtype=query.dtype)
+    if is_causal:
+        assert attn_mask is None
+        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+        attn_bias = attn_bias.to(query.dtype)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+        else:
+            attn_bias += attn_mask
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+    return attn_weight @ value
 
 class AttnProcessor2_0:
     r"""
@@ -3310,7 +3330,7 @@ def __call__(
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
+        hidden_states = xla_scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
 
@@ -3408,6 +3428,7 @@ def __call__(
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
+            logger.warning("Using flash attention")
             if attention_mask is not None:
                 attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
                 # Convert mask to float and replace 0s with -inf and 1s with 0
@@ -3426,7 +3447,7 @@ def __call__(
             logger.warning(
                 "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
             )
-            hidden_states = F.scaled_dot_product_attention(
+            hidden_states = xla_scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )

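Taken together, the XLA processor now dispatches between two paths: the Pallas flash-attention kernel for long sequences and the manual decomposition otherwise. A condensed sketch of that control flow (a hypothetical helper, not the actual processor code; the real `__call__` also projects and reshapes heads and converts masks, which this omits):

```python
# Condensed sketch of the dispatch added in this commit, under the stated
# assumptions; flash_attention is torch_xla's Pallas kernel wrapper.
from torch_xla.experimental.custom_kernel import flash_attention
from diffusers.models.attention_processor import xla_scaled_dot_product_attention

def attention_dispatch(query, key, value, attention_mask=None):
    # The Pallas kernel is only used when every QKV sequence length is at
    # least 4096, where flash attention pays off.
    if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
        return flash_attention(query, key, value, causal=False)
    # Shorter sequences fall back to the manual implementation instead of
    # F.scaled_dot_product_attention, keeping the graph XLA-friendly.
    return xla_scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )
```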