Skip to content

Commit aa913ea

Browse files
anijain2305epwalsh
authored andcommitted
[compile][startup] Disable C++ compilation of symbolic shapes (vllm-project#20836)
Signed-off-by: Animesh Jain <[email protected]>
1 parent 9eb4313 commit aa913ea

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

vllm/compilation/decorators.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,15 @@ def patched_inline_call(parent, func, args, kwargs):
267267
code.co_filename)
268268
return inline_call(parent, func, args, kwargs)
269269

270-
with patch.object(InliningInstructionTranslator, 'inline_call',
271-
patched_inline_call):
270+
# Disable the C++ compilation of symbolic shape guards. C++-fication
271+
# of symbolic shape guards can improve guard overhead. But, since
272+
# vllm skip guards anyways, setting this flag to False can improve
273+
# compile time.
274+
with torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards",
275+
False), patch.object(
276+
InliningInstructionTranslator,
277+
'inline_call',
278+
patched_inline_call):
272279
output = self.compiled_callable(*args, **kwargs)
273280
return output
274281

0 commit comments

Comments
 (0)