
Commit 2809859

save checkpoint. Inference example.
1 parent: f8a48d6

File tree: 2 files changed (+54, -32)

examples/research_projects/pytorch_xla/training/text_to_image/README_sdxl.md

Lines changed: 15 additions & 10 deletions
@@ -100,10 +100,10 @@ export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1
 export PROFILE_DIR=/tmp/
 export CACHE_DIR=/tmp/
 export DATASET_NAME=lambdalabs/naruto-blip-captions
-export GLOBAL_BATCH_SIZE=32
+export PER_HOST_BATCH_SIZE=32
 export TRAIN_STEPS=50
 export OUTPUT_DIR=/tmp/trained-model/
-python examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=1024 --center_crop --random_flip --train_batch_size=$GLOBAL_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=16 --device_prefetch_size=16'
+python examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=1024 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=16 --device_prefetch_size=16
 ```
 
 Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.
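
A note on the rename above: the value passed to `--train_batch_size` appears to be interpreted per host rather than globally, which is presumably why `GLOBAL_BATCH_SIZE` becomes `PER_HOST_BATCH_SIZE`. Below is a minimal sketch of the implied arithmetic, assuming one training process per host; `NUM_HOSTS` is a hypothetical variable used only for this illustration and is not part of the README.

```python
# Hypothetical illustration of the rename (not part of the README): the batch
# size exported above is assumed to be per host, so the effective global batch
# on a multi-host TPU slice scales with the number of hosts.
import os

per_host_batch_size = int(os.environ.get("PER_HOST_BATCH_SIZE", "32"))
num_hosts = int(os.environ.get("NUM_HOSTS", "1"))  # hypothetical, for illustration only

print("per-host batch size:", per_host_batch_size)
print("global batch size  :", per_host_batch_size * num_hosts)
```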
@@ -135,21 +135,23 @@ import numpy as np
 
 import torch_xla.core.xla_model as xm
 from time import time
-from diffusers import StableDiffusionPipeline
+from diffusers import StableDiffusionXLPipeline
 import torch_xla.runtime as xr
 
+MODEL_PATH = os.environ.get("OUTPUT_DIR", None)
+
 CACHE_DIR = os.environ.get("CACHE_DIR", None)
 if CACHE_DIR:
     xr.initialize_cache(CACHE_DIR, readonly=False)
 
 def main():
     device = xm.xla_device()
-    model_path = "jffacevedo/pxla_trained_model"
-    pipe = StableDiffusionPipeline.from_pretrained(
-        model_path,
+    pipe = StableDiffusionXLPipeline.from_pretrained(
+        MODEL_PATH,
         torch_dtype=torch.bfloat16
     )
     pipe.to(device)
+    pipe.unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))
     prompt = ["A naruto with green eyes and red legs."]
     start = time()
     print("compiling...")
@@ -168,9 +170,12 @@ if __name__ == '__main__':
 Expected Results:
 
 ```bash
+WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
+Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 7.93it/s]
 compiling...
-100%|██████████| 30/30 [10:03<00:00, 20.10s/it]
-compile time: 720.656970500946
+100%|██████████| 30/30 [01:35<00:00, 3.19s/it]
+compile time: 241.23492813110352
 generate...
-100%|██████████| 30/30 [00:01<00:00, 17.65it/s]
-generation time (after compile) : 1.8461642265319824
+100%|██████████| 30/30 [00:04<00:00, 6.72it/s]
+generation time (after compile) : 5.266263246536255
+```
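
For context, the README's inference example after this commit would read roughly as follows. This is a sketch assembled from the hunks above: the lines after `print("compiling...")` fall outside the diff, so the two `pipe(prompt)` calls, the timing prints, and the `naruto.png` output path are assumptions reconstructed from the "Expected Results" log rather than lines quoted from the README.

```python
# Sketch of the inference example after this commit. Assembled from the diff;
# everything after print("compiling...") is an assumption based on the expected
# output, not a quote from the README.
import os
from time import time

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from diffusers import StableDiffusionXLPipeline

MODEL_PATH = os.environ.get("OUTPUT_DIR", None)  # directory written by the training run

CACHE_DIR = os.environ.get("CACHE_DIR", None)
if CACHE_DIR:
    xr.initialize_cache(CACHE_DIR, readonly=False)


def main():
    device = xm.xla_device()
    pipe = StableDiffusionXLPipeline.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
    )
    pipe.to(device)
    pipe.unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))

    prompt = ["A naruto with green eyes and red legs."]
    start = time()
    print("compiling...")
    image = pipe(prompt).images[0]  # first call triggers XLA compilation
    print(f"compile time: {time() - start}")

    print("generate...")
    start = time()
    image = pipe(prompt).images[0]  # reuses the compiled graph
    print(f"generation time (after compile) : {time() - start}")

    image.save("naruto.png")  # hypothetical output path


if __name__ == "__main__":
    main()
```

Running it requires `OUTPUT_DIR` to point at a directory produced by `pipeline.save_pretrained`, which the training-script change below now writes at the end of training.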

examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py

Lines changed: 39 additions & 22 deletions
@@ -774,28 +774,45 @@ def collate_fn(examples):
         args=args,
     )
     trainer.start_training()
-    # unet = trainer.unet.to("cpu")
-    # vae = trainer.vae.to("cpu")
-    # text_encoder = trainer.text_encoder.to("cpu")
-
-    # pipeline = StableDiffusionXLPipeline.from_pretrained(
-    #     args.pretrained_model_name_or_path,
-    #     text_encoder=text_encoder,
-    #     vae=vae,
-    #     unet=unet,
-    #     revision=args.revision,
-    #     variant=args.variant,
-    # )
-    # pipeline.save_pretrained(args.output_dir)
-
-    # if xm.is_master_ordinal() and args.push_to_hub:
-    #     save_model_card(args, repo_id, repo_folder=args.output_dir)
-    #     upload_folder(
-    #         repo_id=repo_id,
-    #         folder_path=args.output_dir,
-    #         commit_message="End of training",
-    #         ignore_patterns=["step_*", "epoch_*"],
-    #     )
+    unet = trainer.unet.to("cpu")
+
+    text_encoder = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        revision=args.revision,
+        variant=args.variant,
+    )
+    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder_2",
+        revision=args.revision,
+        variant=args.variant,
+    )
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="vae",
+        revision=args.revision,
+        variant=args.variant,
+    )
+
+    pipeline = StableDiffusionXLPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        revision=args.revision,
+        variant=args.variant,
+    )
+    pipeline.save_pretrained(args.output_dir)
+
+    if xm.is_master_ordinal() and args.push_to_hub:
+        save_model_card(args, repo_id, repo_folder=args.output_dir)
+        upload_folder(
+            repo_id=repo_id,
+            folder_path=args.output_dir,
+            commit_message="End of training",
+            ignore_patterns=["step_*", "epoch_*"],
+        )
 
 
 if __name__ == "__main__":
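
The uncommented block relies on names that have to be in scope in `train_text_to_image_sdxl.py`. As a reference, the sketch below lists where those classes and helpers live in the public transformers, diffusers, and huggingface_hub APIs; `save_model_card` is a helper defined in the script itself rather than an import.

```python
# Dependency sketch for the checkpoint-saving block (a reference, not an excerpt
# from train_text_to_image_sdxl.py).
import torch_xla.core.xla_model as xm  # xm.is_master_ordinal()
from diffusers import AutoencoderKL, StableDiffusionXLPipeline
from huggingface_hub import upload_folder  # only needed when --push_to_hub is set
from transformers import CLIPTextModel, CLIPTextModelWithProjection
```

Once `pipeline.save_pretrained(args.output_dir)` has run, the saved checkpoint in `OUTPUT_DIR` is what the inference example above loads through `MODEL_PATH`.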
