Commit f2ea6b4

Fix Context Parallelism doc
Parent: 04f9d2b

1 file changed: +2 −2 lines

docs/source/en/training/distributed_inference.md (2 additions, 2 deletions)
@@ -253,7 +253,7 @@ try:
     device = torch.device("cuda", rank % torch.cuda.device_count())
     torch.cuda.set_device(device)
 
-    transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2))
+    transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2), device_map="cuda")
     pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda")
     pipeline.transformer.set_attention_backend("flash")

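The fix adds `device_map="cuda"` so the transformer weights are loaded directly onto the current rank's GPU, matching the `device_map` the pipeline call on the next line already uses. For context, here is a minimal end-to-end sketch of the kind of script these lines sit in; the `torch.distributed` setup, prompt, and output handling are reconstructions around the diffed lines, not part of this commit. Launched with, for example, `torchrun --nproc_per_node=2 demo.py`:

```py
# Minimal sketch, assuming a 2-GPU launch via torchrun.
# Everything outside the diffed lines is an illustrative assumption.
import torch
import torch.distributed as dist
from diffusers import AutoModel, ContextParallelConfig, QwenImagePipeline

try:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

    # device_map="cuda" (the fix) places the weights on the current GPU at load time
    transformer = AutoModel.from_pretrained(
        "Qwen/Qwen-Image",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        parallel_config=ContextParallelConfig(ring_degree=2),
        device_map="cuda",
    )
    pipeline = QwenImagePipeline.from_pretrained(
        "Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda"
    )
    pipeline.transformer.set_attention_backend("flash")

    # every rank runs the same call; ring attention splits the sequence across ranks
    image = pipeline("A cat holding a sign that says hello", num_inference_steps=50).images[0]
    if rank == 0:
        image.save("output.png")
finally:
    if dist.is_initialized():
        dist.destroy_process_group()
```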
@@ -289,4 +289,4 @@ Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
 
 ```py
 pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
-```
+```
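The final context line above shows the alternative entry point: rather than passing `parallel_config` to `from_pretrained`, a [`ContextParallelConfig`] can be applied to an already-loaded model through [`~ModelMixin.enable_parallelism`]. A minimal sketch, assuming the same `torch.distributed` setup as the sketch above and a `pipeline` loaded without a `parallel_config`; `ulysses_degree` shards attention across ranks with all-to-all exchanges, while `ring_degree` passes key/value blocks around a ring, and in either case the configured degree is expected to match the number of launched processes:

```py
# Minimal sketch: `pipeline` is assumed to be loaded as above,
# but without parallel_config; the prompt is illustrative.
from diffusers import ContextParallelConfig

pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
pipeline.transformer.set_attention_backend("flash")
image = pipeline("A cat holding a sign that says hello", num_inference_steps=50).images[0]
```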
