| 59 | 59 |     compute_density_for_timestep_sampling, | 
| 60 | 60 |     compute_loss_weighting_for_sd3, | 
| 61 | 61 |     free_memory, | 
|  | 62 | +    offload_models, | 
| 62 | 63 | ) | 
| 63 | 64 | from diffusers.utils import ( | 
| 64 | 65 |     check_min_version, | 
| @@ -1375,43 +1376,34 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): | 
| 1375 | 1376 |     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid | 
| 1376 | 1377 |     # the redundant encoding. | 
| 1377 | 1378 |     if not train_dataset.custom_instance_prompts: | 
| 1378 |  | -        if args.offload: | 
| 1379 |  | -            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) | 
| 1380 |  | -        ( | 
| 1381 |  | -            instance_prompt_hidden_states_t5, | 
| 1382 |  | -            instance_prompt_hidden_states_llama3, | 
| 1383 |  | -            instance_pooled_prompt_embeds, | 
| 1384 |  | -            _, | 
| 1385 |  | -            _, | 
| 1386 |  | -            _, | 
| 1387 |  | -        ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline) | 
| 1388 |  | -        if args.offload: | 
| 1389 |  | -            text_encoding_pipeline = text_encoding_pipeline.to("cpu") | 
|  | 1379 | +        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): | 
|  | 1380 | +            ( | 
|  | 1381 | +                instance_prompt_hidden_states_t5, | 
|  | 1382 | +                instance_prompt_hidden_states_llama3, | 
|  | 1383 | +                instance_pooled_prompt_embeds, | 
|  | 1384 | +                _, | 
|  | 1385 | +                _, | 
|  | 1386 | +                _, | 
|  | 1387 | +            ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline) | 
| 1390 | 1388 |  | 
| 1391 | 1389 |     # Handle class prompt for prior-preservation. | 
| 1392 | 1390 |     if args.with_prior_preservation: | 
| 1393 |  | -        if args.offload: | 
| 1394 |  | -            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) | 
| 1395 |  | -        (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = ( | 
| 1396 |  | -            compute_text_embeddings(args.class_prompt, text_encoding_pipeline) | 
| 1397 |  | -        ) | 
| 1398 |  | -        if args.offload: | 
| 1399 |  | -            text_encoding_pipeline = text_encoding_pipeline.to("cpu") | 
|  | 1391 | +        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): | 
|  | 1392 | +            (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = ( | 
|  | 1393 | +                compute_text_embeddings(args.class_prompt, text_encoding_pipeline) | 
|  | 1394 | +            ) | 
| 1400 | 1395 |  | 
| 1401 | 1396 |     validation_embeddings = {} | 
| 1402 | 1397 |     if args.validation_prompt is not None: | 
| 1403 |  | -        if args.offload: | 
| 1404 |  | -            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) | 
| 1405 |  | -        ( | 
| 1406 |  | -            validation_embeddings["prompt_embeds_t5"], | 
| 1407 |  | -            validation_embeddings["prompt_embeds_llama3"], | 
| 1408 |  | -            validation_embeddings["pooled_prompt_embeds"], | 
| 1409 |  | -            validation_embeddings["negative_prompt_embeds_t5"], | 
| 1410 |  | -            validation_embeddings["negative_prompt_embeds_llama3"], | 
| 1411 |  | -            validation_embeddings["negative_pooled_prompt_embeds"], | 
| 1412 |  | -        ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline) | 
| 1413 |  | -        if args.offload: | 
| 1414 |  | -            text_encoding_pipeline = text_encoding_pipeline.to("cpu") | 
|  | 1398 | +        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): | 
|  | 1399 | +            ( | 
|  | 1400 | +                validation_embeddings["prompt_embeds_t5"], | 
|  | 1401 | +                validation_embeddings["prompt_embeds_llama3"], | 
|  | 1402 | +                validation_embeddings["pooled_prompt_embeds"], | 
|  | 1403 | +                validation_embeddings["negative_prompt_embeds_t5"], | 
|  | 1404 | +                validation_embeddings["negative_prompt_embeds_llama3"], | 
|  | 1405 | +                validation_embeddings["negative_pooled_prompt_embeds"], | 
|  | 1406 | +            ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline) | 
| 1415 | 1407 |  | 
| 1416 | 1408 |     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), | 
| 1417 | 1409 |     # pack the statically computed variables appropriately here. This is so that we don't | 
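
The removed `if args.offload:` blocks that manually moved the text-encoding pipeline to `accelerator.device` and then back to the CPU are collapsed above into single `with offload_models(...)` statements. For readers who have not seen the helper, here is a minimal sketch of what such a context manager can look like; the name `offload_models_sketch` and its body are illustrative assumptions, not the `offload_models` implementation added to the import list above.

```python
# Minimal sketch of an offload_models-style context manager (assumed behavior,
# not the diffusers implementation): move the module/pipeline to `device` on
# entry and back to the CPU on exit whenever `offload` is enabled.
from contextlib import contextmanager

import torch


@contextmanager
def offload_models_sketch(module, *, device: torch.device, offload: bool = True):
    if offload:
        module.to(device)
    try:
        yield module
    finally:
        if offload:
            # Restore the module to the CPU even if the wrapped block raised.
            module.to("cpu")
```

Wrapping the work in a context manager also guarantees the move-back on every exit path, which the hand-written `.to("cpu")` calls did not cover if an exception was raised in between.
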
| @@ -1593,12 +1585,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): | 
| 1593 | 1585 |                 if args.cache_latents: | 
| 1594 | 1586 |                     model_input = latents_cache[step].sample() | 
| 1595 | 1587 |                 else: | 
| 1596 |  | -                    if args.offload: | 
| 1597 |  | -                        vae = vae.to(accelerator.device) | 
| 1598 |  | -                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype) | 
| 1599 |  | -                    model_input = vae.encode(pixel_values).latent_dist.sample() | 
| 1600 |  | -                    if args.offload: | 
| 1601 |  | -                        vae = vae.to("cpu") | 
|  | 1588 | +                    with offload_models(vae, device=accelerator.device, offload=args.offload): | 
|  | 1589 | +                        pixel_values = batch["pixel_values"].to(dtype=vae.dtype) | 
|  | 1590 | +                        model_input = vae.encode(pixel_values).latent_dist.sample() | 
|  | 1591 | + | 
| 1602 | 1592 |                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor | 
| 1603 | 1593 |                 model_input = model_input.to(dtype=weight_dtype) | 
| 1604 | 1594 |  | 
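
For completeness, a quick sanity check of the sketch above (illustrative only; a `torch.nn.Linear` stands in for the VAE or text-encoding pipeline): when offloading is enabled, the weights should sit on the target device inside the block and return to the CPU afterwards, mirroring what the replaced manual `.to(...)` calls did around `vae.encode`.

```python
import torch

# Stand-in for the VAE / text-encoding pipeline wrapped in the training script.
module = torch.nn.Linear(8, 8)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with offload_models_sketch(module, device=device, offload=True):
    # Inside the block the parameters live on the target device...
    assert next(module.parameters()).device.type == device.type
# ...and are back on the CPU once the block exits.
assert next(module.parameters()).device.type == "cpu"
```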