diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index a282ca717a9f..b331e4b13760 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -56,7 +56,7 @@
   - local: using-diffusers/overview_techniques
     title: Overview
   - local: training/distributed_inference
-    title: Distributed inference with multiple GPUs
+    title: Distributed inference
   - local: using-diffusers/merge_loras
     title: Merge LoRAs
   - local: using-diffusers/scheduler_features
diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md
index 5c371033dfd5..cd642d6aca07 100644
--- a/docs/source/en/training/distributed_inference.md
+++ b/docs/source/en/training/distributed_inference.md
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
 specific language governing permissions and limitations under the License.
 -->
 
-# Distributed inference with multiple GPUs
+# Distributed inference
 
 On distributed setups, you can run inference across multiple GPUs with 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) or [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html), which is useful for generating with multiple prompts in parallel.
 
@@ -109,3 +109,131 @@ torchrun run_distributed.py --nproc_per_node=2
 
 > [!TIP]
 > You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.
+
+## Model sharding
+
+Modern diffusion systems such as [Flux](../api/pipelines/flux) are very large and have multiple models. For example, [Flux.1-Dev](https://hf.co/black-forest-labs/FLUX.1-dev) is made up of two text encoders - [T5-XXL](https://hf.co/google/t5-v1_1-xxl) and [CLIP-L](https://hf.co/openai/clip-vit-large-patch14) - a [diffusion transformer](../api/models/flux_transformer), and a [VAE](../api/models/autoencoderkl). With a model this size, it can be challenging to run inference on consumer GPUs.
+
+Model sharding is a technique that distributes models across GPUs when the models don't fit on a single GPU. The example below assumes two 16GB GPUs are available for inference.
+
+Start by computing the text embeddings with the text encoders. Keep the text encoders on two GPUs by setting `device_map="balanced"`. The `balanced` strategy evenly distributes the model across all available GPUs. Use the `max_memory` parameter to allocate the maximum amount of memory for each text encoder on each GPU.
+
+> [!TIP]
+> **Only** load the text encoders for this step! The diffusion transformer and VAE are loaded in a later step to preserve memory.
+
+```py
+from diffusers import FluxPipeline
+import torch
+
+prompt = "a photo of a dog with cat-like look"
+
+pipeline = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    transformer=None,
+    vae=None,
+    device_map="balanced",
+    max_memory={0: "16GB", 1: "16GB"},
+    torch_dtype=torch.bfloat16
+)
+with torch.no_grad():
+    print("Encoding prompts.")
+    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+        prompt=prompt, prompt_2=None, max_sequence_length=512
+    )
+```
+
+Once the text embeddings are computed, remove the text encoders from the GPU to make space for the diffusion transformer.
+
+```py
+import gc
+
+def flush():
+    # Release Python references, clear the CUDA cache, and reset the memory statistics.
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.reset_max_memory_allocated()
+    torch.cuda.reset_peak_memory_stats()
+
+del pipeline.text_encoder
+del pipeline.text_encoder_2
+del pipeline.tokenizer
+del pipeline.tokenizer_2
+del pipeline
+
+flush()
+```
+
+Next, load the diffusion transformer, which has 12.5B parameters. This time, set `device_map="auto"` to automatically distribute the model across two 16GB GPUs. The `auto` strategy is backed by [Accelerate](https://hf.co/docs/accelerate/index) and available as a part of the [Big Model Inference](https://hf.co/docs/accelerate/concept_guides/big_model_inference) feature. It starts by distributing a model across the fastest device first (GPU) before moving to slower devices like the CPU and hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency.
+
+```py
+from diffusers import FluxTransformer2DModel
+import torch
+
+transformer = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    subfolder="transformer",
+    device_map="auto",
+    torch_dtype=torch.bfloat16
+)
+```
+
+> [!TIP]
+> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models.
+
+Add the transformer model to the pipeline for denoising, but set the other model-level components like the text encoders and VAE to `None` because you don't need them yet.
+
+```py
+pipeline = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    text_encoder=None,
+    text_encoder_2=None,
+    tokenizer=None,
+    tokenizer_2=None,
+    vae=None,
+    transformer=transformer,
+    torch_dtype=torch.bfloat16
+)
+
+print("Running denoising.")
+height, width = 768, 1360
+latents = pipeline(
+    prompt_embeds=prompt_embeds,
+    pooled_prompt_embeds=pooled_prompt_embeds,
+    num_inference_steps=50,
+    guidance_scale=3.5,
+    height=height,
+    width=width,
+    output_type="latent",
+).images
+```
+
+Remove the pipeline and transformer from memory as they're no longer needed.
+
+```py
+del pipeline.transformer
+del pipeline
+
+flush()
+```
+
+Finally, decode the latents with the VAE into an image. The VAE is typically small enough to be loaded on a single GPU.
+
+```py
+from diffusers import AutoencoderKL
+from diffusers.image_processor import VaeImageProcessor
+import torch
+
+vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
+vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
+
+with torch.no_grad():
+    print("Running decoding.")
+    # Unpack the packed latents and undo the latent normalization before decoding.
+    latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
+    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
+
+    image = vae.decode(latents, return_dict=False)[0]
+    image = image_processor.postprocess(image, output_type="pil")
+    image[0].save("split_transformer.png")
+```
+
+By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.
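+
+To confirm that each stage stays within the 16GB-per-GPU budget, you can print the peak memory allocated on every GPU after each step. The snippet below is a minimal sketch built on standard PyTorch utilities; the `print_peak_memory` helper name is only for illustration, and it assumes the `flush()` helper above is used to reset the peak statistics between stages.
+
+```py
+import torch
+
+def print_peak_memory(stage):
+    # Report the peak memory allocated on each visible GPU since the last reset.
+    for device_id in range(torch.cuda.device_count()):
+        peak_gb = torch.cuda.max_memory_allocated(device_id) / 1024**3
+        print(f"{stage} - GPU {device_id}: {peak_gb:.2f}GB peak")
+
+# For example, call print_peak_memory("denoising") right after the denoising step.
+```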