Skip to content

Commit 8efd754

Browse files
declark1 and njhill
authored and committed
Implement NlpService embeddings methods (#9)
Co-authored-by: Nick Hill <[email protected]> Signed-off-by: Prashant <[email protected]>
1 parent 506fdf0 commit 8efd754

File tree

90 files changed

+2088
-185
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

90 files changed

+2088
-185
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
.DS_Store
12
.idea
23
target
34
fmaas-router/src/pb

Cargo.lock

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

fmaas-router/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@ clap = { version = "^4.5.3", features = ["derive", "env"] }
2020
futures = "^0.3.30"
2121
tonic = { version = "=0.11.0", features = ["tls"] }
2222
ginepro = "=0.7.1"
23-
lazy_static = "^1.4.0"
2423
tokio = { version = "^1.36.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "fs"] }
2524
tracing = "0.1.40"
2625
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
2726
prost = "^0.12.3"
27+
prost-types = "^0.12.3"
2828
serde_yaml = "^0.9.33"
29+
serde = { version = "^1.0.197", features = ["derive"] }
2930

3031
mio = "^0.8.11" # Override to address CVE-2024-27308
3132
rustls-webpki = "^0.102.2" # Override to address WS-2023-0305, CVE-2018-16875

fmaas-router/build.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
77
.build_server(true)
88
.out_dir("src/pb")
99
.include_file("mod.rs")
10-
.compile(&["../proto/generation.proto"], &["../proto"])
10+
.compile(
11+
&["../proto/generation.proto", "../proto/nlpservice.proto"],
12+
&["../proto"],
13+
)
1114
.unwrap_or_else(|e| panic!("protobuf compilation failed: {}", e));
1215

1316
Ok(())

fmaas-router/src/lib.rs

Lines changed: 118 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,121 @@
1-
extern crate core;
2-
extern crate lazy_static;
1+
use std::{collections::HashMap, path::Path};
2+
use anyhow::Context;
3+
use futures::future::try_join_all;
4+
use ginepro::LoadBalancedChannel;
35

6+
use serde::{Deserialize, Deserializer};
7+
use tonic::transport::ClientTlsConfig;
8+
use tracing::info;
9+
10+
#[allow(clippy::enum_variant_names)]
411
mod pb;
12+
pub mod rpc;
513
pub mod server;
14+
15+
#[derive(Debug, Clone, Deserialize)]
16+
pub struct ServiceAddr {
17+
pub hostname: String,
18+
pub port: Option<u16>,
19+
}
20+
21+
/// Old format without top-level keys, generation models only.
22+
#[derive(Debug, Clone, Deserialize)]
23+
pub struct ModelMapV1(#[serde(deserialize_with = "de_service_addr")] HashMap<String, ServiceAddr>);
24+
25+
/// New format with top-level keys for generation and embeddings models.
26+
#[derive(Debug, Clone, Deserialize)]
27+
pub struct ModelMapV2 {
28+
#[serde(deserialize_with = "de_service_addr", default = "HashMap::default")]
29+
generation: HashMap<String, ServiceAddr>,
30+
#[serde(deserialize_with = "de_service_addr", default = "HashMap::default")]
31+
embeddings: HashMap<String, ServiceAddr>,
32+
}
33+
34+
/// Maps model names to service address.
35+
#[derive(Debug, Clone, Deserialize)]
36+
#[serde(untagged)]
37+
pub enum ModelMap {
38+
V1(ModelMapV1),
39+
V2(ModelMapV2),
40+
}
41+
42+
impl ModelMap {
43+
pub fn load(path: impl AsRef<Path>) -> Self {
44+
let s = std::fs::read_to_string(path).expect("Failed to load model map config");
45+
serde_yaml::from_str(&s).expect("Invalid model map config")
46+
}
47+
48+
pub fn generation(&self) -> Option<&HashMap<String, ServiceAddr>> {
49+
match self {
50+
ModelMap::V1(v1) => Some(&v1.0),
51+
ModelMap::V2(v2) => (!v2.generation.is_empty()).then_some(&v2.generation),
52+
}
53+
}
54+
55+
pub fn embeddings(&self) -> Option<&HashMap<String, ServiceAddr>> {
56+
match self {
57+
ModelMap::V1(_) => None,
58+
ModelMap::V2(v2) => (!v2.embeddings.is_empty()).then_some(&v2.embeddings),
59+
}
60+
}
61+
}
62+
63+
fn service_addr_from_str<'de, D>(deserializer: D) -> Result<ServiceAddr, D::Error>
64+
where
65+
D: Deserializer<'de>,
66+
{
67+
let s = String::deserialize(deserializer).map_err(serde::de::Error::custom)?;
68+
let mut parts = s.split(':');
69+
let hostname = parts.next().unwrap().to_string();
70+
let port = parts.next().map(|p| {
71+
p.parse::<u16>()
72+
.unwrap_or_else(|_| panic!("Invalid port in configured service name: {p}"))
73+
});
74+
if parts.next().is_some() {
75+
panic!("Configured service name contains more than one : character");
76+
}
77+
Ok(ServiceAddr { hostname, port })
78+
}
79+
80+
fn de_service_addr<'de, D>(deserializer: D) -> Result<HashMap<String, ServiceAddr>, D::Error>
81+
where
82+
D: Deserializer<'de>,
83+
{
84+
#[derive(Deserialize)]
85+
struct Wrapper(#[serde(deserialize_with = "service_addr_from_str")] ServiceAddr);
86+
87+
let v = HashMap::<String, Wrapper>::deserialize(deserializer)?;
88+
Ok(v.into_iter().map(|(k, Wrapper(v))| (k, v)).collect())
89+
}
90+
91+
async fn create_clients<C>(
92+
default_target_port: u16,
93+
client_tls: Option<&ClientTlsConfig>,
94+
model_map: &HashMap<String, ServiceAddr>,
95+
new: fn(LoadBalancedChannel) -> C,
96+
) -> HashMap<String, C> {
97+
let clients = model_map
98+
.iter()
99+
.map(|(name, service_addr)| async move {
100+
info!("Configuring client for model name: [{name}]");
101+
// Build a load-balanced channel given a service name and a port.
102+
let mut builder = LoadBalancedChannel::builder((
103+
service_addr.hostname.clone(),
104+
service_addr.port.unwrap_or(default_target_port),
105+
));
106+
if let Some(tls_config) = client_tls {
107+
builder = builder.with_tls(tls_config.clone());
108+
}
109+
let channel = builder
110+
.channel()
111+
.await
112+
.context(format!("Channel failed for service {name}"))?;
113+
Ok((name.clone(), new(channel))) as Result<(String, C), anyhow::Error>
114+
})
115+
.collect::<Vec<_>>();
116+
try_join_all(clients)
117+
.await
118+
.expect("Error creating upstream service clients")
119+
.into_iter()
120+
.collect()
121+
}

fmaas-router/src/main.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
22

33
use clap::Parser;
4-
use fmaas_router::server;
4+
use fmaas_router::{server, ModelMap};
55
use tracing_subscriber::EnvFilter;
66

77
/// App Configuration
@@ -17,6 +17,8 @@ struct Args {
1717
#[clap(long, env)]
1818
json_output: bool,
1919
#[clap(long, env)]
20+
model_map_config: String,
21+
#[clap(long, env)]
2022
tls_cert_path: Option<String>,
2123
#[clap(long, env)]
2224
tls_key_path: Option<String>,
@@ -53,19 +55,19 @@ fn main() -> Result<(), std::io::Error> {
5355
if args.tls_key_path.is_some() != args.tls_cert_path.is_some() {
5456
panic!("tls: must provide both cert and key")
5557
}
56-
5758
if args.tls_client_ca_cert_path.is_some() && args.tls_cert_path.is_none() {
5859
panic!("tls: cannot provide client ca cert without keypair")
5960
}
6061

62+
// Load model map config
63+
let model_map = ModelMap::load(args.model_map_config);
64+
6165
// Launch Tokio runtime
6266
tokio::runtime::Builder::new_multi_thread()
6367
.enable_all()
6468
.build()
6569
.unwrap()
6670
.block_on(async {
67-
//TODO initialize clients
68-
6971
let grpc_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), args.grpc_port);
7072
let http_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), args.probe_port);
7173

@@ -78,6 +80,7 @@ fn main() -> Result<(), std::io::Error> {
7880
args.default_upstream_port,
7981
args.upstream_tls,
8082
args.upstream_tls_ca_cert_path,
83+
model_map,
8184
)
8285
.await;
8386

fmaas-router/src/rpc.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pub mod generation;
2+
pub mod nlp;

fmaas-router/src/rpc/generation.rs

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
use std::collections::HashMap;
2+
3+
use ginepro::LoadBalancedChannel;
4+
use tonic::{transport::ClientTlsConfig, Request, Response, Status, Streaming};
5+
use tracing::{debug, instrument};
6+
7+
use crate::{pb::fmaas::{
8+
generation_service_client::GenerationServiceClient,
9+
generation_service_server::GenerationService, BatchedGenerationRequest,
10+
BatchedGenerationResponse, BatchedTokenizeRequest, BatchedTokenizeResponse,
11+
GenerationResponse, ModelInfoRequest, ModelInfoResponse, SingleGenerationRequest,
12+
}, create_clients, ServiceAddr};
13+
14+
#[derive(Debug, Default)]
15+
pub struct GenerationServicer {
16+
clients: HashMap<String, GenerationServiceClient<LoadBalancedChannel>>,
17+
}
18+
19+
impl GenerationServicer {
20+
pub async fn new(
21+
default_target_port: u16,
22+
client_tls: Option<&ClientTlsConfig>,
23+
model_map: &HashMap<String, ServiceAddr>,
24+
) -> Self {
25+
let clients = create_clients(
26+
default_target_port, client_tls, model_map, GenerationServiceClient::new
27+
).await;
28+
Self { clients }
29+
}
30+
31+
async fn client(
32+
&self,
33+
model_id: &str,
34+
) -> Result<GenerationServiceClient<LoadBalancedChannel>, Status> {
35+
Ok(self
36+
.clients
37+
.get(model_id)
38+
.ok_or_else(|| Status::not_found(format!("Unrecognized model_id: {model_id}")))?
39+
.clone())
40+
}
41+
}
42+
43+
#[tonic::async_trait]
44+
impl GenerationService for GenerationServicer {
45+
#[instrument(skip_all)]
46+
async fn generate(
47+
&self,
48+
request: Request<BatchedGenerationRequest>,
49+
) -> Result<Response<BatchedGenerationResponse>, Status> {
50+
let br = request.get_ref();
51+
if br.requests.is_empty() {
52+
return Ok(Response::new(BatchedGenerationResponse {
53+
responses: vec![],
54+
}));
55+
}
56+
debug!("Routing generation request for Model ID {}", &br.model_id);
57+
self.client(&br.model_id).await?.generate(request).await
58+
}
59+
60+
type GenerateStreamStream = Streaming<GenerationResponse>;
61+
62+
#[instrument(skip_all)]
63+
async fn generate_stream(
64+
&self,
65+
request: Request<SingleGenerationRequest>,
66+
) -> Result<Response<Self::GenerateStreamStream>, Status> {
67+
let sr = request.get_ref();
68+
if sr.request.is_none() {
69+
return Err(Status::invalid_argument("missing request"));
70+
}
71+
debug!(
72+
"Routing streaming generation request for Model ID {}",
73+
&sr.model_id
74+
);
75+
self.client(&sr.model_id)
76+
.await?
77+
.generate_stream(request)
78+
.await
79+
}
80+
81+
#[instrument(skip_all)]
82+
async fn tokenize(
83+
&self,
84+
request: Request<BatchedTokenizeRequest>,
85+
) -> Result<Response<BatchedTokenizeResponse>, Status> {
86+
let br = request.get_ref();
87+
if br.requests.is_empty() {
88+
return Ok(Response::new(BatchedTokenizeResponse { responses: vec![] }));
89+
}
90+
debug!("Routing tokenization request for Model ID {}", &br.model_id);
91+
self.client(&br.model_id).await?.tokenize(request).await
92+
}
93+
94+
#[instrument(skip_all)]
95+
async fn model_info(
96+
&self,
97+
request: Request<ModelInfoRequest>,
98+
) -> Result<Response<ModelInfoResponse>, Status> {
99+
debug!(
100+
"Routing model info request for Model ID {}",
101+
&request.get_ref().model_id
102+
);
103+
self.client(&request.get_ref().model_id)
104+
.await?
105+
.model_info(request)
106+
.await
107+
}
108+
}

0 commit comments

Comments
 (0)