Skip to content

Commit c784b3e

Browse files
authored
[core] add torch compile for diffusion (vllm-project#684)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
1 parent cddcc40 commit c784b3e

15 files changed

+134
-25
lines changed

vllm_omni/diffusion/compile.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from typing import Any
4+
5+
import torch.nn as nn
6+
from vllm.logger import init_logger
7+
8+
logger = init_logger(__name__)
9+
10+
11+
def regionally_compile(model: nn.Module, *compile_args: Any, **compile_kwargs: Any) -> nn.Module:
12+
"""
13+
Apply regional compilation to a PyTorch model.
14+
15+
Args:
16+
model: The PyTorch model instance to compile
17+
*compile_args: Positional arguments forwarded to torch.compile
18+
**compile_kwargs: Keyword arguments forwarded to torch.compile
19+
20+
Returns:
21+
The same model instance (modified in-place)
22+
"""
23+
# Get the list of repeated blocks from the model
24+
repeated_blocks = getattr(model, "_repeated_blocks", None)
25+
26+
if not repeated_blocks:
27+
logger.warning("Regional compilation skipped because the model does not define `_repeated_blocks`.")
28+
return model
29+
30+
# Check if we have modules with the specified class names
31+
has_compiled_region = False
32+
for submod in model.modules():
33+
if submod.__class__.__name__ in repeated_blocks:
34+
# Compile this submodule
35+
submod.compile(*compile_args, **compile_kwargs)
36+
has_compiled_region = True
37+
38+
if not has_compiled_region:
39+
logger.warning(f"Regional compilation skipped because {repeated_blocks} classes are not found in the model.")
40+
41+
return model

vllm_omni/diffusion/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ class OmniDiffusionConfig:
303303
skip_time_steps: int = 15
304304

305305
# Compilation
306-
enable_torch_compile: bool = False
306+
enforce_eager: bool = False
307307

308308
# Enable sleep mode
309309
enable_sleep_mode: bool = False

vllm_omni/diffusion/diffusion_engine.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,15 @@
88
from dataclasses import dataclass
99
from typing import Any
1010

11+
import PIL.Image
1112
from vllm.logger import init_logger
1213

1314
from vllm_omni.diffusion.data import SHUTDOWN_MESSAGE, OmniDiffusionConfig
14-
from vllm_omni.diffusion.registry import get_diffusion_post_process_func, get_diffusion_pre_process_func
15+
from vllm_omni.diffusion.registry import (
16+
DiffusionModelRegistry,
17+
get_diffusion_post_process_func,
18+
get_diffusion_pre_process_func,
19+
)
1520
from vllm_omni.diffusion.request import OmniDiffusionRequest
1621
from vllm_omni.diffusion.scheduler import Scheduler, scheduler
1722
from vllm_omni.outputs import OmniRequestOutput
@@ -20,6 +25,13 @@
2025
logger = init_logger(__name__)
2126

2227

28+
def supports_image_input(model_class_name: str) -> bool:
29+
model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name)
30+
if model_cls is None:
31+
return False
32+
return bool(getattr(model_cls, "support_image_input", False))
33+
34+
2335
@dataclass
2436
class BackgroundResources:
2537
"""
@@ -70,6 +82,12 @@ def __init__(self, od_config: OmniDiffusionConfig):
7082
self._processes: list[mp.Process] = []
7183
self._closed = False
7284
self._make_client()
85+
try:
86+
self._dummy_run()
87+
except Exception as e:
88+
logger.error(f"Dummy run failed: {e}")
89+
self.close()
90+
raise e
7391

7492
def step(self, requests: list[OmniDiffusionRequest]):
7593
try:
@@ -272,6 +290,30 @@ def _launch_workers(self, broadcast_handle):
272290
def add_req_and_wait_for_response(self, requests: list[OmniDiffusionRequest]):
273291
return scheduler.add_req(requests)
274292

293+
def _dummy_run(self):
294+
"""A dummy run to warm up the model."""
295+
prompt = "dummy run"
296+
num_inference_steps = 1
297+
height = 1024
298+
width = 1024
299+
if supports_image_input(self.od_config.model_class_name):
300+
# Provide a dummy image input if the model supports it
301+
302+
dummy_image = PIL.Image.new("RGB", (width, height), color=(0, 0, 0))
303+
else:
304+
dummy_image = None
305+
req = OmniDiffusionRequest(
306+
prompt=prompt,
307+
height=height,
308+
width=width,
309+
pil_image=dummy_image,
310+
num_inference_steps=num_inference_steps,
311+
num_outputs_per_prompt=1,
312+
)
313+
logger.info("dummy run to warm up the model")
314+
requests = self.pre_process_func([req]) if self.pre_process_func is not None else [req]
315+
self.add_req_and_wait_for_response(requests)
316+
275317
def collective_rpc(
276318
self,
277319
method: str | Callable,
@@ -343,22 +385,6 @@ def collective_rpc(
343385
logger.error(f"RPC call failed: {e}")
344386
raise
345387

346-
def _dummy_run(self):
347-
"""A dummy run to warm up the model."""
348-
prompt = "dummy run"
349-
num_inference_steps = 1
350-
height = 1024
351-
width = 1024
352-
req = OmniDiffusionRequest(
353-
prompt=prompt,
354-
height=height,
355-
width=width,
356-
num_inference_steps=num_inference_steps,
357-
num_outputs_per_prompt=1,
358-
)
359-
logger.info("dummy run to warm up the model")
360-
self.add_req_and_wait_for_response([req])
361-
362388
def close(self) -> None:
363389
self._finalizer()
364390

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from typing import (
4+
ClassVar,
5+
Protocol,
6+
runtime_checkable,
7+
)
8+
9+
10+
@runtime_checkable
11+
class SupportImageInput(Protocol):
12+
support_image_input: ClassVar[bool] = True

vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,11 @@ class LongCatImageTransformer2DModel(nn.Module):
353353
The Transformer model introduced in Flux.
354354
"""
355355

