Skip to content

Commit 506fdf0

Browse files
declark1 authored and prashantgupta24 committed
Apply rustfmt, clippy suggestions, update .gitignore
1 parent 4fa9a4d commit 506fdf0

File tree

4 files changed

+127
-95
lines changed

4 files changed

+127
-95
lines changed

fmaas-router/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
1-
extern crate lazy_static;
21
extern crate core;
2+
extern crate lazy_static;
33

4+
mod pb;
45
pub mod server;
5-
mod pb;

fmaas-router/src/main.rs

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,8 @@
1-
use clap::Parser;
21
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
3-
use tracing_subscriber::{EnvFilter};
2+
3+
use clap::Parser;
44
use fmaas_router::server;
5+
use tracing_subscriber::EnvFilter;
56

67
/// App Configuration
78
#[derive(Parser, Debug)]
@@ -27,7 +28,6 @@ struct Args {
2728
upstream_tls_ca_cert_path: Option<String>,
2829
}
2930

30-
3131
fn main() -> Result<(), std::io::Error> {
3232
//Get args
3333
let args = Args::parse();
@@ -66,22 +66,20 @@ fn main() -> Result<(), std::io::Error> {
6666
.block_on(async {
6767
//TODO initialize clients
6868

69-
let grpc_addr = SocketAddr::new(
70-
IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), args.grpc_port
71-
);
72-
let http_addr = SocketAddr::new(
73-
IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), args.probe_port
74-
);
69+
let grpc_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), args.grpc_port);
70+
let http_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), args.probe_port);
7571

7672
server::run(
7773
grpc_addr,
7874
http_addr,
79-
args.tls_cert_path.map(|cp| (cp, args.tls_key_path.unwrap())),
75+
args.tls_cert_path
76+
.map(|cp| (cp, args.tls_key_path.unwrap())),
8077
args.tls_client_ca_cert_path,
8178
args.default_upstream_port,
8279
args.upstream_tls,
83-
args.upstream_tls_ca_cert_path
84-
).await;
80+
args.upstream_tls_ca_cert_path,
81+
)
82+
.await;
8583

8684
Ok(())
8785
})

fmaas-router/src/server.rs

Lines changed: 113 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -1,27 +1,24 @@
1-
use std::collections::HashMap;
2-
use std::net::SocketAddr;
3-
use std::time::Duration;
4-
use anyhow::Context;
1+
use std::{collections::HashMap, net::SocketAddr, time::Duration};
52

6-
use axum::Router;
7-
use axum::routing::get;
3+
use anyhow::Context;
4+
use axum::{routing::get, Router};
85
use futures::future::try_join_all;
96
use ginepro::LoadBalancedChannel;
107
use lazy_static::lazy_static;
11-
use tokio::fs::read;
12-
use tokio::signal;
13-
use tokio::time::sleep;
14-
use tonic::{Request, Response, Status, Streaming};
15-
use tonic::transport::{Certificate, ClientTlsConfig, Identity, Server, ServerTlsConfig};
8+
use tokio::{fs::read, signal, time::sleep};
9+
use tonic::{
10+
transport::{Certificate, ClientTlsConfig, Identity, Server, ServerTlsConfig},
11+
Request, Response, Status, Streaming,
12+
};
1613
use tracing::instrument;
14+
1715
use crate::pb::fmaas::{
18-
BatchedGenerationRequest, BatchedGenerationResponse, GenerationResponse,
19-
SingleGenerationRequest, BatchedTokenizeRequest, BatchedTokenizeResponse,
20-
ModelInfoRequest, ModelInfoResponse,
16+
generation_service_client::GenerationServiceClient,
17+
generation_service_server::{GenerationService, GenerationServiceServer},
18+
BatchedGenerationRequest, BatchedGenerationResponse, BatchedTokenizeRequest,
19+
BatchedTokenizeResponse, GenerationResponse, ModelInfoRequest, ModelInfoResponse,
20+
SingleGenerationRequest,
2121
};
22-
use crate::pb::fmaas::generation_service_client::GenerationServiceClient;
23-
use crate::pb::fmaas::generation_service_server::{GenerationService, GenerationServiceServer};
24-
2522

