 )
 from dynamo.runtime import DistributedRuntime, dynamo_worker
 from dynamo.runtime.logging import configure_dynamo_logging
+from dynamo.vllm.multimodal_handlers import (
+    EncodeWorkerHandler,
+    MultimodalPDWorkerHandler,
+    ProcessorHandler,
+)
 
 from .args import ENABLE_LMCACHE, Config, configure_ports, overwrite_args, parse_args
 from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
@@ -92,7 +97,17 @@ def signal_handler():
     if not os.path.exists(config.model):
         config.model = config.engine_args.model = await fetch_llm(config.model)
 
-    if config.is_prefill_worker:
+    # Route to the appropriate initialization path based on the config flags
+    if config.multimodal_processor:
+        await init_multimodal_processor(runtime, config)
+        logger.debug("init_multimodal_processor completed")
+    elif config.multimodal_encode_worker:
+        await init_multimodal_encode_worker(runtime, config)
+        logger.debug("init_multimodal_encode_worker completed")
+    elif config.multimodal_worker:
+        await init_multimodal_worker(runtime, config)
+        logger.debug("init_multimodal_worker completed")
+    elif config.is_prefill_worker:
         await init_prefill(runtime, config)
         logger.debug("init_prefill completed")
     else:
@@ -430,6 +445,147 @@ def get_engine_cache_info(engine: AsyncLLM):
         raise
 
 
+async def init_multimodal_processor(runtime: DistributedRuntime, config: Config):
+    """Initialize multimodal processor component"""
+    component = runtime.namespace(config.namespace).component(config.component)
+    await component.create_service()
+
+    generate_endpoint = component.endpoint(config.endpoint)
+
+    # Get encode worker client
+    encode_worker_client = (
+        await runtime.namespace(config.namespace)
+        .component("encoder")
+        .endpoint("generate")
+        .client()
+    )
+
+    # Get prompt template from args (must be passed via environment or command line)
+    mm_prompt_template = config.mm_prompt_template
+
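+    # ProcessorHandler combines the engine args, the prompt template, and the
+    # encode-worker client created above, so preprocessed requests can be
+    # forwarded to the "encoder" component.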
+    handler = ProcessorHandler(
+        config.engine_args,
+        encode_worker_client,
+        mm_prompt_template,
+    )
+
+    logger.info("Waiting for Encoder Worker Instances ...")
+    await encode_worker_client.wait_for_instances()
+
+    # Register the endpoint as the entrypoint to a model
+    await register_llm(
+        ModelInput.Text,  # A custom processor is used, so this type bypasses the SDK processor
+        ModelType.Chat,
+        generate_endpoint,
+        config.model,
+        config.served_model_name,
+        kv_cache_block_size=config.engine_args.block_size,
+    )
+
+    logger.info("Starting to serve the processor endpoint...")
+
+    try:
+        await asyncio.gather(
+            generate_endpoint.serve_endpoint(
+                handler.generate, metrics_labels=[("model", config.model)]
+            ),
+        )
+    except Exception as e:
+        logger.error(f"Failed to serve endpoints: {e}")
+        raise
+    finally:
+        handler.cleanup()
+
+
+async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Config):
+    """Initialize multimodal encode worker component"""
+    component = runtime.namespace(config.namespace).component(config.component)
+    await component.create_service()
+
+    generate_endpoint = component.endpoint(config.endpoint)
+
+    # Get PD worker client
+    # In multimodal mode, the PD worker always registers as "backend"
+    # (even in disaggregated mode with prefill/decode split, we still connect to "backend")
+    pd_worker_client = (
+        await runtime.namespace(config.namespace)
+        .component("backend")
+        .endpoint("generate")
+        .client()
+    )
+
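+    # EncodeWorkerHandler performs the multimodal encode step and hands its
+    # output to the PD worker through the "backend" client above.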
+    handler = EncodeWorkerHandler(
+        config.engine_args,
+        pd_worker_client,
+    )
+    await handler.async_init(runtime)
+    logger.info("Waiting for PD Worker Instances ...")
+    await pd_worker_client.wait_for_instances()
+    logger.info("Starting to serve the encode worker endpoint...")
+
+    try:
+        await asyncio.gather(
+            generate_endpoint.serve_endpoint(
+                handler.generate, metrics_labels=[("model", config.model)]
+            ),
+        )
+    except Exception as e:
+        logger.error(f"Failed to serve endpoints: {e}")
+        raise
+    finally:
+        handler.cleanup()
+
+
+async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
+    """Initialize multimodal worker component for aggregated or disaggregated mode"""
+
+    component = runtime.namespace(config.namespace).component(config.component)
+    await component.create_service()
+
+    generate_endpoint = component.endpoint(config.endpoint)
+    clear_endpoint = component.endpoint("clear_kv_blocks")
+
+    engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config)
+
+    # TODO: Support disaggregated mode separately
+    client = (
+        await runtime.namespace(config.namespace)
+        .component("backend")
+        .endpoint("generate")
+        .client()
+    )
+
+    handler = MultimodalPDWorkerHandler(
+        runtime, component, engine_client, config, client
+    )
+
+    await handler.async_init(runtime)
+
+    # Set up KV event publisher for prefix caching if enabled
+    kv_publisher = setup_kv_event_publisher(
+        config, component, generate_endpoint, vllm_config
+    )
+    if kv_publisher:
+        handler.kv_publisher = kv_publisher
+
+    metrics_labels = [("model", config.model)]
+
+    try:
+        await asyncio.gather(
+            generate_endpoint.serve_endpoint(
+                handler.generate, metrics_labels=metrics_labels
+            ),
+            clear_endpoint.serve_endpoint(
+                handler.clear_kv_blocks, metrics_labels=metrics_labels
+            ),
+        )
+    except Exception as e:
+        logger.error(f"Failed to serve endpoints: {e}")
+        raise
+    finally:
+        handler.cleanup()
+
+
 def main():
     uvloop.run(worker())
 
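Taken together, the clients created above define the multimodal request path: the processor holds a client to the "encoder" component, and the encode worker holds a client to the "backend" component, so a request flows processor → encode worker → PD worker. Below is a minimal sketch of the namespace → component → endpoint → client chain plus the readiness gate that each init function uses; it reuses only calls that appear in this diff, and the `downstream` name is illustrative:

    # Resolve a client for the PD worker ("backend" is the component name the
    # comments above call out) and block until at least one instance registers.
    downstream = (
        await runtime.namespace(config.namespace)
        .component("backend")
        .endpoint("generate")
        .client()
    )
    await downstream.wait_for_instances()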