55import os
66import pprint
77import time
8+ from collections .abc import Sequence
89from contextlib import ExitStack
9- from typing import Any , Callable , Dict , List , Optional , Sequence , Set , Tuple
10+ from typing import Any , Callable , Optional
1011from unittest .mock import patch
1112
1213import torch
@@ -56,7 +57,7 @@ class CompilerManager:
5657 """
5758
5859 def __init__ (self , compilation_config : CompilationConfig ):
59- self .cache : Dict [ Tuple [Optional [int ], int , str ], Any ] = dict ()
60+ self .cache : dict [ tuple [Optional [int ], int , str ], Any ] = dict ()
6061 self .is_cache_updated = False
6162 self .compilation_config = compilation_config
6263 self .compiler = make_compiler (compilation_config )
@@ -90,7 +91,7 @@ def save_to_file(self):
9091
9192 def load (self ,
9293 graph : fx .GraphModule ,
93- example_inputs : List [Any ],
94+ example_inputs : list [Any ],
9495 graph_index : int ,
9596 runtime_shape : Optional [int ] = None ) -> Optional [Callable ]:
9697 if (runtime_shape , graph_index , self .compiler .name ) not in self .cache :
@@ -186,7 +187,7 @@ class SplitItem:
186187
187188
188189def split_graph (graph : fx .GraphModule ,
189- ops : List [str ]) -> Tuple [fx .GraphModule , List [SplitItem ]]:
190+ ops : list [str ]) -> tuple [fx .GraphModule , list [SplitItem ]]:
190191 # split graph by ops
191192 subgraph_id = 0
192193 node_to_subgraph_id = {}
@@ -252,7 +253,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
252253 """
253254
254255 def __init__ (self , module : torch .fx .GraphModule ,
255- compile_submod_names : List [str ], vllm_config : VllmConfig ,
256+ compile_submod_names : list [str ], vllm_config : VllmConfig ,
256257 graph_pool , vllm_backend : "VllmBackend" ):
257258 super ().__init__ (module )
258259 from torch ._guards import detect_fake_mode
@@ -274,8 +275,8 @@ def run(self, *args):
274275 return super ().run (* fake_args )
275276
276277 def call_module (self , target : torch .fx .node .Target ,
277- args : Tuple [torch .fx .node .Argument ,
278- ...], kwargs : Dict [str , Any ]) -> Any :
278+ args : tuple [torch .fx .node .Argument ,
279+ ...], kwargs : dict [str , Any ]) -> Any :
279280 assert isinstance (target , str )
280281 output = super ().call_module (target , args , kwargs )
281282
@@ -326,12 +327,12 @@ class VllmBackend:
326327 graph : fx .GraphModule
327328 # the stitching graph module for all the piecewise graphs
328329 split_gm : fx .GraphModule
329- piecewise_graphs : List [SplitItem ]
330+ piecewise_graphs : list [SplitItem ]
330331 returned_callable : Callable
331332 # Inductor passes to run on the graph pre-defunctionalization
332333 post_grad_passes : Sequence [Callable ]
333- sym_tensor_indices : List [int ]
334- input_buffers : List [torch .Tensor ]
334+ sym_tensor_indices : list [int ]
335+ input_buffers : list [torch .Tensor ]
335336 compiler_manager : CompilerManager
336337
337338 def __init__ (
@@ -573,14 +574,14 @@ class ConcreteSizeEntry:
573574
574575 # for cudagraph debugging, track the input addresses
575576 # during capture, and check if they are the same during replay
576- input_addresses : Optional [List [int ]] = None
577+ input_addresses : Optional [list [int ]] = None
577578
578579
579580class PiecewiseBackend :
580581
581582 def __init__ (self , graph : fx .GraphModule , vllm_config : VllmConfig ,
582583 graph_pool : Any , piecewise_compile_index : int ,
583- total_piecewise_compiles : int , sym_shape_indices : List [int ],
584+ total_piecewise_compiles : int , sym_shape_indices : list [int ],
584585 compiled_graph_for_general_shape : Callable ,
585586 vllm_backend : VllmBackend ):
586587 """
@@ -608,9 +609,9 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
608609 self .is_last_graph = (
609610 piecewise_compile_index == total_piecewise_compiles - 1 )
610611
611- self .compile_sizes : Set [int ] = set (
612+ self .compile_sizes : set [int ] = set (
612613 self .compilation_config .compile_sizes )
613- self .cudagraph_capture_sizes : Set [int ] = set (
614+ self .cudagraph_capture_sizes : set [int ] = set (
614615 self .compilation_config .cudagraph_capture_sizes
615616 ) if self .compilation_config .use_cudagraph else set ()
616617
@@ -624,11 +625,11 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
624625
625626 # the entries for different shapes that we need to either
626627 # compile or capture cudagraph
627- self .concrete_size_entries : Dict [int , ConcreteSizeEntry ] = {}
628+ self .concrete_size_entries : dict [int , ConcreteSizeEntry ] = {}
628629
629630 # to_be_compiled_sizes tracks the remaining sizes to compile,
630631 # and updates during the compilation process, so we need to copy it
631- self .to_be_compiled_sizes : Set [int ] = self .compile_sizes .copy ()
632+ self .to_be_compiled_sizes : set [int ] = self .compile_sizes .copy ()
632633 for shape in self .compile_sizes .union (self .cudagraph_capture_sizes ):
633634 self .concrete_size_entries [shape ] = ConcreteSizeEntry (
634635 runtime_shape = shape ,
0 commit comments