Skip to content
This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit 0a4806f

Browse files
authored
[plugin][torch.compile] allow to add custom compile backend (vllm-project#8445)
1 parent ecd7a1d commit 0a4806f

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

vllm/plugins/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from typing import Callable, Optional, Union
23

34
import vllm.envs as envs
45

@@ -29,3 +30,15 @@ def load_general_plugins():
2930
except Exception:
3031
logger.exception("Failed to load general plugin: %s",
3132
plugin.name)
33+
34+
35+
# Process-wide registry for a plugin-supplied torch.compile backend.
# ``None`` means no plugin has registered one, and callers should fall
# back to their own default (e.g. "eager").
_torch_compile_backend: Optional[Union[Callable, str]] = None


def set_torch_compile_backend(backend: Union[Callable, str]):
    """Register *backend* as the torch.compile backend vLLM should use.

    *backend* is either a backend-name string understood by
    ``torch.compile`` or a callable backend implementation.
    """
    global _torch_compile_backend
    _torch_compile_backend = backend


def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
    """Return the registered torch.compile backend, or ``None`` when no
    plugin has set one."""
    return _torch_compile_backend

vllm/worker/model_runner.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1064,10 +1064,12 @@ def load_model(self) -> None:
10641064
"This may lead to less accurate results!")
10651065

10661066
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
1067+
from vllm.plugins import get_torch_compile_backend
1068+
backend = get_torch_compile_backend() or "eager"
10671069
self.model = torch.compile(
10681070
self.model,
10691071
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
1070-
backend="eager")
1072+
backend=backend)
10711073

10721074
def save_sharded_state(
10731075
self,

0 commit comments

Comments (0)