Skip to content

Commit 36c2876

Browse files
authored
[Core]Add GPU Diffusion Runner (vllm-project#822)
Signed-off-by: princepride <wangzhipeng628@gmail.com>
1 parent bb24e07 commit 36c2876

File tree

9 files changed

+405
-166
lines changed

9 files changed

+405
-166
lines changed

.buildkite/pipeline.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ steps:
116116
timeout_in_minutes: 20
117117
depends_on: image-build
118118
commands:
119-
- pytest -s -v tests/diffusion/test_gpu_worker.py
119+
- pytest -s -v tests/diffusion/test_gpu_diffusion_worker.py
120120
agents:
121121
queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU
122122
plugins:

.buildkite/test-amd.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ steps:
5454
commands:
5555
- export MIOPEN_DEBUG_CONV_DIRECT=0
5656
- export MIOPEN_DEBUG_CONV_GEMM=0
57-
- pytest -s -v tests/diffusion/test_gpu_worker.py
57+
- pytest -s -v tests/diffusion/test_gpu_diffusion_worker.py
5858

5959
- label: "Omni Model Test Qwen2-5-Omni"
6060
timeout_in_minutes: 15

docs/api/README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,9 @@ Configuration classes.
103103

104104
Worker classes and model runners for distributed inference.
105105

106-
- [vllm_omni.diffusion.worker.gpu_worker.GPUWorker][]
107-
- [vllm_omni.diffusion.worker.gpu_worker.WorkerProc][]
106+
- [vllm_omni.diffusion.worker.gpu_diffusion_model_runner.GPUDiffusionModelRunner][]
107+
- [vllm_omni.diffusion.worker.gpu_diffusion_worker.GPUDiffusionWorker][]
108+
- [vllm_omni.diffusion.worker.gpu_diffusion_worker.WorkerProc][]
108109
- [vllm_omni.diffusion.worker.npu.npu_worker.NPUWorker][]
109110
- [vllm_omni.diffusion.worker.npu.npu_worker.NPUWorkerProc][]
110111
- [vllm_omni.worker.gpu_ar_model_runner.ExecuteModelState][]

tests/diffusion/test_gpu_worker.py renamed to tests/diffusion/test_gpu_diffusion_worker.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
"""
5-
Unit tests for GPUWorker class.
5+
Unit tests for GPUDiffusionWorker class.
66
7-
This module tests the GPUWorker implementation:
7+
This module tests the GPUDiffusionWorker implementation:
88
- load_weights: Loading model weights
99
- sleep: Putting worker into sleep mode (levels 1 and 2)
1010
- wake_up: Waking worker from sleep mode
@@ -15,7 +15,7 @@
1515
import pytest
1616
import torch
1717

18-
from vllm_omni.diffusion.worker.gpu_worker import GPUWorker
18+
from vllm_omni.diffusion.worker.gpu_diffusion_worker import GPUDiffusionWorker
1919

2020

2121
@pytest.fixture
@@ -33,51 +33,52 @@ def mock_od_config():
3333

3434
@pytest.fixture
3535
def mock_gpu_worker(mock_od_config):
36-
"""Create a GPUWorker with mocked initialization."""
37-
with patch.object(GPUWorker, "init_device_and_model"):
38-
worker = GPUWorker(local_rank=0, rank=0, od_config=mock_od_config)
39-
# Mock the pipeline
40-
worker.pipeline = Mock()
41-
worker.cache_backend = None
36+
"""Create a GPUDiffusionWorker with mocked initialization."""
37+
with patch.object(GPUDiffusionWorker, "init_device"):
38+
worker = GPUDiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config)
39+
# Mock the model_runner with pipeline
40+
worker.model_runner = Mock()
41+
worker.model_runner.pipeline = Mock()
42+
worker._sleep_saved_buffers = {}
4243
return worker
4344

4445

45-
class TestGPUWorkerLoadWeights:
46-
"""Test GPUWorker.load_weights method."""
46+
class TestGPUDiffusionWorkerLoadWeights:
47+
"""Test GPUDiffusionWorker.load_weights method."""
4748

4849
def test_load_weights_calls_pipeline(self, mock_gpu_worker):
49-
"""Test that load_weights delegates to pipeline.load_weights."""
50+
"""Test that load_weights delegates to model_runner.load_weights."""
5051
# Setup mock weights
5152
mock_weights = [
5253
("layer1.weight", torch.randn(10, 10)),
5354
("layer2.weight", torch.randn(20, 20)),
5455
]
5556
expected_loaded = {"layer1.weight", "layer2.weight"}
5657

57-
# Configure pipeline mock
58-
mock_gpu_worker.pipeline.load_weights = Mock(return_value=expected_loaded)
58+
# Configure model_runner mock
59+
mock_gpu_worker.model_runner.load_weights = Mock(return_value=expected_loaded)
5960

6061
# Call load_weights
6162
result = mock_gpu_worker.load_weights(mock_weights)
6263

63-
# Verify pipeline.load_weights was called with the weights
64-
mock_gpu_worker.pipeline.load_weights.assert_called_once_with(mock_weights)
64+
# Verify model_runner.load_weights was called with the weights
65+
mock_gpu_worker.model_runner.load_weights.assert_called_once_with(mock_weights)
6566
assert result == expected_loaded
6667

6768
def test_load_weights_empty_iterable(self, mock_gpu_worker):
6869
"""Test load_weights with empty weights iterable."""
69-
mock_gpu_worker.pipeline.load_weights = Mock(return_value=set())
70+
mock_gpu_worker.model_runner.load_weights = Mock(return_value=set())
7071

7172
result = mock_gpu_worker.load_weights([])
7273

73-
mock_gpu_worker.pipeline.load_weights.assert_called_once_with([])
74+
mock_gpu_worker.model_runner.load_weights.assert_called_once_with([])
7475
assert result == set()
7576

7677

77-
class TestGPUWorkerSleep:
78-
"""Test GPUWorker.sleep method."""
78+
class TestGPUDiffusionWorkerSleep:
79+
"""Test GPUDiffusionWorker.sleep method."""
7980

80-
@patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info")
81+
@patch("vllm_omni.diffusion.worker.gpu_diffusion_worker.torch.cuda.mem_get_info")
8182
@patch("vllm.device_allocator.cumem.CuMemAllocator")
8283
def test_sleep_level_1(self, mock_allocator_class, mock_mem_info, mock_gpu_worker):
8384
"""Test sleep mode level 1 (offload weights only)."""
@@ -103,7 +104,7 @@ def test_sleep_level_1(self, mock_allocator_class, mock_mem_info, mock_gpu_worke
103104
# Verify buffers were NOT saved (level 1 doesn't save buffers)
104105
assert len(mock_gpu_worker._sleep_saved_buffers) == 0
105106

106-
@patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info")
107+
@patch("vllm_omni.diffusion.worker.gpu_diffusion_worker.torch.cuda.mem_get_info")
107108
@patch("vllm.device_allocator.cumem.CuMemAllocator")
108109
def test_sleep_level_2(self, mock_allocator_class, mock_mem_info, mock_gpu_worker):
109110
"""Test sleep mode level 2 (offload all, save buffers)."""
@@ -121,7 +122,7 @@ def test_sleep_level_2(self, mock_allocator_class, mock_mem_info, mock_gpu_worke
121122
# Mock pipeline buffers
122123
mock_buffer1 = torch.randn(10, 10)
123124
mock_buffer2 = torch.randn(20, 20)
124-
mock_gpu_worker.pipeline.named_buffers = Mock(
125+
mock_gpu_worker.model_runner.pipeline.named_buffers = Mock(
125126
return_value=[
126127
("buffer1", mock_buffer1),
127128
("buffer2", mock_buffer2),
@@ -140,7 +141,7 @@ def test_sleep_level_2(self, mock_allocator_class, mock_mem_info, mock_gpu_worke
140141
assert "buffer1" in mock_gpu_worker._sleep_saved_buffers
141142
assert "buffer2" in mock_gpu_worker._sleep_saved_buffers
142143

143-
@patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info")
144+
@patch("vllm_omni.diffusion.worker.gpu_diffusion_worker.torch.cuda.mem_get_info")
144145
@patch("vllm.device_allocator.cumem.CuMemAllocator")
145146
def test_sleep_memory_freed_validation(self, mock_allocator_class, mock_mem_info, mock_gpu_worker):
146147
"""Test that sleep validates memory was actually freed."""
@@ -159,8 +160,8 @@ def test_sleep_memory_freed_validation(self, mock_allocator_class, mock_mem_info
159160
mock_gpu_worker.sleep(level=1)
160161

161162

162-
class TestGPUWorkerWakeUp:
163-
"""Test GPUWorker.wake_up method."""
163+
class TestGPUDiffusionWorkerWakeUp:
164+
"""Test GPUDiffusionWorker.wake_up method."""
164165

165166
@patch("vllm.device_allocator.cumem.CuMemAllocator")
166167
def test_wake_up_without_buffers(self, mock_allocator_class, mock_gpu_worker):
@@ -202,7 +203,7 @@ def test_wake_up_with_buffers(self, mock_allocator_class, mock_gpu_worker):
202203
mock_buffer2 = Mock()
203204
mock_buffer2.data = Mock()
204205

205-
mock_gpu_worker.pipeline.named_buffers = Mock(
206+
mock_gpu_worker.model_runner.pipeline.named_buffers = Mock(
206207
return_value=[
207208
("buffer1", mock_buffer1),
208209
("buffer2", mock_buffer2),
@@ -243,7 +244,7 @@ def test_wake_up_partial_buffer_restore(self, mock_allocator_class, mock_gpu_wor
243244
mock_buffer2 = Mock()
244245
mock_buffer2.data = Mock()
245246

246-
mock_gpu_worker.pipeline.named_buffers = Mock(
247+
mock_gpu_worker.model_runner.pipeline.named_buffers = Mock(
247248
return_value=[
248249
("buffer1", mock_buffer1),
249250
("buffer2", mock_buffer2),

vllm_omni/diffusion/worker/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Worker classes for diffusion models."""

from vllm_omni.diffusion.worker.gpu_diffusion_model_runner import GPUDiffusionModelRunner
from vllm_omni.diffusion.worker.gpu_diffusion_worker import GPUDiffusionWorker, WorkerProc

# Public API of the diffusion worker package.
__all__ = ["GPUDiffusionModelRunner", "GPUDiffusionWorker", "WorkerProc"]
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
Diffusion Model Runner for vLLM-Omni.
5+
6+
Handles model loading, compilation, caching, and execution of diffusion model
7+
forward passes. This follows the AR pattern where the Runner handles all
8+
model-related operations.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
import time
14+
from collections.abc import Iterable
15+
from contextlib import nullcontext
16+
17+
import torch
18+
from vllm.config import LoadConfig
19+
from vllm.logger import init_logger
20+
from vllm.utils.mem_utils import DeviceMemoryProfiler, GiB_bytes
21+
22+
from vllm_omni.diffusion.cache.selector import get_cache_backend
23+
from vllm_omni.diffusion.compile import regionally_compile
24+
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
25+
from vllm_omni.diffusion.forward_context import set_forward_context
26+
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
27+
from vllm_omni.diffusion.offload import apply_offload_hooks
28+
from vllm_omni.diffusion.request import OmniDiffusionRequest
29+
30+
logger = init_logger(__name__)
31+
32+
33+
class GPUDiffusionModelRunner:
    """Model runner that owns the diffusion pipeline.

    Handles model loading, torch.compile, CPU offloading, cache-backend
    setup, weight (re)loading, and forward execution. Follows the AR
    (autoregressive) runner pattern: the Runner handles all model-related
    operations while the Worker only handles infrastructure (device setup,
    distributed environment).
    """

    def __init__(
        self,
        vllm_config,
        od_config: OmniDiffusionConfig,
        device: torch.device,
    ):
        """
        Initialize the diffusion model runner.

        Args:
            vllm_config: vLLM configuration.
            od_config: OmniDiffusion configuration.
            device: The device to run on.
        """
        self.vllm_config = vllm_config
        self.od_config = od_config
        self.device = device
        # Both are populated by load_model(); they stay None until then.
        self.pipeline = None
        self.cache_backend = None

    def load_model(
        self,
        memory_pool_context_fn: callable | None = None,
    ) -> None:
        """
        Load the diffusion model, apply compilation and offloading.

        Args:
            memory_pool_context_fn: Optional function that returns a context
                manager for memory pool allocation (used for sleep mode);
                invoked as ``memory_pool_context_fn(tag="weights")``.
        """
        # With CPU offload enabled, weights are first materialized on CPU and
        # moved to the GPU later by the offload hooks.
        load_device = "cpu" if self.od_config.enable_cpu_offload else str(self.device)

        def get_memory_context():
            # Allocate weights inside the caller-provided memory pool when one
            # is supplied (sleep mode); otherwise a no-op context.
            if memory_pool_context_fn is not None:
                return memory_pool_context_fn(tag="weights")
            return nullcontext()

        # Load model within forward context
        with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config):
            load_config = LoadConfig()
            model_loader = DiffusersPipelineLoader(load_config)
            time_before_load = time.perf_counter()

            with get_memory_context():
                with DeviceMemoryProfiler() as m:
                    self.pipeline = model_loader.load_model(
                        od_config=self.od_config,
                        load_device=load_device,
                    )
            time_after_load = time.perf_counter()

            logger.info(
                "Model loading took %.4f GiB and %.6f seconds",
                m.consumed_memory / GiB_bytes,
                time_after_load - time_before_load,
            )
            logger.info("Model runner: Model loaded successfully.")

        # Apply CPU offloading (DiT <-> encoders mutual exclusion)
        if self.od_config.enable_cpu_offload:
            # Keep select modules (currently only the VAE) resident on GPU.
            # Best-effort: failure is logged and the module stays on CPU.
            for name in ["vae"]:
                module = getattr(self.pipeline, name, None)
                if module is None:
                    continue
                try:
                    module.to(self.device, non_blocking=True)
                except Exception as exc:
                    logger.debug("Failed to move %s to GPU: %s", name, exc)

            # NOTE(review): placement inside the offload branch inferred from
            # the diff layout — confirm against the committed file.
            apply_offload_hooks(self.pipeline, self.od_config, device=self.device)

        # Apply torch.compile if not in eager mode
        if not self.od_config.enforce_eager:
            try:
                self.pipeline.transformer = regionally_compile(
                    self.pipeline.transformer,
                    dynamic=True,
                )
                logger.info("Model runner: Model compiled with torch.compile.")
            except Exception as e:
                # Compilation is best-effort; fall back to eager execution.
                logger.warning(f"Model runner: torch.compile failed with error: {e}. Using eager mode.")

        # Setup cache backend (may be None when caching is disabled).
        self.cache_backend = get_cache_backend(self.od_config.cache_backend, self.od_config.cache_config)

        if self.cache_backend is not None:
            self.cache_backend.enable(self.pipeline)

        logger.info("Model runner: Initialization complete.")

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights into the pipeline; returns the set of loaded names."""
        return self.pipeline.load_weights(weights)

    @torch.inference_mode()
    def execute_model(self, reqs: list[OmniDiffusionRequest]) -> DiffusionOutput:
        """
        Execute a forward pass for the given requests.

        Args:
            reqs: List of diffusion requests to process (must be non-empty).

        Returns:
            DiffusionOutput with generated results.

        Raises:
            ValueError: If ``reqs`` is empty.
        """
        assert self.pipeline is not None, "Model not loaded. Call load_model() first."
        # Fix: `not reqs` alone covers the empty case; the original
        # `not reqs or len(reqs) == 0` second clause was redundant.
        if not reqs:
            raise ValueError("Cannot execute model with empty request list")

        # TODO: dealing with first req for now
        req = reqs[0]

        # Derive a deterministic generator from the request seed when the
        # caller did not supply one explicitly.
        if req.generator is None and req.seed is not None:
            req.generator = torch.Generator(device=self.device).manual_seed(req.seed)

        # Refresh cache context if needed
        if self.cache_backend is not None and self.cache_backend.is_enabled():
            self.cache_backend.refresh(self.pipeline, req.num_inference_steps)

        with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config):
            output = self.pipeline.forward(req)

        return output

0 commit comments

Comments
 (0)