2828from vllm .logger import init_logger
2929from vllm .sequence import IntermediateTensors
3030from vllm .utils .import_utils import resolve_obj_by_qualname
31- from vllm .utils .torch_utils import supports_dynamo
31+ from vllm .utils .torch_utils import is_torch_equal_or_newer , supports_dynamo
3232
3333from .monitor import start_monitoring_torch_compile
3434
@@ -316,7 +316,13 @@ def __init__(
316316 def _mark_dynamic_inputs (mod , type , * args , ** kwargs ):
317317 def mark_dynamic (arg , dims ):
318318 if type == DynamicShapesType .UNBACKED :
319- torch ._dynamo .decorators .mark_unbacked (arg , dims )
319+ if is_torch_equal_or_newer ("2.10.0.dev" ):
320+ for dim in dims :
321+ torch ._dynamo .decorators .mark_unbacked (
322+ arg , dim , hint_override = arg .size ()[dim ]
323+ )
324+ else :
325+ torch ._dynamo .decorators .mark_unbacked (arg , dims )
320326 else :
321327 torch ._dynamo .mark_dynamic (arg , dims )
322328
@@ -350,7 +356,13 @@ def mark_dynamic(arg, dims):
350356 if isinstance (arg , torch .Tensor ):
351357 # In case dims is specified with negative indexing
352358 dims = [arg .ndim + dim if dim < 0 else dim for dim in dims ]
353- torch ._dynamo .decorators .mark_unbacked (arg , dims )
359+ if is_torch_equal_or_newer ("2.10.0.dev" ):
360+ for dim in dims :
361+ torch ._dynamo .decorators .mark_unbacked (
362+ arg , dim , hint_override = arg .size ()[dim ]
363+ )
364+ else :
365+ torch ._dynamo .decorators .mark_unbacked (arg , dims )
354366
355367 def __call__ (self , * args , ** kwargs ):
356368 # torch.compiler.is_compiling() means we are inside the compilation
@@ -488,6 +500,12 @@ def patched_inline_call(self_):
488500 if ds_type == DynamicShapesType .BACKED_SIZE_OBLIVIOUS :
489501 fx_config_patches ["backed_size_oblivious" ] = True
490502
503+ # Prepare inductor config patches
504+ # assume_32bit_indexing is only available in torch 2.10.0.dev+
505+ inductor_config_patches = {}
506+ if is_torch_equal_or_newer ("2.10.0.dev" ):
507+ inductor_config_patches ["assume_32bit_indexing" ] = True
508+
491509 with (
492510 patch .object (
493511 InliningInstructionTranslator , "inline_call_" , patched_inline_call
@@ -496,6 +514,7 @@ def patched_inline_call(self_):
496514 maybe_use_cudagraph_partition_wrapper (self .vllm_config ),
497515 torch .fx .experimental ._config .patch (** fx_config_patches ),
498516 _torch27_patch_tensor_subclasses (),
517+ torch ._inductor .config .patch (** inductor_config_patches ),
499518 ):
500519 if envs .VLLM_USE_AOT_COMPILE :
501520 self .aot_compiled_fn = self .aot_compile (* args , ** kwargs )
0 commit comments