| @@ -58,6 +58,7 @@ | 
| 58 | 58 |     compute_density_for_timestep_sampling, | 
| 59 | 59 |     compute_loss_weighting_for_sd3, | 
| 60 | 60 |     free_memory, | 
|  | 61 | +    offload_models, | 
| 61 | 62 | ) | 
| 62 | 63 | from diffusers.utils import ( | 
| 63 | 64 |     check_min_version, | 
| @@ -1364,43 +1365,34 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): | 
| 1364 | 1365 |     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid | 
| 1365 | 1366 |     # the redundant encoding. | 
| 1366 | 1367 |     if not train_dataset.custom_instance_prompts: | 
| 1367 |  | -        if args.offload: | 
| 1368 |  | -            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) | 
| 1369 |  | -        ( | 
| 1370 |  | -            instance_prompt_hidden_states_t5, | 
| 1371 |  | -            instance_prompt_hidden_states_llama3, | 
| 1372 |  | -            instance_pooled_prompt_embeds, | 
| 1373 |  | -            _, | 
| 1374 |  | -            _, | 
| 1375 |  | -            _, | 
| 1376 |  | -        ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline) | 
| 1377 |  | -        if args.offload: | 
| 1378 |  | -            text_encoding_pipeline = text_encoding_pipeline.to("cpu") | 
|  | 1368 | +        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): | 
|  | 1369 | +            ( | 
|  | 1370 | +                instance_prompt_hidden_states_t5, | 
|  | 1371 | +                instance_prompt_hidden_states_llama3, | 
|  | 1372 | +                instance_pooled_prompt_embeds, | 
|  | 1373 | +                _, | 
|  | 1374 | +                _, | 
|  | 1375 | +                _, | 
|  | 1376 | +            ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline) | 
| 1379 | 1377 |  | 
| 1380 | 1378 |     # Handle class prompt for prior-preservation. | 
| 1381 | 1379 |     if args.with_prior_preservation: | 
| 1382 |  | -        if args.offload: | 
| 1383 |  | -            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) | 
| 1384 |  | -        (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = ( | 
| 1385 |  | -            compute_text_embeddings(args.class_prompt, text_encoding_pipeline) | 
| 1386 |  | -        ) | 
| 1387 |  | -        if args.offload: | 
| 1388 |  | -            text_encoding_pipeline = text_encoding_pipeline.to("cpu") | 
|  | 1380 | +        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): | 
|  | 1381 | +            (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = ( | 
|  | 1382 | +                compute_text_embeddings(args.class_prompt, text_encoding_pipeline) | 
|  | 1383 | +            ) | 
| 1389 | 1384 |  | 
| 1390 | 1385 |     validation_embeddings = {} | 
| 1391 | 1386 |     if args.validation_prompt is not None: | 
| 1392 |  | -        if args.offload: | 
| 1393 |  | -            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) | 
| 1394 |  | -        ( | 
| 1395 |  | -            validation_embeddings["prompt_embeds_t5"], | 
| 1396 |  | -            validation_embeddings["prompt_embeds_llama3"], | 
| 1397 |  | -            validation_embeddings["pooled_prompt_embeds"], | 
| 1398 |  | -            validation_embeddings["negative_prompt_embeds_t5"], | 
| 1399 |  | -            validation_embeddings["negative_prompt_embeds_llama3"], | 
| 1400 |  | -            validation_embeddings["negative_pooled_prompt_embeds"], | 
| 1401 |  | -        ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline) | 
| 1402 |  | -        if args.offload: | 
| 1403 |  | -            text_encoding_pipeline = text_encoding_pipeline.to("cpu") | 
|  | 1387 | +        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): | 
|  | 1388 | +            ( | 
|  | 1389 | +                validation_embeddings["prompt_embeds_t5"], | 
|  | 1390 | +                validation_embeddings["prompt_embeds_llama3"], | 
|  | 1391 | +                validation_embeddings["pooled_prompt_embeds"], | 
|  | 1392 | +                validation_embeddings["negative_prompt_embeds_t5"], | 
|  | 1393 | +                validation_embeddings["negative_prompt_embeds_llama3"], | 
|  | 1394 | +                validation_embeddings["negative_pooled_prompt_embeds"], | 
|  | 1395 | +            ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline) | 
| 1404 | 1396 |  | 
| 1405 | 1397 |     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), | 
| 1406 | 1398 |     # pack the statically computed variables appropriately here. This is so that we don't | 
| @@ -1581,12 +1573,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): | 
| 1581 | 1573 |                 if args.cache_latents: | 
| 1582 | 1574 |                     model_input = latents_cache[step].sample() | 
| 1583 | 1575 |                 else: | 
| 1584 |  | -                    if args.offload: | 
| 1585 |  | -                        vae = vae.to(accelerator.device) | 
| 1586 |  | -                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype) | 
|  | 1576 | +                    with offload_models(vae, device=accelerator.device, offload=args.offload): | 
|  | 1577 | +                        pixel_values = batch["pixel_values"].to(dtype=vae.dtype) | 
| 1587 |  | -                    model_input = vae.encode(pixel_values).latent_dist.sample() | 
|  | 1578 | +                        model_input = vae.encode(pixel_values).latent_dist.sample() | 
| 1588 |  | -                    if args.offload: | 
| 1589 |  | -                        vae = vae.to("cpu") | 
|  | 1579 | + | 
| 1590 | 1580 |                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor | 
| 1591 | 1581 |                 model_input = model_input.to(dtype=weight_dtype) | 
| 1592 | 1582 |  | 
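
The diff swaps three hand-rolled `if args.offload: ... .to(accelerator.device) ... .to("cpu")` blocks for the `offload_models` context manager newly imported from `diffusers.training_utils`. As a rough mental model of what such a helper does (a minimal sketch under assumptions, not the library's actual implementation; the real helper also accepts pipelines and may differ in signature and details):

```python
from contextlib import contextmanager


@contextmanager
def offload_models(*modules, device, offload: bool = True):
    # Minimal sketch, NOT the actual diffusers implementation: move each
    # module to `device` on entry and restore its original device on exit
    # when `offload=True`; act as a no-op otherwise.
    if offload:
        # Remember where each module currently lives so it can be restored.
        original_devices = [next(m.parameters()).device for m in modules]
        for m in modules:
            m.to(device)
    try:
        yield
    finally:
        if offload:
            for m, original_device in zip(modules, original_devices):
                m.to(original_device)
```

One consequence of this pattern is that any compute needing the model on the accelerator must sit inside the `with` block: once the block exits, the model is offloaded again. This is why the `vae.encode(...)` call moves into the context alongside the `pixel_values` cast in the second hunk.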