
Commit fad8020

optimizations - XLA_DISABLE_FUNCTIONALIZATION=1
1 parent 2001ed2 commit fad8020

4 files changed: +29 -34 lines changed

examples/research_projects/pytorch_xla/training/text_to_image/README_sdxl.md

Lines changed: 11 additions & 10 deletions
@@ -44,13 +44,15 @@ Install PyTorch and PyTorch/XLA nightly versions:
 gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
 --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
 --command='
-pip install torch==2.6.0+cpu.cxx11.abi \
-https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl \
-'torch_xla[tpu]' \
+pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev-cp310-cp310-linux_x86_64.whl' \
 -f https://storage.googleapis.com/libtpu-releases/index.html \
--f https://storage.googleapis.com/libtpu-wheels/index.html \
--f https://download.pytorch.org/whl/torch
-pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+-f https://storage.googleapis.com/libtpu-wheels/index.html
+
+# Optional: if you're using custom kernels, install pallas dependencies
+pip install 'torch_xla[pallas]' \
+-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 '
 ```
 
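For anyone reviewing this locally, a quick sanity check after the nightly install (a sketch of my own, not part of the README; it only assumes the wheels above installed cleanly):

```python
# Sanity-check sketch: confirm the nightly torch/torch_xla wheels import
# and that a TPU device is visible.
import torch
import torch_xla.core.xla_model as xm

print(torch.__version__)   # expect a nightly/dev version string
print(xm.xla_device())     # expect an XLA device such as xla:0
```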
@@ -72,7 +74,6 @@ cd diffusers
 git checkout main
 cd examples/research_projects/pytorch_xla/training/text_to_image/
 pip3 install -r requirements_sdxl.txt
-pip3 install pillow --upgrade
 cd ../../../../../
 pip3 install .'
 ```
@@ -94,14 +95,14 @@ are fixed.
 gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
 --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
 --command='
-export XLA_DISABLE_FUNCTIONALIZATION=0
+export XLA_DISABLE_FUNCTIONALIZATION=1
 export PROFILE_DIR=/tmp/
 export CACHE_DIR=/tmp/
 export DATASET_NAME=lambdalabs/naruto-blip-captions
-export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
+export GLOBAL_BATCH_SIZE=32
 export TRAIN_STEPS=50
 export OUTPUT_DIR=/tmp/trained-model/
-python diffusers/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=1024 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
+python examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=1024 --center_crop --random_flip --train_batch_size=$GLOBAL_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=16 --device_prefetch_size=16'
 ```
 
 Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.
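The `--print_loss` caveat exists because printing a lazy XLA tensor forces a host-device sync mid-step. The usual PyTorch/XLA workaround is a step closure; a minimal sketch of that pattern (the script's actual `print_loss_closure` may differ):

```python
import torch_xla.core.xla_model as xm

def print_loss_closure(step, loss):
    # Runs after the step's graph executes, so reading loss does not
    # interrupt tracing of the next step.
    print(f"step: {step}, loss: {loss}")

# Inside the training loop:
# xm.add_step_closure(print_loss_closure, args=(step, loss), run_async=True)
```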

examples/research_projects/pytorch_xla/training/text_to_image/requirements_sdxl.txt

Lines changed: 0 additions & 2 deletions
@@ -1,6 +1,4 @@
 accelerate>=0.16.0
-torch==2.5.1
-torchvision==0.20.1
 transformers>=4.25.1
 datasets>=2.19.1
 ftfy

examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py

Lines changed: 13 additions & 15 deletions
@@ -31,6 +31,7 @@
 from diffusers.utils import is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 
+# torch._dynamo.config.force_parameter_static_shapes = False
 
 if is_wandb_available():
     pass
@@ -148,16 +149,12 @@ def start_training(self):
         for step in range(0, self.args.max_train_steps):
             print("step: ", step)
             batch = next(self.dataloader)
-            if step == measure_start_step and PROFILE_DIR is not None:
-                xm.wait_device_ops()
-                xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
+            if step == measure_start_step:
+                if PROFILE_DIR is not None:
+                    xm.wait_device_ops()
+                    xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
                 last_time = time.time()
-            loss = self.step_fn(
-                batch["model_input"],
-                batch["prompt_embeds"],
-                batch["pooled_prompt_embeds"],
-                batch["original_sizes"],
-                batch["crop_top_lefts"])
+            loss = self.step_fn(batch)
             self.global_step += 1
 
         def print_loss_closure(step, loss):
@@ -182,15 +179,15 @@ def print_loss_closure(step, loss):
 
     def step_fn(
         self,
-        model_input,
-        prompt_embeds,
-        pooled_prompt_embeds,
-        original_sizes,
-        crop_top_lefts
+        batch
     ):
         with xp.Trace("model.forward"):
             self.optimizer.zero_grad()
-
+            model_input = batch["model_input"]
+            prompt_embeds = batch["prompt_embeds"]
+            pooled_prompt_embeds = batch["pooled_prompt_embeds"]
+            original_sizes = batch["original_sizes"]
+            crop_top_lefts = batch["crop_top_lefts"]
 
             noise = torch.randn_like(model_input).to(self.device, dtype=self.weight_dtype)
             bsz = model_input.shape[0]
@@ -638,6 +635,7 @@ def main(args):
     text_encoder_2 = text_encoder_2.to(device, dtype=weight_dtype)
     vae = vae.to(device, dtype=weight_dtype)
     unet = unet.to(device, dtype=weight_dtype)
+    #unet = torch.compile(unet, backend='openxla', dynamic=True)
     optimizer = setup_optimizer(unet, args)
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
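The commented-out `torch.compile` line pairs with the commented-out dynamo flag at the top of the file. If enabled, the path would look roughly like this sketch (assumptions: the `openxla` dynamo backend is registered by torch_xla, and the model is already on the XLA device; the Linear module is a stand-in for the unet):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 8).to(device)

# dynamic=True asks dynamo not to specialize on input shapes; the module-level
# force_parameter_static_shapes = False flag would relax parameter shapes too.
compiled = torch.compile(model, backend="openxla", dynamic=True)
out = compiled(torch.randn(4, 8, device=device))
xm.mark_step()  # cut the lazy graph so the compiled step actually executes
```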

src/diffusers/models/attention_processor.py

Lines changed: 5 additions & 7 deletions
@@ -3241,7 +3241,6 @@ def __call__(
 def xla_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
     L, S = query.size(-2), key.size(-2)
     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
-    attn_bias = torch.zeros(L, S, dtype=query.dtype)
     if is_causal:
         assert attn_mask is None
         temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
@@ -3254,7 +3253,6 @@ def xla_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_
         else:
             attn_bias += attn_mask
     attn_weight = query @ key.transpose(-2, -1) * scale_factor
-    attn_weight += attn_bias
     attn_weight = torch.softmax(attn_weight, dim=-1)
     attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
     return attn_weight @ value
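Taken together, the two deletions mean `attn_bias` is never allocated or added on the mask-free path, which is the only path this training flow exercises (`attn_mask=None`, `is_causal=False`); note the causal/masked branches still reference `attn_bias` and would need the allocation restored if ever taken. A condensed sketch of the resulting fast path:

```python
import math
import torch

def sdpa_no_bias(query, key, value, dropout_p=0.0, scale=None):
    # Mask-free scaled dot-product attention: no bias tensor is created,
    # saving a lazy-tensor allocation and add per attention call.
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
```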
@@ -3330,7 +3328,7 @@ def __call__(
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = self.xla_scaled_dot_product_attention(
+        hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
 
@@ -3428,7 +3426,7 @@ def __call__(
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
-            logger.warning("Using flash attention")
+            # logger.warning("Using flash attention")
             if attention_mask is not None:
                 attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
                 # Convert mask to float and replace 0s with -inf and 1s with 0
@@ -3444,9 +3442,9 @@ def __call__(
             partition_spec = self.partition_spec if is_spmd() else None
             hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
         else:
-            logger.warning(
-                "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
-            )
+            # logger.warning(
+            #     "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
+            # )
             hidden_states = xla_scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )
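With the warnings silenced, the surviving control flow reduces to a length-gated dispatch. A sketch (assuming the usual `flash_attention` import from torch_xla's custom kernels, and `xla_scaled_dot_product_attention` as defined above; the mask-to-float conversion on the flash path is elided):

```python
from torch_xla.experimental.custom_kernel import flash_attention

def dispatch_attention(query, key, value, attention_mask, partition_spec):
    # The Pallas flash-attention kernel requires every QKV sequence length
    # to be at least 4096; otherwise fall back to the plain XLA SDPA.
    if all(t.shape[2] >= 4096 for t in (query, key, value)):
        return flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
    return xla_scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )
```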
