@@ -9,6 +9,7 @@
 from ...._utils import mpi_rank, mpi_world_size
 from ....bindings.executor import ExecutorConfig
 from ....bindings.internal.batch_manager import CacheType
+from ....llmapi.llm_args import _AutoDeployLlmArgs
 from ....mapping import Mapping
 from ...distributed import MPIDist
 from ...pyexecutor.config import PyTorchConfig
@@ -27,7 +28,7 @@
 from ..models import ModelFactoryRegistry
 from ..transformations.transform import InferenceOptimizer
 from ..utils.logger import ad_logger
-from .interface import AutoDeployConfig, CachedSequenceInterface, GetInferenceModel
+from .interface import CachedSequenceInterface, GetInferenceModel


 class _CacheManagerWithFakePool(KVCacheManager):
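Note: the core change in this diff is replacing the backend-local `AutoDeployConfig` with the shared `_AutoDeployLlmArgs` from `llmapi.llm_args`. Below is a minimal sketch of the shape such an args object needs for the code in this file; the class and its defaults are illustrative stand-ins, and only the four field names (`max_batch_size`, `max_seq_len`, `max_num_tokens`, `attn_page_size`) are taken from this diff.

```python
from dataclasses import dataclass


@dataclass
class _AutoDeployLlmArgsSketch:
    """Illustrative stand-in for _AutoDeployLlmArgs (not the real class)."""

    max_batch_size: int = 8     # read below instead of executor_config.max_batch_size
    max_seq_len: int = 2048     # read below instead of executor_config.max_seq_len
    max_num_tokens: int = 8192  # read below instead of executor_config.max_num_tokens
    attn_page_size: int = 64    # supersedes executor_config.tokens_per_block


args = _AutoDeployLlmArgsSketch(max_seq_len=4096, attn_page_size=128)
print(args.max_seq_len // args.attn_page_size)  # 32 pages cover one full-length sequence
```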
@@ -84,11 +85,11 @@ def _device(self) -> DeviceLikeType:
     def build_from_config(
         cls,
         model: str,
-        ad_config: AutoDeployConfig,
+        ad_config: _AutoDeployLlmArgs,
         seq_info: SequenceInfo,
         device: DeviceLikeType,
     ):
-        """Build the ADEngine using the AutoDeployConfig that gets passed through from the LLM."""
+        """Build the ADEngine using the _AutoDeployLlmArgs that gets passed through from the LLM."""

         # update device to contain the current default device if it's in cuda
         device = torch.device(device)
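For reference, a hedged usage sketch of the new signature (not part of the diff): it assumes `ad_config` is an `_AutoDeployLlmArgs` instance and that `ADEngine` and `SequenceInfo` are imported as in this module; the checkpoint path is a placeholder.

```python
seq_info = SequenceInfo(
    max_seq_len=ad_config.max_seq_len,
    max_batch_size=ad_config.max_batch_size,
    page_size=ad_config.attn_page_size,
    max_num_tokens=ad_config.max_num_tokens,
)
engine = ADEngine.build_from_config(
    model="/path/to/checkpoint",  # placeholder path
    ad_config=ad_config,
    seq_info=seq_info,
    device="cuda",
)
```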
@@ -245,7 +246,7 @@ def create_autodeploy_executor(
 ):
     """Create an AutoDeploy executor from the given configuration and checkpoint directory.

-    This is the entrypoint API to the autodeploy backend.
+    This is the entrypoint API to the _autodeploy backend.
     """
     # initialize process groups
     world_size = mpi_world_size()
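The docstring now names the backend by its registered string. A hedged end-to-end sketch of reaching this entrypoint, assuming the high-level LLM API selects the executor via that backend name (the selection mechanism itself is not shown in this diff):

```python
from tensorrt_llm import LLM

# Assumption: backend="_autodeploy" routes execution through
# create_autodeploy_executor(); the checkpoint path is a placeholder.
llm = LLM(model="/path/to/checkpoint", backend="_autodeploy")
```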
@@ -258,33 +259,34 @@ def create_autodeploy_executor(
     dist.initialize_or_skip(rank, world_size, port)

     # some config
-    if executor_config.pytorch_backend_config is None:
-        executor_config.pytorch_backend_config = AutoDeployConfig(attn_backend="FlashInfer")
+    msg = "pytorch_backend_config must be an _AutoDeployLlmArgs object"
+    assert isinstance(executor_config.pytorch_backend_config, _AutoDeployLlmArgs), msg
+    ad_config: _AutoDeployLlmArgs = executor_config.pytorch_backend_config

-    max_batch_size = executor_config.max_batch_size
-    max_seq_len = executor_config.max_seq_len
-    tokens_per_block = executor_config.tokens_per_block
-    max_num_tokens = executor_config.max_num_tokens
-    ad_logger.info(f"{max_seq_len=}, {max_batch_size=}, {tokens_per_block=}, {max_num_tokens=}")
+    max_batch_size = ad_config.max_batch_size
+    max_seq_len = ad_config.max_seq_len
+    attn_page_size = ad_config.attn_page_size
+    max_num_tokens = ad_config.max_num_tokens
+    ad_logger.info(f"{max_seq_len=}, {max_batch_size=}, {attn_page_size=}, {max_num_tokens=}")

     # initialize model engine
     engine = ADEngine.build_from_config(
         model=checkpoint_dir,
-        ad_config=executor_config.pytorch_backend_config,
+        ad_config=ad_config,
         seq_info=SequenceInfo(
             max_seq_len=max_seq_len,
             max_batch_size=max_batch_size,
-            page_size=tokens_per_block,
+            page_size=attn_page_size,
             max_num_tokens=max_num_tokens,
         ),
         device="cuda",
     )

     # resource managers
     kv_cache_manager = _CacheManagerWithFakePool(
-        executor_config.kv_cache_config,
+        ad_config.kv_cache_config,
         num_blocks=engine.cache_seq_interface.info.num_pages,
-        tokens_per_block=tokens_per_block,
+        tokens_per_block=attn_page_size,
         max_seq_len=max_seq_len,
         max_batch_size=max_batch_size,
     )
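The fake-pool cache manager is sized from the engine's computed page count, with `attn_page_size` now standing in for `tokens_per_block`. A back-of-envelope sketch of the paging arithmetic (assumed relationship, illustrative numbers only):

```python
import math

attn_page_size = 64   # tokens per KV-cache page (illustrative)
max_seq_len = 2048
max_batch_size = 8

pages_per_seq = math.ceil(max_seq_len / attn_page_size)  # 32 pages per full sequence
total_pages = pages_per_seq * max_batch_size             # 256 pages as a naive upper bound
print(pages_per_seq, total_pages)
```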
@@ -302,18 +304,17 @@ def create_autodeploy_executor(
     sampler = TorchSampler(max_seq_len=max_seq_len)

     # creating the executor object
-    py_config: PyTorchConfig = executor_config.pytorch_backend_config
     py_executor = PyExecutor(
         resource_manager,
         scheduler,
         model_engine=engine,
         sampler=sampler,
         dist=mpi_dist,
-        disable_overlap_scheduler=py_config.disable_overlap_scheduler,
-        max_input_len=executor_config.max_input_len,
-        max_batch_size=executor_config.max_batch_size,
-        max_draft_tokens=executor_config.speculative_config.max_draft_tokens
-        if executor_config.speculative_config is not None
+        disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
+        max_input_len=ad_config.max_input_len,
+        max_batch_size=ad_config.max_batch_size,
+        max_draft_tokens=ad_config.speculative_config.max_draft_tokens
+        if ad_config.speculative_config is not None
         else 0,
     )
     return py_executor
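The `max_draft_tokens` argument above uses a conditional expression so that a missing `speculative_config` degrades to zero draft tokens; spelled out:

```python
if ad_config.speculative_config is not None:
    max_draft_tokens = ad_config.speculative_config.max_draft_tokens
else:
    max_draft_tokens = 0  # no speculative decoding configured
```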