
Commit d78acde

Authored by bghira and sayakpaul
apple mps: training support for SDXL (ControlNet, LoRA, Dreambooth, T2I) (#7447)
* apple mps: training support for SDXL LoRA
* sdxl: support training lora, dreambooth, t2i, pix2pix, and controlnet on apple mps

Co-authored-by: bghira <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 6df103d commit d78acde

File tree

5 files changed: +94 −20 lines changed
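Every script touched by this commit adds the same early guard before the Accelerator is constructed, since bfloat16 is not yet supported on Apple MPS (pytorch#99272). A minimal standalone sketch of that guard, with `mixed_precision` standing in for the parsed `--mixed_precision` argument:

import torch


def check_mps_precision(mixed_precision: str) -> None:
    # Mirror of the guard added to each training script: bfloat16 is not yet
    # supported on Apple MPS (pytorch#99272), so fail fast with a clear message.
    if torch.backends.mps.is_available() and mixed_precision == "bf16":
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. "
            "Please use fp16 (recommended) or fp32 instead."
        )


check_mps_precision("fp16")  # fine on any backend
check_mps_precision("fp32")  # fine on any backend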

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 11 additions & 1 deletion
@@ -125,7 +125,11 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
     )

     image_logs = []
-    inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
+    inference_ctx = (
+        contextlib.nullcontext()
+        if (is_final_validation or torch.backends.mps.is_available())
+        else torch.autocast("cuda")
+    )

     for validation_prompt, validation_image in zip(validation_prompts, validation_images):
         validation_image = Image.open(validation_image).convert("RGB")
@@ -792,6 +796,12 @@ def main(args):

     logging_dir = Path(args.output_dir, args.logging_dir)

+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

     accelerator = Accelerator(
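As a standalone illustration of the context selection above (the helper name is ours, not the script's), a hedged sketch: intermediate validations on CUDA run under autocast, while the final pass and any run on Apple MPS fall back to a plain null context.

import contextlib

import torch


def select_inference_ctx(is_final_validation: bool):
    # Intermediate validations on CUDA run under autocast; the final pass and
    # any run on Apple MPS use a null context (full precision) instead.
    if is_final_validation or torch.backends.mps.is_available():
        return contextlib.nullcontext()
    return torch.autocast("cuda")


with select_inference_ctx(is_final_validation=True):
    pass  # pipeline inference would run here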

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 30 additions & 9 deletions
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and

 import argparse
-import contextlib
 import gc
 import itertools
 import json
@@ -208,11 +207,18 @@ def log_validation(
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
     # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
     # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
-    inference_ctx = (
-        contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
-    )
+    enable_autocast = True
+    if torch.backends.mps.is_available() or (
+        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
+    ):
+        enable_autocast = False
+    if "playground" in args.pretrained_model_name_or_path:
+        enable_autocast = False

-    with inference_ctx:
+    with torch.autocast(
+        accelerator.device.type,
+        enabled=enable_autocast,
+    ):
         images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

     for tracker in accelerator.trackers:
@@ -230,7 +236,8 @@ def log_validation(
     )

     del pipeline
-    torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()

     return images

@@ -967,6 +974,12 @@ def main(args):
     if args.do_edm_style_training and args.snr_gamma is not None:
         raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")

+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     logging_dir = Path(args.output_dir, args.logging_dir)

     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
@@ -1009,7 +1022,8 @@ def main(args):
             cur_class_images = len(list(class_images_dir.iterdir()))

             if cur_class_images < args.num_class_images:
-                torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+                has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+                torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
                 if args.prior_generation_precision == "fp32":
                     torch_dtype = torch.float32
                 elif args.prior_generation_precision == "fp16":
@@ -1134,6 +1148,12 @@ def main(args):
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

+    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     # Move unet, vae and text_encoder to device and cast to weight_dtype
     unet.to(accelerator.device, dtype=weight_dtype)

@@ -1278,7 +1298,7 @@ def load_model_hook(models, input_dir):

     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-    if args.allow_tf32:
+    if args.allow_tf32 and torch.cuda.is_available():
         torch.backends.cuda.matmul.allow_tf32 = True

     if args.scale_lr:
@@ -1455,7 +1475,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         del tokenizers, text_encoders
         gc.collect()
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
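Beyond the validation-context change, the DreamBooth script guards its CUDA-only housekeeping calls so they are skipped on Apple Silicon. A small sketch of the same guards in isolation (the function names are illustrative, not from the script):

import gc

import torch


def free_memory() -> None:
    # Drop Python references first; only touch the CUDA caching allocator when
    # CUDA is actually present (e.g. skip it on Apple Silicon).
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def maybe_enable_tf32(allow_tf32: bool) -> None:
    # TF32 matmuls are a CUDA (Ampere and later) feature, so the backend flag
    # is only set when CUDA is available.
    if allow_tf32 and torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True


free_memory()
maybe_enable_tf32(allow_tf32=True)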

examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py

Lines changed: 18 additions & 7 deletions
@@ -71,12 +71,7 @@


 def log_validation(
-    pipeline,
-    args,
-    accelerator,
-    generator,
-    global_step,
-    is_final_validation=False,
+    pipeline, args, accelerator, generator, global_step, is_final_validation=False, enable_autocast=True
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
@@ -96,7 +91,7 @@ def log_validation(
         else Image.open(image_url_or_path).convert("RGB")
     )(args.val_image_url_or_path)

-    with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
+    with torch.autocast(accelerator.device.type, enabled=enable_autocast):
         edited_images = []
         # Run inference
         for val_img_idx in range(args.num_validation_images):
@@ -497,6 +492,13 @@ def main():
         ),
     )
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -981,6 +983,13 @@ def collate_fn(examples):
     if accelerator.is_main_process:
         accelerator.init_trackers("instruct-pix2pix-xl", config=vars(args))

+    # Some configurations require autocast to be disabled.
+    enable_autocast = True
+    if torch.backends.mps.is_available() or (
+        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
+    ):
+        enable_autocast = False
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

@@ -1193,6 +1202,7 @@ def collate_fn(examples):
                         generator,
                         global_step,
                         is_final_validation=False,
+                        enable_autocast=enable_autocast,
                     )

         if args.use_ema:
@@ -1242,6 +1252,7 @@ def collate_fn(examples):
                 generator,
                 global_step,
                 is_final_validation=True,
+                enable_autocast=enable_autocast,
             )

     accelerator.end_training()
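The pix2pix script now decides once whether autocast should be on and threads that flag into log_validation. A minimal sketch of that flow, with a stub in place of the real validation function (decide_autocast and run_validation are illustrative names, not from the script):

import torch


def decide_autocast(mixed_precision: str) -> bool:
    # Autocast is switched off on MPS and whenever Accelerate already handles
    # fp16/bf16 mixed precision, to avoid double-casting during validation.
    if torch.backends.mps.is_available() or mixed_precision in ("fp16", "bf16"):
        return False
    return True


def run_validation(device_type: str, enable_autocast: bool = True) -> None:
    # Stand-in for log_validation: inference runs under a device-agnostic
    # autocast context that callers can simply disable.
    with torch.autocast(device_type, enabled=enable_autocast):
        pass  # pipeline(...) calls would go here


device_type = "cuda" if torch.cuda.is_available() else "cpu"
enable_autocast = decide_autocast(mixed_precision="no")
run_validation(device_type, enable_autocast=enable_autocast)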

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 17 additions & 1 deletion
@@ -501,6 +501,12 @@ def main(args):

     logging_dir = Path(args.output_dir, args.logging_dir)

+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
     kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
     accelerator = Accelerator(
@@ -973,6 +979,13 @@ def collate_fn(examples):
     if accelerator.is_main_process:
         accelerator.init_trackers("text2image-fine-tune", config=vars(args))

+    # Some configurations require autocast to be disabled.
+    enable_autocast = True
+    if torch.backends.mps.is_available() or (
+        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
+    ):
+        enable_autocast = False
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

@@ -1199,7 +1212,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                 pipeline_args = {"prompt": args.validation_prompt}

-                with torch.cuda.amp.autocast():
+                with torch.autocast(
+                    accelerator.device.type,
+                    enabled=enable_autocast,
+                ):
                     images = [
                         pipeline(**pipeline_args, generator=generator).images[0]
                         for _ in range(args.num_validation_images)
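The key switch here is from the CUDA-only torch.cuda.amp.autocast() to torch.autocast keyed on the accelerator's device type. A standalone sketch of the same idea without Accelerate; the device-type resolution below is our assumption of what accelerator.device.type would report, and passing "mps" to torch.autocast assumes a reasonably recent PyTorch release:

import torch

# Resolve a device type the way accelerator.device.type would report it.
if torch.cuda.is_available():
    device_type = "cuda"
elif torch.backends.mps.is_available():
    device_type = "mps"
else:
    device_type = "cpu"

# Unlike torch.cuda.amp.autocast(), torch.autocast takes the device type as a
# string, so one code path covers CUDA, MPS, and CPU; `enabled` lets callers
# turn autocast off entirely (as the scripts do on MPS or under fp16/bf16).
enable_autocast = not torch.backends.mps.is_available()
with torch.autocast(device_type, enabled=enable_autocast):
    x = torch.ones(8, 8, device=device_type)
    y = x @ x
print(y.dtype)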

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 18 additions & 2 deletions
@@ -590,6 +590,12 @@ def main(args):

     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
@@ -980,6 +986,13 @@ def unwrap_model(model):
         model = model._orig_mod if is_compiled_module(model) else model
         return model

+    # Some configurations require autocast to be disabled.
+    enable_autocast = True
+    if torch.backends.mps.is_available() or (
+        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
+    ):
+        enable_autocast = False
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

@@ -1213,7 +1226,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                 pipeline_args = {"prompt": args.validation_prompt}

-                with torch.cuda.amp.autocast():
+                with torch.autocast(
+                    accelerator.device.type,
+                    enabled=enable_autocast,
+                ):
                     images = [
                         pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
                         for _ in range(args.num_validation_images)
@@ -1268,7 +1284,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
     if args.validation_prompt and args.num_validation_images > 0:
         pipeline = pipeline.to(accelerator.device)
         generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-        with torch.cuda.amp.autocast():
+        with torch.autocast(accelerator.device.type, enabled=enable_autocast):
             images = [
                 pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
                 for _ in range(args.num_validation_images)
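Since all five scripts now target Apple Silicon, a quick pre-flight check before launching a run can save a failed startup. The helper below is a hedged sketch, not part of any of the scripts:

import torch


def mps_preflight(mixed_precision: str = "fp16") -> str:
    # Distinguish "built without MPS" from "built with MPS but no usable
    # Metal device", then repeat the scripts' bf16 restriction.
    if not torch.backends.mps.is_built():
        return "This PyTorch build was compiled without MPS support."
    if not torch.backends.mps.is_available():
        return "MPS is built but no usable Metal device was found."
    if mixed_precision == "bf16":
        return "bf16 is not supported on MPS; use --mixed_precision fp16 (or fp32)."
    return "MPS looks usable with fp16 mixed precision (or fp32)."


print(mps_preflight("fp16"))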