356+
_repeated_blocks = [
357+
"LongCatImageTransformerBlock",
358+
"LongCatImageSingleTransformerBlock",
359+
]
360+
356361
def __init__(
357362
self,
358363
od_config: OmniDiffusionConfig,

vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
2929
from vllm_omni.diffusion.distributed.utils import get_local_device
3030
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
31+
from vllm_omni.diffusion.models.interface import SupportImageInput
3132
from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import (
3233
LongCatImageTransformer2DModel,
3334
)
@@ -196,7 +197,7 @@ def split_quotation(prompt, quote_pairs=None):
196197
return result
197198

198199

199-
class LongCatImageEditPipeline(nn.Module):
200+
class LongCatImageEditPipeline(nn.Module, SupportImageInput):
200201
def __init__(
201202
self,
202203
*,

vllm_omni/diffusion/models/ovis_image/ovis_image_transformer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,8 @@ class OvisImageTransformer2DModel(nn.Module):
365365
The dimensions to use for the rotary positional embeddings.
366366
"""
367367

368+
_repeated_blocks = ["OvisImageTransformerBlock", "OvisImageSingleTransformerBlock"]
369+
368370
def __init__(
369371
self,
370372
od_config: OmniDiffusionConfig,

vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from vllm_omni.diffusion.distributed.utils import get_local_device
3434
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
35+
from vllm_omni.diffusion.models.interface import SupportImageInput
3536
from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift
3637
from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import (
3738
QwenImageTransformer2DModel,
@@ -195,6 +196,7 @@ def retrieve_latents(
195196

196197
class QwenImageEditPipeline(
197198
nn.Module,
199+
SupportImageInput,
198200
):
199201
def __init__(
200202
self,

vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131
from vllm_omni.diffusion.distributed.utils import get_local_device
3232
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
33+
from vllm_omni.diffusion.models.interface import SupportImageInput
3334
from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift
3435
from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit import (
3536
calculate_dimensions,
@@ -156,9 +157,7 @@ def post_process_func(
156157
return post_process_func
157158

158159

159-
class QwenImageEditPlusPipeline(
160-
nn.Module,
161-
):
160+
class QwenImageEditPlusPipeline(nn.Module, SupportImageInput):
162161
def __init__(
163162
self,
164163
*,

vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030
from vllm_omni.diffusion.distributed.utils import get_local_device
3131
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
32+
from vllm_omni.diffusion.models.interface import SupportImageInput
3233
from vllm_omni.diffusion.models.qwen_image.autoencoder_kl_qwenimage import (
3334
AutoencoderKLQwenImage,
3435
)
@@ -170,9 +171,7 @@ def retrieve_latents(
170171
raise AttributeError("Could not access latents of provided encoder_output")
171172

172173

173-
class QwenImageLayeredPipeline(
174-
nn.Module,
175-
):
174+
class QwenImageLayeredPipeline(nn.Module, SupportImageInput):
176175
def __init__(
177176
self,
178177
*,

0 commit comments

Comments
 (0)