Commit 4c6bfc5

Handle enable CUDA graph param in SD example (deepspeedai#246)
This PR updates how the `enable_cuda_graph` param is set depending on the `world_size`, i.e., CUDA graphs should only be enabled when `world_size==1`.
1 parent 127c7a1 commit 4c6bfc5

File tree

2 files changed: +7 −3 lines changed


inference/huggingface/stable-diffusion/README.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -11,6 +11,8 @@ pip install -r requirements.txt
 Examples can be run as follows:
 <pre>deepspeed --num_gpus [number of GPUs] test-[model].py</pre>
 
+NOTE: Local CUDA graphs for replaced SD modules will only be enabled when `mp_size==1`.
+
 # Example Output
 Command:
 <pre>
```

inference/huggingface/stable-diffusion/test-stable-diffusion.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -9,7 +9,7 @@
 model = "prompthero/midjourney-v4-diffusion"
 local_rank = int(os.getenv("LOCAL_RANK", "0"))
 device = torch.device(f"cuda:{local_rank}")
-world_size = int(os.getenv('WORLD_SIZE', '4'))
+world_size = int(os.getenv('WORLD_SIZE', '1'))
 generator = torch.Generator(device=torch.cuda.current_device())
 
 pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.half)
@@ -19,12 +19,14 @@
 baseline_image = pipe(prompt, guidance_scale=7.5, generator=generator).images[0]
 baseline_image.save(f"baseline.png")
 
-# NOTE: DeepSpeed inference supports local CUDA graphs for replaced SD modules
+# NOTE: DeepSpeed inference supports local CUDA graphs for replaced SD modules.
+# Local CUDA graphs for replaced SD modules will only be enabled when `mp_size==1`
 pipe = deepspeed.init_inference(
     pipe,
+    mp_size=world_size,
     dtype=torch.half,
     replace_with_kernel_inject=True,
-    enable_cuda_graph=True,
+    enable_cuda_graph=True if world_size==1 else False,
 )
 
 generator.manual_seed(0xABEDABE7)
```
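The net effect of the commit is that `enable_cuda_graph` is derived from the world size rather than hard-coded. A minimal sketch of that selection logic, using a hypothetical `inference_config` helper (not part of the example script):

```python
def inference_config(world_size: int) -> dict:
    """Build keyword arguments for DeepSpeed's init_inference call.

    CUDA graphs capture a fixed execution stream on a single device, so the
    SD example only enables them when the model is not split across GPUs
    (world_size == 1).
    """
    return {
        "mp_size": world_size,
        "replace_with_kernel_inject": True,
        "enable_cuda_graph": world_size == 1,
    }

# Single GPU: CUDA graphs on; multi-GPU: off.
single = inference_config(1)
multi = inference_config(4)
```

These kwargs would then be splatted into `deepspeed.init_inference(pipe, dtype=torch.half, **inference_config(world_size))`, matching the guarded `enable_cuda_graph` in the diff above.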
