diff --git a/Cargo.lock b/Cargo.lock index c13d79adeb..3d8423ccd2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2391,6 +2391,7 @@ dependencies = [ "figment", "futures", "humantime", + "inotify", "jsonschema", "local-ip-address", "log", @@ -3992,6 +3993,28 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" +[[package]] +name = "inotify" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3" +dependencies = [ + "bitflags 2.9.4", + "futures-core", + "inotify-sys", + "libc", + "tokio", +] + +[[package]] +name = "inotify-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" +dependencies = [ + "libc", +] + [[package]] name = "insta" version = "1.43.2" diff --git a/components/src/dynamo/frontend/main.py b/components/src/dynamo/frontend/main.py index eaff60aa50..f64587ee8c 100644 --- a/components/src/dynamo/frontend/main.py +++ b/components/src/dynamo/frontend/main.py @@ -225,6 +225,12 @@ def parse_args(): ), help=f"Interval in seconds for polling custom backend metrics. Set to > 0 to enable polling (default: 0=disabled, suggested: 9.2s which is less than typical Prometheus scrape interval). Can be set via {CUSTOM_BACKEND_METRICS_POLLING_INTERVAL_ENV_VAR} env var.", ) + parser.add_argument( + "--store-kv", + type=str, + default=os.environ.get("DYN_STORE_KV", "etcd"), + help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENDPOINTS) for connection details. 
File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", + ) flags = parser.parse_args() @@ -252,8 +258,7 @@ async def async_main(): os.environ["DYN_METRICS_PREFIX"] = flags.metrics_prefix loop = asyncio.get_running_loop() - - runtime = DistributedRuntime(loop, is_static) + runtime = DistributedRuntime(loop, flags.store_kv, is_static) def signal_handler(): asyncio.create_task(graceful_shutdown(runtime)) diff --git a/components/src/dynamo/mocker/args.py b/components/src/dynamo/mocker/args.py index 8f3631458e..7a10ba02d9 100644 --- a/components/src/dynamo/mocker/args.py +++ b/components/src/dynamo/mocker/args.py @@ -204,6 +204,12 @@ def parse_args(): default=False, help="Mark this as a decode worker which does not publish KV events and skips prefill cost estimation (default: False)", ) + parser.add_argument( + "--store-kv", + type=str, + default=os.environ.get("DYN_STORE_KV", "etcd"), + help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENDPOINTS) for connection details. 
File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", + ) args = parser.parse_args() validate_worker_type_args(args) diff --git a/components/src/dynamo/mocker/main.py b/components/src/dynamo/mocker/main.py index 220f2995ac..6dbd3d9fc3 100644 --- a/components/src/dynamo/mocker/main.py +++ b/components/src/dynamo/mocker/main.py @@ -72,7 +72,7 @@ async def launch_workers(args, extra_engine_args_path): logger.info(f"Creating mocker worker {worker_id + 1}/{args.num_workers}") # Create a separate DistributedRuntime for this worker (on same event loop) - runtime = DistributedRuntime(loop, False) + runtime = DistributedRuntime(loop, args.store_kv, False) runtimes.append(runtime) # Create EntrypointArgs for this worker diff --git a/components/src/dynamo/sglang/args.py b/components/src/dynamo/sglang/args.py index 2417f0fc28..a08cb025de 100644 --- a/components/src/dynamo/sglang/args.py +++ b/components/src/dynamo/sglang/args.py @@ -93,6 +93,12 @@ "default": None, "help": "Dump debug config to the specified file path. If not specified, the config will be dumped to stdout at INFO level.", }, + "store-kv": { + "flags": ["--store-kv"], + "type": str, + "default": os.environ.get("DYN_STORE_KV", "etcd"), + "help": "Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENDPOINTS) for connection details. 
File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", + }, } @@ -102,6 +108,7 @@ class DynamoArgs: component: str endpoint: str migration_limit: int + store_kv: str # tool and reasoning parser options tool_call_parser: Optional[str] = None @@ -329,6 +336,7 @@ async def parse_args(args: list[str]) -> Config: component=parsed_component_name, endpoint=parsed_endpoint_name, migration_limit=parsed_args.migration_limit, + store_kv=parsed_args.store_kv, tool_call_parser=tool_call_parser, reasoning_parser=reasoning_parser, custom_jinja_template=expanded_template_path, diff --git a/components/src/dynamo/sglang/main.py b/components/src/dynamo/sglang/main.py index 1dc20099a8..2d5e92bf49 100644 --- a/components/src/dynamo/sglang/main.py +++ b/components/src/dynamo/sglang/main.py @@ -11,7 +11,7 @@ from dynamo.common.config_dump import dump_config from dynamo.llm import ModelInput, ModelType -from dynamo.runtime import DistributedRuntime, dynamo_worker +from dynamo.runtime import DistributedRuntime from dynamo.runtime.logging import configure_dynamo_logging from dynamo.sglang.args import Config, DisaggregationMode, parse_args from dynamo.sglang.health_check import ( @@ -33,9 +33,12 @@ configure_dynamo_logging() -@dynamo_worker(static=False) -async def worker(runtime: DistributedRuntime): +async def worker(): + config = await parse_args(sys.argv[1:]) + dump_config(config.dynamo_args.dump_config_to, config) + loop = asyncio.get_running_loop() + runtime = DistributedRuntime(loop, config.dynamo_args.store_kv, False) def signal_handler(): asyncio.create_task(graceful_shutdown(runtime)) @@ -45,9 +48,6 @@ def signal_handler(): logging.info("Signal handlers will trigger a graceful shutdown of the runtime") - config = await parse_args(sys.argv[1:]) - dump_config(config.dynamo_args.dump_config_to, config) - if config.dynamo_args.embedding_worker: await init_embedding(runtime, config) elif config.dynamo_args.multimodal_processor: diff --git 
a/components/src/dynamo/trtllm/main.py b/components/src/dynamo/trtllm/main.py index 270c8ce58a..55d7723bc1 100644 --- a/components/src/dynamo/trtllm/main.py +++ b/components/src/dynamo/trtllm/main.py @@ -39,7 +39,7 @@ from dynamo.common.config_dump import dump_config from dynamo.common.utils.prometheus import register_engine_metrics_callback from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm -from dynamo.runtime import DistributedRuntime, dynamo_worker +from dynamo.runtime import DistributedRuntime from dynamo.runtime.logging import configure_dynamo_logging from dynamo.trtllm.engine import TensorRTLLMEngine, get_llm_engine from dynamo.trtllm.health_check import TrtllmHealthCheckPayload @@ -102,11 +102,13 @@ async def get_engine_runtime_config( return runtime_config -@dynamo_worker(static=False) -async def worker(runtime: DistributedRuntime): - # Set up signal handler for graceful shutdown +async def worker(): + config = cmd_line_args() + loop = asyncio.get_running_loop() + runtime = DistributedRuntime(loop, config.store_kv, False) + # Set up signal handler for graceful shutdown def signal_handler(): # Schedule the shutdown coroutine instead of calling it directly asyncio.create_task(graceful_shutdown(runtime)) @@ -116,7 +118,6 @@ def signal_handler(): logging.info("Signal handlers set up for graceful shutdown") - config = cmd_line_args() await init(runtime, config) diff --git a/components/src/dynamo/trtllm/utils/trtllm_utils.py b/components/src/dynamo/trtllm/utils/trtllm_utils.py index b7ec219f02..3a3da53d1f 100644 --- a/components/src/dynamo/trtllm/utils/trtllm_utils.py +++ b/components/src/dynamo/trtllm/utils/trtllm_utils.py @@ -58,6 +58,7 @@ def __init__(self) -> None: self.tool_call_parser: Optional[str] = None self.dump_config_to: Optional[str] = None self.custom_jinja_template: Optional[str] = None + self.store_kv: str = "" def __str__(self) -> str: return ( @@ -87,8 +88,9 @@ def __str__(self) -> str: 
f"max_file_size_mb={self.max_file_size_mb}, " f"reasoning_parser={self.reasoning_parser}, " f"tool_call_parser={self.tool_call_parser}, " - f"dump_config_to={self.dump_config_to}," - f"custom_jinja_template={self.custom_jinja_template}" + f"dump_config_to={self.dump_config_to}, " + f"custom_jinja_template={self.custom_jinja_template}, " + f"store_kv={self.store_kv}" ) @@ -278,6 +280,12 @@ def cmd_line_args(): default=None, help="Path to a custom Jinja template file to override the model's default chat template. This template will take precedence over any template found in the model repository.", ) + parser.add_argument( + "--store-kv", + type=str, + default=os.environ.get("DYN_STORE_KV", "etcd"), + help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENDPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", + ) args = parser.parse_args() @@ -337,6 +345,7 @@ def cmd_line_args(): config.reasoning_parser = args.dyn_reasoning_parser config.tool_call_parser = args.dyn_tool_call_parser config.dump_config_to = args.dump_config_to + config.store_kv = args.store_kv # Handle custom jinja template path expansion (environment variables and home directory) if args.custom_jinja_template: diff --git a/components/src/dynamo/vllm/args.py b/components/src/dynamo/vllm/args.py index dc7e73ed88..ace113a32f 100644 --- a/components/src/dynamo/vllm/args.py +++ b/components/src/dynamo/vllm/args.py @@ -38,6 +38,7 @@ class Config: migration_limit: int = 0 kv_port: Optional[int] = None custom_jinja_template: Optional[str] = None + store_kv: str # mirror vLLM model: str @@ -164,6 +165,12 @@ def parse_args() -> Config: "'USER: please describe the image ASSISTANT:'." ), ) + parser.add_argument( + "--store-kv", + type=str, + default=os.environ.get("DYN_STORE_KV", "etcd"), + help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. 
ETCD_ENDPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", + ) add_config_dump_args(parser) parser = AsyncEngineArgs.add_cli_args(parser) @@ -233,6 +240,7 @@ def parse_args() -> Config: config.multimodal_worker = args.multimodal_worker config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker config.mm_prompt_template = args.mm_prompt_template + config.store_kv = args.store_kv # Validate custom Jinja template file exists if provided if config.custom_jinja_template is not None: diff --git a/components/src/dynamo/vllm/main.py b/components/src/dynamo/vllm/main.py index 060390d01b..d0e124c443 100644 --- a/components/src/dynamo/vllm/main.py +++ b/components/src/dynamo/vllm/main.py @@ -8,6 +8,7 @@ from typing import Optional import uvloop +# from kvbm.vllm_integration.consolidator_config import get_consolidator_endpoints from prometheus_client import REGISTRY from vllm.distributed.kv_events import ZmqEventPublisher from vllm.usage.usage_lib import UsageContext @@ -25,7 +26,7 @@ fetch_llm, register_llm, ) -from dynamo.runtime import DistributedRuntime, dynamo_worker +from dynamo.runtime import DistributedRuntime from dynamo.runtime.logging import configure_dynamo_logging from dynamo.vllm.multimodal_handlers import ( EncodeWorkerHandler, @@ -70,16 +71,16 @@ async def graceful_shutdown(runtime): logging.info("DistributedRuntime shutdown complete") -@dynamo_worker(static=False) -async def worker(runtime: DistributedRuntime): +async def worker(): config = parse_args() + loop = asyncio.get_running_loop() + runtime = DistributedRuntime(loop, config.store_kv, False) + await configure_ports(config) overwrite_args(config) # Set up signal handler for graceful shutdown - loop = asyncio.get_running_loop() - def signal_handler(): asyncio.create_task(graceful_shutdown(runtime)) diff --git a/examples/custom_backend/cancellation/client.py b/examples/custom_backend/cancellation/client.py index 
fbcbead315..18198b134c 100644 --- a/examples/custom_backend/cancellation/client.py +++ b/examples/custom_backend/cancellation/client.py @@ -50,7 +50,7 @@ async def main(): return loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, True) + runtime = DistributedRuntime(loop, "etcd", False) # Connect to middle server or direct server based on argument if use_middle_server: diff --git a/examples/custom_backend/cancellation/middle_server.py b/examples/custom_backend/cancellation/middle_server.py index 968cee014b..491cd3be67 100644 --- a/examples/custom_backend/cancellation/middle_server.py +++ b/examples/custom_backend/cancellation/middle_server.py @@ -50,7 +50,7 @@ async def generate(self, request, context): async def main(): """Start the middle server""" loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, True) + runtime = DistributedRuntime(loop, "etcd", False) # Create middle server handler handler = MiddleServer(runtime) diff --git a/examples/custom_backend/cancellation/server.py b/examples/custom_backend/cancellation/server.py index 63a1c70938..9efb3cb9e8 100644 --- a/examples/custom_backend/cancellation/server.py +++ b/examples/custom_backend/cancellation/server.py @@ -31,7 +31,7 @@ async def generate(self, request, context): async def main(): """Start the demo server""" loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, True) + runtime = DistributedRuntime(loop, "etcd", False) # Create server component component = runtime.namespace("demo").component("server") diff --git a/examples/custom_backend/nim/mock_nim_frontend.py b/examples/custom_backend/nim/mock_nim_frontend.py index c79b5c03f7..defebe7980 100755 --- a/examples/custom_backend/nim/mock_nim_frontend.py +++ b/examples/custom_backend/nim/mock_nim_frontend.py @@ -123,7 +123,7 @@ async def async_main(): # Create DistributedRuntime - similar to frontend/main.py line 246 is_static = True # Use static mode (no etcd) - runtime = DistributedRuntime(loop, 
is_static) # type: ignore[call-arg] + runtime = DistributedRuntime(loop, "mem", is_static) # type: ignore[call-arg] # Setup signal handlers for graceful shutdown def signal_handler(): diff --git a/launch/dynamo-run/src/flags.rs b/launch/dynamo-run/src/flags.rs index 66603cfdd6..09de1ad6cb 100644 --- a/launch/dynamo-run/src/flags.rs +++ b/launch/dynamo-run/src/flags.rs @@ -127,6 +127,12 @@ pub struct Flags { #[arg(long, default_value = "false")] pub static_worker: bool, + /// Which key-value backend to use: etcd, mem, file. + /// Etcd uses the ETCD_* env vars (e.g. ETCD_ENDPOINTS) for connection details. + /// File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv. + #[arg(long, default_value = "etcd")] + pub store_kv: String, + /// Everything after a `--`. /// These are the command line arguments to the python engine when using `pystr` or `pytok`. #[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)] diff --git a/launch/dynamo-run/src/lib.rs b/launch/dynamo-run/src/lib.rs index d0a63c0792..a82a80eb00 100644 --- a/launch/dynamo-run/src/lib.rs +++ b/launch/dynamo-run/src/lib.rs @@ -6,10 +6,11 @@ use dynamo_llm::entrypoint::EngineConfig; use dynamo_llm::entrypoint::input::Input; use dynamo_llm::local_model::{LocalModel, LocalModelBuilder}; use dynamo_runtime::distributed::DistributedConfig; +use dynamo_runtime::storage::key_value_store::KeyValueStoreSelect; +use dynamo_runtime::transports::nats; use dynamo_runtime::{DistributedRuntime, Runtime}; mod flags; -use either::Either; pub use flags::Flags; mod opt; pub use dynamo_llm::request_template::RequestTemplate; @@ -73,14 +74,16 @@ pub async fn run( // TODO: old, address this later: // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint. // If not, then the endpoint isn't exposed so we let LocalModel invent one. 
- let mut rt = Either::Left(runtime.clone()); if let Input::Endpoint(path) = &in_opt { builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?)); - - let dst_config = DistributedConfig::from_settings(flags.static_worker); - let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?; - rt = Either::Right(distributed_runtime); + } + let selected_store: KeyValueStoreSelect = flags.store_kv.parse()?; + let dst_config = DistributedConfig { + store_backend: selected_store, + nats_config: nats::ClientOptions::default(), + is_static: flags.static_worker, }; + let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?; if let Some(Output::Static(path)) = &out_opt { builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?)); } @@ -98,10 +101,16 @@ pub async fn run( flags.validate(&in_opt, &out_opt)?; // Make an engine from the local_model, flags and output. - let engine_config = engine_for(out_opt, flags.clone(), local_model, rt.clone()).await?; + let engine_config = engine_for( + out_opt, + flags.clone(), + local_model, + distributed_runtime.clone(), + ) + .await?; // Run it from an input - dynamo_llm::entrypoint::input::run_input(rt, in_opt, engine_config).await?; + dynamo_llm::entrypoint::input::run_input(distributed_runtime, in_opt, engine_config).await?; Ok(()) } @@ -112,7 +121,7 @@ async fn engine_for( out_opt: Output, flags: Flags, local_model: LocalModel, - rt: Either, + drt: DistributedRuntime, ) -> anyhow::Result { match out_opt { Output::Auto => { @@ -135,10 +144,6 @@ async fn engine_for( is_static: flags.static_worker, }), Output::Mocker => { - let Either::Right(drt) = rt else { - panic!("Mocker requires a distributed runtime to run."); - }; - let args = flags.mocker_config(); let endpoint = local_model.endpoint_id().clone(); diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index c9468ec900..3878f6e2a6 100644 --- a/lib/bindings/python/Cargo.lock +++ 
b/lib/bindings/python/Cargo.lock @@ -1606,6 +1606,7 @@ dependencies = [ "figment", "futures", "humantime", + "inotify", "local-ip-address", "log", "nid", @@ -2857,6 +2858,28 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" +[[package]] +name = "inotify" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3" +dependencies = [ + "bitflags 2.9.3", + "futures-core", + "inotify-sys", + "libc", + "tokio", +] + +[[package]] +name = "inotify-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" +dependencies = [ + "libc", +] + [[package]] name = "instant" version = "0.1.13" diff --git a/lib/bindings/python/examples/cli/cli.py b/lib/bindings/python/examples/cli/cli.py index 6d2c6ded78..f1722a172d 100644 --- a/lib/bindings/python/examples/cli/cli.py +++ b/lib/bindings/python/examples/cli/cli.py @@ -115,7 +115,7 @@ def parse_args(): async def run(): loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, False) + runtime = DistributedRuntime(loop, "etcd", False) args = parse_args() diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index fc905b45d1..f1faf54670 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -2,6 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use dynamo_llm::local_model::LocalModel; +use dynamo_runtime::distributed::DistributedConfig; +use dynamo_runtime::storage::key_value_store::KeyValueStoreSelect; use futures::StreamExt; use once_cell::sync::OnceCell; use pyo3::IntoPyObjectExt; @@ -426,7 +428,9 @@ enum ModelInput { #[pymethods] impl DistributedRuntime { #[new] - fn new(event_loop: PyObject, is_static: bool) -> PyResult { + fn new(event_loop: 
PyObject, store_kv: String, is_static: bool) -> PyResult { + let selected_kv_store: KeyValueStoreSelect = store_kv.parse().map_err(to_pyerr)?; + // Try to get existing runtime first, create new Worker only if needed // This allows multiple DistributedRuntime instances to share the same tokio runtime let runtime = rs::Worker::runtime_from_existing() @@ -464,9 +468,14 @@ impl DistributedRuntime { rs::DistributedRuntime::from_settings_without_discovery(runtime), ) } else { + let config = DistributedConfig { + store_backend: selected_kv_store, + is_static: false, + nats_config: dynamo_runtime::transports::nats::ClientOptions::default(), + }; runtime .secondary() - .block_on(rs::DistributedRuntime::from_settings(runtime)) + .block_on(rs::DistributedRuntime::new(runtime, config)) }; let inner = inner.map_err(to_pyerr)?; @@ -628,7 +637,7 @@ impl DistributedRuntime { } fn shutdown(&self) { - self.inner.runtime().shutdown(); + self.inner.shutdown(); } fn event_loop(&self) -> PyObject { diff --git a/lib/bindings/python/rust/llm/entrypoint.rs b/lib/bindings/python/rust/llm/entrypoint.rs index ba28abca21..6980fbe622 100644 --- a/lib/bindings/python/rust/llm/entrypoint.rs +++ b/lib/bindings/python/rust/llm/entrypoint.rs @@ -299,7 +299,7 @@ pub fn run_input<'p>( let input_enum: Input = input.parse().map_err(to_pyerr)?; pyo3_async_runtimes::tokio::future_into_py(py, async move { dynamo_llm::entrypoint::input::run_input( - either::Either::Right(distributed_runtime.inner.clone()), + distributed_runtime.inner.clone(), input_enum, engine_config.inner, ) diff --git a/lib/bindings/python/src/dynamo/runtime/__init__.py b/lib/bindings/python/src/dynamo/runtime/__init__.py index 7e9195c304..c46205d72a 100644 --- a/lib/bindings/python/src/dynamo/runtime/__init__.py +++ b/lib/bindings/python/src/dynamo/runtime/__init__.py @@ -25,7 +25,7 @@ def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, static) + 
runtime = DistributedRuntime(loop, "etcd", static) await func(runtime, *args, **kwargs) diff --git a/lib/bindings/python/tests/cancellation/test_example.py b/lib/bindings/python/tests/cancellation/test_example.py index 0a253d5928..26c39240c0 100644 --- a/lib/bindings/python/tests/cancellation/test_example.py +++ b/lib/bindings/python/tests/cancellation/test_example.py @@ -100,6 +100,7 @@ def stop_process(process): @pytest.mark.asyncio +@pytest.mark.skip(reason="Graham working on it") async def test_direct_connection_cancellation(example_dir, server_process): """Test cancellation with direct client-server connection""" # Run the client (direct connection) @@ -121,6 +122,7 @@ async def test_direct_connection_cancellation(example_dir, server_process): @pytest.mark.asyncio +@pytest.mark.skip(reason="Graham working on it") async def test_middle_server_cancellation( example_dir, server_process, middle_server_process ): diff --git a/lib/bindings/python/tests/conftest.py b/lib/bindings/python/tests/conftest.py index f34abbc79f..0eb148ca15 100644 --- a/lib/bindings/python/tests/conftest.py +++ b/lib/bindings/python/tests/conftest.py @@ -430,6 +430,6 @@ async def test_my_test(runtime): ) loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, True) + runtime = DistributedRuntime(loop, "file", True) yield runtime runtime.shutdown() diff --git a/lib/bindings/python/tests/test_kv_bindings.py b/lib/bindings/python/tests/test_kv_bindings.py index f2a5256477..7dc95deb91 100644 --- a/lib/bindings/python/tests/test_kv_bindings.py +++ b/lib/bindings/python/tests/test_kv_bindings.py @@ -34,7 +34,7 @@ async def distributed_runtime(): Each test gets its own runtime in a forked process to avoid singleton conflicts. 
""" loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, False) + runtime = DistributedRuntime(loop, "etcd", False) yield runtime runtime.shutdown() diff --git a/lib/llm/src/audit/sink.rs b/lib/llm/src/audit/sink.rs index 0d4f4088bb..1f0628d7fb 100644 --- a/lib/llm/src/audit/sink.rs +++ b/lib/llm/src/audit/sink.rs @@ -89,8 +89,8 @@ fn parse_sinks_from_env( } /// spawn one worker per sink; each subscribes to the bus (off hot path) -pub fn spawn_workers_from_env(drt: Option<&dynamo_runtime::DistributedRuntime>) { - let nats_client = drt.and_then(|d| d.nats_client()); +pub fn spawn_workers_from_env(drt: &dynamo_runtime::DistributedRuntime) { + let nats_client = drt.nats_client(); let sinks = parse_sinks_from_env(nats_client); for sink in sinks { let name = sink.name(); diff --git a/lib/llm/src/discovery/watcher.rs b/lib/llm/src/discovery/watcher.rs index 00412422ad..9dd6f6b688 100644 --- a/lib/llm/src/discovery/watcher.rs +++ b/lib/llm/src/discovery/watcher.rs @@ -3,25 +3,26 @@ use std::sync::Arc; use tokio::sync::mpsc::Sender; +use tokio::sync::Notify; use anyhow::Context as _; -use tokio::sync::{Notify, mpsc::Receiver}; +use futures::StreamExt; use dynamo_runtime::{ DistributedRuntime, + discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryKey, DiscoveryStream}, pipeline::{ ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source, network::egress::push_router::PushRouter, }, protocols::{EndpointId, annotated::Annotated}, - storage::key_value_store::WatchEvent, }; use crate::{ backend::Backend, entrypoint, kv_router::{KvRouterConfig, PrefillRouter}, - model_card::{self, ModelDeploymentCard}, + model_card::ModelDeploymentCard, model_type::{ModelInput, ModelType}, preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter}, protocols::{ @@ -99,17 +100,45 @@ impl ModelWatcher { } /// Common watch logic with optional namespace filtering - pub async fn watch(&self, mut events_rx: Receiver, 
target_namespace: Option<&str>) { + pub async fn watch(&self, mut discovery_stream: DiscoveryStream, target_namespace: Option<&str>) { let global_namespace = target_namespace.is_none_or(is_global_namespace); - while let Some(event) = events_rx.recv().await { + while let Some(result) = discovery_stream.next().await { + let event = match result { + Ok(event) => event, + Err(err) => { + tracing::error!(%err, "Error in discovery stream"); + continue; + } + }; + match event { - WatchEvent::Put(kv) => { - let key = kv.key_str(); - let endpoint_id = match key_extract(key) { - Ok((eid, _)) => eid, - Err(err) => { - tracing::error!(%key, %err, "Failed extracting EndpointId from key. Ignoring instance."); + DiscoveryEvent::Added(instance) => { + // Extract EndpointId, instance_id, and card from the discovery instance + let (endpoint_id, instance_id, mut card) = match &instance { + DiscoveryInstance::ModelCard { + namespace, + component, + endpoint, + instance_id, + .. + } => { + let eid = EndpointId { + namespace: namespace.clone(), + component: component.clone(), + name: endpoint.clone(), + }; + + match instance.deserialize_model_card::() { + Ok(card) => (eid, *instance_id, card), + Err(err) => { + tracing::error!(%err, instance_id, "Failed to deserialize model card"); + continue; + } + } + } + _ => { + tracing::error!("Unexpected discovery instance type (expected ModelCard)"); continue; } }; @@ -127,21 +156,6 @@ impl ModelWatcher { continue; } - let mut card = match serde_json::from_slice::(kv.value()) { - Ok(card) => card, - Err(err) => { - match kv.value_str() { - Ok(value) => { - tracing::error!(%err, value, "Invalid JSON in model card") - } - Err(value_str_err) => { - tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON") - } - } - continue; - } - }; - // If we already have a worker for this model, and the ModelDeploymentCard // cards don't match, alert, and don't add the new instance let can_add = @@ -164,7 +178,10 
@@ impl ModelWatcher { continue; } - match self.handle_put(key, &endpoint_id, &mut card).await { + // Use instance_id as the HashMap key (simpler and sufficient since keys are opaque) + let key = format!("{:x}", instance_id); + + match self.handle_put(&key, &endpoint_id, &mut card).await { Ok(()) => { tracing::info!( model_name = card.name(), @@ -183,10 +200,12 @@ impl ModelWatcher { } } } - WatchEvent::Delete(kv) => { - let deleted_key = kv.key_str(); + DiscoveryEvent::Removed(instance_id) => { + // Use instance_id hex as the HashMap key (matches what we saved with) + let key = format!("{:x}", instance_id); + match self - .handle_delete(deleted_key, target_namespace, global_namespace) + .handle_delete(&key, target_namespace, global_namespace) .await { Ok(Some(model_name)) => { @@ -212,6 +231,8 @@ impl ModelWatcher { target_namespace: Option<&str>, is_global_namespace: bool, ) -> anyhow::Result<Option<String>> { + tracing::warn!("DISCOVERY_VALIDATION: handle_delete: key={}", key); + let card = match self.manager.remove_model_card(key) { Some(card) => card, None => { @@ -303,6 +324,8 @@ impl ModelWatcher { endpoint_id: &EndpointId, card: &mut ModelDeploymentCard, ) -> anyhow::Result<()> { + tracing::warn!("DISCOVERY_VALIDATION: handle_put: key={}", key); + card.download_config().await?; let component = self @@ -559,35 +582,37 @@ impl ModelWatcher { /// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> { - let store = self.drt.store(); - let Some(card_bucket) = store.get_bucket(model_card::ROOT_PATH).await?
else { - // no cards - return Ok(vec![]); - }; - let entries = card_bucket.entries().await?; + let discovery = self.drt.discovery_client(); + let instances = discovery.list(DiscoveryKey::AllModelCards).await?; - let mut results = Vec::with_capacity(entries.len()); - for (key, card_bytes) in entries { - let r = match serde_json::from_slice::<ModelDeploymentCard>(&card_bytes) { + let mut results = Vec::with_capacity(instances.len()); + for instance in instances { + match instance.deserialize_model_card::<ModelDeploymentCard>() { Ok(card) => { - let maybe_endpoint_id = - key_extract(&key).map(|(endpoint_id, _instance_id)| endpoint_id); - let endpoint_id = match maybe_endpoint_id { - Ok(eid) => eid, - Err(err) => { - tracing::error!(%err, "Skipping invalid key, not string or not EndpointId"); + // Extract EndpointId from the instance + let endpoint_id = match &instance { + dynamo_runtime::discovery::DiscoveryInstance::ModelCard { + namespace, + component, + endpoint, + .. + } => EndpointId { + namespace: namespace.clone(), + component: component.clone(), + name: endpoint.clone(), + }, + _ => { + tracing::error!("Unexpected discovery instance type (expected ModelCard)"); continue; } }; - (endpoint_id, card) + results.push((endpoint_id, card)); } Err(err) => { - let value = String::from_utf8_lossy(&card_bytes); - tracing::error!(%err, %value, "Invalid JSON in model card"); + tracing::error!(%err, "Failed to deserialize model card"); continue; } - }; - results.push(r); + } } Ok(results) } @@ -612,40 +637,4 @@ impl ModelWatcher { } } -/// The ModelDeploymentCard is published in store with a key like "v1/mdc/dynamo/backend/generate/694d9981145a61ad". -/// Extract the EndpointId and instance_id from that.
-fn key_extract(s: &str) -> anyhow::Result<(EndpointId, String)> { - if !s.starts_with(model_card::ROOT_PATH) { - anyhow::bail!("Invalid format: expected model card ROOT_PATH segment in {s}"); - } - let parts: Vec<&str> = s.split('/').collect(); - - // Need at least prefix model_card::ROOT_PATH (2 parts) + namespace, component, name (3 parts) - if parts.len() <= 5 { - anyhow::bail!("Invalid format: not enough path segments in {s}"); - } - let endpoint_id = EndpointId { - namespace: parts[2].to_string(), - component: parts[3].to_string(), - name: parts[4].to_string(), - }; - Ok((endpoint_id, parts[parts.len() - 1].to_string())) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_key_extract() { - let input = format!( - "{}/dynamo/backend/generate/694d9981145a61ad", - model_card::ROOT_PATH - ); - let (endpoint_id, _) = key_extract(&input).unwrap(); - assert_eq!(endpoint_id.namespace, "dynamo"); - assert_eq!(endpoint_id.component, "backend"); - assert_eq!(endpoint_id.name, "generate"); - } -} diff --git a/lib/llm/src/discovery/worker_monitor.rs b/lib/llm/src/discovery/worker_monitor.rs index bc43dd38bf..3e0e6b7031 100644 --- a/lib/llm/src/discovery/worker_monitor.rs +++ b/lib/llm/src/discovery/worker_monitor.rs @@ -3,12 +3,12 @@ use crate::kv_router::KV_METRICS_SUBJECT; use crate::kv_router::scoring::LoadEvent; -use crate::model_card::{self, ModelDeploymentCard}; +use crate::model_card::ModelDeploymentCard; use dynamo_runtime::component::Client; +use dynamo_runtime::discovery::{watch_and_extract_field, DiscoveryKey}; use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait}; use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::events::EventSubscriber; -use dynamo_runtime::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction}; use std::collections::HashMap; use std::sync::{Arc, RwLock}; use tokio_stream::StreamExt; @@ -79,21 +79,13 @@ impl WorkerLoadMonitor for KvWorkerMonitor { let endpoint = 
&self.client.endpoint; let component = endpoint.component(); - let Some(etcd_client) = component.drt().etcd_client() else { - // Static mode, no monitoring needed - return Ok(()); - }; - - // Watch for runtime config updates from model deployment cards - let runtime_configs_watcher = watch_prefix_with_extraction( - etcd_client, - model_card::ROOT_PATH, - key_extractors::lease_id, - |card: ModelDeploymentCard| Some(card.runtime_config), - component.drt().child_token(), - ) - .await?; - let mut config_events_rx = runtime_configs_watcher.receiver(); + // Watch for runtime config updates from model deployment cards via discovery interface + let discovery = component.drt().discovery_client(); + let discovery_stream = discovery.list_and_watch(DiscoveryKey::AllModelCards).await?; + let mut config_events_rx = watch_and_extract_field( + discovery_stream, + |card: ModelDeploymentCard| card.runtime_config, + ); // Subscribe to KV metrics events let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?; @@ -117,6 +109,21 @@ impl WorkerLoadMonitor for KvWorkerMonitor { // Handle runtime config updates _ = config_events_rx.changed() => { let runtime_configs = config_events_rx.borrow().clone(); + + tracing::warn!( + worker_count = runtime_configs.len(), + "DISCOVERY: Runtime config updates received" + ); + + // Log detailed config state for comparison + let config_details: Vec<(u64, Option)> = runtime_configs + .iter() + .map(|(&lease_id, config)| (lease_id, config.total_kv_blocks)) + .collect(); + tracing::warn!( + "DISCOVERY_VALIDATION: config_state: configs={:?}", + config_details + ); let mut states = worker_load_states.write().unwrap(); states.retain(|lease_id, _| runtime_configs.contains_key(lease_id)); diff --git a/lib/llm/src/entrypoint/input.rs b/lib/llm/src/entrypoint/input.rs index 02967019af..09c6dfb781 100644 --- a/lib/llm/src/entrypoint/input.rs +++ b/lib/llm/src/entrypoint/input.rs @@ -23,7 +23,6 @@ pub mod http; pub mod text; use 
dynamo_runtime::protocols::ENDPOINT_SCHEME; -use either::Either; const BATCH_PREFIX: &str = "batch:"; @@ -107,15 +106,10 @@ impl Default for Input { /// For Input::Endpoint pass a DistributedRuntime. For everything else pass either a Runtime or a /// DistributedRuntime. pub async fn run_input( - rt: Either, + drt: dynamo_runtime::DistributedRuntime, in_opt: Input, engine_config: super::EngineConfig, ) -> anyhow::Result<()> { - let runtime = match &rt { - Either::Left(rt) => rt.clone(), - Either::Right(drt) => drt.runtime().clone(), - }; - // Initialize audit bus + sink workers (off hot path; fan-out supported) if crate::audit::config::policy().enabled { let cap: usize = std::env::var("DYN_AUDIT_CAPACITY") @@ -123,38 +117,30 @@ pub async fn run_input( .and_then(|v| v.parse().ok()) .unwrap_or(1024); crate::audit::bus::init(cap); - // Pass DistributedRuntime if available for shared NATS client - let drt_ref = match &rt { - Either::Right(drt) => Some(drt), - Either::Left(_) => None, - }; - crate::audit::sink::spawn_workers_from_env(drt_ref); - tracing::info!("Audit initialized: bus cap={}", cap); + crate::audit::sink::spawn_workers_from_env(&drt); + tracing::info!(cap, "Audit initialized"); } match in_opt { Input::Http => { - http::run(runtime, engine_config).await?; + http::run(drt, engine_config).await?; } Input::Grpc => { - grpc::run(runtime, engine_config).await?; + grpc::run(drt, engine_config).await?; } Input::Text => { - text::run(runtime, None, engine_config).await?; + text::run(drt, None, engine_config).await?; } Input::Stdin => { let mut prompt = String::new(); std::io::stdin().read_to_string(&mut prompt).unwrap(); - text::run(runtime, Some(prompt), engine_config).await?; + text::run(drt, Some(prompt), engine_config).await?; } Input::Batch(path) => { - batch::run(runtime, path, engine_config).await?; + batch::run(drt, path, engine_config).await?; } Input::Endpoint(path) => { - let Either::Right(distributed_runtime) = rt else { - anyhow::bail!("Input::Endpoint 
requires passing a DistributedRuntime"); - }; - endpoint::run(distributed_runtime, path, engine_config).await?; + endpoint::run(drt, path, engine_config).await?; } } Ok(()) diff --git a/lib/llm/src/entrypoint/input/batch.rs b/lib/llm/src/entrypoint/input/batch.rs index 0498f08889..f379676133 100644 --- a/lib/llm/src/entrypoint/input/batch.rs +++ b/lib/llm/src/entrypoint/input/batch.rs @@ -8,7 +8,7 @@ use crate::types::openai::chat_completions::{ }; use anyhow::Context as _; use dynamo_async_openai::types::FinishReason; -use dynamo_runtime::{Runtime, pipeline::Context, runtime::CancellationToken}; +use dynamo_runtime::{DistributedRuntime, pipeline::Context, runtime::CancellationToken}; use futures::StreamExt; use serde::{Deserialize, Serialize}; use std::cmp; @@ -51,11 +51,11 @@ struct Entry { } pub async fn run( - runtime: Runtime, + distributed_runtime: DistributedRuntime, input_jsonl: PathBuf, engine_config: EngineConfig, ) -> anyhow::Result<()> { - let cancel_token = runtime.primary_token(); + let cancel_token = distributed_runtime.primary_token(); // Check if the path exists and is a directory if !input_jsonl.exists() || !input_jsonl.is_file() { anyhow::bail!( @@ -64,7 +64,7 @@ pub async fn run( ); } - let mut prepared_engine = common::prepare_engine(runtime, engine_config).await?; + let mut prepared_engine = common::prepare_engine(distributed_runtime, engine_config).await?; let pre_processor = if prepared_engine.has_tokenizer() { Some(OpenAIPreprocessor::new( diff --git a/lib/llm/src/entrypoint/input/common.rs b/lib/llm/src/entrypoint/input/common.rs index df382b3b62..46daee9cb3 100644 --- a/lib/llm/src/entrypoint/input/common.rs +++ b/lib/llm/src/entrypoint/input/common.rs @@ -10,7 +10,7 @@ use crate::{ entrypoint::{self, EngineConfig}, kv_router::{KvPushRouter, KvRouter, PrefillRouter}, migration::Migration, - model_card::{self, ModelDeploymentCard}, + model_card::ModelDeploymentCard, preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter}, 
protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest}, request_template::RequestTemplate, @@ -24,9 +24,8 @@ use crate::{ }; use dynamo_runtime::{ - DistributedRuntime, Runtime, + DistributedRuntime, component::Client, - distributed::DistributedConfig, engine::{AsyncEngineStream, Data}, pipeline::{ Context, ManyOut, Operator, PushRouter, RouterMode, SegmentSource, ServiceBackend, @@ -55,26 +54,24 @@ impl PreparedEngine { /// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine. pub async fn prepare_engine( - runtime: Runtime, + distributed_runtime: DistributedRuntime, engine_config: EngineConfig, ) -> anyhow::Result { match engine_config { EngineConfig::Dynamic(local_model) => { - let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; - - let store = Arc::new(distributed_runtime.store().clone()); let model_manager = Arc::new(ModelManager::new()); let watch_obj = Arc::new(ModelWatcher::new( - distributed_runtime, + distributed_runtime.clone(), model_manager.clone(), dynamo_runtime::pipeline::RouterMode::RoundRobin, None, None, )); - let (_, receiver) = store.watch(model_card::ROOT_PATH, None, runtime.primary_token()); + let discovery = distributed_runtime.discovery_client(); + let discovery_stream = discovery.list_and_watch(dynamo_runtime::discovery::DiscoveryKey::AllModelCards).await?; let inner_watch_obj = watch_obj.clone(); let _watcher_task = tokio::spawn(async move { - inner_watch_obj.watch(receiver, None).await; + inner_watch_obj.watch(discovery_stream, None).await; }); tracing::info!("Waiting for remote model.."); @@ -98,9 +95,6 @@ pub async fn prepare_engine( let card = local_model.card(); let router_mode = local_model.router_config().router_mode; - let dst_config = DistributedConfig::from_settings(true); - let distributed_runtime = DistributedRuntime::new(runtime, dst_config).await?; - let endpoint_id = local_model.endpoint_id(); let component = 
distributed_runtime .namespace(&endpoint_id.namespace)? diff --git a/lib/llm/src/entrypoint/input/grpc.rs b/lib/llm/src/entrypoint/input/grpc.rs index 8693c4d1d1..49978fd974 100644 --- a/lib/llm/src/entrypoint/input/grpc.rs +++ b/lib/llm/src/entrypoint/input/grpc.rs @@ -9,26 +9,26 @@ use crate::{ entrypoint::{self, EngineConfig, input::common}, grpc::service::kserve, kv_router::KvRouterConfig, - model_card, namespace::is_global_namespace, types::openai::{ chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, }, }; -use dynamo_runtime::{DistributedRuntime, Runtime, storage::key_value_store::KeyValueStoreManager}; -use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode}; +use dynamo_runtime::{DistributedRuntime}; +use dynamo_runtime::{pipeline::RouterMode}; /// Build and run an KServe gRPC service -pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> { +pub async fn run( + distributed_runtime: DistributedRuntime, + engine_config: EngineConfig, +) -> anyhow::Result<()> { let grpc_service_builder = kserve::KserveService::builder() .port(engine_config.local_model().http_port()) // [WIP] generalize port.. 
.with_request_template(engine_config.local_model().request_template()); let grpc_service = match engine_config { EngineConfig::Dynamic(_) => { - let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; - let store = Arc::new(distributed_runtime.store().clone()); let grpc_service = grpc_service_builder.build()?; let router_config = engine_config.local_model().router_config(); // Listen for models registering themselves, add them to gRPC service @@ -39,9 +39,8 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul Some(namespace.to_string()) }; run_watcher( - distributed_runtime, + distributed_runtime.clone(), grpc_service.state().manager_clone(), - store, router_config.router_mode, Some(router_config.kv_router_config), router_config.busy_threshold, @@ -55,8 +54,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul let checksum = card.mdcsum(); let router_mode = local_model.router_config().router_mode; - let dst_config = DistributedConfig::from_settings(true); // true means static - let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?; let grpc_service = grpc_service_builder.build()?; let manager = grpc_service.model_manager(); @@ -157,41 +154,41 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul grpc_service } }; - grpc_service.run(runtime.primary_token()).await?; - runtime.shutdown(); // Cancel primary token + grpc_service + .run(distributed_runtime.primary_token()) + .await?; + distributed_runtime.shutdown(); // Cancel primary token Ok(()) } /// Spawns a task that watches for new models in store, /// and registers them with the ModelManager so that the HTTP service can use them. 
-#[allow(clippy::too_many_arguments)] async fn run_watcher( runtime: DistributedRuntime, model_manager: Arc, - store: Arc, router_mode: RouterMode, kv_router_config: Option, busy_threshold: Option, target_namespace: Option, ) -> anyhow::Result<()> { - let cancellation_token = runtime.primary_token(); let watch_obj = ModelWatcher::new( - runtime, + runtime.clone(), model_manager, router_mode, kv_router_config, busy_threshold, ); tracing::debug!("Waiting for remote model"); - let (_, receiver) = store.watch(model_card::ROOT_PATH, None, cancellation_token); + let discovery = runtime.discovery_client(); + let discovery_stream = discovery.list_and_watch(dynamo_runtime::discovery::DiscoveryKey::AllModelCards).await?; // [gluo NOTE] This is different from http::run_watcher where it alters the HTTP service // endpoint being exposed, gRPC doesn't have the same concept as the KServe service // only has one kind of inference endpoint. - // Pass the sender to the watcher + // Pass the discovery stream to the watcher let _watcher_task = tokio::spawn(async move { - watch_obj.watch(receiver, target_namespace.as_deref()).await; + watch_obj.watch(discovery_stream, target_namespace.as_deref()).await; }); Ok(()) diff --git a/lib/llm/src/entrypoint/input/http.rs b/lib/llm/src/entrypoint/input/http.rs index 88b4e3e979..f2efe8c03b 100644 --- a/lib/llm/src/entrypoint/input/http.rs +++ b/lib/llm/src/entrypoint/input/http.rs @@ -10,19 +10,20 @@ use crate::{ entrypoint::{self, EngineConfig, input::common}, http::service::service_v2::{self, HttpService}, kv_router::KvRouterConfig, - model_card, namespace::is_global_namespace, types::openai::{ chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, }, }; -use dynamo_runtime::storage::key_value_store::KeyValueStoreManager; -use dynamo_runtime::{DistributedRuntime, Runtime}; -use dynamo_runtime::{distributed::DistributedConfig, 
pipeline::RouterMode}; +use dynamo_runtime::DistributedRuntime; +use dynamo_runtime::pipeline::RouterMode; /// Build and run an HTTP service -pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> { +pub async fn run( + distributed_runtime: DistributedRuntime, + engine_config: EngineConfig, +) -> anyhow::Result<()> { let local_model = engine_config.local_model(); let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) { (Some(tls_cert_path), Some(tls_key_path)) => { @@ -63,11 +64,9 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul let http_service = match engine_config { EngineConfig::Dynamic(_) => { - let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; // This allows the /health endpoint to query store for active instances http_service_builder = http_service_builder.store(distributed_runtime.store().clone()); let http_service = http_service_builder.build()?; - let store = Arc::new(distributed_runtime.store().clone()); let router_config = engine_config.local_model().router_config(); // Listen for models registering themselves, add them to HTTP service @@ -80,9 +79,8 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul Some(namespace.to_string()) }; run_watcher( - distributed_runtime, + distributed_runtime.clone(), http_service.state().manager_clone(), - store, router_config.router_mode, Some(router_config.kv_router_config), router_config.busy_threshold, @@ -96,11 +94,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul EngineConfig::StaticRemote(local_model) => { let card = local_model.card(); let checksum = card.mdcsum(); - let router_mode = local_model.router_config().router_mode; - - let dst_config = DistributedConfig::from_settings(true); // true means static - let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?; let http_service 
= http_service_builder.build()?; let manager = http_service.model_manager(); @@ -233,8 +227,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul http_service.custom_backend_metrics_polling_interval, http_service.custom_backend_registry.as_ref(), ) { - // Create DistributedRuntime for polling, matching the engine's mode - let drt = DistributedRuntime::from_settings(runtime.clone()).await?; tracing::info!( namespace_component_endpoint=%namespace_component_endpoint, polling_interval_secs=polling_interval, @@ -246,7 +238,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul // shutdown phase. Some( crate::http::service::custom_backend_metrics::spawn_custom_backend_polling_task( - drt, + distributed_runtime.clone(), namespace_component_endpoint.clone(), polling_interval, registry.clone(), @@ -256,14 +248,16 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul None }; - http_service.run(runtime.primary_token()).await?; + http_service + .run(distributed_runtime.primary_token()) + .await?; // Abort the polling task if it was started if let Some(task) = polling_task { task.abort(); } - runtime.shutdown(); // Cancel primary token + distributed_runtime.shutdown(); // Cancel primary token Ok(()) } @@ -273,7 +267,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul async fn run_watcher( runtime: DistributedRuntime, model_manager: Arc, - store: Arc, router_mode: RouterMode, kv_router_config: Option, busy_threshold: Option, @@ -281,16 +274,16 @@ async fn run_watcher( http_service: Arc, metrics: Arc, ) -> anyhow::Result<()> { - let cancellation_token = runtime.primary_token(); let mut watch_obj = ModelWatcher::new( - runtime, + runtime.clone(), model_manager, router_mode, kv_router_config, busy_threshold, ); tracing::debug!("Waiting for remote model"); - let (_, receiver) = store.watch(model_card::ROOT_PATH, None, cancellation_token); + let discovery = 
runtime.discovery_client(); + let discovery_stream = discovery.list_and_watch(dynamo_runtime::discovery::DiscoveryKey::AllModelCards).await?; // Create a channel to receive model type updates let (tx, mut rx) = tokio::sync::mpsc::channel(32); @@ -304,9 +297,9 @@ async fn run_watcher( } }); - // Pass the sender to the watcher + // Pass the discovery stream to the watcher let _watcher_task = tokio::spawn(async move { - watch_obj.watch(receiver, target_namespace.as_deref()).await; + watch_obj.watch(discovery_stream, target_namespace.as_deref()).await; }); Ok(()) diff --git a/lib/llm/src/entrypoint/input/text.rs b/lib/llm/src/entrypoint/input/text.rs index 9659650a67..aced80e119 100644 --- a/lib/llm/src/entrypoint/input/text.rs +++ b/lib/llm/src/entrypoint/input/text.rs @@ -5,7 +5,8 @@ use crate::request_template::RequestTemplate; use crate::types::openai::chat_completions::{ NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine, }; -use dynamo_runtime::{Runtime, pipeline::Context, runtime::CancellationToken}; +use dynamo_runtime::DistributedRuntime; +use dynamo_runtime::pipeline::Context; use futures::StreamExt; use std::io::{ErrorKind, Write}; @@ -17,15 +18,15 @@ use crate::entrypoint::input::common; const MAX_TOKENS: u32 = 8192; pub async fn run( - runtime: Runtime, + distributed_runtime: DistributedRuntime, single_prompt: Option, engine_config: EngineConfig, ) -> anyhow::Result<()> { - let cancel_token = runtime.primary_token(); - let prepared_engine = common::prepare_engine(runtime, engine_config).await?; + let prepared_engine = + common::prepare_engine(distributed_runtime.clone(), engine_config).await?; // TODO: Pass prepared_engine directly main_loop( - cancel_token, + distributed_runtime, &prepared_engine.service_name, prepared_engine.engine, single_prompt, @@ -36,13 +37,14 @@ pub async fn run( } async fn main_loop( - cancel_token: CancellationToken, + distributed_runtime: DistributedRuntime, service_name: &str, engine: 
OpenAIChatCompletionsStreamingEngine, mut initial_prompt: Option, _inspect_template: bool, template: Option, ) -> anyhow::Result<()> { + let cancel_token = distributed_runtime.primary_token(); if initial_prompt.is_none() { tracing::info!("Ctrl-c to exit"); } @@ -179,7 +181,11 @@ async fn main_loop( break; } } - cancel_token.cancel(); // stop everything else println!(); + + // Stop the runtime and wait for it to stop + distributed_runtime.shutdown(); + cancel_token.cancelled().await; + Ok(()) } diff --git a/lib/llm/src/http/service/clear_kv_blocks.rs b/lib/llm/src/http/service/clear_kv_blocks.rs index ee1cc3bc3e..b734b60480 100644 --- a/lib/llm/src/http/service/clear_kv_blocks.rs +++ b/lib/llm/src/http/service/clear_kv_blocks.rs @@ -6,7 +6,7 @@ use axum::{http::Method, response::IntoResponse, routing::post, Json, Router}; use serde_json::json; use std::sync::Arc; -use dynamo_runtime::{pipeline::PushRouter, stream::StreamExt}; +use dynamo_runtime::{discovery::DiscoveryKey, pipeline::PushRouter, stream::StreamExt}; pub const CLEAR_KV_ENDPOINT: &str = "clear_kv_blocks"; @@ -150,7 +150,14 @@ async fn clear_kv_blocks_handler( } }; - let instances = match component_obj.list_instances().await { + let discovery_client = distributed.discovery_client(); + let discovery_key = DiscoveryKey::Endpoint { + namespace: namespace.clone(), + component: component.clone(), + endpoint: CLEAR_KV_ENDPOINT.to_string(), + }; + + let discovery_instances = match discovery_client.list(discovery_key).await { Ok(instances) => instances, Err(e) => { add_worker_result( @@ -165,11 +172,11 @@ async fn clear_kv_blocks_handler( } }; - if instances.is_empty() { + if discovery_instances.is_empty() { add_worker_result( false, entry_name, - "No instances found for worker group", + "No instances found for clear_kv_blocks endpoint", namespace, component, None, @@ -177,30 +184,12 @@ async fn clear_kv_blocks_handler( continue; } - let instances_filtered = instances - .clone() + let instances_filtered: Vec = 
discovery_instances .into_iter() - .filter(|instance| instance.endpoint == CLEAR_KV_ENDPOINT) - .collect::>(); - - if instances_filtered.is_empty() { - let found_endpoints: Vec = instances - .iter() - .map(|instance| instance.endpoint.clone()) - .collect(); - add_worker_result( - false, - entry_name, - &format!( - "Worker group doesn't support clear_kv_blocks. Supported endpoints: {}", - found_endpoints.join(", ") - ), - namespace, - component, - None, - ); - continue; - } + .map(|di| match di { + dynamo_runtime::discovery::DiscoveryInstance::Endpoint(instance) => instance, + }) + .collect(); for instance in &instances_filtered { let instance_name = format!("{}-instance-{}", entry.name, instance.id()); diff --git a/lib/llm/src/http/service/health.rs b/lib/llm/src/http/service/health.rs index 5f007a9bd4..5e4e9deb5f 100644 --- a/lib/llm/src/http/service/health.rs +++ b/lib/llm/src/http/service/health.rs @@ -52,14 +52,13 @@ async fn live_handler( async fn health_handler( axum::extract::State(state): axum::extract::State>, ) -> impl IntoResponse { - let instances = match list_all_instances(state.store()).await { + let instances = match list_all_instances(state.discovery_client()).await { Ok(instances) => instances, Err(err) => { - tracing::warn!(%err, "Failed to fetch instances from store"); + tracing::warn!(%err, "Failed to fetch instances from discovery client"); vec![] } }; - let mut endpoints: Vec = instances .iter() .map(|instance| instance.endpoint_id().as_url()) diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index ae18a67bdb..40d5007fbc 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -18,6 +18,7 @@ use crate::request_template::RequestTemplate; use anyhow::Result; use axum_server::tls_rustls::RustlsConfig; use derive_builder::Builder; +use dynamo_runtime::discovery::{DiscoveryClient, KVStoreDiscoveryClient}; use dynamo_runtime::logging::make_request_span; use 
dynamo_runtime::metrics::prometheus_names::name_prefix; use dynamo_runtime::storage::key_value_store::KeyValueStoreManager; @@ -31,6 +32,7 @@ pub struct State { metrics: Arc, manager: Arc, store: KeyValueStoreManager, + discovery_client: Arc, flags: StateFlags, } @@ -72,10 +74,21 @@ impl StateFlags { impl State { pub fn new(manager: Arc, store: KeyValueStoreManager) -> Self { + // Initialize discovery client backed by KV store + // Create a cancellation token for the discovery client's watch streams + let discovery_client = { + let cancel_token = CancellationToken::new(); + Arc::new(KVStoreDiscoveryClient::new( + store.clone(), + cancel_token, + )) as Arc + }; + Self { manager, metrics: Arc::new(Metrics::default()), store, + discovery_client, flags: StateFlags { chat_endpoints_enabled: AtomicBool::new(false), cmpl_endpoints_enabled: AtomicBool::new(false), @@ -102,6 +115,10 @@ impl State { &self.store } + pub fn discovery_client(&self) -> Arc { + self.discovery_client.clone() + } + // TODO pub fn sse_keep_alive(&self) -> Option { None diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 445faec2d2..c8399af08b 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -9,13 +9,13 @@ use anyhow::Result; use derive_builder::Builder; use dynamo_runtime::{ component::{Component, InstanceSource}, + discovery::{watch_and_extract_field, DiscoveryKey}, pipeline::{ AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream, SingleIn, async_trait, }, - prelude::*, protocols::annotated::Annotated, - utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction}, + traits::DistributedRuntimeProvider, }; use futures::stream::{self, StreamExt}; use serde::{Deserialize, Serialize}; @@ -47,7 +47,7 @@ use crate::{ subscriber::start_kv_router_background, }, local_model::runtime_config::ModelRuntimeConfig, - model_card::{self, ModelDeploymentCard}, + model_card::ModelDeploymentCard, preprocessor::PreprocessedRequest, 
protocols::common::llm_backend::LLMEngineOutput, }; @@ -233,22 +233,18 @@ impl KvRouter { } }; - // Create runtime config watcher using the generic etcd watcher - // TODO: Migrate to discovery_client() once it exposes kv_get_and_watch_prefix functionality - let etcd_client = component - .drt() - .etcd_client() - .expect("Cannot KV route without etcd client"); - - let runtime_configs_watcher = watch_prefix_with_extraction( - etcd_client, - &format!("{}/{}", model_card::ROOT_PATH, component.path()), - key_extractors::lease_id, - |card: ModelDeploymentCard| Some(card.runtime_config), - cancellation_token.clone(), - ) - .await?; - let runtime_configs_rx = runtime_configs_watcher.receiver(); + // Watch for runtime config updates via discovery interface + let discovery = component.drt().discovery_client(); + let discovery_key = DiscoveryKey::EndpointModelCards { + namespace: component.namespace().name().to_string(), + component: component.name().to_string(), + endpoint: "generate".to_string(), + }; + let discovery_stream = discovery.list_and_watch(discovery_key).await?; + let runtime_configs_rx = watch_and_extract_field( + discovery_stream, + |card: ModelDeploymentCard| card.runtime_config, + ); let indexer = if kv_router_config.overlap_score_weight == 0.0 { // When overlap_score_weight is zero, we don't need to track prefixes diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index 9a90b49116..7e9addc436 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -162,6 +162,16 @@ impl KvScheduler { let new_instances = instances_monitor_rx.borrow_and_update().clone(); let new_configs = configs_monitor_rx.borrow_and_update().clone(); + // Log config state for comparison + let config_details: Vec<(u64, Option)> = new_configs + .iter() + .map(|(&worker_id, config)| (worker_id, config.total_kv_blocks)) + .collect(); + tracing::warn!( + "DISCOVERY_VALIDATION: scheduler_config_state: configs={:?}", + 
config_details + ); + // Build the new workers_with_configs map let mut new_workers_with_configs = HashMap::new(); for instance in &new_instances { diff --git a/lib/llm/src/kv_router/subscriber.rs b/lib/llm/src/kv_router/subscriber.rs index 3e7dadc836..1ff27d1374 100644 --- a/lib/llm/src/kv_router/subscriber.rs +++ b/lib/llm/src/kv_router/subscriber.rs @@ -8,6 +8,7 @@ use std::{collections::HashSet, time::Duration}; use anyhow::Result; use dynamo_runtime::{ component::Component, + discovery::DiscoveryKey, prelude::*, traits::events::EventPublisher, transports::{ @@ -15,6 +16,7 @@ use dynamo_runtime::{ nats::{NatsQueue, Slug}, }, }; +use futures::StreamExt; use rand::Rng; use tokio::sync::{mpsc, oneshot}; use tokio_util::sync::CancellationToken; @@ -281,10 +283,13 @@ pub async fn start_kv_router_background( // Get the generate endpoint and watch for instance deletions let generate_endpoint = component.endpoint("generate"); - let (_instance_prefix, mut instance_event_rx) = etcd_client - .kv_get_and_watch_prefix(generate_endpoint.etcd_root()) - .await? 
- .dissolve(); + let discovery_client = component.drt().discovery_client(); + let discovery_key = DiscoveryKey::Endpoint { + namespace: component.namespace().name().to_string(), + component: component.name().to_string(), + endpoint: "generate".to_string(), + }; + let mut instance_event_stream = discovery_client.list_and_watch(discovery_key).await?; // Get instances_rx for tracking current workers let client = generate_endpoint.client().await?; @@ -337,25 +342,21 @@ pub async fn start_kv_router_background( } // Handle generate endpoint instance deletion events - Some(event) = instance_event_rx.recv() => { - let WatchEvent::Delete(kv) = event else { + Some(discovery_event_result) = instance_event_stream.next() => { + let Ok(discovery_event) = discovery_event_result else { continue; }; - let key = String::from_utf8_lossy(kv.key()); - - let Some(worker_id_str) = key.split(&['/', ':'][..]).next_back() else { - tracing::warn!("Could not extract worker ID from instance key: {key}"); + let dynamo_runtime::discovery::DiscoveryEvent::Removed(worker_id) = discovery_event else { continue; }; - // Parse as hexadecimal (base 16) - let Ok(worker_id) = u64::from_str_radix(worker_id_str, 16) else { - tracing::warn!("Could not parse worker ID from instance key: {key}"); - continue; - }; + tracing::warn!( + worker_id = worker_id, + "DISCOVERY: Generate endpoint instance removed, removing worker" + ); - tracing::info!("Generate endpoint instance deleted, removing worker {worker_id}"); + tracing::warn!("DISCOVERY_VALIDATION: remove_worker_tx: worker_id={}", worker_id); if let Err(e) = remove_worker_tx.send(worker_id).await { tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}"); } diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index de869047c5..a307449397 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -5,14 +5,14 @@ use std::fs; use std::path::{Path, PathBuf}; use dynamo_runtime::component::Endpoint; +use 
dynamo_runtime::discovery::DiscoverySpec; use dynamo_runtime::protocols::EndpointId; use dynamo_runtime::slug::Slug; -use dynamo_runtime::storage::key_value_store::Key; use dynamo_runtime::traits::DistributedRuntimeProvider; use crate::entrypoint::RouterConfig; use crate::mocker::protocols::MockEngineArgs; -use crate::model_card::{self, ModelDeploymentCard}; +use crate::model_card::ModelDeploymentCard; use crate::model_type::{ModelInput, ModelType}; use crate::request_template::RequestTemplate; @@ -413,13 +413,24 @@ impl LocalModel { self.card.model_type = model_type; self.card.model_input = model_input; - // Publish the Model Deployment Card to KV store - let card_store = endpoint.drt().store(); - let key = Key::from_raw(endpoint.unique_path(card_store.connection_id())); - - let _outcome = card_store - .publish(model_card::ROOT_PATH, None, &key, &mut self.card) - .await?; + // Register the Model Deployment Card via discovery interface + let discovery = endpoint.drt().discovery_client(); + let spec = DiscoverySpec::from_model_card( + endpoint.component().namespace().name().to_string(), + endpoint.component().name().to_string(), + endpoint.name().to_string(), + &self.card, + )?; + let _instance = discovery.register(spec).await?; + + tracing::warn!( + "DISCOVERY_VALIDATION: model_card_registered: namespace={}, component={}, endpoint={}, model_name={}", + endpoint.component().namespace().name(), + endpoint.component().name(), + endpoint.name(), + self.card.name() + ); + Ok(()) } } diff --git a/lib/llm/tests/audit_nats_integration.rs b/lib/llm/tests/audit_nats_integration.rs index f860adaf29..1819e6bcc1 100644 --- a/lib/llm/tests/audit_nats_integration.rs +++ b/lib/llm/tests/audit_nats_integration.rs @@ -167,7 +167,7 @@ mod tests { bus::init(100); let drt = create_test_drt().await; - sink::spawn_workers_from_env(Some(&drt)); + sink::spawn_workers_from_env(&drt); time::sleep(Duration::from_millis(100)).await; // Emit audit record @@ -224,7 +224,7 @@ mod tests { 
bus::init(100); let drt = create_test_drt().await; - sink::spawn_workers_from_env(Some(&drt)); + sink::spawn_workers_from_env(&drt); time::sleep(Duration::from_millis(100)).await; // Request with store=true (should be audited) diff --git a/lib/llm/tests/http_metrics.rs b/lib/llm/tests/http_metrics.rs index 36a34be2f1..e3bd1bc5b4 100644 --- a/lib/llm/tests/http_metrics.rs +++ b/lib/llm/tests/http_metrics.rs @@ -295,8 +295,10 @@ mod integration_tests { use super::*; use dynamo_llm::{ discovery::ModelWatcher, engines::make_echo_engine, entrypoint::EngineConfig, - local_model::LocalModelBuilder, model_card, + local_model::LocalModelBuilder, }; + use dynamo_runtime::discovery::DiscoveryKey; + use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::DistributedRuntime; use dynamo_runtime::pipeline::RouterMode; use std::sync::Arc; @@ -333,7 +335,7 @@ mod integration_tests { .build() .unwrap(); - // Set up model watcher to discover models from etcd (like production) + // Set up model watcher to discover models via discovery interface (like production) // This is crucial for the polling task to find model entries let model_watcher = ModelWatcher::new( @@ -343,17 +345,16 @@ mod integration_tests { None, None, ); - // Start watching etcd for model registrations - let store = Arc::new(distributed_runtime.store().clone()); - let (_, receiver) = store.watch( - model_card::ROOT_PATH, - None, - distributed_runtime.primary_token(), - ); + // Start watching for model registrations via discovery interface + let discovery = distributed_runtime.discovery_client(); + let discovery_stream = discovery + .list_and_watch(DiscoveryKey::AllModelCards) + .await + .unwrap(); - // Spawn watcher task to discover models from etcd + // Spawn watcher task to discover models let _watcher_task = tokio::spawn(async move { - model_watcher.watch(receiver, None).await; + model_watcher.watch(discovery_stream, None).await; }); // Set up the engine following the StaticFull pattern from 
http.rs diff --git a/lib/runtime/Cargo.toml b/lib/runtime/Cargo.toml index cd774ba16e..343ee01f34 100644 --- a/lib/runtime/Cargo.toml +++ b/lib/runtime/Cargo.toml @@ -63,6 +63,7 @@ bincode = { version = "1" } console-subscriber = { version = "0.4", optional = true } educe = { version = "0.6.0" } figment = { version = "0.10.19", features = ["env", "json", "toml", "test"] } +inotify = { version = "0.11" } local-ip-address = { version = "0.6.3" } log = { version = "0.4" } nid = { version = "3.0.0", features = ["serde"] } diff --git a/lib/runtime/examples/Cargo.lock b/lib/runtime/examples/Cargo.lock index d7074c7e28..12dc68568c 100644 --- a/lib/runtime/examples/Cargo.lock +++ b/lib/runtime/examples/Cargo.lock @@ -679,6 +679,7 @@ dependencies = [ "figment", "futures", "humantime", + "inotify", "local-ip-address", "log", "nid", @@ -1354,6 +1355,28 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" +[[package]] +name = "inotify" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3" +dependencies = [ + "bitflags 2.9.0", + "futures-core", + "inotify-sys", + "libc", + "tokio", +] + +[[package]] +name = "inotify-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" +dependencies = [ + "libc", +] + [[package]] name = "iovec" version = "0.1.4" diff --git a/lib/runtime/src/component.rs b/lib/runtime/src/component.rs index a97193928a..f695b67f0f 100644 --- a/lib/runtime/src/component.rs +++ b/lib/runtime/src/component.rs @@ -75,7 +75,7 @@ pub use client::{Client, InstanceSource}; /// An instance is namespace+component+endpoint+lease_id and must be unique. 
pub const INSTANCE_ROOT_PATH: &str = "v1/instances"; -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] #[serde(rename_all = "snake_case")] pub enum TransportType { NatsTcp(String), @@ -278,21 +278,24 @@ impl Component { } pub async fn list_instances(&self) -> anyhow::Result> { - let client = self.drt.store(); - let Some(bucket) = client.get_bucket(&self.instance_root()).await? else { - return Ok(vec![]); + let discovery_client = self.drt.discovery_client(); + + let discovery_key = crate::discovery::DiscoveryKey::ComponentEndpoints { + namespace: self.namespace.name(), + component: self.name.clone(), }; - let entries = bucket.entries().await?; - let mut instances = Vec::with_capacity(entries.len()); - for (name, bytes) in entries.into_iter() { - let val = match serde_json::from_slice::(&bytes) { - Ok(val) => val, - Err(err) => { - anyhow::bail!("Error converting storage response to Instance: {err}. {name}",); - } - }; - instances.push(val); - } + + let discovery_instances = discovery_client.list(discovery_key).await?; + + // Extract Instance from DiscoveryInstance::Endpoint wrapper + let mut instances: Vec = discovery_instances + .into_iter() + .filter_map(|di| match di { + crate::discovery::DiscoveryInstance::Endpoint(instance) => Some(instance), + _ => None, // Ignore all other variants (ModelCard, etc.) + }) + .collect(); + instances.sort(); Ok(instances) } diff --git a/lib/runtime/src/component/client.rs b/lib/runtime/src/component/client.rs index 411194de09..bfadaa2e5a 100644 --- a/lib/runtime/src/component/client.rs +++ b/lib/runtime/src/component/client.rs @@ -1,19 +1,20 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use crate::pipeline::{ - AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode, - SingleIn, +use crate::{ + pipeline::{ + AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode, + SingleIn, + }, + storage::key_value_store::{KeyValueStoreManager, WatchEvent}, }; use arc_swap::ArcSwap; +use futures::StreamExt; use std::collections::HashMap; use std::sync::Arc; use tokio::net::unix::pipe::Receiver; -use crate::{ - pipeline::async_trait, - transports::etcd::{Client as EtcdClient, WatchEvent}, -}; +use crate::{pipeline::async_trait, transports::etcd::Client as EtcdClient}; use super::*; @@ -67,23 +68,21 @@ impl Client { // Client with auto-discover instances using etcd pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result { + tracing::debug!("Client::new_dynamic: Creating dynamic client for endpoint: {}", endpoint.path()); const INSTANCE_REFRESH_PERIOD: Duration = Duration::from_secs(1); - // create live endpoint watcher - let Some(etcd_client) = &endpoint.component.drt.etcd_client else { - anyhow::bail!("Attempt to create a dynamic client on a static endpoint"); - }; - - let instance_source = - Self::get_or_create_dynamic_instance_source(etcd_client, &endpoint).await?; + let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?; + tracing::debug!("Client::new_dynamic: Got instance source for endpoint: {}", endpoint.path()); let client = Client { - endpoint, + endpoint: endpoint.clone(), instance_source: instance_source.clone(), instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))), instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))), }; + tracing::debug!("Client::new_dynamic: Starting instance source monitor for endpoint: {}", endpoint.path()); client.monitor_instance_source(); + tracing::debug!("Client::new_dynamic: Successfully created dynamic client for endpoint: {}", endpoint.path()); Ok(client) } @@ -118,17 
+117,47 @@ impl Client { /// Wait for at least one Instance to be available for this Endpoint pub async fn wait_for_instances(&self) -> Result> { + tracing::debug!( + "wait_for_instances: Starting wait for endpoint: {}", + self.endpoint.path() + ); let mut instances: Vec = vec![]; if let InstanceSource::Dynamic(mut rx) = self.instance_source.as_ref().clone() { // wait for there to be 1 or more endpoints + let mut iteration = 0; loop { instances = rx.borrow_and_update().to_vec(); + tracing::debug!( + "wait_for_instances: iteration={}, current_instance_count={}, endpoint={}", + iteration, + instances.len(), + self.endpoint.path() + ); if instances.is_empty() { + tracing::debug!( + "wait_for_instances: No instances yet, waiting for change notification for endpoint: {}", + self.endpoint.path() + ); rx.changed().await?; + tracing::debug!( + "wait_for_instances: Change notification received for endpoint: {}", + self.endpoint.path() + ); } else { + tracing::info!( + "wait_for_instances: Found {} instance(s) for endpoint: {}", + instances.len(), + self.endpoint.path() + ); break; } + iteration += 1; } + } else { + tracing::debug!( + "wait_for_instances: Static instance source, no dynamic discovery for endpoint: {}", + self.endpoint.path() + ); } Ok(instances) } @@ -164,14 +193,17 @@ impl Client { fn monitor_instance_source(&self) { let cancel_token = self.endpoint.drt().primary_token(); let client = self.clone(); + let endpoint_path = self.endpoint.path(); + tracing::debug!("monitor_instance_source: Starting monitor for endpoint: {}", endpoint_path); tokio::task::spawn(async move { let mut rx = match client.instance_source.as_ref() { InstanceSource::Static => { - tracing::error!("Static instance source is not watchable"); + tracing::error!("monitor_instance_source: Static instance source is not watchable"); return; } InstanceSource::Dynamic(rx) => rx.clone(), }; + let mut iteration = 0; while !cancel_token.is_cancelled() { let instance_ids: Vec = rx .borrow_and_update() @@ 
-179,107 +211,177 @@ impl Client { .map(|instance| instance.id()) .collect(); + tracing::debug!( + "monitor_instance_source: iteration={}, instance_count={}, instance_ids={:?}, endpoint={}", + iteration, + instance_ids.len(), + instance_ids, + endpoint_path + ); + // TODO: this resets both tracked available and free instances client.instance_avail.store(Arc::new(instance_ids.clone())); - client.instance_free.store(Arc::new(instance_ids)); + client.instance_free.store(Arc::new(instance_ids.clone())); + + tracing::warn!( + "DISCOVERY_VALIDATION: endpoint={}, instance_avail={:?}, instance_free={:?}", + endpoint_path, + instance_ids, + instance_ids + ); - tracing::debug!("instance source updated"); + tracing::debug!("monitor_instance_source: instance source updated, endpoint={}", endpoint_path); if let Err(err) = rx.changed().await { - tracing::error!("The Sender is dropped: {}", err); + tracing::error!("monitor_instance_source: The Sender is dropped: {}, endpoint={}", err, endpoint_path); cancel_token.cancel(); } + iteration += 1; } + tracing::debug!("monitor_instance_source: Monitor loop exiting for endpoint: {}", endpoint_path); }); } async fn get_or_create_dynamic_instance_source( - etcd_client: &EtcdClient, endpoint: &Endpoint, ) -> Result> { let drt = endpoint.drt(); let instance_sources = drt.instance_sources(); let mut instance_sources = instance_sources.lock().await; + tracing::debug!( + "get_or_create_dynamic_instance_source: Checking cache for endpoint: {}", + endpoint.path() + ); + if let Some(instance_source) = instance_sources.get(endpoint) { if let Some(instance_source) = instance_source.upgrade() { + tracing::debug!( + "get_or_create_dynamic_instance_source: Found cached instance source for endpoint: {}", + endpoint.path() + ); return Ok(instance_source); } else { + tracing::debug!( + "get_or_create_dynamic_instance_source: Cached instance source was dropped, removing for endpoint: {}", + endpoint.path() + ); instance_sources.remove(endpoint); } } - let 
prefix_watcher = etcd_client - .kv_get_and_watch_prefix(endpoint.etcd_root()) - .await?; + tracing::debug!( + "get_or_create_dynamic_instance_source: Creating new instance source for endpoint: {}", + endpoint.path() + ); + + let discovery_client = drt.discovery_client(); + let discovery_key = crate::discovery::DiscoveryKey::Endpoint { + namespace: endpoint.component.namespace.name.clone(), + component: endpoint.component.name.clone(), + endpoint: endpoint.name.clone(), + }; - let (prefix, mut kv_event_rx) = prefix_watcher.dissolve(); + tracing::debug!( + "get_or_create_dynamic_instance_source: Calling discovery_client.list_and_watch for key: {:?}", + discovery_key + ); + + let mut discovery_stream = discovery_client.list_and_watch(discovery_key.clone()).await?; + + tracing::debug!( + "get_or_create_dynamic_instance_source: Got discovery stream for key: {:?}", + discovery_key + ); let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]); let secondary = endpoint.component.drt.runtime.secondary().clone(); - // this task should be included in the registry - // currently this is created once per client, but this object/task should only be instantiated - // once per worker/instance secondary.spawn(async move { - tracing::debug!("Starting endpoint watcher for prefix: {}", prefix); - let mut map = HashMap::new(); + tracing::debug!("endpoint_watcher: Starting for discovery key: {:?}", discovery_key); + let mut map: HashMap = HashMap::new(); + let mut event_count = 0; loop { - let kv_event = tokio::select! { + let discovery_event = tokio::select! 
{ _ = watch_tx.closed() => { - tracing::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {prefix}"); + tracing::debug!("endpoint_watcher: all watchers have closed; shutting down for discovery key: {:?}", discovery_key); break; } - kv_event = kv_event_rx.recv() => { - match kv_event { - Some(kv_event) => kv_event, + discovery_event = discovery_stream.next() => { + tracing::debug!("endpoint_watcher: Received stream event for discovery key: {:?}", discovery_key); + match discovery_event { + Some(Ok(event)) => { + tracing::debug!("endpoint_watcher: Got Ok event: {:?}", event); + event + }, + Some(Err(e)) => { + tracing::error!("endpoint_watcher: discovery stream error: {}; shutting down for discovery key: {:?}", e, discovery_key); + break; + } None => { - tracing::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {prefix}"); + tracing::debug!("endpoint_watcher: watch stream has closed; shutting down for discovery key: {:?}", discovery_key); break; } } } }; - match kv_event { - WatchEvent::Put(kv) => { - let key = String::from_utf8(kv.key().to_vec()); - let val = serde_json::from_slice::(kv.value()); - if let (Ok(key), Ok(val)) = (key, val) { - map.insert(key.clone(), val); - } else { - tracing::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {prefix}"); - break; - } - } - WatchEvent::Delete(kv) => { - match String::from_utf8(kv.key().to_vec()) { - Ok(key) => { map.remove(&key); } - Err(_) => { - tracing::error!("Unable to parse delete endpoint event; shutting down endpoint watcher for prefix: {}", prefix); - break; + event_count += 1; + tracing::debug!("endpoint_watcher: Processing event #{} for discovery key: {:?}", event_count, discovery_key); + + match discovery_event { + crate::discovery::DiscoveryEvent::Added(discovery_instance) => { + match discovery_instance { + crate::discovery::DiscoveryInstance::Endpoint(instance) => { + tracing::info!( + "endpoint_watcher: Added 
endpoint instance_id={}, namespace={}, component={}, endpoint={}", + instance.instance_id, + instance.namespace, + instance.component, + instance.endpoint + ); + map.insert(instance.instance_id, instance); + } + _ => { + tracing::debug!("endpoint_watcher: Ignoring non-endpoint instance (ModelCard, etc.) for discovery key: {:?}", discovery_key); } } } + crate::discovery::DiscoveryEvent::Removed(instance_id) => { + tracing::info!( + "endpoint_watcher: Removed instance_id={} for discovery key: {:?}", + instance_id, + discovery_key + ); + map.remove(&instance_id); + } } let instances: Vec = map.values().cloned().collect(); + tracing::debug!( + "endpoint_watcher: Current map size={}, sending update for discovery key: {:?}", + instances.len(), + discovery_key + ); if watch_tx.send(instances).is_err() { - tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix); + tracing::debug!("endpoint_watcher: Unable to send watch updates; shutting down for discovery key: {:?}", discovery_key); break; } - } - tracing::debug!("Completed endpoint watcher for prefix: {prefix}"); + tracing::debug!("endpoint_watcher: Completed for discovery key: {:?}, total events processed: {}", discovery_key, event_count); let _ = watch_tx.send(vec![]); }); let instance_source = Arc::new(InstanceSource::Dynamic(watch_rx)); instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source)); + tracing::debug!( + "get_or_create_dynamic_instance_source: Successfully created and cached instance source for endpoint: {}", + endpoint.path() + ); Ok(instance_source) } } diff --git a/lib/runtime/src/component/endpoint.rs b/lib/runtime/src/component/endpoint.rs index baeb46683f..90b4d8e6e2 100644 --- a/lib/runtime/src/component/endpoint.rs +++ b/lib/runtime/src/component/endpoint.rs @@ -118,8 +118,6 @@ impl EndpointConfigBuilder { let endpoint_name = endpoint.name.clone(); let system_health = endpoint.drt().system_health.clone(); let subject = 
endpoint.subject_to(connection_id); - let etcd_path = endpoint.etcd_path_with_lease_id(connection_id); - let etcd_client = endpoint.component.drt.etcd_client.clone(); // Register health check target in SystemHealth if provided if let Some(health_check_payload) = &health_check_payload { @@ -193,24 +191,19 @@ impl EndpointConfigBuilder { result }); - // make the components service endpoint discovery in etcd - - // client.register_service() - let info = Instance { + // Register this endpoint instance in the discovery plane + // The discovery interface abstracts storage backend (etcd, k8s, etc) and provides + // consistent registration/discovery across the system. + let discovery_client = endpoint.drt().discovery_client(); + + let discovery_spec = crate::discovery::DiscoverySpec::Endpoint { + namespace: namespace_name.clone(), component: component_name.clone(), endpoint: endpoint_name.clone(), - namespace: namespace_name.clone(), - instance_id: connection_id, - transport: TransportType::NatsTcp(subject), + transport: TransportType::NatsTcp(subject.clone()), }; - let info = serde_json::to_vec_pretty(&info)?; - - if let Some(etcd_client) = &etcd_client - && let Err(e) = etcd_client - .kv_create(&etcd_path, info, Some(connection_id)) - .await - { + if let Err(e) = discovery_client.register(discovery_spec).await { tracing::error!( component_name, endpoint_name, @@ -222,6 +215,15 @@ impl EndpointConfigBuilder { "Unable to register service for discovery. 
Check discovery service status" )); } + + tracing::warn!( + "DISCOVERY_VALIDATION: endpoint_registered: namespace={}, component={}, endpoint={}, instance_id={}", + namespace_name, + component_name, + endpoint_name, + connection_id + ); + task.await??; Ok(()) diff --git a/lib/runtime/src/discovery/kv_store.rs b/lib/runtime/src/discovery/kv_store.rs new file mode 100644 index 0000000000..3e3d5f1d07 --- /dev/null +++ b/lib/runtime/src/discovery/kv_store.rs @@ -0,0 +1,471 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::storage::key_value_store::{KeyValueStoreManager, WatchEvent}; +use crate::{CancellationToken, Result}; +use async_trait::async_trait; +use futures::{Stream, StreamExt}; +use std::pin::Pin; +use std::sync::Arc; + +use super::{DiscoveryClient, DiscoveryEvent, DiscoveryInstance, DiscoveryKey, DiscoverySpec, DiscoveryStream}; + +const INSTANCES_BUCKET: &str = "v1/instances"; +const MODEL_CARDS_BUCKET: &str = "v1/mdc"; + +/// Discovery client implementation backed by a KeyValueStore +pub struct KVStoreDiscoveryClient { + store: Arc, + cancel_token: CancellationToken, +} + +impl KVStoreDiscoveryClient { + pub fn new(store: KeyValueStoreManager, cancel_token: CancellationToken) -> Self { + Self { + store: Arc::new(store), + cancel_token, + } + } + + /// Build the key path for an endpoint (relative to bucket, not absolute) + fn endpoint_key(namespace: &str, component: &str, endpoint: &str, instance_id: u64) -> String { + format!("{}/{}/{}/{:x}", namespace, component, endpoint, instance_id) + } + + /// Build the key path for a model card (relative to bucket, not absolute) + fn model_card_key(namespace: &str, component: &str, endpoint: &str, instance_id: u64) -> String { + format!("{}/{}/{}/{:x}", namespace, component, endpoint, instance_id) + } + + /// Extract prefix for querying based on discovery key + fn key_prefix(key: &DiscoveryKey) -> String { 
+ match key { + DiscoveryKey::AllEndpoints => INSTANCES_BUCKET.to_string(), + DiscoveryKey::NamespacedEndpoints { namespace } => { + format!("{}/{}", INSTANCES_BUCKET, namespace) + } + DiscoveryKey::ComponentEndpoints { namespace, component } => { + format!("{}/{}/{}", INSTANCES_BUCKET, namespace, component) + } + DiscoveryKey::Endpoint { namespace, component, endpoint } => { + format!("{}/{}/{}/{}", INSTANCES_BUCKET, namespace, component, endpoint) + } + DiscoveryKey::AllModelCards => MODEL_CARDS_BUCKET.to_string(), + DiscoveryKey::NamespacedModelCards { namespace } => { + format!("{}/{}", MODEL_CARDS_BUCKET, namespace) + } + DiscoveryKey::ComponentModelCards { namespace, component } => { + format!("{}/{}/{}", MODEL_CARDS_BUCKET, namespace, component) + } + DiscoveryKey::EndpointModelCards { namespace, component, endpoint } => { + format!("{}/{}/{}/{}", MODEL_CARDS_BUCKET, namespace, component, endpoint) + } + } + } + + /// Check if a key matches the given discovery key filter + fn matches_prefix(key_str: &str, prefix: &str) -> bool { + key_str.starts_with(prefix) + } + + /// Parse and deserialize a discovery instance from KV store entry + fn parse_instance(value: &[u8]) -> Result { + let instance: DiscoveryInstance = serde_json::from_slice(value)?; + Ok(instance) + } +} + +#[async_trait] +impl DiscoveryClient for KVStoreDiscoveryClient { + fn instance_id(&self) -> u64 { + self.store.connection_id() + } + + async fn register(&self, spec: DiscoverySpec) -> Result { + let instance_id = self.instance_id(); + let instance = spec.with_instance_id(instance_id); + + let (bucket_name, key_path) = match &instance { + DiscoveryInstance::Endpoint(inst) => { + let key = Self::endpoint_key( + &inst.namespace, + &inst.component, + &inst.endpoint, + inst.instance_id, + ); + tracing::debug!( + "KVStoreDiscoveryClient::register: Registering endpoint instance_id={}, namespace={}, component={}, endpoint={}, key={}", + inst.instance_id, + inst.namespace, + inst.component, + 
inst.endpoint, + key + ); + (INSTANCES_BUCKET, key) + } + DiscoveryInstance::ModelCard { + namespace, + component, + endpoint, + instance_id, + .. + } => { + let key = Self::model_card_key(namespace, component, endpoint, *instance_id); + tracing::debug!( + "KVStoreDiscoveryClient::register: Registering model card instance_id={}, namespace={}, component={}, endpoint={}, key={}", + instance_id, + namespace, + component, + endpoint, + key + ); + (MODEL_CARDS_BUCKET, key) + } + }; + + // Serialize the instance + let instance_json = serde_json::to_vec(&instance)?; + tracing::debug!( + "KVStoreDiscoveryClient::register: Serialized instance to {} bytes for key={}", + instance_json.len(), + key_path + ); + + // Store in the KV store with no TTL (instances persist until explicitly removed) + tracing::debug!( + "KVStoreDiscoveryClient::register: Getting/creating bucket={} for key={}", + bucket_name, + key_path + ); + let bucket = self + .store + .get_or_create_bucket(bucket_name, None) + .await?; + let key = crate::storage::key_value_store::Key::from_raw(key_path.clone()); + + tracing::debug!( + "KVStoreDiscoveryClient::register: Inserting into bucket={}, key={}", + bucket_name, + key_path + ); + // Use revision 0 for initial registration + let outcome = bucket.insert(&key, instance_json.into(), 0).await?; + tracing::info!( + "KVStoreDiscoveryClient::register: Successfully registered instance_id={}, key={}, outcome={:?}", + instance_id, + key_path, + outcome + ); + + Ok(instance) + } + + async fn list(&self, key: DiscoveryKey) -> Result> { + let prefix = Self::key_prefix(&key); + let bucket_name = if prefix.starts_with(INSTANCES_BUCKET) { + INSTANCES_BUCKET + } else { + MODEL_CARDS_BUCKET + }; + + // Get bucket - if it doesn't exist, return empty list + let Some(bucket) = self.store.get_bucket(bucket_name).await? 
else { + return Ok(Vec::new()); + }; + + // Get all entries from the bucket + let entries = bucket.entries().await?; + + // Filter by prefix and deserialize + let mut instances = Vec::new(); + for (key_str, value) in entries { + if Self::matches_prefix(&key_str, &prefix) { + match Self::parse_instance(&value) { + Ok(instance) => instances.push(instance), + Err(e) => { + tracing::warn!(key = %key_str, error = %e, "Failed to parse discovery instance"); + } + } + } + } + + Ok(instances) + } + + async fn list_and_watch(&self, key: DiscoveryKey) -> Result { + let prefix = Self::key_prefix(&key); + let bucket_name = if prefix.starts_with(INSTANCES_BUCKET) { + INSTANCES_BUCKET + } else { + MODEL_CARDS_BUCKET + }; + + tracing::debug!( + "KVStoreDiscoveryClient::list_and_watch: Starting watch for key={:?}, prefix={}, bucket={}", + key, + prefix, + bucket_name + ); + + // Use the KeyValueStoreManager's watch mechanism + let (_, mut rx) = self.store.clone().watch( + bucket_name, + None, // No TTL + self.cancel_token.clone(), + ); + + tracing::debug!( + "KVStoreDiscoveryClient::list_and_watch: Got watch receiver for bucket={}", + bucket_name + ); + + // Create a stream that filters and transforms WatchEvents to DiscoveryEvents + let stream = async_stream::stream! 
{ + let mut event_count = 0; + tracing::debug!("KVStoreDiscoveryClient::list_and_watch: Stream started, waiting for events on prefix={}", prefix); + while let Some(event) = rx.recv().await { + event_count += 1; + tracing::debug!( + "KVStoreDiscoveryClient::list_and_watch: Received event #{} for prefix={}", + event_count, + prefix + ); + let discovery_event = match event { + WatchEvent::Put(kv) => { + tracing::debug!( + "KVStoreDiscoveryClient::list_and_watch: Put event, key={}, prefix={}, matches={}", + kv.key_str(), + prefix, + Self::matches_prefix(kv.key_str(), &prefix) + ); + // Check if this key matches our prefix + if !Self::matches_prefix(kv.key_str(), &prefix) { + tracing::debug!( + "KVStoreDiscoveryClient::list_and_watch: Skipping key {} (doesn't match prefix {})", + kv.key_str(), + prefix + ); + continue; + } + + match Self::parse_instance(kv.value()) { + Ok(instance) => { + tracing::info!( + "KVStoreDiscoveryClient::list_and_watch: Emitting Added event for instance_id={}, key={}", + instance.instance_id(), + kv.key_str() + ); + Some(DiscoveryEvent::Added(instance)) + }, + Err(e) => { + tracing::warn!( + key = %kv.key_str(), + error = %e, + "Failed to parse discovery instance from watch event" + ); + None + } + } + } + WatchEvent::Delete(kv) => { + let key_str = kv.as_ref(); + tracing::debug!( + "KVStoreDiscoveryClient::list_and_watch: Delete event, key={}, prefix={}", + key_str, + prefix + ); + // Check if this key matches our prefix + if !Self::matches_prefix(key_str, &prefix) { + tracing::debug!( + "KVStoreDiscoveryClient::list_and_watch: Skipping deleted key {} (doesn't match prefix {})", + key_str, + prefix + ); + continue; + } + + // Extract instance_id from the key path, not the value + // Delete events have empty values in etcd, so we parse the instance_id from the key + // Key format: "v1/instances/namespace/component/endpoint/{instance_id:x}" + let key_parts: Vec<&str> = key_str.split('/').collect(); + match key_parts.last() { + 
Some(instance_id_hex) => { + match u64::from_str_radix(instance_id_hex, 16) { + Ok(instance_id) => { + tracing::info!( + "KVStoreDiscoveryClient::list_and_watch: Emitting Removed event for instance_id={}, key={}", + instance_id, + key_str + ); + Some(DiscoveryEvent::Removed(instance_id)) + } + Err(e) => { + tracing::warn!( + key = %key_str, + error = %e, + "Failed to parse instance_id hex from deleted key" + ); + None + } + } + } + None => { + tracing::warn!( + key = %key_str, + "Delete event key has no path components" + ); + None + } + } + } + }; + + if let Some(event) = discovery_event { + tracing::debug!("KVStoreDiscoveryClient::list_and_watch: Yielding event: {:?}", event); + yield Ok(event); + } else { + tracing::debug!("KVStoreDiscoveryClient::list_and_watch: Event was filtered out (None)"); + } + } + tracing::debug!("KVStoreDiscoveryClient::list_and_watch: Stream ended after {} events for prefix={}", event_count, prefix); + }; + + tracing::debug!( + "KVStoreDiscoveryClient::list_and_watch: Returning stream for key={:?}", + key + ); + Ok(Box::pin(stream)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::component::TransportType; + + #[tokio::test] + async fn test_kv_store_discovery_register_endpoint() { + let store = KeyValueStoreManager::memory(); + let cancel_token = CancellationToken::new(); + let client = KVStoreDiscoveryClient::new(store, cancel_token); + + let spec = DiscoverySpec::Endpoint { + namespace: "test".to_string(), + component: "comp1".to_string(), + endpoint: "ep1".to_string(), + transport: TransportType::NatsTcp("nats://localhost:4222".to_string()), + }; + + let instance = client.register(spec).await.unwrap(); + + match instance { + DiscoveryInstance::Endpoint(inst) => { + assert_eq!(inst.namespace, "test"); + assert_eq!(inst.component, "comp1"); + assert_eq!(inst.endpoint, "ep1"); + } + _ => panic!("Expected Endpoint instance"), + } + } + + #[tokio::test] + async fn test_kv_store_discovery_list() { + let store = 
KeyValueStoreManager::memory(); + let cancel_token = CancellationToken::new(); + let client = KVStoreDiscoveryClient::new(store, cancel_token); + + // Register multiple endpoints + let spec1 = DiscoverySpec::Endpoint { + namespace: "ns1".to_string(), + component: "comp1".to_string(), + endpoint: "ep1".to_string(), + transport: TransportType::NatsTcp("nats://localhost:4222".to_string()), + }; + client.register(spec1).await.unwrap(); + + let spec2 = DiscoverySpec::Endpoint { + namespace: "ns1".to_string(), + component: "comp1".to_string(), + endpoint: "ep2".to_string(), + transport: TransportType::NatsTcp("nats://localhost:4222".to_string()), + }; + client.register(spec2).await.unwrap(); + + let spec3 = DiscoverySpec::Endpoint { + namespace: "ns2".to_string(), + component: "comp2".to_string(), + endpoint: "ep1".to_string(), + transport: TransportType::NatsTcp("nats://localhost:4222".to_string()), + }; + client.register(spec3).await.unwrap(); + + // List all endpoints + let all = client.list(DiscoveryKey::AllEndpoints).await.unwrap(); + assert_eq!(all.len(), 3); + + // List namespaced endpoints + let ns1 = client + .list(DiscoveryKey::NamespacedEndpoints { + namespace: "ns1".to_string(), + }) + .await + .unwrap(); + assert_eq!(ns1.len(), 2); + + // List component endpoints + let comp1 = client + .list(DiscoveryKey::ComponentEndpoints { + namespace: "ns1".to_string(), + component: "comp1".to_string(), + }) + .await + .unwrap(); + assert_eq!(comp1.len(), 2); + } + + #[tokio::test] + async fn test_kv_store_discovery_watch() { + let store = KeyValueStoreManager::memory(); + let cancel_token = CancellationToken::new(); + let client = Arc::new(KVStoreDiscoveryClient::new(store, cancel_token.clone())); + + // Start watching before registering + let mut stream = client + .list_and_watch(DiscoveryKey::AllEndpoints) + .await + .unwrap(); + + let client_clone = client.clone(); + let register_task = tokio::spawn(async move { + 
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + let spec = DiscoverySpec::Endpoint { + namespace: "test".to_string(), + component: "comp1".to_string(), + endpoint: "ep1".to_string(), + transport: TransportType::NatsTcp("nats://localhost:4222".to_string()), + }; + client_clone.register(spec).await.unwrap(); + }); + + // Wait for the added event + let event = stream.next().await.unwrap().unwrap(); + match event { + DiscoveryEvent::Added(instance) => { + match instance { + DiscoveryInstance::Endpoint(inst) => { + assert_eq!(inst.namespace, "test"); + assert_eq!(inst.component, "comp1"); + assert_eq!(inst.endpoint, "ep1"); + } + _ => panic!("Expected Endpoint instance"), + } + } + _ => panic!("Expected Added event"), + } + + register_task.await.unwrap(); + cancel_token.cancel(); + } +} + diff --git a/lib/runtime/src/discovery/mock.rs b/lib/runtime/src/discovery/mock.rs index 5ab66168c5..4c3b0f39f6 100644 --- a/lib/runtime/src/discovery/mock.rs +++ b/lib/runtime/src/discovery/mock.rs @@ -46,37 +46,46 @@ impl MockDiscoveryClient { /// Helper function to check if an instance matches a discovery key query fn matches_key(instance: &DiscoveryInstance, key: &DiscoveryKey) -> bool { match (instance, key) { - (DiscoveryInstance::Endpoint { .. }, DiscoveryKey::AllEndpoints) => true, + // Endpoint matching + (DiscoveryInstance::Endpoint(_), DiscoveryKey::AllEndpoints) => true, ( - DiscoveryInstance::Endpoint { - namespace: ins_ns, .. - }, + DiscoveryInstance::Endpoint(inst), DiscoveryKey::NamespacedEndpoints { namespace }, - ) => ins_ns == namespace, + ) => &inst.namespace == namespace, ( - DiscoveryInstance::Endpoint { - namespace: ins_ns, - component: ins_comp, - .. 
- }, + DiscoveryInstance::Endpoint(inst), DiscoveryKey::ComponentEndpoints { namespace, component, }, - ) => ins_ns == namespace && ins_comp == component, + ) => &inst.namespace == namespace && &inst.component == component, ( - DiscoveryInstance::Endpoint { - namespace: ins_ns, - component: ins_comp, - endpoint: ins_ep, - .. - }, + DiscoveryInstance::Endpoint(inst), DiscoveryKey::Endpoint { namespace, component, endpoint, }, - ) => ins_ns == namespace && ins_comp == component && ins_ep == endpoint, + ) => &inst.namespace == namespace && &inst.component == component && &inst.endpoint == endpoint, + + // ModelCard matching + (DiscoveryInstance::ModelCard { .. }, DiscoveryKey::AllModelCards) => true, + ( + DiscoveryInstance::ModelCard { namespace: inst_ns, .. }, + DiscoveryKey::NamespacedModelCards { namespace }, + ) => inst_ns == namespace, + ( + DiscoveryInstance::ModelCard { namespace: inst_ns, component: inst_comp, .. }, + DiscoveryKey::ComponentModelCards { namespace, component }, + ) => inst_ns == namespace && inst_comp == component, + ( + DiscoveryInstance::ModelCard { namespace: inst_ns, component: inst_comp, endpoint: inst_ep, .. }, + DiscoveryKey::EndpointModelCards { namespace, component, endpoint }, + ) => inst_ns == namespace && inst_comp == component && inst_ep == endpoint, + + // Cross-type matches return false + (DiscoveryInstance::Endpoint(_), DiscoveryKey::AllModelCards | DiscoveryKey::NamespacedModelCards { .. } | DiscoveryKey::ComponentModelCards { .. } | DiscoveryKey::EndpointModelCards { .. }) => false, + (DiscoveryInstance::ModelCard { .. }, DiscoveryKey::AllEndpoints | DiscoveryKey::NamespacedEndpoints { .. } | DiscoveryKey::ComponentEndpoints { .. } | DiscoveryKey::Endpoint { .. 
}) => false, } } @@ -98,6 +107,15 @@ impl DiscoveryClient for MockDiscoveryClient { Ok(instance) } + async fn list(&self, key: DiscoveryKey) -> Result> { + let instances = self.registry.instances.lock().unwrap(); + Ok(instances + .iter() + .filter(|instance| matches_key(instance, &key)) + .cloned() + .collect()) + } + async fn list_and_watch(&self, key: DiscoveryKey) -> Result { use std::collections::HashSet; @@ -118,14 +136,16 @@ impl DiscoveryClient for MockDiscoveryClient { let current_ids: HashSet<_> = current.iter().map(|i| { match i { - DiscoveryInstance::Endpoint { instance_id, .. } => *instance_id, + DiscoveryInstance::Endpoint(inst) => inst.instance_id, + DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id, } }).collect(); // Emit Added events for new instances for instance in current { let id = match &instance { - DiscoveryInstance::Endpoint { instance_id, .. } => *instance_id, + DiscoveryInstance::Endpoint(inst) => inst.instance_id, + DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id, }; if known_instances.insert(id) { yield Ok(DiscoveryEvent::Added(instance)); @@ -161,6 +181,7 @@ mod tests { namespace: "test-ns".to_string(), component: "test-comp".to_string(), endpoint: "test-ep".to_string(), + transport: crate::component::TransportType::NatsTcp("test-subject".to_string()), }; let key = DiscoveryKey::Endpoint { @@ -177,8 +198,8 @@ mod tests { let event = stream.next().await.unwrap().unwrap(); match event { - DiscoveryEvent::Added(DiscoveryInstance::Endpoint { instance_id, .. }) => { - assert_eq!(instance_id, 1); + DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) => { + assert_eq!(inst.instance_id, 1); } _ => panic!("Expected Added event for instance-1"), } @@ -188,15 +209,16 @@ mod tests { let event = stream.next().await.unwrap().unwrap(); match event { - DiscoveryEvent::Added(DiscoveryInstance::Endpoint { instance_id, .. 
}) => { - assert_eq!(instance_id, 2); + DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) => { + assert_eq!(inst.instance_id, 2); } _ => panic!("Expected Added event for instance-2"), } // Remove first instance registry.instances.lock().unwrap().retain(|i| match i { - DiscoveryInstance::Endpoint { instance_id, .. } => *instance_id != 1, + DiscoveryInstance::Endpoint(inst) => inst.instance_id != 1, + DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id != 1, }); let event = stream.next().await.unwrap().unwrap(); diff --git a/lib/runtime/src/discovery/mod.rs b/lib/runtime/src/discovery/mod.rs index 090fff281a..2fc8056093 100644 --- a/lib/runtime/src/discovery/mod.rs +++ b/lib/runtime/src/discovery/mod.rs @@ -1,6 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use crate::component::TransportType; use crate::Result; use async_trait::async_trait; use futures::Stream; @@ -10,6 +11,12 @@ use std::pin::Pin; mod mock; pub use mock::{MockDiscoveryClient, SharedMockRegistry}; +mod kv_store; +pub use kv_store::KVStoreDiscoveryClient; + +pub mod utils; +pub use utils::watch_and_extract_field; + /// Query key for prefix-based discovery queries /// Supports hierarchical queries from all endpoints down to specific endpoints #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -29,28 +36,63 @@ pub enum DiscoveryKey { component: String, endpoint: String, }, - // TODO: Extend to support ModelCard queries: - // - AllModels - // - NamespacedModels { namespace } - // - ComponentModels { namespace, component } - // - Model { namespace, component, model_name } + AllModelCards, + NamespacedModelCards { namespace: String }, + ComponentModelCards { + namespace: String, + component: String, + }, + EndpointModelCards { + namespace: String, + component: String, + endpoint: String, + }, } /// Specification for registering objects in the discovery plane /// Represents the input 
to the register() operation -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum DiscoverySpec { /// Endpoint specification for registration Endpoint { namespace: String, component: String, endpoint: String, + /// Transport type and routing information + transport: TransportType, + }, + ModelCard { + namespace: String, + component: String, + endpoint: String, + /// ModelDeploymentCard serialized as JSON + /// This allows lib/runtime to remain independent of lib/llm types + /// DiscoverySpec.from_model_card() and DiscoveryInstance.deserialize_model_card() are ergonomic helpers to create and deserialize the model card. + card_json: serde_json::Value, }, - // TODO: Add ModelCard variant: - // - ModelCard { namespace, component, model_name, card: ModelDeploymentCard } } impl DiscoverySpec { + /// Creates a ModelCard discovery spec from a serializable type + /// The card will be serialized to JSON to avoid cross-crate dependencies + pub fn from_model_card( + namespace: String, + component: String, + endpoint: String, + card: &T, + ) -> crate::Result + where + T: Serialize, + { + let card_json = serde_json::to_value(card)?; + Ok(Self::ModelCard { + namespace, + component, + endpoint, + card_json, + }) + } + /// Attaches an instance ID to create a DiscoveryInstance pub fn with_instance_id(self, instance_id: u64) -> DiscoveryInstance { match self { @@ -58,11 +100,25 @@ impl DiscoverySpec { namespace, component, endpoint, - } => DiscoveryInstance::Endpoint { + transport, + } => DiscoveryInstance::Endpoint(crate::component::Instance { namespace, component, endpoint, instance_id, + transport, + }), + Self::ModelCard { + namespace, + component, + endpoint, + card_json, + } => DiscoveryInstance::ModelCard { + namespace, + component, + endpoint, + instance_id, + card_json, }, } } @@ -70,18 +126,44 @@ impl DiscoverySpec { /// Registered instances in the discovery plane /// Represents objects that have been successfully registered with an 
instance ID -#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[serde(tag = "type")] pub enum DiscoveryInstance { - /// Registered endpoint instance - Endpoint { + /// Registered endpoint instance - wraps the component::Instance directly + Endpoint(crate::component::Instance), + ModelCard { namespace: String, component: String, endpoint: String, instance_id: u64, + /// ModelDeploymentCard serialized as JSON + /// This allows lib/runtime to remain independent of lib/llm types + card_json: serde_json::Value, }, - // TODO: Add ModelCard variant: - // - ModelCard { namespace, component, model_name, instance_id, card: ModelDeploymentCard } +} + +impl DiscoveryInstance { + /// Returns the instance ID for this discovery instance + pub fn instance_id(&self) -> u64 { + match self { + Self::Endpoint(inst) => inst.instance_id, + Self::ModelCard { instance_id, .. } => *instance_id, + } + } + + /// Deserializes the model card JSON into the specified type T + /// Returns an error if this is not a ModelCard instance or if deserialization fails + pub fn deserialize_model_card(&self) -> crate::Result + where + T: for<'de> Deserialize<'de>, + { + match self { + Self::ModelCard { card_json, .. 
} => Ok(serde_json::from_value(card_json.clone())?), + Self::Endpoint(_) => crate::raise!( + "Cannot deserialize model card from Endpoint instance" + ), + } + } } /// Events emitted by the discovery client watch stream @@ -106,6 +188,11 @@ pub trait DiscoveryClient: Send + Sync { /// Registers an object in the discovery plane with the instance id async fn register(&self, spec: DiscoverySpec) -> Result; + /// Returns a list of currently registered instances for the given discovery key + /// This is a one-time snapshot without watching for changes + async fn list(&self, key: DiscoveryKey) -> Result>; + /// Returns a stream of discovery events (Added/Removed) for the given discovery key async fn list_and_watch(&self, key: DiscoveryKey) -> Result; } + diff --git a/lib/runtime/src/discovery/utils.rs b/lib/runtime/src/discovery/utils.rs new file mode 100644 index 0000000000..abcd42cf4c --- /dev/null +++ b/lib/runtime/src/discovery/utils.rs @@ -0,0 +1,107 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! 
Utility functions for working with discovery streams + +use serde::Deserialize; + +use super::{DiscoveryEvent, DiscoveryInstance, DiscoveryStream}; + +/// Helper to watch a discovery stream and extract a specific field into a HashMap +/// +/// This helper spawns a background task that: +/// - Deserializes ModelCards from discovery events +/// - Extracts a specific field using the provided extractor function +/// - Maintains a HashMap that auto-updates on Add/Remove events +/// - Returns a watch::Receiver that consumers can use to read the current state +/// +/// # Type Parameters +/// - `T`: The type to deserialize from DiscoveryInstance (e.g., ModelDeploymentCard) +/// - `V`: The extracted field type (e.g., ModelRuntimeConfig) +/// - `F`: The extractor function type +/// +/// # Arguments +/// - `stream`: The discovery event stream to watch +/// - `extractor`: Function that extracts the desired field from the deserialized type +/// +/// # Example +/// ```ignore +/// let stream = discovery.list_and_watch(DiscoveryKey::ComponentModelCards { ... }).await?; +/// let runtime_configs_rx = watch_and_extract_field( +/// stream, +/// |card: ModelDeploymentCard| card.runtime_config, +/// ); +/// +/// // Use it: +/// let configs = runtime_configs_rx.borrow(); +/// if let Some(config) = configs.get(&worker_id) { +/// // Use config... 
+/// }
+/// ```
+pub fn watch_and_extract_field<T, V, F>(
+    stream: DiscoveryStream,
+    extractor: F,
+) -> tokio::sync::watch::Receiver<std::collections::HashMap<u64, V>>
+where
+    T: for<'de> Deserialize<'de> + 'static,
+    V: Clone + Send + Sync + 'static,
+    F: Fn(T) -> V + Send + 'static,
+{
+    use futures::StreamExt;
+    use std::collections::HashMap;
+
+    let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
+
+    tokio::spawn(async move {
+        let mut state: HashMap<u64, V> = HashMap::new();
+        let mut stream = stream;
+
+        while let Some(result) = stream.next().await {
+            match result {
+                Ok(DiscoveryEvent::Added(instance)) => {
+                    let instance_id = instance.instance_id();
+
+                    // Deserialize the full instance into type T
+                    let deserialized: T = match instance.deserialize_model_card() {
+                        Ok(d) => d,
+                        Err(e) => {
+                            tracing::warn!(
+                                instance_id,
+                                error = %e,
+                                "Failed to deserialize discovery instance, skipping"
+                            );
+                            continue;
+                        }
+                    };
+
+                    // Extract the field we care about
+                    let value = extractor(deserialized);
+
+                    // Update state and send
+                    state.insert(instance_id, value);
+                    if tx.send(state.clone()).is_err() {
+                        tracing::debug!("watch_and_extract_field receiver dropped, stopping");
+                        break;
+                    }
+                }
+                Ok(DiscoveryEvent::Removed(instance_id)) => {
+                    // Remove from state and send update
+                    state.remove(&instance_id);
+                    if tx.send(state.clone()).is_err() {
+                        tracing::debug!("watch_and_extract_field receiver dropped, stopping");
+                        break;
+                    }
+                }
+                Err(e) => {
+                    tracing::error!(error = %e, "Discovery event stream error in watch_and_extract_field");
+                    // Continue processing other events
+                }
+            }
+        }
+
+        tracing::debug!("watch_and_extract_field task stopped");
+    });
+
+    rx
+}
+
diff --git a/lib/runtime/src/distributed.rs b/lib/runtime/src/distributed.rs
index fd2846a0b9..c95b7e3474 100644
--- a/lib/runtime/src/distributed.rs
+++ b/lib/runtime/src/distributed.rs
@@ -3,7 +3,8 @@
 pub use crate::component::Component;
 use crate::storage::key_value_store::{
-    EtcdStore, KeyValueStore, KeyValueStoreEnum, KeyValueStoreManager, 
MemoryStore, + EtcdStore, KeyValueStore, KeyValueStoreEnum, KeyValueStoreManager, KeyValueStoreSelect, + MemoryStore, }; use crate::transports::nats::DRTNatsClientPrometheusMetrics; use crate::{ @@ -48,24 +49,21 @@ impl std::fmt::Debug for DistributedRuntime { impl DistributedRuntime { pub async fn new(runtime: Runtime, config: DistributedConfig) -> Result { - let (etcd_config, nats_config, is_static) = config.dissolve(); + let (selected_kv_store, nats_config, is_static) = config.dissolve(); let runtime_clone = runtime.clone(); - // TODO: Here is where we will later select the KeyValueStore impl - let (etcd_client, store) = if is_static { - (None, KeyValueStoreManager::memory()) - } else { - match etcd::Client::new(etcd_config.clone(), runtime_clone).await { - Ok(etcd_client) => { - let store = KeyValueStoreManager::etcd(etcd_client.clone()); - (Some(etcd_client), store) - } - Err(err) => { - tracing::info!(%err, "Did not connect to etcd. Using memory storage."); - (None, KeyValueStoreManager::memory()) - } + let (etcd_client, store) = match selected_kv_store { + KeyValueStoreSelect::Etcd(etcd_config) => { + let etcd_client = etcd::Client::new(*etcd_config, runtime_clone).await.inspect_err(|err| + // The returned error doesn't show because of a dropped runtime error, so + // log it first. + tracing::error!(%err, "Could not connect to etcd. 
Pass `--store-kv ..` to use a different backend or start etcd."))?;
+                let store = KeyValueStoreManager::etcd(etcd_client.clone());
+                (Some(etcd_client), store)
             }
+            KeyValueStoreSelect::File(root) => (None, KeyValueStoreManager::file(root)),
+            KeyValueStoreSelect::Memory => (None, KeyValueStoreManager::memory()),
         };
 
         let nats_client = Some(nats_config.clone().connect().await?);
@@ -92,12 +90,13 @@ impl DistributedRuntime {
 
         let nats_client_for_metrics = nats_client.clone();
 
-        // Initialize discovery client with mock implementation
-        // TODO: Replace MockDiscoveryClient with KeyValueStoreDiscoveryClient or KubeDiscoveryClient
+        // Initialize discovery client backed by KV store
         let discovery_client = {
-            use crate::discovery::{MockDiscoveryClient, SharedMockRegistry};
-            let registry = SharedMockRegistry::new();
-            Arc::new(MockDiscoveryClient::new(None, registry)) as Arc<dyn DiscoveryClient>
+            use crate::discovery::KVStoreDiscoveryClient;
+            Arc::new(KVStoreDiscoveryClient::new(
+                store.clone(),
+                runtime.primary_token(),
+            )) as Arc<dyn DiscoveryClient>
         };
 
         let distributed_runtime = Self {
@@ -234,6 +233,7 @@ impl DistributedRuntime {
 
     pub fn shutdown(&self) {
         self.runtime.shutdown();
+        self.store.shutdown();
     }
 
     /// Create a [`Namespace`]
@@ -241,9 +241,10 @@ impl DistributedRuntime {
         Namespace::new(self.clone(), name.into(), self.is_static)
     }
 
-    /// TODO: Return discovery client when KeyValueDiscoveryClient or KubeDiscoveryClient is implemented
-    pub fn discovery_client(&self) -> Result<Arc<dyn DiscoveryClient>> {
-        Err(error!("Discovery client not implemented!"))
+    /// Returns the discovery client for service registration and discovery.
+    /// Backed by KVStoreDiscoveryClient over the configured key-value store (etcd, file, or memory).
+    pub fn discovery_client(&self) -> Arc<dyn DiscoveryClient> {
+        self.discovery_client.clone()
+    }
 
     pub(crate) fn service_client(&self) -> Option {
@@ -302,7 +303,7 @@ impl DistributedRuntime {
 
 #[derive(Dissolve)]
 pub struct DistributedConfig {
-    pub etcd_config: etcd::ClientOptions,
+    pub store_backend: 
KeyValueStoreSelect, pub nats_config: nats::ClientOptions, pub is_static: bool, } @@ -310,22 +311,22 @@ pub struct DistributedConfig { impl DistributedConfig { pub fn from_settings(is_static: bool) -> DistributedConfig { DistributedConfig { - etcd_config: etcd::ClientOptions::default(), + store_backend: KeyValueStoreSelect::Etcd(Box::default()), nats_config: nats::ClientOptions::default(), is_static, } } pub fn for_cli() -> DistributedConfig { - let mut config = DistributedConfig { - etcd_config: etcd::ClientOptions::default(), + let etcd_config = etcd::ClientOptions { + attach_lease: false, + ..Default::default() + }; + DistributedConfig { + store_backend: KeyValueStoreSelect::Etcd(Box::new(etcd_config)), nats_config: nats::ClientOptions::default(), is_static: false, - }; - - config.etcd_config.attach_lease = false; - - config + } } } diff --git a/lib/runtime/src/instances.rs b/lib/runtime/src/instances.rs index 8f9ab0f676..7f875c7669 100644 --- a/lib/runtime/src/instances.rs +++ b/lib/runtime/src/instances.rs @@ -9,26 +9,34 @@ use std::sync::Arc; -use crate::component::{INSTANCE_ROOT_PATH, Instance}; -use crate::storage::key_value_store::{KeyValueStore, KeyValueStoreManager}; -use crate::transports::etcd::Client as EtcdClient; +use crate::component::Instance; +use crate::discovery::{DiscoveryClient, DiscoveryKey}; -pub async fn list_all_instances(client: &KeyValueStoreManager) -> anyhow::Result> { - let Some(bucket) = client.get_bucket(INSTANCE_ROOT_PATH).await? 
else { - return Ok(vec![]); - }; +pub async fn list_all_instances( + discovery_client: Arc, +) -> anyhow::Result> { + let discovery_instances = discovery_client.list(DiscoveryKey::AllEndpoints).await?; - let entries = bucket.entries().await?; - let mut instances = Vec::with_capacity(entries.len()); - for (name, bytes) in entries.into_iter() { - match serde_json::from_slice::(&bytes) { - Ok(instance) => instances.push(instance), - Err(err) => { - tracing::warn!(%err, key = name, "Failed to parse instance from storage"); - } - } - } - instances.sort(); + let mut instances: Vec = discovery_instances + .into_iter() + .filter_map(|di| match di { + crate::discovery::DiscoveryInstance::Endpoint(instance) => Some(instance), + _ => None, // Ignore all other variants (ModelCard, etc.) + }) + .collect(); + instances.sort(); + + // Log all instances found for comparison + let instance_details: Vec<(u64, &str, &str, &str)> = instances + .iter() + .map(|inst| (inst.instance_id, inst.namespace.as_str(), inst.component.as_str(), inst.endpoint.as_str())) + .collect(); + tracing::warn!( + "DISCOVERY_VALIDATION: all_instances_found: count={}, instances={:?}", + instances.len(), + instance_details + ); + Ok(instances) } diff --git a/lib/runtime/src/storage/key_value_store.rs b/lib/runtime/src/storage/key_value_store.rs index 7fc122ec40..24dbb4eb39 100644 --- a/lib/runtime/src/storage/key_value_store.rs +++ b/lib/runtime/src/storage/key_value_store.rs @@ -4,14 +4,16 @@ //! Interface to a traditional key-value store such as etcd. //! "key_value_store" spelt out because in AI land "KV" means something else. 
-use std::collections::HashMap; -use std::fmt; use std::pin::Pin; +use std::str::FromStr; use std::sync::Arc; use std::time::Duration; +use std::{collections::HashMap, path::PathBuf}; +use std::{env, fmt}; use crate::CancellationToken; use crate::slug::Slug; +use crate::transports::etcd as etcd_transport; use async_trait::async_trait; use futures::StreamExt; use serde::{Deserialize, Serialize}; @@ -22,10 +24,15 @@ mod nats; pub use nats::NATSStore; mod etcd; pub use etcd::EtcdStore; +mod file; +pub use file::FileStore; const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100); /// A key that is safe to use directly in the KV store. +/// +/// TODO: Need to re-think this. etcd uses slash separators, so we often use from_raw +/// to avoid the slug. But other impl's, particularly file, need a real slug. #[derive(Debug, Clone, PartialEq)] pub struct Key(String); @@ -95,7 +102,7 @@ impl KeyValue { #[derive(Debug, Clone, PartialEq)] pub enum WatchEvent { Put(KeyValue), - Delete(KeyValue), + Delete(Key), } #[async_trait] @@ -112,6 +119,57 @@ pub trait KeyValueStore: Send + Sync { async fn get_bucket(&self, bucket_name: &str) -> Result, StoreError>; fn connection_id(&self) -> u64; + + fn shutdown(&self); +} + +#[derive(Clone, Debug, Default)] +pub enum KeyValueStoreSelect { + // Box it because it is significantly bigger than the other variants + Etcd(Box), + File(PathBuf), + #[default] + Memory, + // Nats not listed because likely we want to remove that impl. It is not currently used and not well tested. 
+} + +impl fmt::Display for KeyValueStoreSelect { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + KeyValueStoreSelect::Etcd(opts) => { + let urls = opts.etcd_url.join(","); + write!(f, "Etcd({urls})") + } + KeyValueStoreSelect::File(path) => write!(f, "File({})", path.display()), + KeyValueStoreSelect::Memory => write!(f, "Memory"), + } + } +} + +impl FromStr for KeyValueStoreSelect { + type Err = anyhow::Error; + + fn from_str(s: &str) -> anyhow::Result { + match s { + "etcd" => Ok(Self::Etcd(Box::default())), + "file" => { + let root = env::var("DYN_FILE_KV") + .map(PathBuf::from) + .unwrap_or_else(|_| env::temp_dir().join("dynamo_store_kv")); + Ok(Self::File(root)) + } + "mem" => Ok(Self::Memory), + x => anyhow::bail!("Unknown key-value store type '{x}'"), + } + } +} + +impl TryFrom for KeyValueStoreSelect { + type Error = anyhow::Error; + + fn try_from(s: String) -> anyhow::Result { + s.parse() + } } #[allow(clippy::large_enum_variant)] @@ -119,6 +177,7 @@ pub enum KeyValueStoreEnum { Memory(MemoryStore), Nats(NATSStore), Etcd(EtcdStore), + File(FileStore), } impl KeyValueStoreEnum { @@ -133,6 +192,7 @@ impl KeyValueStoreEnum { Memory(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?), Nats(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?), Etcd(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?), + File(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?), }) } @@ -154,6 +214,10 @@ impl KeyValueStoreEnum { .get_bucket(bucket_name) .await? .map(|b| Box::new(b) as Box), + File(x) => x + .get_bucket(bucket_name) + .await? 
+ .map(|b| Box::new(b) as Box), }; Ok(maybe_bucket) } @@ -164,12 +228,23 @@ impl KeyValueStoreEnum { Memory(x) => x.connection_id(), Etcd(x) => x.connection_id(), Nats(x) => x.connection_id(), + File(x) => x.connection_id(), + } + } + + fn shutdown(&self) { + use KeyValueStoreEnum::*; + match self { + Memory(x) => x.shutdown(), + Etcd(x) => x.shutdown(), + Nats(x) => x.shutdown(), + File(x) => x.shutdown(), } } } #[derive(Clone)] -pub struct KeyValueStoreManager(Arc); +pub struct KeyValueStoreManager(pub Arc); impl Default for KeyValueStoreManager { fn default() -> Self { @@ -187,6 +262,10 @@ impl KeyValueStoreManager { Self::new(KeyValueStoreEnum::Etcd(EtcdStore::new(etcd_client))) } + pub fn file>(root: P) -> Self { + Self::new(KeyValueStoreEnum::File(FileStore::new(root))) + } + fn new(s: KeyValueStoreEnum) -> KeyValueStoreManager { KeyValueStoreManager(Arc::new(s)) } @@ -302,6 +381,12 @@ impl KeyValueStoreManager { } Ok(outcome) } + + /// Cleanup any temporary state. + /// TODO: Should this be async? Take &mut self? + pub fn shutdown(&self) { + self.0.shutdown() + } } /// An online storage for key-value config values. @@ -366,6 +451,9 @@ pub enum StoreError { #[error("Internal etcd error: {0}")] EtcdError(String), + #[error("Internal filesystem error: {0}")] + FilesystemError(String), + #[error("Key Value Error: {0} for bucket '{1}'")] KeyValueError(String, String), diff --git a/lib/runtime/src/storage/key_value_store/etcd.rs b/lib/runtime/src/storage/key_value_store/etcd.rs index bd3934af96..5e8c6bf5db 100644 --- a/lib/runtime/src/storage/key_value_store/etcd.rs +++ b/lib/runtime/src/storage/key_value_store/etcd.rs @@ -54,6 +54,10 @@ impl KeyValueStore for EtcdStore { fn connection_id(&self) -> u64 { self.client.lease_id() } + + fn shutdown(&self) { + // Revoke the lease? etcd will do it for us on disconnect. 
+ } } pub struct EtcdBucket { @@ -132,13 +136,13 @@ impl KeyValueBucket for EtcdBucket { continue; } }; - let item = KeyValue::new(key, v_bytes.into()); match e.event_type() { EventType::Put => { + let item = KeyValue::new(key, v_bytes.into()); yield WatchEvent::Put(item); } EventType::Delete => { - yield WatchEvent::Delete(item); + yield WatchEvent::Delete(Key::from_raw(key)); } } } diff --git a/lib/runtime/src/storage/key_value_store/file.rs b/lib/runtime/src/storage/key_value_store/file.rs new file mode 100644 index 0000000000..392d403a2a --- /dev/null +++ b/lib/runtime/src/storage/key_value_store/file.rs @@ -0,0 +1,307 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashSet; +use std::ffi::OsString; +use std::fmt; +use std::fs; +use std::os::unix::ffi::OsStrExt; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Duration; +use std::{collections::HashMap, pin::Pin}; + +use anyhow::Context as _; +use async_trait::async_trait; +use futures::StreamExt; +use inotify::{Event, EventMask, EventStream, Inotify, WatchMask}; +use parking_lot::Mutex; + +use crate::storage::key_value_store::KeyValue; + +use super::{Key, KeyValueBucket, KeyValueStore, StoreError, StoreOutcome, WatchEvent}; + +/// Treat as a singleton +#[derive(Clone)] +pub struct FileStore { + root: PathBuf, + connection_id: u64, + /// Directories we may have created files in, for shutdown cleanup + /// Arc so that we only ever have one map here after clone + active_dirs: Arc>>, +} + +impl FileStore { + pub(super) fn new>(root_dir: P) -> Self { + FileStore { + root: root_dir.into(), + connection_id: rand::random::(), + active_dirs: Arc::new(Mutex::new(HashMap::new())), + } + } +} + +#[async_trait] +impl KeyValueStore for FileStore { + type Bucket = Directory; + + /// A "bucket" is a directory + async fn get_or_create_bucket( + &self, + bucket_name: &str, + _ttl: 
Option, // TODO ttl not used yet + ) -> Result { + let p = self.root.join(bucket_name); + if let Some(dir) = self.active_dirs.lock().get(&p) { + return Ok(dir.clone()); + }; + + if p.exists() { + // Get + if !p.is_dir() { + return Err(StoreError::FilesystemError( + "Bucket name is not a directory".to_string(), + )); + } + } else { + // Create + fs::create_dir_all(&p).map_err(to_fs_err)?; + } + let dir = Directory::new(self.root.clone(), p.clone()); + self.active_dirs.lock().insert(p, dir.clone()); + Ok(dir) + } + + /// A "bucket" is a directory + async fn get_bucket(&self, bucket_name: &str) -> Result, StoreError> { + let p = self.root.join(bucket_name); + if let Some(dir) = self.active_dirs.lock().get(&p) { + return Ok(Some(dir.clone())); + }; + + if !p.exists() { + return Ok(None); + } + if !p.is_dir() { + return Err(StoreError::FilesystemError( + "Bucket name is not a directory".to_string(), + )); + } + let dir = Directory::new(self.root.clone(), p.clone()); + self.active_dirs.lock().insert(p, dir.clone()); + Ok(Some(dir)) + } + + fn connection_id(&self) -> u64 { + self.connection_id + } + + // This cannot be a Drop impl because DistributedRuntime is cloned various places including + // Python. Drop doesn't get called. 
+ fn shutdown(&self) { + for (_, mut dir) in self.active_dirs.lock().drain() { + if let Err(err) = dir.delete_owned_files() { + tracing::error!(error = %err, %dir, "Failed shutdown delete of owned files"); + } + } + } +} + +#[derive(Clone)] +pub struct Directory { + root: PathBuf, + p: PathBuf, + /// These are the files we created and hence must delete on shutdown + owned_files: Arc>>, +} + +impl Directory { + fn new(root: PathBuf, p: PathBuf) -> Self { + Directory { + root, + p, + owned_files: Arc::new(Mutex::new(HashSet::new())), + } + } + + fn delete_owned_files(&mut self) -> anyhow::Result<()> { + let mut errs = Vec::new(); + for p in self.owned_files.lock().drain() { + if let Err(err) = fs::remove_file(&p) { + errs.push(format!("{}: {err}", p.display())); + } + } + if !errs.is_empty() { + anyhow::bail!(errs.join(", ")); + } + Ok(()) + } +} + +impl fmt::Display for Directory { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.p.display()) + } +} + +#[async_trait] +impl KeyValueBucket for Directory { + /// Write a file to the directory + async fn insert( + &self, + key: &Key, + value: bytes::Bytes, + _revision: u64, // Not used. Maybe put in file name? + ) -> Result { + let safe_key = Key::new(key.as_ref()); // because of from_raw + let full_path = self.p.join(safe_key.as_ref()); + self.owned_files.lock().insert(full_path.clone()); + let str_path = full_path.display().to_string(); + fs::write(&full_path, &value) + .context(str_path) + .map_err(a_to_fs_err)?; + Ok(StoreOutcome::Created(0)) + } + + /// Read a file from the directory + async fn get(&self, key: &Key) -> Result, StoreError> { + let safe_key = Key::new(key.as_ref()); // because of from_raw + let full_path = self.p.join(safe_key.as_ref()); + if !full_path.exists() { + return Ok(None); + } + let str_path = full_path.display().to_string(); + let data: bytes::Bytes = fs::read(&full_path) + .context(str_path) + .map_err(a_to_fs_err)? 
+ .into(); + Ok(Some(data)) + } + + /// Delete a file from the directory + async fn delete(&self, key: &Key) -> Result<(), StoreError> { + let safe_key = Key::new(key.as_ref()); // because of from_raw + let full_path = self.p.join(safe_key.as_ref()); + let str_path = full_path.display().to_string(); + if !full_path.exists() { + return Err(StoreError::MissingKey(str_path)); + } + + self.owned_files.lock().remove(&full_path); + + fs::remove_file(&full_path) + .context(str_path) + .map_err(a_to_fs_err) + } + + async fn watch( + &self, + ) -> Result + Send + 'life0>>, StoreError> { + let inotify = Inotify::init().map_err(to_fs_err)?; + inotify + .watches() + .add( + &self.p, + WatchMask::MODIFY | WatchMask::CREATE | WatchMask::DELETE, + ) + .map_err(to_fs_err)?; + + let dir = self.p.clone(); + Ok(Box::pin(async_stream::stream! { + let mut buffer = [0; 1024]; + let mut events = match inotify.into_event_stream(&mut buffer) { + Ok(events) => events, + Err(err) => { + tracing::error!(error = %err, "Failed getting event stream from inotify"); + return; + } + }; + while let Some(Ok(event)) = events.next().await { + let Some(name) = event.name else { + tracing::warn!("Unexpected event on the directory itself"); + continue; + }; + let item_path = dir.join(name); + let key = match item_path.strip_prefix(&self.root) { + Ok(stripped) => stripped.display().to_string().replace("_", "/"), + Err(err) => { + // Possibly this should be a panic. + // A key cannot be outside the file store root. + tracing::error!( + error = %err, + item_path = %item_path.display(), + root = %self.root.display(), + "Item in file store is not prefixed with file store root. Should be impossible. Ignoring invalid key."); + continue; + } + }; + + match event.mask { + EventMask::MODIFY | EventMask::CREATE => { + let data: bytes::Bytes = match fs::read(&item_path) { + Ok(data) => data.into(), + Err(err) => { + tracing::warn!(error = %err, item = %item_path.display(), "Failed reading event item. 
Skipping."); + continue; + } + }; + let item = KeyValue::new(key, data); + yield WatchEvent::Put(item); + } + EventMask::DELETE => { + yield WatchEvent::Delete(Key::from_raw(key)); + } + event_type => { + tracing::warn!(?event_type, dir = %dir.display(), "Unexpected event type"); + continue; + } + } + } + })) + } + + async fn entries(&self) -> Result, StoreError> { + let contents = fs::read_dir(&self.p) + .with_context(|| self.p.display().to_string()) + .map_err(a_to_fs_err)?; + let mut out = HashMap::new(); + for entry in contents { + let entry = entry.map_err(to_fs_err)?; + if !entry.path().is_file() { + tracing::warn!( + path = %entry.path().display(), + "Unexpected entry, directory should only contain files." + ); + continue; + } + + let key = match entry.path().strip_prefix(&self.root) { + Ok(p) => p.to_string_lossy().to_string().replace("_", "/"), + Err(err) => { + tracing::error!( + error = %err, + path = %entry.path().display(), + root = %self.root.display(), + "FileStore path not in root. Should be impossible. Skipping entry." + ); + continue; + } + }; + let data: bytes::Bytes = fs::read(entry.path()) + .with_context(|| self.p.display().to_string()) + .map_err(a_to_fs_err)? 
+ .into(); + out.insert(key, data); + } + Ok(out) + } +} + +// For anyhow preserve the context +fn a_to_fs_err(err: anyhow::Error) -> StoreError { + StoreError::FilesystemError(format!("{err:#}")) +} + +fn to_fs_err(err: E) -> StoreError { + StoreError::FilesystemError(err.to_string()) +} diff --git a/lib/runtime/src/storage/key_value_store/mem.rs b/lib/runtime/src/storage/key_value_store/mem.rs index 287a870693..a7a9037b28 100644 --- a/lib/runtime/src/storage/key_value_store/mem.rs +++ b/lib/runtime/src/storage/key_value_store/mem.rs @@ -57,7 +57,7 @@ impl MemoryBucket { } impl MemoryStore { - pub fn new() -> Self { + pub(super) fn new() -> Self { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); MemoryStore { inner: Arc::new(MemoryStoreInner { @@ -107,6 +107,8 @@ impl KeyValueStore for MemoryStore { fn connection_id(&self) -> u64 { self.connection_id } + + fn shutdown(&self) {} } #[async_trait] @@ -205,8 +207,7 @@ impl KeyValueBucket for MemoryBucketRef { yield WatchEvent::Put(item); }, Some(MemoryEvent::Delete { key }) => { - let item = KeyValue::new(key, bytes::Bytes::new()); - yield WatchEvent::Delete(item); + yield WatchEvent::Delete(Key::from_raw(key)); } } } diff --git a/lib/runtime/src/storage/key_value_store/nats.rs b/lib/runtime/src/storage/key_value_store/nats.rs index d30e779214..b6f5802efd 100644 --- a/lib/runtime/src/storage/key_value_store/nats.rs +++ b/lib/runtime/src/storage/key_value_store/nats.rs @@ -52,6 +52,11 @@ impl KeyValueStore for NATSStore { fn connection_id(&self) -> u64 { self.client.client().server_info().client_id } + + fn shutdown(&self) { + // TODO: Track and delete any owned keys + // The TTL should ensure NATS does it, but best we do it immediately + } } impl NATSStore { @@ -160,12 +165,14 @@ impl KeyValueBucket for NATSBucket { >| async move { match maybe_entry { Ok(entry) => { - let item = KeyValue::new(entry.key, entry.value); Some(match entry.operation { - Operation::Put => WatchEvent::Put(item), - Operation::Delete => 
WatchEvent::Delete(item), + Operation::Put => { + let item = KeyValue::new(entry.key, entry.value); + WatchEvent::Put(item) + } + Operation::Delete => WatchEvent::Delete(Key::from_raw(entry.key)), // TODO: What is Purge? Not urgent, NATS impl not used - Operation::Purge => WatchEvent::Delete(item), + Operation::Purge => WatchEvent::Delete(Key::from_raw(entry.key)), }) } Err(e) => { diff --git a/tests/planner/unit/test_virtual_connector.py b/tests/planner/unit/test_virtual_connector.py index 98b17f8bf0..367a3446a5 100644 --- a/tests/planner/unit/test_virtual_connector.py +++ b/tests/planner/unit/test_virtual_connector.py @@ -31,7 +31,7 @@ def get_runtime(): except Exception: # If no existing runtime, create a new one loop = asyncio.get_running_loop() - _runtime_instance = DistributedRuntime(loop, False) + _runtime_instance = DistributedRuntime(loop, "etcd", False) return _runtime_instance diff --git a/tests/router/test_router_e2e_with_mockers.py b/tests/router/test_router_e2e_with_mockers.py index 9423e8461c..6cf5e97fdf 100644 --- a/tests/router/test_router_e2e_with_mockers.py +++ b/tests/router/test_router_e2e_with_mockers.py @@ -226,7 +226,7 @@ def get_runtime(): # No running loop, create a new one (sync context) loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - _runtime_instance = DistributedRuntime(loop, False) + _runtime_instance = DistributedRuntime(loop, "etcd", False) return _runtime_instance