
Commit 96af06e

update ptxla example based on Pei's comments.
Parent: f04ee1d

3 files changed (+14 −5 lines)


examples/research_projects/pytorch_xla/README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -97,7 +97,7 @@ export DATASET_NAME=lambdalabs/naruto-blip-captions
 export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
 export TRAIN_STEPS=50
 export OUTPUT_DIR=/tmp/trained-model/
-python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=4 --loader_prefetch_size=4 --device_prefetch_size=4'
+python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
 
 ```
````
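The flags in this command map directly onto the script's `torch.utils.data.DataLoader` setup, so raising `--dataloader_num_workers` from 4 to 8 doubles the number of CPU worker processes feeding the TPU. One way to check whether extra workers actually help is to time the host-side input pipeline in isolation. The sketch below is a hypothetical standalone benchmark, not part of the example script; the dataset shape and sizes are stand-ins:

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the real image dataset; small enough to build in memory.
dataset = TensorDataset(torch.randn(256, 3, 224, 224))

def batches_per_second(num_workers: int, prefetch_factor: int, batch_size: int = 32) -> float:
    """Iterate the loader once and report host-side batch throughput."""
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,  # only valid when num_workers > 0
    )
    start = time.perf_counter()
    num_batches = sum(1 for _ in loader)
    return num_batches / (time.perf_counter() - start)

if __name__ == "__main__":  # guard needed for DataLoader workers on some platforms
    for workers in (4, 8):
        rate = batches_per_second(workers, prefetch_factor=4)
        print(f"num_workers={workers}: {rate:.1f} batches/s")
```

If throughput plateaus between 4 and 8 workers, the input pipeline is not the bottleneck and the extra processes only cost host memory.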

examples/research_projects/pytorch_xla/train_text_to_image_xla.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -32,7 +32,7 @@
 if is_wandb_available():
     pass
 
-PROFILE_DIR=os.environ.get('PROFILE_DIR', None)
+PROFILE_DIR = os.environ.get('PROFILE_DIR', None)
 CACHE_DIR = os.environ.get('CACHE_DIR', None)
 if CACHE_DIR:
     xr.initialize_cache(CACHE_DIR, readonly=False)
@@ -363,6 +363,14 @@ def parse_args():
             "Number of subprocesses to use for data loading to cpu."
         ),
     )
+    parser.add_argument(
+        "--loader_prefetch_factor",
+        type=int,
+        default=2,
+        help=(
+            "Number of batches loaded in advance by each worker."
+        ),
+    )
     parser.add_argument(
         "--device_prefetch_size",
         type=int,
@@ -579,7 +587,7 @@ def preprocess_train(examples):
         return examples
 
     train_dataset = dataset["train"]
-    train_dataset.set_format('torch')
+    train_dataset.set_format("torch")
     train_dataset.set_transform(preprocess_train)
 
     def collate_fn(examples):
@@ -601,6 +609,7 @@ def collate_fn(examples):
         collate_fn=collate_fn,
         num_workers=args.dataloader_num_workers,
         batch_size=args.train_batch_size,
+        prefetch_factor=args.loader_prefetch_factor,
     )
 
     train_dataloader = pl.MpDeviceLoader(
```
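Taken together, the new `--loader_prefetch_factor` flag and the existing flags give two buffering stages: the `DataLoader` keeps up to `num_workers * prefetch_factor` batches ready on the host CPU, and `MpDeviceLoader` then stages batches onto the TPU. A minimal sketch of how the pieces fit together, assuming a machine with `torch_xla` installed; `train_dataset` here is a stand-in for the script's transformed dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

# Stand-in for the script's preprocessed dataset.
train_dataset = TensorDataset(torch.randn(128, 3, 64, 64))

cpu_loader = DataLoader(
    train_dataset,
    batch_size=32,      # --train_batch_size
    num_workers=8,      # --dataloader_num_workers
    prefetch_factor=2,  # --loader_prefetch_factor (default 2); up to
                        # num_workers * prefetch_factor batches buffered on CPU
)

device = xm.xla_device()
train_dataloader = pl.MpDeviceLoader(
    cpu_loader,
    device,
    loader_prefetch_size=4,  # --loader_prefetch_size
    device_prefetch_size=4,  # --device_prefetch_size
)

for (batch,) in train_dataloader:
    pass  # batch tensors arrive already placed on the XLA device
```

Wiring the flag through `prefetch_factor` (rather than hard-coding it) lets the host-side buffer be tuned independently of the device-side staging controlled by `--device_prefetch_size`.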

src/diffusers/models/attention_processor.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -20,7 +20,7 @@
 from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging, is_torch_xla_available
+from ..utils import deprecate, is_torch_xla_available, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
@@ -2484,7 +2484,7 @@ def __call__(
         attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
         # Convert mask to float and replace 0s with -inf and 1s with 0
         attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))
-
+
         # Apply attention mask to key
         key = key + attention_mask
         query /= math.sqrt(query.shape[3])
```
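The hunk above (the change itself is only trailing-whitespace removal) sits in the standard additive-mask pattern: a binary attention mask (1 = attend, 0 = ignore) is converted to a float mask of 0.0 / -inf so it can simply be added to the attention computation before the softmax. A small self-contained illustration of that conversion:

```python
import torch

# Binary mask: 1 = attend, 0 = ignore (shape: batch x key_len).
attention_mask = torch.tensor([[1, 1, 0, 0]])

# Same conversion as in the hunk above: 0 -> -inf, 1 -> 0.0.
additive = (
    attention_mask.float()
    .masked_fill(attention_mask == 0, float("-inf"))
    .masked_fill(attention_mask == 1, 0.0)
)
print(additive)  # tensor([[0., 0., -inf, -inf]])

# Adding the float mask to raw attention scores drives masked positions
# to -inf, so softmax assigns them exactly zero weight.
scores = torch.randn(1, 4) + additive
print(scores.softmax(dim=-1))  # last two entries are 0
```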
