diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index 22e8a30427b9..b7cd0e20f481 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -263,8 +263,8 @@ def main(): world_size = dist.get_world_size() pipeline = DiffusionPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device - ) + "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 + ).to(device) pipeline.transformer.set_attention_backend("_native_cudnn") cp_config = ContextParallelConfig(ring_degree=world_size)