2623
const MODEL_MAP_ENV_VAR_NAME: &str = "MODEL_MAP_CONFIG";
2724

@@ -40,8 +37,8 @@ lazy_static! {
4037
}
4138
};
4239
}
43-
let map: HashMap<&'static str, &'static str> = serde_yaml::from_str(&MODEL_MAP_STR)
44-
.expect("Failed to parse the model mapping config");
40+
let map: HashMap<&'static str, &'static str> =
41+
serde_yaml::from_str(&MODEL_MAP_STR).expect("Failed to parse the model mapping config");
4542
tracing::info!("{} model mappings configured", map.len());
4643
map
4744
};
@@ -54,7 +51,7 @@ pub async fn run(
5451
tls_client_ca_cert: Option<String>,
5552
default_target_port: u16,
5653
upstream_tls: bool,
57-
upstream_tls_ca_cert: Option<String>
54+
upstream_tls_ca_cert: Option<String>,
5855
) {
5956
let mut builder = Server::builder();
6057

@@ -82,40 +79,47 @@ pub async fn run(
8279
let ca_cert_pem = load_pem(ca_cert_path, "client ca cert").await;
8380
tls_config = tls_config.client_ca_root(Certificate::from_pem(ca_cert_pem));
8481
}
85-
builder = builder.tls_config(tls_config).expect("tls configuration error");
82+
builder = builder
83+
.tls_config(tls_config)
84+
.expect("tls configuration error");
8685
} else if upstream_tls {
8786
panic!("Upstream TLS enabled without any certificates");
8887
}
8988

