Skip to content

Commit 7c6d314

Browse files
authored
fix the use of device_map in CP docs (#12902)
up
1 parent 3138e37 commit 7c6d314

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

docs/source/en/training/distributed_inference.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,8 @@ def main():
263263
world_size = dist.get_world_size()
264264

265265
pipeline = DiffusionPipeline.from_pretrained(
266-
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
267-
)
266+
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
267+
).to(device)
268268
pipeline.transformer.set_attention_backend("_native_cudnn")
269269

270270
cp_config = ContextParallelConfig(ring_degree=world_size)

0 commit comments

Comments
 (0)