Skip to content

Commit 94ad508

Browse files
fix: introduce service discovery interface (1/n) (#3937)
Signed-off-by: mohammedabdulwahhab <[email protected]>
1 parent bfb2574 commit 94ad508

File tree

4 files changed

+340
-0
lines changed

4 files changed

+340
-0
lines changed

lib/runtime/src/discovery/mock.rs

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use super::{
5+
DiscoveryClient, DiscoveryEvent, DiscoveryInstance, DiscoveryKey, DiscoverySpec,
6+
DiscoveryStream,
7+
};
8+
use crate::Result;
9+
use async_trait::async_trait;
10+
use std::sync::{Arc, Mutex};
11+
12+
/// Shared in-memory registry for mock discovery
13+
#[derive(Clone, Default)]
14+
pub struct SharedMockRegistry {
15+
instances: Arc<Mutex<Vec<DiscoveryInstance>>>,
16+
}
17+
18+
impl SharedMockRegistry {
19+
pub fn new() -> Self {
20+
Self::default()
21+
}
22+
}
23+
24+
/// Mock implementation of DiscoveryClient for testing
25+
/// We can potentially remove this once we have KeyValueDiscoveryClient implemented
26+
pub struct MockDiscoveryClient {
27+
instance_id: u64,
28+
registry: SharedMockRegistry,
29+
}
30+
31+
impl MockDiscoveryClient {
32+
pub fn new(instance_id: Option<u64>, registry: SharedMockRegistry) -> Self {
33+
let instance_id = instance_id.unwrap_or_else(|| {
34+
use std::sync::atomic::{AtomicU64, Ordering};
35+
static COUNTER: AtomicU64 = AtomicU64::new(1);
36+
COUNTER.fetch_add(1, Ordering::SeqCst)
37+
});
38+
39+
Self {
40+
instance_id,
41+
registry,
42+
}
43+
}
44+
}
45+
46+
/// Helper function to check if an instance matches a discovery key query
47+
fn matches_key(instance: &DiscoveryInstance, key: &DiscoveryKey) -> bool {
48+
match (instance, key) {
49+
(DiscoveryInstance::Endpoint { .. }, DiscoveryKey::AllEndpoints) => true,
50+
(
51+
DiscoveryInstance::Endpoint {
52+
namespace: ins_ns, ..
53+
},
54+
DiscoveryKey::NamespacedEndpoints { namespace },
55+
) => ins_ns == namespace,
56+
(
57+
DiscoveryInstance::Endpoint {
58+
namespace: ins_ns,
59+
component: ins_comp,
60+
..
61+
},
62+
DiscoveryKey::ComponentEndpoints {
63+
namespace,
64+
component,
65+
},
66+
) => ins_ns == namespace && ins_comp == component,
67+
(
68+
DiscoveryInstance::Endpoint {
69+
namespace: ins_ns,
70+
component: ins_comp,
71+
endpoint: ins_ep,
72+
..
73+
},
74+
DiscoveryKey::Endpoint {
75+
namespace,
76+
component,
77+
endpoint,
78+
},
79+
) => ins_ns == namespace && ins_comp == component && ins_ep == endpoint,
80+
}
81+
}
82+
83+
#[async_trait]
84+
impl DiscoveryClient for MockDiscoveryClient {
85+
fn instance_id(&self) -> u64 {
86+
self.instance_id
87+
}
88+
89+
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
90+
let instance = spec.with_instance_id(self.instance_id);
91+
92+
self.registry
93+
.instances
94+
.lock()
95+
.unwrap()
96+
.push(instance.clone());
97+
98+
Ok(instance)
99+
}
100+
101+
async fn list_and_watch(&self, key: DiscoveryKey) -> Result<DiscoveryStream> {
102+
use std::collections::HashSet;
103+
104+
let registry = self.registry.clone();
105+
106+
let stream = async_stream::stream! {
107+
let mut known_instances = HashSet::new();
108+
109+
loop {
110+
let current: Vec<_> = {
111+
let instances = registry.instances.lock().unwrap();
112+
instances
113+
.iter()
114+
.filter(|instance| matches_key(instance, &key))
115+
.cloned()
116+
.collect()
117+
};
118+
119+
let current_ids: HashSet<_> = current.iter().map(|i| {
120+
match i {
121+
DiscoveryInstance::Endpoint { instance_id, .. } => *instance_id,
122+
}
123+
}).collect();
124+
125+
// Emit Added events for new instances
126+
for instance in current {
127+
let id = match &instance {
128+
DiscoveryInstance::Endpoint { instance_id, .. } => *instance_id,
129+
};
130+
if known_instances.insert(id) {
131+
yield Ok(DiscoveryEvent::Added(instance));
132+
}
133+
}
134+
135+
// Emit Removed events for instances that are gone
136+
for id in known_instances.difference(&current_ids).cloned().collect::<Vec<_>>() {
137+
yield Ok(DiscoveryEvent::Removed(id));
138+
known_instances.remove(&id);
139+
}
140+
141+
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
142+
}
143+
};
144+
145+
Ok(Box::pin(stream))
146+
}
147+
}
148+
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::time::Duration;

    /// Upper bound for a single event to arrive. The mock polls every 10 ms, so hitting
    /// this limit means the expected event is never going to be produced.
    const EVENT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Waits for the next event, failing the test instead of hanging if none arrives.
    async fn next_event(stream: &mut DiscoveryStream) -> DiscoveryEvent {
        tokio::time::timeout(EVENT_TIMEOUT, stream.next())
            .await
            .expect("timed out waiting for a discovery event")
            .expect("discovery stream ended unexpectedly")
            .expect("discovery stream yielded an error")
    }

    /// Two clients register the same endpoint; a watcher must see both `Added` events in
    /// registration order and a `Removed` event once the first instance is dropped from
    /// the registry.
    #[tokio::test]
    async fn test_mock_discovery_add_and_remove() {
        let registry = SharedMockRegistry::new();
        let client1 = MockDiscoveryClient::new(Some(1), registry.clone());
        let client2 = MockDiscoveryClient::new(Some(2), registry.clone());

        let spec = DiscoverySpec::Endpoint {
            namespace: "test-ns".to_string(),
            component: "test-comp".to_string(),
            endpoint: "test-ep".to_string(),
        };

        let key = DiscoveryKey::Endpoint {
            namespace: "test-ns".to_string(),
            component: "test-comp".to_string(),
            endpoint: "test-ep".to_string(),
        };

        // Start watching
        let mut stream = client1.list_and_watch(key).await.unwrap();

        // Add first instance
        client1.register(spec.clone()).await.unwrap();
        assert_eq!(
            next_event(&mut stream).await,
            DiscoveryEvent::Added(spec.clone().with_instance_id(1)),
            "Expected Added event for instance-1"
        );

        // Add second instance
        client2.register(spec.clone()).await.unwrap();
        assert_eq!(
            next_event(&mut stream).await,
            DiscoveryEvent::Added(spec.with_instance_id(2)),
            "Expected Added event for instance-2"
        );

        // Remove first instance
        registry.instances.lock().unwrap().retain(|i| match i {
            DiscoveryInstance::Endpoint { instance_id, .. } => *instance_id != 1,
        });
        assert_eq!(
            next_event(&mut stream).await,
            DiscoveryEvent::Removed(1),
            "Expected Removed event for instance-1"
        );
    }
}

