Skip to content

Commit 6234a37

Browse files
add print loss cli argument. Run make style and quality.
1 parent 96af06e commit 6234a37

File tree

3 files changed

+70
-70
lines changed

3 files changed

+70
-70
lines changed

examples/research_projects/pytorch_xla/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,10 @@ export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to
9898
export TRAIN_STEPS=50
9999
export OUTPUT_DIR=/tmp/trained-model/
100100
python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
101-
102101
```
103102

103+
Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step interrupts the optimized execution flow, so each step will take longer.
104+
104105
### Environment Envs Explained
105106

106107
* `XLA_DISABLE_FUNCTIONALIZATION`: To optimize the performance for AdamW optimizer.

examples/research_projects/pytorch_xla/train_text_to_image_xla.py

Lines changed: 60 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import argparse
22
import os
33
import random
4-
54
import time
65
from pathlib import Path
76

@@ -29,11 +28,12 @@
2928
from diffusers.utils import is_wandb_available
3029
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
3130

31+
3232
if is_wandb_available():
3333
pass
3434

35-
PROFILE_DIR = os.environ.get('PROFILE_DIR', None)
36-
CACHE_DIR = os.environ.get('CACHE_DIR', None)
35+
PROFILE_DIR = os.environ.get("PROFILE_DIR", None)
36+
CACHE_DIR = os.environ.get("CACHE_DIR", None)
3737
if CACHE_DIR:
3838
xr.initialize_cache(CACHE_DIR, readonly=False)
3939
xr.use_spmd()
@@ -151,12 +151,24 @@ def start_training(self):
151151
dataloader_exception = True
152152
print(e)
153153
break
154-
if step == measure_start_step and PROFILE_DIR is not None:
154+
if step == measure_start_step and PROFILE_DIR is not None:
155155
xm.wait_device_ops()
156-
xp.trace_detached('localhost:9012', PROFILE_DIR, duration_ms=args.profile_duration)
157-
last_time = time.time()
156+
xp.trace_detached("localhost:9012", PROFILE_DIR, duration_ms=args.profile_duration)
157+
last_time = time.time()
158158
loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
159159
self.global_step += 1
160+
161+
def print_loss_closure(step, loss):
162+
print(f"Step: {step}, Loss: {loss}")
163+
164+
if args.print_loss:
165+
xm.add_step_closure(
166+
print_loss_closure,
167+
args=(
168+
self.global_step,
169+
loss,
170+
),
171+
)
160172
xm.mark_step()
161173
if not dataloader_exception:
162174
xm.wait_device_ops()
@@ -170,7 +182,7 @@ def step_fn(
170182
self,
171183
pixel_values,
172184
input_ids,
173-
):
185+
):
174186
with xp.Trace("model.forward"):
175187
self.optimizer.zero_grad()
176188
latents = self.vae.encode(pixel_values).latent_dist.sample()
@@ -196,12 +208,8 @@ def step_fn(
196208
elif self.noise_scheduler.config.prediction_type == "v_prediction":
197209
target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
198210
else:
199-
raise ValueError(
200-
f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
201-
)
202-
model_pred = self.unet(
203-
noisy_latents, timesteps, encoder_hidden_states, return_dict=False
204-
)[0]
211+
raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
212+
model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
205213
with xp.Trace("model.backward"):
206214
if self.args.snr_gamma is None:
207215
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
@@ -210,9 +218,9 @@ def step_fn(
210218
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
211219
# This is discussed in Section 4.2 of the same paper.
212220
snr = compute_snr(self.noise_scheduler, timesteps)
213-
mse_loss_weights = torch.stack(
214-
[snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1
215-
).min(dim=1)[0]
221+
mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
222+
dim=1
223+
)[0]
216224
if self.noise_scheduler.config.prediction_type == "epsilon":
217225
mse_loss_weights = mse_loss_weights / snr
218226
elif self.noise_scheduler.config.prediction_type == "v_prediction":
@@ -226,11 +234,10 @@ def step_fn(
226234
self.run_optimizer()
227235
return loss
228236

237+
229238
def parse_args():
230239
parser = argparse.ArgumentParser(description="Simple example of a training script.")
231-
parser.add_argument(
232-
"--profile_duration", type=int, default=10000, help="Profile duration in ms"
233-
)
240+
parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms")
234241
parser.add_argument(
235242
"--pretrained_model_name_or_path",
236243
type=str,
@@ -359,25 +366,19 @@ def parse_args():
359366
"--loader_prefetch_size",
360367
type=int,
361368
default=1,
362-
help=(
363-
"Number of subprocesses to use for data loading to cpu."
364-
),
369+
help=("Number of subprocesses to use for data loading to cpu."),
365370
)
366371
parser.add_argument(
367372
"--loader_prefetch_factor",
368373
type=int,
369374
default=2,
370-
help=(
371-
"Number of batches loaded in advance by each worker."
372-
),
375+
help=("Number of batches loaded in advance by each worker."),
373376
)
374377
parser.add_argument(
375378
"--device_prefetch_size",
376379
type=int,
377380
default=1,
378-
help=(
379-
"Number of subprocesses to use for data loading to tpu from cpu. "
380-
),
381+
help=("Number of subprocesses to use for data loading to tpu from cpu. "),
381382
)
382383
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
383384
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
@@ -394,10 +395,7 @@ def parse_args():
394395
type=str,
395396
default=None,
396397
choices=["no", "bf16"],
397-
help=(
398-
"Whether to use mixed precision. Bf16 requires PyTorch >= 1.10"
399-
),
400-
398+
help=("Whether to use mixed precision. Bf16 requires PyTorch >= 1.10"),
401399
)
402400
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
403401
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
@@ -407,6 +405,12 @@ def parse_args():
407405
default=None,
408406
help="The name of the repository to keep in sync with the local `output_dir`.",
409407
)
408+
parser.add_argument(
409+
"--print_loss",
410+
default=False,
411+
action="store_true",
412+
help=("Print loss at every step."),
413+
)
410414

411415
args = parser.parse_args()
412416

@@ -416,6 +420,7 @@ def parse_args():
416420

417421
return args
418422

423+
419424
def setup_optimizer(unet, args):
420425
optimizer_cls = torch.optim.AdamW
421426
return optimizer_cls(
@@ -427,6 +432,7 @@ def setup_optimizer(unet, args):
427432
foreach=True,
428433
)
429434

435+
430436
def load_dataset(args):
431437
if args.dataset_name is not None:
432438
# Downloading and loading a dataset from the hub.
@@ -446,6 +452,7 @@ def load_dataset(args):
446452
)
447453
return dataset
448454

455+
449456
def get_column_names(dataset, args):
450457
column_names = dataset["train"].column_names
451458

@@ -470,13 +477,12 @@ def get_column_names(dataset, args):
470477

471478

472479
def main(args):
473-
474480
args = parse_args()
475481

476-
server = xp.start_server(9012)
482+
_ = xp.start_server(9012)
477483

478484
num_devices = xr.global_runtime_device_count()
479-
mesh = xs.get_1d_mesh('data')
485+
mesh = xs.get_1d_mesh("data")
480486
xs.set_global_mesh(mesh)
481487

482488
text_encoder = CLIPTextModel.from_pretrained(
@@ -511,6 +517,7 @@ def main(args):
511517
)
512518

513519
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
520+
514521
unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
515522

516523
vae.requires_grad_(False)
@@ -562,19 +569,9 @@ def tokenize_captions(examples, is_train=True):
562569

563570
train_transforms = transforms.Compose(
564571
[
565-
transforms.Resize(
566-
args.resolution, interpolation=transforms.InterpolationMode.BILINEAR
567-
),
568-
(
569-
transforms.CenterCrop(args.resolution)
570-
if args.center_crop
571-
else transforms.RandomCrop(args.resolution)
572-
),
573-
(
574-
transforms.RandomHorizontalFlip()
575-
if args.random_flip
576-
else transforms.Lambda(lambda x: x)
577-
),
572+
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
573+
(transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)),
574+
(transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x)),
578575
transforms.ToTensor(),
579576
transforms.Normalize([0.5], [0.5]),
580577
]
@@ -592,17 +589,13 @@ def preprocess_train(examples):
592589

593590
def collate_fn(examples):
594591
pixel_values = torch.stack([example["pixel_values"] for example in examples])
595-
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).to(
596-
weight_dtype
597-
)
592+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).to(weight_dtype)
598593
input_ids = torch.stack([example["input_ids"] for example in examples])
599594
return {"pixel_values": pixel_values, "input_ids": input_ids}
600595

601596
g = torch.Generator()
602597
g.manual_seed(xr.host_index())
603-
sampler = torch.utils.data.RandomSampler(
604-
train_dataset, replacement=True, num_samples=int(1e10), generator=g
605-
)
598+
sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10), generator=g)
606599
train_dataloader = torch.utils.data.DataLoader(
607600
train_dataset,
608601
sampler=sampler,
@@ -616,9 +609,7 @@ def collate_fn(examples):
616609
train_dataloader,
617610
device,
618611
input_sharding={
619-
"pixel_values": xs.ShardingSpec(
620-
mesh, ("data", None, None, None), minibatch=True
621-
),
612+
"pixel_values": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
622613
"input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
623614
},
624615
loader_prefetch_size=args.loader_prefetch_size,
@@ -635,15 +626,17 @@ def collate_fn(examples):
635626
)
636627
print(f" Total optimization steps = {args.max_train_steps}")
637628

638-
trainer = TrainSD(vae=vae,
639-
weight_dtype=weight_dtype,
640-
device=device,
641-
noise_scheduler=noise_scheduler,
642-
unet=unet,
643-
optimizer=optimizer,
644-
text_encoder=text_encoder,
645-
dataloader=train_dataloader,
646-
args=args)
629+
trainer = TrainSD(
630+
vae=vae,
631+
weight_dtype=weight_dtype,
632+
device=device,
633+
noise_scheduler=noise_scheduler,
634+
unet=unet,
635+
optimizer=optimizer,
636+
text_encoder=text_encoder,
637+
dataloader=train_dataloader,
638+
args=args,
639+
)
647640

648641
trainer.start_training()
649642
unet = trainer.unet.to("cpu")
@@ -672,4 +665,4 @@ def collate_fn(examples):
672665

673666
if __name__ == "__main__":
674667
args = parse_args()
675-
main(args)
668+
main(args)

src/diffusers/models/attention_processor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,12 @@
3838

3939
if is_torch_xla_available():
4040
from torch_xla.experimental.custom_kernel import flash_attention
41+
4142
XLA_AVAILABLE = True
4243
else:
4344
XLA_AVAILABLE = False
4445

46+
4547
@maybe_allow_in_graph
4648
class Attention(nn.Module):
4749
r"""
@@ -2483,12 +2485,16 @@ def __call__(
24832485
if attention_mask is not None:
24842486
attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
24852487
# Convert mask to float and replace 0s with -inf and 1s with 0
2486-
attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))
2488+
attention_mask = (
2489+
attention_mask.float()
2490+
.masked_fill(attention_mask == 0, float("-inf"))
2491+
.masked_fill(attention_mask == 1, float(0.0))
2492+
)
24872493

24882494
# Apply attention mask to key
24892495
key = key + attention_mask
24902496
query /= math.sqrt(query.shape[3])
2491-
hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None))
2497+
hidden_states = flash_attention(query, key, value, causal=False, partition_spec=("data", None, None, None))
24922498
else:
24932499
hidden_states = F.scaled_dot_product_attention(
24942500
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False

0 commit comments

Comments
 (0)