1
- use std:: collections:: HashMap ;
2
- use std:: net:: SocketAddr ;
3
- use std:: time:: Duration ;
4
- use anyhow:: Context ;
1
+ use std:: { collections:: HashMap , net:: SocketAddr , time:: Duration } ;
5
2
6
- use axum :: Router ;
7
- use axum:: routing:: get;
3
+ use anyhow :: Context ;
4
+ use axum:: { routing:: get, Router } ;
8
5
use futures:: future:: try_join_all;
9
6
use ginepro:: LoadBalancedChannel ;
10
7
use lazy_static:: lazy_static;
11
- use tokio:: fs:: read;
12
- use tokio :: signal ;
13
- use tokio :: time :: sleep ;
14
- use tonic :: { Request , Response , Status , Streaming } ;
15
- use tonic :: transport :: { Certificate , ClientTlsConfig , Identity , Server , ServerTlsConfig } ;
8
+ use tokio:: { fs:: read, signal , time :: sleep } ;
9
+ use tonic :: {
10
+ transport :: { Certificate , ClientTlsConfig , Identity , Server , ServerTlsConfig } ,
11
+ Request , Response , Status , Streaming ,
12
+ } ;
16
13
use tracing:: instrument;
14
+
17
15
use crate :: pb:: fmaas:: {
18
- BatchedGenerationRequest , BatchedGenerationResponse , GenerationResponse ,
19
- SingleGenerationRequest , BatchedTokenizeRequest , BatchedTokenizeResponse ,
20
- ModelInfoRequest , ModelInfoResponse ,
16
+ generation_service_client:: GenerationServiceClient ,
17
+ generation_service_server:: { GenerationService , GenerationServiceServer } ,
18
+ BatchedGenerationRequest , BatchedGenerationResponse , BatchedTokenizeRequest ,
19
+ BatchedTokenizeResponse , GenerationResponse , ModelInfoRequest , ModelInfoResponse ,
20
+ SingleGenerationRequest ,
21
21
} ;
22
- use crate :: pb:: fmaas:: generation_service_client:: GenerationServiceClient ;
23
- use crate :: pb:: fmaas:: generation_service_server:: { GenerationService , GenerationServiceServer } ;
24
-
25
22
26
23
const MODEL_MAP_ENV_VAR_NAME : & str = "MODEL_MAP_CONFIG" ;
27
24
@@ -40,8 +37,8 @@ lazy_static! {
40
37
}
41
38
} ;
42
39
}
43
- let map: HashMap <& ' static str , & ' static str > = serde_yaml :: from_str ( & MODEL_MAP_STR )
44
- . expect( "Failed to parse the model mapping config" ) ;
40
+ let map: HashMap <& ' static str , & ' static str > =
41
+ serde_yaml :: from_str ( & MODEL_MAP_STR ) . expect( "Failed to parse the model mapping config" ) ;
45
42
tracing:: info!( "{} model mappings configured" , map. len( ) ) ;
46
43
map
47
44
} ;
@@ -54,7 +51,7 @@ pub async fn run(
54
51
tls_client_ca_cert : Option < String > ,
55
52
default_target_port : u16 ,
56
53
upstream_tls : bool ,
57
- upstream_tls_ca_cert : Option < String >
54
+ upstream_tls_ca_cert : Option < String > ,
58
55
) {
59
56
let mut builder = Server :: builder ( ) ;
60
57
@@ -82,40 +79,47 @@ pub async fn run(
82
79
let ca_cert_pem = load_pem ( ca_cert_path, "client ca cert" ) . await ;
83
80
tls_config = tls_config. client_ca_root ( Certificate :: from_pem ( ca_cert_pem) ) ;
84
81
}
85
- builder = builder. tls_config ( tls_config) . expect ( "tls configuration error" ) ;
82
+ builder = builder
83
+ . tls_config ( tls_config)
84
+ . expect ( "tls configuration error" ) ;
86
85
} else if upstream_tls {
87
86
panic ! ( "Upstream TLS enabled without any certificates" ) ;
88
87
}
89
88
90
89
// Set up clients
91
- let clients = try_join_all (
92
- MODEL_MAP . iter ( ) . map ( |( model, service) | async {
93
- tracing:: info!( "Configuring client for model name: [{}]" , * model) ;
94
- // Parse hostname and optional port from target service name
95
- let mut service_parts = service. split ( ":" ) ;
96
- let hostname = service_parts. next ( ) . unwrap ( ) ;
97
- let port = service_parts. next ( ) . map_or (
98
- default_target_port,
99
- |p| p. parse :: < u16 > ( ) . expect (
100
- & * format ! ( "Invalid port in configured service name: {}" , p)
101
- ) ,
102
- ) ;
103
- if service_parts. next ( ) . is_some ( ) {
104
- panic ! ( "Configured service name contains more than one : character" ) ;
105
- }
106
- // Build a load-balanced channel given a service name and a port.
107
- let mut builder = LoadBalancedChannel :: builder ( ( hostname, port) ) ;
108
- //.dns_probe_interval(Duration::from_secs(10))
109
- if let Some ( tls_config) = & client_tls {
110
- builder = builder. with_tls ( tls_config. clone ( ) ) ;
111
- }
112
- let channel = builder. channel ( ) . await
113
- . context ( format ! ( "Channel failed for service {}" , * service) ) ?;
114
- Ok ( ( * model, GenerationServiceClient :: new ( channel) ) ) as Result <
115
- ( & ' static str , GenerationServiceClient < LoadBalancedChannel > ) , anyhow:: Error
116
- >
117
- } ) ) . await . expect ( "Error creating upstream service clients" ) . into_iter ( ) . collect ( ) ;
118
- tracing:: info!( "{} upstream gRPC clients created successfully" , grpc_addr. port( ) ) ;
90
+ let clients = try_join_all ( MODEL_MAP . iter ( ) . map ( |( model, service) | async {
91
+ tracing:: info!( "Configuring client for model name: [{}]" , * model) ;
92
+ // Parse hostname and optional port from target service name
93
+ let mut service_parts = service. split ( ':' ) ;
94
+ let hostname = service_parts. next ( ) . unwrap ( ) ;
95
+ let port = service_parts. next ( ) . map_or ( default_target_port, |p| {
96
+ p. parse :: < u16 > ( )
97
+ . unwrap_or_else ( |_| panic ! ( "Invalid port in configured service name: {}" , p) )
98
+ } ) ;
99
+ if service_parts. next ( ) . is_some ( ) {
100
+ panic ! ( "Configured service name contains more than one : character" ) ;
101
+ }
102
+ // Build a load-balanced channel given a service name and a port.
103
+ let mut builder = LoadBalancedChannel :: builder ( ( hostname, port) ) ;
104
+ //.dns_probe_interval(Duration::from_secs(10))
105
+ if let Some ( tls_config) = & client_tls {
106
+ builder = builder. with_tls ( tls_config. clone ( ) ) ;
107
+ }
108
+ let channel = builder
109
+ . channel ( )
110
+ . await
111
+ . context ( format ! ( "Channel failed for service {}" , * service) ) ?;
112
+ Ok ( ( * model, GenerationServiceClient :: new ( channel) ) )
113
+ as Result < ( & ' static str , GenerationServiceClient < LoadBalancedChannel > ) , anyhow:: Error >
114
+ } ) )
115
+ . await
116
+ . expect ( "Error creating upstream service clients" )
117
+ . into_iter ( )
118
+ . collect ( ) ;
119
+ tracing:: info!(
120
+ "{} upstream gRPC clients created successfully" ,
121
+ grpc_addr. port( )
122
+ ) ;
119
123
120
124
// Build and start gRPC server in background task
121
125
let grpc_service = GenerationServicer { clients } ;
@@ -132,13 +136,15 @@ pub async fn run(
132
136
// fail before starting
133
137
sleep ( Duration :: from_secs ( 2 ) ) . await ;
134
138
if grpc_server_handle. is_finished ( ) {
135
- grpc_server_handle. await . unwrap ( ) . expect ( "gRPC server startup failed" ) ;
139
+ grpc_server_handle
140
+ . await
141
+ . unwrap ( )
142
+ . expect ( "gRPC server startup failed" ) ;
136
143
panic ! ( ) ; // should not reach here
137
144
}
138
145
139
146
// Build and await on the HTTP server
140
- let app = Router :: new ( )
141
- . route ( "/health" , get ( health) ) ;
147
+ let app = Router :: new ( ) . route ( "/health" , get ( health) ) ;
142
148
143
149
let server = axum:: Server :: bind ( & http_addr)
144
150
. serve ( app. into_make_service ( ) )
@@ -147,7 +153,10 @@ pub async fn run(
147
153
tracing:: info!( "HTTP server started on port {}" , http_addr. port( ) ) ;
148
154
server. await . expect ( "HTTP server crashed!" ) ;
149
155
150
- grpc_server_handle. await . unwrap ( ) . expect ( "gRPC server crashed" ) ;
156
+ grpc_server_handle
157
+ . await
158
+ . unwrap ( )
159
+ . expect ( "gRPC server crashed" ) ;
151
160
}
152
161
153
162
async fn health ( ) -> & ' static str {
@@ -156,42 +165,47 @@ async fn health() -> &'static str {
156
165
}
157
166
158
167
async fn load_pem ( path : String , name : & str ) -> Vec < u8 > {
159
- read ( & path) . await . expect ( & * format ! ( "couldn't load {name} from {path}" ) )
168
+ read ( & path)
169
+ . await
170
+ . unwrap_or_else ( |_| panic ! ( "couldn't load {name} from {path}" ) )
160
171
}
161
172
162
173
/*
163
174
TODO:
164
175
- Log errors/timings
165
176
*/
166
177
167
-
168
178
#[ derive( Debug , Default ) ]
169
179
pub struct GenerationServicer {
170
- clients : HashMap < & ' static str , GenerationServiceClient < LoadBalancedChannel > >
180
+ clients : HashMap < & ' static str , GenerationServiceClient < LoadBalancedChannel > > ,
171
181
}
172
182
173
183
impl GenerationServicer {
174
- async fn client ( & self , model_id : & String )
175
- -> Result < GenerationServiceClient < LoadBalancedChannel > , Status > {
176
- Ok (
177
- self . clients . get ( & * * model_id )
178
- . ok_or_else ( || Status :: not_found (
179
- format ! ( "Unrecognized model_id: {model_id}" ) )
180
- ) ?
181
- . clone ( )
182
- )
184
+ async fn client (
185
+ & self ,
186
+ model_id : & String ,
187
+ ) -> Result < GenerationServiceClient < LoadBalancedChannel > , Status > {
188
+ Ok ( self
189
+ . clients
190
+ . get ( & * * model_id )
191
+ . ok_or_else ( || Status :: not_found ( format ! ( "Unrecognized model_id: {model_id}" ) ) ) ?
192
+ . clone ( ) )
183
193
}
184
194
}
185
195
186
196
#[ tonic:: async_trait]
187
197
impl GenerationService for GenerationServicer {
188
198
#[ instrument( skip_all) ]
189
- async fn generate ( & self , request : Request < BatchedGenerationRequest > )
190
- -> Result < Response < BatchedGenerationResponse > , Status > {
199
+ async fn generate (
200
+ & self ,
201
+ request : Request < BatchedGenerationRequest > ,
202
+ ) -> Result < Response < BatchedGenerationResponse > , Status > {
191
203
//let start_time = Instant::now();
192
204
let br = request. get_ref ( ) ;
193
205
if br. requests . is_empty ( ) {
194
- return Ok ( Response :: new ( BatchedGenerationResponse { responses : vec ! [ ] } ) ) ;
206
+ return Ok ( Response :: new ( BatchedGenerationResponse {
207
+ responses : vec ! [ ] ,
208
+ } ) ) ;
195
209
}
196
210
tracing:: debug!( "Routing generation request for Model ID {}" , & br. model_id) ;
197
211
self . client ( & br. model_id ) . await ?. generate ( request) . await
@@ -200,32 +214,50 @@ impl GenerationService for GenerationServicer {
200
214
type GenerateStreamStream = Streaming < GenerationResponse > ;
201
215
202
216
#[ instrument( skip_all) ]
203
- async fn generate_stream ( & self , request : Request < SingleGenerationRequest > )
204
- -> Result < Response < Self :: GenerateStreamStream > , Status > {
217
+ async fn generate_stream (
218
+ & self ,
219
+ request : Request < SingleGenerationRequest > ,
220
+ ) -> Result < Response < Self :: GenerateStreamStream > , Status > {
205
221
let sr = request. get_ref ( ) ;
206
222
if sr. request . is_none ( ) {
207
223
return Err ( Status :: invalid_argument ( "missing request" ) ) ;
208
224
}
209
- tracing:: debug!( "Routing streaming generation request for Model ID {}" , & sr. model_id) ;
210
- self . client ( & sr. model_id ) . await ?. generate_stream ( request) . await
225
+ tracing:: debug!(
226
+ "Routing streaming generation request for Model ID {}" ,
227
+ & sr. model_id
228
+ ) ;
229
+ self . client ( & sr. model_id )
230
+ . await ?
231
+ . generate_stream ( request)
232
+ . await
211
233
}
212
234
213
235
#[ instrument( skip_all) ]
214
- async fn tokenize ( & self , request : Request < BatchedTokenizeRequest > )
215
- -> Result < Response < BatchedTokenizeResponse > , Status > {
236
+ async fn tokenize (
237
+ & self ,
238
+ request : Request < BatchedTokenizeRequest > ,
239
+ ) -> Result < Response < BatchedTokenizeResponse > , Status > {
216
240
let br = request. get_ref ( ) ;
217
241
if br. requests . is_empty ( ) {
218
- return Ok ( Response :: new ( BatchedTokenizeResponse { responses : vec ! [ ] } ) ) ;
242
+ return Ok ( Response :: new ( BatchedTokenizeResponse { responses : vec ! [ ] } ) ) ;
219
243
}
220
244
tracing:: debug!( "Routing tokenization request for Model ID {}" , & br. model_id) ;
221
245
self . client ( & br. model_id ) . await ?. tokenize ( request) . await
222
246
}
223
247
224
248
#[ instrument( skip_all) ]
225
- async fn model_info ( & self , request : Request < ModelInfoRequest > )
226
- -> Result < Response < ModelInfoResponse > , Status > {
227
- tracing:: debug!( "Routing model info request for Model ID {}" , & request. get_ref( ) . model_id) ;
228
- self . client ( & request. get_ref ( ) . model_id ) . await ?. model_info ( request) . await
249
+ async fn model_info (
250
+ & self ,
251
+ request : Request < ModelInfoRequest > ,
252
+ ) -> Result < Response < ModelInfoResponse > , Status > {
253
+ tracing:: debug!(
254
+ "Routing model info request for Model ID {}" ,
255
+ & request. get_ref( ) . model_id
256
+ ) ;
257
+ self . client ( & request. get_ref ( ) . model_id )
258
+ . await ?
259
+ . model_info ( request)
260
+ . await
229
261
}
230
262
}
231
263
@@ -238,15 +270,15 @@ async fn shutdown_signal() {
238
270
} ;
239
271
240
272
#[ cfg( unix) ]
241
- let terminate = async {
273
+ let terminate = async {
242
274
signal:: unix:: signal ( signal:: unix:: SignalKind :: terminate ( ) )
243
275
. expect ( "failed to install signal handler" )
244
276
. recv ( )
245
277
. await ;
246
278
} ;
247
279
248
280
#[ cfg( not( unix) ) ]
249
- let terminate = std:: future:: pending :: < ( ) > ( ) ;
281
+ let terminate = std:: future:: pending :: < ( ) > ( ) ;
250
282
251
283
tokio:: select! {
252
284
_ = ctrl_c => { } ,
0 commit comments