lib/runtime/src/discovery/mod.rs

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use crate::Result;
5+
use async_trait::async_trait;
6+
use futures::Stream;
7+
use serde::{Deserialize, Serialize};
8+
use std::pin::Pin;
9+
10+
mod mock;
11+
pub use mock::{MockDiscoveryClient, SharedMockRegistry};
12+
/// Query key for prefix-based discovery queries.
///
/// Supports hierarchical queries from all endpoints down to specific endpoints.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DiscoveryKey {
    /// Query all endpoints in the system
    AllEndpoints,
    /// Query all endpoints in a specific namespace
    NamespacedEndpoints { namespace: String },
    /// Query all endpoints in a namespace/component
    ComponentEndpoints {
        namespace: String,
        component: String,
    },
    /// Query a specific endpoint
    Endpoint {
        namespace: String,
        component: String,
        endpoint: String,
    },
    // TODO: Extend to support ModelCard queries:
    // - AllModels
    // - NamespacedModels { namespace }
    // - ComponentModels { namespace, component }
    // - Model { namespace, component, model_name }
}
38+
/// Specification for registering objects in the discovery plane.
///
/// Represents the input to the register() operation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DiscoverySpec {
    /// Endpoint specification for registration
    Endpoint {
        namespace: String,
        component: String,
        endpoint: String,
    },
    // TODO: Add ModelCard variant:
    // - ModelCard { namespace, component, model_name, card: ModelDeploymentCard }
}
52+
53+
impl DiscoverySpec {
54+
/// Attaches an instance ID to create a DiscoveryInstance
55+
pub fn with_instance_id(self, instance_id: u64) -> DiscoveryInstance {
56+
match self {
57+
Self::Endpoint {
58+
namespace,
59+
component,
60+
endpoint,
61+
} => DiscoveryInstance::Endpoint {
62+
namespace,
63+
component,
64+
endpoint,
65+
instance_id,
66+
},
67+
}
68+
}
69+
}
70+
71+
/// Registered instances in the discovery plane
72+
/// Represents objects that have been successfully registered with an instance ID
73+
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
74+
#[serde(tag = "type")]
75+
pub enum DiscoveryInstance {
76+
/// Registered endpoint instance
77+
Endpoint {
78+
namespace: String,
79+
component: String,
80+
endpoint: String,
81+
instance_id: u64,
82+
},
83+
// TODO: Add ModelCard variant:
84+
// - ModelCard { namespace, component, model_name, instance_id, card: ModelDeploymentCard }
85+
}
86+
87+
/// Events emitted by the discovery client watch stream
88+
#[derive(Debug, Clone, PartialEq, Eq)]
89+
pub enum DiscoveryEvent {
90+
/// A new instance was added
91+
Added(DiscoveryInstance),
92+
/// An instance was removed (identified by instance_id)
93+
Removed(u64),
94+
}
95+
96+
/// Stream type for discovery events
97+
pub type DiscoveryStream = Pin<Box<dyn Stream<Item = Result<DiscoveryEvent>> + Send>>;
98+
99+
/// Discovery client trait for service discovery across different backends
100+
#[async_trait]
101+
pub trait DiscoveryClient: Send + Sync {
102+
/// Returns a unique identifier for this worker (e.g lease id if using etcd or generated id for memory store)
103+
/// Discovery objects created by this worker will be associated with this id.
104+
fn instance_id(&self) -> u64;
105+
106+
/// Registers an object in the discovery plane with the instance id
107+
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance>;
108+
109+
/// Returns a stream of discovery events (Added/Removed) for the given discovery key
110+
async fn list_and_watch(&self, key: DiscoveryKey) -> Result<DiscoveryStream>;
111+
}

