Skip to content

Commit e1547e2

Browse files
fix: expand discovery interface to support model types (#4090)
Signed-off-by: mohammedabdulwahhab <[email protected]>
1 parent 3c0763f commit e1547e2

File tree

3 files changed

+286
-38
lines changed

3 files changed

+286
-38
lines changed

lib/runtime/src/discovery/mock.rs

Lines changed: 78 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -46,37 +46,79 @@ impl MockDiscoveryClient {
4646
/// Helper function to check if an instance matches a discovery key query
4747
fn matches_key(instance: &DiscoveryInstance, key: &DiscoveryKey) -> bool {
4848
match (instance, key) {
49-
(DiscoveryInstance::Endpoint { .. }, DiscoveryKey::AllEndpoints) => true,
49+
// Endpoint matching
50+
(DiscoveryInstance::Endpoint(_), DiscoveryKey::AllEndpoints) => true,
51+
(DiscoveryInstance::Endpoint(inst), DiscoveryKey::NamespacedEndpoints { namespace }) => {
52+
&inst.namespace == namespace
53+
}
54+
(
55+
DiscoveryInstance::Endpoint(inst),
56+
DiscoveryKey::ComponentEndpoints {
57+
namespace,
58+
component,
59+
},
60+
) => &inst.namespace == namespace && &inst.component == component,
61+
(
62+
DiscoveryInstance::Endpoint(inst),
63+
DiscoveryKey::Endpoint {
64+
namespace,
65+
component,
66+
endpoint,
67+
},
68+
) => {
69+
&inst.namespace == namespace
70+
&& &inst.component == component
71+
&& &inst.endpoint == endpoint
72+
}
73+
74+
// ModelCard matching
75+
(DiscoveryInstance::ModelCard { .. }, DiscoveryKey::AllModelCards) => true,
5076
(
51-
DiscoveryInstance::Endpoint {
52-
namespace: ins_ns, ..
77+
DiscoveryInstance::ModelCard {
78+
namespace: inst_ns, ..
5379
},
54-
DiscoveryKey::NamespacedEndpoints { namespace },
55-
) => ins_ns == namespace,
80+
DiscoveryKey::NamespacedModelCards { namespace },
81+
) => inst_ns == namespace,
5682
(
57-
DiscoveryInstance::Endpoint {
58-
namespace: ins_ns,
59-
component: ins_comp,
83+
DiscoveryInstance::ModelCard {
84+
namespace: inst_ns,
85+
component: inst_comp,
6086
..
6187
},
62-
DiscoveryKey::ComponentEndpoints {
88+
DiscoveryKey::ComponentModelCards {
6389
namespace,
6490
component,
6591
},
66-
) => ins_ns == namespace && ins_comp == component,
92+
) => inst_ns == namespace && inst_comp == component,
6793
(
68-
DiscoveryInstance::Endpoint {
69-
namespace: ins_ns,
70-
component: ins_comp,
71-
endpoint: ins_ep,
94+
DiscoveryInstance::ModelCard {
95+
namespace: inst_ns,
96+
component: inst_comp,
97+
endpoint: inst_ep,
7298
..
7399
},
74-
DiscoveryKey::Endpoint {
100+
DiscoveryKey::EndpointModelCards {
75101
namespace,
76102
component,
77103
endpoint,
78104
},
79-
) => ins_ns == namespace && ins_comp == component && ins_ep == endpoint,
105+
) => inst_ns == namespace && inst_comp == component && inst_ep == endpoint,
106+
107+
// Cross-type matches return false
108+
(
109+
DiscoveryInstance::Endpoint(_),
110+
DiscoveryKey::AllModelCards
111+
| DiscoveryKey::NamespacedModelCards { .. }
112+
| DiscoveryKey::ComponentModelCards { .. }
113+
| DiscoveryKey::EndpointModelCards { .. },
114+
) => false,
115+
(
116+
DiscoveryInstance::ModelCard { .. },
117+
DiscoveryKey::AllEndpoints
118+
| DiscoveryKey::NamespacedEndpoints { .. }
119+
| DiscoveryKey::ComponentEndpoints { .. }
120+
| DiscoveryKey::Endpoint { .. },
121+
) => false,
80122
}
81123
}
82124

@@ -98,6 +140,15 @@ impl DiscoveryClient for MockDiscoveryClient {
98140
Ok(instance)
99141
}
100142

143+
async fn list(&self, key: DiscoveryKey) -> Result<Vec<DiscoveryInstance>> {
144+
let instances = self.registry.instances.lock().unwrap();
145+
Ok(instances
146+
.iter()
147+
.filter(|instance| matches_key(instance, &key))
148+
.cloned()
149+
.collect())
150+
}
151+
101152
async fn list_and_watch(&self, key: DiscoveryKey) -> Result<DiscoveryStream> {
102153
use std::collections::HashSet;
103154

@@ -118,14 +169,16 @@ impl DiscoveryClient for MockDiscoveryClient {
118169

119170
let current_ids: HashSet<_> = current.iter().map(|i| {
120171
match i {
121-
DiscoveryInstance::Endpoint { instance_id, .. } => *instance_id,
172+
DiscoveryInstance::Endpoint(inst) => inst.instance_id,
173+
DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id,
122174
}
123175
}).collect();
124176

125177
// Emit Added events for new instances
126178
for instance in current {
127179
let id = match &instance {
128-
DiscoveryInstance::Endpoint { instance_id, .. } => *instance_id,
180+
DiscoveryInstance::Endpoint(inst) => inst.instance_id,
181+
DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id,
129182
};
130183
if known_instances.insert(id) {
131184
yield Ok(DiscoveryEvent::Added(instance));
@@ -161,6 +214,7 @@ mod tests {
161214
namespace: "test-ns".to_string(),
162215
component: "test-comp".to_string(),
163216
endpoint: "test-ep".to_string(),
217+
transport: crate::component::TransportType::NatsTcp("test-subject".to_string()),
164218
};
165219

166220
let key = DiscoveryKey::Endpoint {
@@ -177,8 +231,8 @@ mod tests {
177231

178232
let event = stream.next().await.unwrap().unwrap();
179233
match event {
180-
DiscoveryEvent::Added(DiscoveryInstance::Endpoint { instance_id, .. }) => {
181-
assert_eq!(instance_id, 1);
234+
DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) => {
235+
assert_eq!(inst.instance_id, 1);
182236
}
183237
_ => panic!("Expected Added event for instance-1"),
184238
}
@@ -188,15 +242,16 @@ mod tests {
188242

189243
let event = stream.next().await.unwrap().unwrap();
190244
match event {
191-
DiscoveryEvent::Added(DiscoveryInstance::Endpoint { instance_id, .. }) => {
192-
assert_eq!(instance_id, 2);
245+
DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) => {
246+
assert_eq!(inst.instance_id, 2);
193247
}
194248
_ => panic!("Expected Added event for instance-2"),
195249
}
196250

197251
// Remove first instance
198252
registry.instances.lock().unwrap().retain(|i| match i {
199-
DiscoveryInstance::Endpoint { instance_id, .. } => *instance_id != 1,
253+
DiscoveryInstance::Endpoint(inst) => inst.instance_id != 1,
254+
DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id != 1,
200255
});
201256

202257
let event = stream.next().await.unwrap().unwrap();

lib/runtime/src/discovery/mod.rs

Lines changed: 102 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
use crate::Result;
5+
use crate::component::TransportType;
56
use async_trait::async_trait;
67
use futures::Stream;
78
use serde::{Deserialize, Serialize};
@@ -10,14 +11,19 @@ use std::pin::Pin;
1011
mod mock;
1112
pub use mock::{MockDiscoveryClient, SharedMockRegistry};
1213

14+
pub mod utils;
15+
pub use utils::watch_and_extract_field;
16+
1317
/// Query key for prefix-based discovery queries
1418
/// Supports hierarchical queries from all endpoints down to specific endpoints
1519
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1620
pub enum DiscoveryKey {
1721
/// Query all endpoints in the system
1822
AllEndpoints,
1923
/// Query all endpoints in a specific namespace
20-
NamespacedEndpoints { namespace: String },
24+
NamespacedEndpoints {
25+
namespace: String,
26+
},
2127
/// Query all endpoints in a namespace/component
2228
ComponentEndpoints {
2329
namespace: String,
@@ -29,59 +35,136 @@ pub enum DiscoveryKey {
2935
component: String,
3036
endpoint: String,
3137
},
32-
// TODO: Extend to support ModelCard queries:
33-
// - AllModels
34-
// - NamespacedModels { namespace }
35-
// - ComponentModels { namespace, component }
36-
// - Model { namespace, component, model_name }
38+
AllModelCards,
39+
NamespacedModelCards {
40+
namespace: String,
41+
},
42+
ComponentModelCards {
43+
namespace: String,
44+
component: String,
45+
},
46+
EndpointModelCards {
47+
namespace: String,
48+
component: String,
49+
endpoint: String,
50+
},
3751
}
3852

3953
/// Specification for registering objects in the discovery plane
4054
/// Represents the input to the register() operation
41-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
55+
#[derive(Debug, Clone, PartialEq, Eq)]
4256
pub enum DiscoverySpec {
4357
/// Endpoint specification for registration
4458
Endpoint {
4559
namespace: String,
4660
component: String,
4761
endpoint: String,
62+
/// Transport type and routing information
63+
transport: TransportType,
64+
},
65+
ModelCard {
66+
namespace: String,
67+
component: String,
68+
endpoint: String,
69+
/// ModelDeploymentCard serialized as JSON
70+
/// This allows lib/runtime to remain independent of lib/llm types
71+
/// DiscoverySpec.from_model_card() and DiscoveryInstance.deserialize_model_card() are ergonomic helpers to create and deserialize the model card.
72+
card_json: serde_json::Value,
4873
},
49-
// TODO: Add ModelCard variant:
50-
// - ModelCard { namespace, component, model_name, card: ModelDeploymentCard }
5174
}
5275

5376
impl DiscoverySpec {
77+
/// Creates a ModelCard discovery spec from a serializable type
78+
/// The card will be serialized to JSON to avoid cross-crate dependencies
79+
pub fn from_model_card<T>(
80+
namespace: String,
81+
component: String,
82+
endpoint: String,
83+
card: &T,
84+
) -> crate::Result<Self>
85+
where
86+
T: Serialize,
87+
{
88+
let card_json = serde_json::to_value(card)?;
89+
Ok(Self::ModelCard {
90+
namespace,
91+
component,
92+
endpoint,
93+
card_json,
94+
})
95+
}
96+
5497
/// Attaches an instance ID to create a DiscoveryInstance
5598
pub fn with_instance_id(self, instance_id: u64) -> DiscoveryInstance {
5699
match self {
57100
Self::Endpoint {
58101
namespace,
59102
component,
60103
endpoint,
61-
} => DiscoveryInstance::Endpoint {
104+
transport,
105+
} => DiscoveryInstance::Endpoint(crate::component::Instance {
106+
namespace,
107+
component,
108+
endpoint,
109+
instance_id,
110+
transport,
111+
}),
112+
Self::ModelCard {
113+
namespace,
114+
component,
115+
endpoint,
116+
card_json,
117+
} => DiscoveryInstance::ModelCard {
62118
namespace,
63119
component,
64120
endpoint,
65121
instance_id,
122+
card_json,
66123
},
67124
}
68125
}
69126
}
70127

71128
/// Registered instances in the discovery plane
72129
/// Represents objects that have been successfully registered with an instance ID
73-
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
130+
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
74131
#[serde(tag = "type")]
75132
pub enum DiscoveryInstance {
76-
/// Registered endpoint instance
77-
Endpoint {
133+
/// Registered endpoint instance - wraps the component::Instance directly
134+
Endpoint(crate::component::Instance),
135+
ModelCard {
78136
namespace: String,
79137
component: String,
80138
endpoint: String,
81139
instance_id: u64,
140+
/// ModelDeploymentCard serialized as JSON
141+
/// This allows lib/runtime to remain independent of lib/llm types
142+
card_json: serde_json::Value,
82143
},
83-
// TODO: Add ModelCard variant:
84-
// - ModelCard { namespace, component, model_name, instance_id, card: ModelDeploymentCard }
144+
}
145+
146+
impl DiscoveryInstance {
147+
/// Returns the instance ID for this discovery instance
148+
pub fn instance_id(&self) -> u64 {
149+
match self {
150+
Self::Endpoint(inst) => inst.instance_id,
151+
Self::ModelCard { instance_id, .. } => *instance_id,
152+
}
153+
}
154+
155+
/// Deserializes the model card JSON into the specified type T
156+
/// Returns an error if this is not a ModelCard instance or if deserialization fails
157+
pub fn deserialize_model_card<T>(&self) -> crate::Result<T>
158+
where
159+
T: for<'de> Deserialize<'de>,
160+
{
161+
match self {
162+
Self::ModelCard { card_json, .. } => Ok(serde_json::from_value(card_json.clone())?),
163+
Self::Endpoint(_) => {
164+
crate::raise!("Cannot deserialize model card from Endpoint instance")
165+
}
166+
}
167+
}
85168
}
86169

87170
/// Events emitted by the discovery client watch stream
@@ -106,6 +189,10 @@ pub trait DiscoveryClient: Send + Sync {
106189
/// Registers an object in the discovery plane with the instance id
107190
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance>;
108191

192+
/// Returns a list of currently registered instances for the given discovery key
193+
/// This is a one-time snapshot without watching for changes
194+
async fn list(&self, key: DiscoveryKey) -> Result<Vec<DiscoveryInstance>>;
195+
109196
/// Returns a stream of discovery events (Added/Removed) for the given discovery key
110197
async fn list_and_watch(&self, key: DiscoveryKey) -> Result<DiscoveryStream>;
111198
}

0 commit comments

Comments
 (0)