58 | 58 |     compute_density_for_timestep_sampling,  | 
59 | 59 |     compute_loss_weighting_for_sd3,  | 
60 | 60 |     free_memory,  | 
 | 61 | +    offload_models,  | 
61 | 62 | )  | 
62 | 63 | from diffusers.utils import (  | 
63 | 64 |     check_min_version,  | 
@@ -1364,43 +1365,34 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):  | 
1364 | 1365 |     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid  | 
1365 | 1366 |     # the redundant encoding.  | 
1366 | 1367 |     if not train_dataset.custom_instance_prompts:  | 
1367 |  | -        if args.offload:  | 
1368 |  | -            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)  | 
1369 |  | -        (  | 
1370 |  | -            instance_prompt_hidden_states_t5,  | 
1371 |  | -            instance_prompt_hidden_states_llama3,  | 
1372 |  | -            instance_pooled_prompt_embeds,  | 
1373 |  | -            _,  | 
1374 |  | -            _,  | 
1375 |  | -            _,  | 
1376 |  | -        ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)  | 
1377 |  | -        if args.offload:  | 
1378 |  | -            text_encoding_pipeline = text_encoding_pipeline.to("cpu")  | 
 | 1368 | +        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):  | 
 | 1369 | +            (  | 
 | 1370 | +                instance_prompt_hidden_states_t5,  | 
 | 1371 | +                instance_prompt_hidden_states_llama3,  | 
 | 1372 | +                instance_pooled_prompt_embeds,  | 
 | 1373 | +                _,  | 
 | 1374 | +                _,  | 
 | 1375 | +                _,  | 
 | 1376 | +            ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)  | 
1379 | 1377 |   | 
1380 | 1378 |     # Handle class prompt for prior-preservation.  | 
1381 | 1379 |     if args.with_prior_preservation:  | 
1382 |  | -        if args.offload:  | 
1383 |  | -            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)  | 
1384 |  | -        (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (  | 
1385 |  | -            compute_text_embeddings(args.class_prompt, text_encoding_pipeline)  | 
1386 |  | -        )  | 
1387 |  | -        if args.offload:  | 
1388 |  | -            text_encoding_pipeline = text_encoding_pipeline.to("cpu")  | 
 | 1380 | +        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):  | 
 | 1381 | +            (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (  | 
 | 1382 | +                compute_text_embeddings(args.class_prompt, text_encoding_pipeline)  | 
 | 1383 | +            )  | 
1389 | 1384 |   | 
1390 | 1385 |     validation_embeddings = {}  | 
1391 | 1386 |     if args.validation_prompt is not None:  | 
1392 |  | -        if args.offload:  | 
1393 |  | -            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)  | 
1394 |  | -        (  | 
1395 |  | -            validation_embeddings["prompt_embeds_t5"],  | 
1396 |  | -            validation_embeddings["prompt_embeds_llama3"],  | 
1397 |  | -            validation_embeddings["pooled_prompt_embeds"],  | 
1398 |  | -            validation_embeddings["negative_prompt_embeds_t5"],  | 
1399 |  | -            validation_embeddings["negative_prompt_embeds_llama3"],  | 
1400 |  | -            validation_embeddings["negative_pooled_prompt_embeds"],  | 
1401 |  | -        ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)  | 
1402 |  | -        if args.offload:  | 
1403 |  | -            text_encoding_pipeline = text_encoding_pipeline.to("cpu")  | 
 | 1387 | +        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):  | 
 | 1388 | +            (  | 
 | 1389 | +                validation_embeddings["prompt_embeds_t5"],  | 
 | 1390 | +                validation_embeddings["prompt_embeds_llama3"],  | 
 | 1391 | +                validation_embeddings["pooled_prompt_embeds"],  | 
 | 1392 | +                validation_embeddings["negative_prompt_embeds_t5"],  | 
 | 1393 | +                validation_embeddings["negative_prompt_embeds_llama3"],  | 
 | 1394 | +                validation_embeddings["negative_pooled_prompt_embeds"],  | 
 | 1395 | +            ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)  | 
1404 | 1396 |   | 
1405 | 1397 |     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),  | 
1406 | 1398 |     # pack the statically computed variables appropriately here. This is so that we don't  | 
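
The repeated `if args.offload: ... .to(device) ... .to("cpu")` blocks above are collapsed into a single `offload_models` context manager. A minimal sketch of what such a helper might look like, inferred only from the call sites in this diff (the actual `diffusers.training_utils` implementation may differ):

```python
from contextlib import contextmanager

@contextmanager
def offload_models(*modules, device, offload: bool = True):
    """Move `modules` to `device` on entry and restore their original
    placement on exit; a no-op when `offload` is False."""
    if offload:
        # Record where each module/pipeline currently lives.
        original_devices = [
            m.device if hasattr(m, "device") else next(m.parameters()).device
            for m in modules
        ]
        for m in modules:
            m.to(device)
    try:
        yield
    finally:
        # Runs even if the body raises, so modules never stay stranded on GPU.
        if offload:
            for m, original in zip(modules, original_devices):
                m.to(original)
```

The `try`/`finally` is the practical win over the manual pattern: the modules are moved back even when the body raises, and the offload flag is checked in one place instead of twice per call site.
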
@@ -1581,12 +1573,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):  | 
1581 | 1573 |                 if args.cache_latents:  | 
1582 | 1574 |                     model_input = latents_cache[step].sample()  | 
1583 | 1575 |                 else:  | 
1584 |  | -                    if args.offload:  | 
1585 |  | -                        vae = vae.to(accelerator.device)  | 
1586 |  | -                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)  | 
 | 1576 | +                    with offload_models(vae, device=accelerator.device, offload=args.offload):  | 
 | 1577 | +                        pixel_values = batch["pixel_values"].to(dtype=vae.dtype)  | 
1587 |  | -                    model_input = vae.encode(pixel_values).latent_dist.sample()  | 
 | 1578 | +                        model_input = vae.encode(pixel_values).latent_dist.sample()  | 
1588 |  | -                    if args.offload:  | 
1589 |  | -                        vae = vae.to("cpu")  | 
 | 1579 | + | 
1590 | 1580 |                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor  | 
1591 | 1581 |                 model_input = model_input.to(dtype=weight_dtype)  | 
1592 | 1582 |   | 
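
Note that `vae.encode` has to run inside the `with` block: when `args.offload` is true, the context restores the VAE to its original (CPU) placement on exit while `pixel_values` still lives on the accelerator, so an encode after the block would fail with a device mismatch. A sketch of the resulting pattern, using the names from the hunk above:

```python
# Every use of the VAE stays inside the context, so it is still on the
# accelerator when called.
with offload_models(vae, device=accelerator.device, offload=args.offload):
    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)     # match the VAE's dtype
    model_input = vae.encode(pixel_values).latent_dist.sample()  # VAE on GPU here
# VAE is back on its original device; `model_input` stays on the accelerator.
```
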