lib/runtime/src/distributed.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::transports::nats::DRTNatsClientPrometheusMetrics;
99
use crate::{
1010
ErrorContext,
1111
component::{self, ComponentBuilder, Endpoint, InstanceSource, Namespace},
12+
discovery::DiscoveryClient,
1213
metrics::PrometheusUpdateCallback,
1314
metrics::{MetricsHierarchy, MetricsRegistry},
1415
service::ServiceClient,
@@ -83,13 +84,22 @@ impl DistributedRuntime {
8384

8485
let nats_client_for_metrics = nats_client.clone();
8586

87+
// Initialize discovery client with mock implementation
88+
// TODO: Replace MockDiscoveryClient with KeyValueStoreDiscoveryClient or KubeDiscoveryClient
89+
let discovery_client = {
90+
use crate::discovery::{MockDiscoveryClient, SharedMockRegistry};
91+
let registry = SharedMockRegistry::new();
92+
Arc::new(MockDiscoveryClient::new(None, registry)) as Arc<dyn DiscoveryClient>
93+
};
94+
8695
let distributed_runtime = Self {
8796
runtime,
8897
etcd_client,
8998
store,
9099
nats_client,
91100
tcp_server: Arc::new(OnceCell::new()),
92101
system_status_server: Arc::new(OnceLock::new()),
102+
discovery_client,
93103
component_registry: component::Registry::new(),
94104
is_static,
95105
instance_sources: Arc::new(Mutex::new(HashMap::new())),
@@ -223,6 +233,11 @@ impl DistributedRuntime {
223233
Namespace::new(self.clone(), name.into(), self.is_static)
224234
}
225235

236+
/// TODO: Return discovery client when KeyValueDiscoveryClient or KubeDiscoveryClient is implemented
237+
pub fn discovery_client(&self) -> Result<Arc<dyn DiscoveryClient>> {
238+
Err(error!("Discovery client not implemented!"))
239+
}
240+
226241
pub(crate) fn service_client(&self) -> Option<ServiceClient> {
227242
self.nats_client().map(|nc| ServiceClient::new(nc.clone()))
228243
}

lib/runtime/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub use config::RuntimeConfig;
2222

2323
pub mod component;
2424
pub mod compute;
25+
pub mod discovery;
2526
pub mod engine;
2627
pub mod health_check;
2728
pub mod system_status_server;
@@ -95,6 +96,9 @@ pub struct DistributedRuntime {
9596
tcp_server: Arc<OnceCell<Arc<transports::tcp::server::TcpStreamServer>>>,
9697
system_status_server: Arc<OnceLock<Arc<system_status_server::SystemStatusServerInfo>>>,
9798

99+
// Service discovery client
100+
discovery_client: Arc<dyn discovery::DiscoveryClient>,
101+
98102
// local registry for components
99103
// the registry allows us to use share runtime resources across instances of the same component object.
100104
// take for example two instances of a client to the same remote component. The registry allows us to use

0 commit comments

Comments
 (0)