9089
// Set up clients
91-
let clients = try_join_all(
92-
MODEL_MAP.iter().map(|(model, service)| async {
93-
tracing::info!("Configuring client for model name: [{}]", *model);
94-
// Parse hostname and optional port from target service name
95-
let mut service_parts = service.split(":");
96-
let hostname = service_parts.next().unwrap();
97-
let port = service_parts.next().map_or(
98-
default_target_port,
99-
|p| p.parse::<u16>().expect(
100-
&*format!("Invalid port in configured service name: {}", p)
101-
),
102-
);
103-
if service_parts.next().is_some() {
104-
panic!("Configured service name contains more than one : character");
105-
}
106-
// Build a load-balanced channel given a service name and a port.
107-
let mut builder = LoadBalancedChannel::builder((hostname, port));
108-
//.dns_probe_interval(Duration::from_secs(10))
109-
if let Some(tls_config) = &client_tls {
110-
builder = builder.with_tls(tls_config.clone());
111-
}
112-
let channel = builder.channel().await
113-
.context(format!("Channel failed for service {}", *service))?;
114-
Ok((*model, GenerationServiceClient::new(channel))) as Result<
115-
(&'static str, GenerationServiceClient<LoadBalancedChannel>), anyhow::Error
116-
>
117-
})).await.expect("Error creating upstream service clients").into_iter().collect();
118-
tracing::info!("{} upstream gRPC clients created successfully", grpc_addr.port());
90+
let clients = try_join_all(MODEL_MAP.iter().map(|(model, service)| async {
91+
tracing::info!("Configuring client for model name: [{}]", *model);
92+
// Parse hostname and optional port from target service name
93+
let mut service_parts = service.split(':');
94+
let hostname = service_parts.next().unwrap();
95+
let port = service_parts.next().map_or(default_target_port, |p| {
96+
p.parse::<u16>()
97+
.unwrap_or_else(|_| panic!("Invalid port in configured service name: {}", p))
98+
});
99+
if service_parts.next().is_some() {
100+
panic!("Configured service name contains more than one : character");
101+
}
102+
// Build a load-balanced channel given a service name and a port.
103+
let mut builder = LoadBalancedChannel::builder((hostname, port));
104+
//.dns_probe_interval(Duration::from_secs(10))
105+
if let Some(tls_config) = &client_tls {
106+
builder = builder.with_tls(tls_config.clone());
107+
}
108+
let channel = builder
109+
.channel()
110+
.await
111+
.context(format!("Channel failed for service {}", *service))?;
112+
Ok((*model, GenerationServiceClient::new(channel)))
113+
as Result<(&'static str, GenerationServiceClient<LoadBalancedChannel>), anyhow::Error>
114+
}))
115+
.await
116+
.expect("Error creating upstream service clients")
117+
.into_iter()
118+
.collect();
119+
tracing::info!(
120+
"{} upstream gRPC clients created successfully",
121+
grpc_addr.port()
122+
);
119123

120124
// Build and start gRPC server in background task
121125
let grpc_service = GenerationServicer { clients };
@@ -132,13 +136,15 @@ pub async fn run(
132136
// fail before starting
133137
sleep(Duration::from_secs(2)).await;
134138
if grpc_server_handle.is_finished() {
135-
grpc_server_handle.await.unwrap().expect("gRPC server startup failed");
139+
grpc_server_handle
140+
.await
141+
.unwrap()
142+
.expect("gRPC server startup failed");
136143
panic!(); // should not reach here
137144
}
138145

139146
// Build and await on the HTTP server
140-
let app = Router::new()
141-
.route("/health", get(health));
147+
let app = Router::new().route("/health", get(health));
142148

143149
let server = axum::Server::bind(&http_addr)
144150
.serve(app.into_make_service())
@@ -147,7 +153,10 @@ pub async fn run(
147153
tracing::info!("HTTP server started on port {}", http_addr.port());
148154
server.await.expect("HTTP server crashed!");
149155

150-
grpc_server_handle.await.unwrap().expect("gRPC server crashed");
156+
grpc_server_handle
157+
.await
158+
.unwrap()
159+
.expect("gRPC server crashed");
151160
}
152161

153162
async fn health() -> &'static str {
@@ -156,42 +165,47 @@ async fn health() -> &'static str {
156165
}
157166

158167
async fn load_pem(path: String, name: &str) -> Vec<u8> {
159-
read(&path).await.expect(&*format!("couldn't load {name} from {path}"))
168+
read(&path)
169+
.await
170+
.unwrap_or_else(|_| panic!("couldn't load {name} from {path}"))
160171
}
161172

162173
/*
163174
TODO:
164175
- Log errors/timings
165176
*/
166177

167-
168178
#[derive(Debug, Default)]
169179
pub struct GenerationServicer {
170-
clients: HashMap<&'static str, GenerationServiceClient<LoadBalancedChannel>>
180+
clients: HashMap<&'static str, GenerationServiceClient<LoadBalancedChannel>>,
171181
}
172182

173183
impl GenerationServicer {
174-
async fn client(&self, model_id: &String)
175-
-> Result<GenerationServiceClient<LoadBalancedChannel>, Status> {
176-
Ok(
177-
self.clients.get(&**model_id)
178-
.ok_or_else(|| Status::not_found(
179-
format!("Unrecognized model_id: {model_id}"))
180-
)?
181-
.clone()
182-
)
184+
async fn client(
185+
&self,
186+
model_id: &String,
187+
) -> Result<GenerationServiceClient<LoadBalancedChannel>, Status> {
188+
Ok(self
189+
.clients
190+
.get(&**model_id)
191+
.ok_or_else(|| Status::not_found(format!("Unrecognized model_id: {model_id}")))?
192+
.clone())
183193
}
184194
}
185195

186196
#[tonic::async_trait]
187197
impl GenerationService for GenerationServicer {
188198
#[instrument(skip_all)]
189-
async fn generate(&self, request: Request<BatchedGenerationRequest>)
190-
-> Result<Response<BatchedGenerationResponse>, Status> {
199+
async fn generate(
200+
&self,
201+
request: Request<BatchedGenerationRequest>,
202+
) -> Result<Response<BatchedGenerationResponse>, Status> {
191203
//let start_time = Instant::now();
192204
let br = request.get_ref();
193205
if br.requests.is_empty() {
194-
return Ok(Response::new(BatchedGenerationResponse{ responses: vec![] }));
206+
return Ok(Response::new(BatchedGenerationResponse {
207+
responses: vec![],
208+
}));
195209
}
196210
tracing::debug!("Routing generation request for Model ID {}", &br.model_id);
197211
self.client(&br.model_id).await?.generate(request).await
@@ -200,32 +214,50 @@ impl GenerationService for GenerationServicer {
200214
type GenerateStreamStream = Streaming<GenerationResponse>;
201215

202216
#[instrument(skip_all)]
203-
async fn generate_stream(&self, request: Request<SingleGenerationRequest>)
204-
-> Result<Response<Self::GenerateStreamStream>, Status> {
217+
async fn generate_stream(
218+
&self,
219+
request: Request<SingleGenerationRequest>,
220+
) -> Result<Response<Self::GenerateStreamStream>, Status> {
205221
let sr = request.get_ref();
206222
if sr.request.is_none() {
207223
return Err(Status::invalid_argument("missing request"));
208224
}
209-
tracing::debug!("Routing streaming generation request for Model ID {}", &sr.model_id);
210-
self.client(&sr.model_id).await?.generate_stream(request).await
225+
tracing::debug!(
226+
"Routing streaming generation request for Model ID {}",
227+
&sr.model_id
228+
);
229+
self.client(&sr.model_id)
230+
.await?
231+
.generate_stream(request)
232+
.await
211233
}
212234

213235
#[instrument(skip_all)]
214-
async fn tokenize(&self, request: Request<BatchedTokenizeRequest>)
215-
-> Result<Response<BatchedTokenizeResponse>, Status> {
236+
async fn tokenize(
237+
&self,
238+
request: Request<BatchedTokenizeRequest>,
239+
) -> Result<Response<BatchedTokenizeResponse>, Status> {
216240
let br = request.get_ref();
217241
if br.requests.is_empty() {
218-
return Ok(Response::new(BatchedTokenizeResponse{ responses: vec![] }));
242+
return Ok(Response::new(BatchedTokenizeResponse { responses: vec![] }));
219243
}
220244
tracing::debug!("Routing tokenization request for Model ID {}", &br.model_id);
221245
self.client(&br.model_id).await?.tokenize(request).await
222246
}
223247

224248
#[instrument(skip_all)]
225-
async fn model_info(&self, request: Request<ModelInfoRequest>)
226-
-> Result<Response<ModelInfoResponse>, Status> {
227-
tracing::debug!("Routing model info request for Model ID {}", &request.get_ref().model_id);
228-
self.client(&request.get_ref().model_id).await?.model_info(request).await
249+
async fn model_info(
250+
&self,
251+
request: Request<ModelInfoRequest>,
252+
) -> Result<Response<ModelInfoResponse>, Status> {
253+
tracing::debug!(
254+
"Routing model info request for Model ID {}",
255+
&request.get_ref().model_id
256+
);
257+
self.client(&request.get_ref().model_id)
258+
.await?
259+
.model_info(request)
260+
.await
229261
}
230262
}
231263

@@ -238,15 +270,15 @@ async fn shutdown_signal() {
238270
};
239271

240272
#[cfg(unix)]
241-
let terminate = async {
273+
let terminate = async {
242274
signal::unix::signal(signal::unix::SignalKind::terminate())
243275
.expect("failed to install signal handler")
244276
.recv()
245277
.await;
246278
};
247279

248280
#[cfg(not(unix))]
249-
let terminate = std::future::pending::<()>();
281+
let terminate = std::future::pending::<()>();
250282

251283
tokio::select! {
252284
_ = ctrl_c => {},

rustfmt.toml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,2 @@
1+
group_imports = "StdExternalCrate"
2+
imports_granularity = "Crate"

0 commit comments

Comments
 (0)