
Commit 763963a

set assume_32bit_indexing and pass unbacked hints (vllm-project#30459)
Signed-off-by: Laith Sakka <[email protected]>
1 parent 39cefbd commit 763963a

File tree

1 file changed: +22 -3 lines changed


vllm/compilation/decorators.py

Lines changed: 22 additions & 3 deletions
@@ -28,7 +28,7 @@
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import resolve_obj_by_qualname
-from vllm.utils.torch_utils import supports_dynamo
+from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
 
 from .monitor import start_monitoring_torch_compile

@@ -316,7 +316,13 @@ def __init__(
     def _mark_dynamic_inputs(mod, type, *args, **kwargs):
         def mark_dynamic(arg, dims):
             if type == DynamicShapesType.UNBACKED:
-                torch._dynamo.decorators.mark_unbacked(arg, dims)
+                if is_torch_equal_or_newer("2.10.0.dev"):
+                    for dim in dims:
+                        torch._dynamo.decorators.mark_unbacked(
+                            arg, dim, hint_override=arg.size()[dim]
+                        )
+                else:
+                    torch._dynamo.decorators.mark_unbacked(arg, dims)
             else:
                 torch._dynamo.mark_dynamic(arg, dims)

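For context: the hunk above stops marking all dims unbacked in a single call and instead marks each dim individually so a size hint can be attached. A minimal standalone sketch of the new call shape, assuming a PyTorch nightly (>= 2.10.0.dev) where mark_unbacked accepts hint_override:

import torch
import torch._dynamo

def fn(x):
    return x * 2

x = torch.randn(8, 16)
# Treat dim 0 as unbacked (no guard on its concrete value), while
# hint_override passes the current size (8) so size-based heuristics in
# the compiler still see a concrete number.
torch._dynamo.decorators.mark_unbacked(x, 0, hint_override=x.size()[0])
compiled = torch.compile(fn)
out = compiled(x)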
@@ -350,7 +356,13 @@ def mark_dynamic(arg, dims):
             if isinstance(arg, torch.Tensor):
                 # In case dims is specified with negative indexing
                 dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
-                torch._dynamo.decorators.mark_unbacked(arg, dims)
+                if is_torch_equal_or_newer("2.10.0.dev"):
+                    for dim in dims:
+                        torch._dynamo.decorators.mark_unbacked(
+                            arg, dim, hint_override=arg.size()[dim]
+                        )
+                else:
+                    torch._dynamo.decorators.mark_unbacked(arg, dims)
 
     def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
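This second call site differs from the first only in normalizing negative dim indices before the version gate. A tiny illustration with hypothetical values:

import torch

arg = torch.randn(2, 3, 4)  # arg.ndim == 3
dims = [0, -1]
# Same list comprehension as the hunk: map negative indices to positive ones.
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
assert dims == [0, 2]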
@@ -488,6 +500,12 @@ def patched_inline_call(self_):
         if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
             fx_config_patches["backed_size_oblivious"] = True
 
+        # Prepare inductor config patches
+        # assume_32bit_indexing is only available in torch 2.10.0.dev+
+        inductor_config_patches = {}
+        if is_torch_equal_or_newer("2.10.0.dev"):
+            inductor_config_patches["assume_32bit_indexing"] = True
+
         with (
             patch.object(
                 InliningInstructionTranslator, "inline_call_", patched_inline_call
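The new dict mirrors the existing fx_config_patches pattern: build the patch set up front and gate version-specific knobs, so older torch builds end up with an empty, harmless patch. A minimal sketch using the same helper the diff imports:

from vllm.utils.torch_utils import is_torch_equal_or_newer

inductor_config_patches = {}
if is_torch_equal_or_newer("2.10.0.dev"):
    # Let inductor assume indices fit in int32, which can make the index
    # arithmetic in generated kernels cheaper; the knob only exists on
    # 2.10.0.dev+ builds.
    inductor_config_patches["assume_32bit_indexing"] = True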
@@ -496,6 +514,7 @@ def patched_inline_call(self_):
             maybe_use_cudagraph_partition_wrapper(self.vllm_config),
             torch.fx.experimental._config.patch(**fx_config_patches),
             _torch27_patch_tensor_subclasses(),
+            torch._inductor.config.patch(**inductor_config_patches),
         ):
             if envs.VLLM_USE_AOT_COMPILE:
                 self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
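Because config.patch(**{}) applies no overrides, the new context manager can be stacked unconditionally. A minimal sketch (not from the commit) of applying the prepared patches around a compile-and-run:

import torch
import torch._inductor.config

inductor_config_patches = {}  # stays empty on torch < 2.10.0.dev, per the gate above

with torch._inductor.config.patch(**inductor_config_patches):
    # Anything compiled inside this block sees the patched inductor config.
    compiled = torch.compile(lambda x: x + 1)
    compiled(torch.randn(4))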
