@@ -3,7 +3,7 @@ use crate::http::types::{
    EmbedAllRequest, EmbedAllResponse, EmbedRequest, EmbedResponse, Input, OpenAICompatEmbedding,
    OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage,
    PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
-    Sequence, SimpleToken, TokenizeRequest, TokenizeResponse,
+    Sequence, SimpleToken, TokenizeRequest, TokenizeResponse, VertexRequest,
};
use crate::{
    shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
@@ -992,6 +992,101 @@ async fn tokenize(
    Ok(Json(TokenizeResponse(tokens)))
}

+/// Generate embeddings from a Vertex request
+#[utoipa::path(
+    post,
+    tag = "Text Embeddings Inference",
+    path = "/vertex",
+    request_body = VertexRequest,
+    responses(
+        (status = 200, description = "Embeddings", body = EmbedResponse),
+        (status = 424, description = "Embedding Error", body = ErrorResponse,
+        example = json!({"error": "Inference failed", "error_type": "backend"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+        example = json!({"error": "Model is overloaded", "error_type": "overloaded"})),
+        (status = 422, description = "Tokenization error", body = ErrorResponse,
+        example = json!({"error": "Tokenization error", "error_type": "tokenizer"})),
+        (status = 413, description = "Batch size error", body = ErrorResponse,
+        example = json!({"error": "Batch size error", "error_type": "validation"})),
+    )
+)]
+#[instrument(skip_all)]
+async fn vertex_compatibility(
+    infer: Extension<Infer>,
+    info: Extension<Info>,
+    Json(req): Json<VertexRequest>,
+) -> Result<(HeaderMap, Json<EmbedResponse>), (StatusCode, Json<ErrorResponse>)> {
+    let span = tracing::Span::current();
+    let start_time = Instant::now();
+
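+    // Refuse requests that carry more instances than the configured per-client
+    // batch limit; the rejection is also recorded in the `te_request_failure` counter.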
+    let batch_size = req.instances.len();
+    if batch_size > info.max_client_batch_size {
+        let message = format!(
+            "batch size {batch_size} > maximum allowed batch size {}",
+            info.max_client_batch_size
+        );
+        tracing::error!("{message}");
+        let err = ErrorResponse {
+            error: message,
+            error_type: ErrorType::Validation,
+        };
+        metrics::increment_counter!("te_request_failure", "err" => "batch_size");
+        Err(err)?;
+    }
+
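+    // Build one embedding future per Vertex instance; each future waits for an
+    // inference permit before calling `embed_pooled`, and all of them are joined below.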
+    let mut futures = Vec::with_capacity(batch_size);
+    let mut compute_chars = 0;
+
+    for instance in req.instances.iter() {
+        let input = instance.inputs.clone();
+        compute_chars += input.chars().count();
+
+        let local_infer = infer.clone();
+        futures.push(async move {
+            let permit = local_infer.acquire_permit().await;
+            local_infer
+                .embed_pooled(input, instance.truncate, instance.normalize, permit)
+                .await
+        })
+    }
+    let results = join_all(futures)
+        .await
+        .into_iter()
+        .collect::<Result<Vec<PooledEmbeddingsInferResponse>, TextEmbeddingsError>>()
+        .map_err(ErrorResponse::from)?;
+
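+    // Sum the per-instance timing metadata so the response headers can report
+    // per-request averages (each total is divided by `batch_size` below).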
+    let mut embeddings = Vec::with_capacity(batch_size);
+    let mut total_tokenization_time = 0;
+    let mut total_queue_time = 0;
+    let mut total_inference_time = 0;
+    let mut total_compute_tokens = 0;
+
+    for r in results {
+        total_tokenization_time += r.metadata.tokenization.as_nanos() as u64;
+        total_queue_time += r.metadata.queue.as_nanos() as u64;
+        total_inference_time += r.metadata.inference.as_nanos() as u64;
+        total_compute_tokens += r.metadata.prompt_tokens;
+        embeddings.push(r.results);
+    }
+    let batch_size = batch_size as u64;
+
+    let response = EmbedResponse(embeddings);
+    let metadata = ResponseMetadata::new(
+        compute_chars,
+        total_compute_tokens,
+        start_time,
+        Duration::from_nanos(total_tokenization_time / batch_size),
+        Duration::from_nanos(total_queue_time / batch_size),
+        Duration::from_nanos(total_inference_time / batch_size),
+    );
+
+    metadata.record_span(&span);
+    metadata.record_metrics();
+    tracing::info!("Success");
+
+    Ok((HeaderMap::from(metadata), Json(response)))
+}
+
/// Prometheus metrics scrape endpoint
#[utoipa::path(
    get,
@@ -1089,9 +1184,32 @@ pub async fn run(
        .allow_headers([http::header::CONTENT_TYPE])
        .allow_origin(allow_origin);

+    // Define VertextApiDoc conditionally only if the "google" feature is enabled
+    let doc = {
+        // avoid `mut` if possible
+        #[cfg(feature = "google")]
+        {
+            use crate::http::types::VertexInstance;
+
+            #[derive(OpenApi)]
+            #[openapi(
+                paths(vertex_compatibility),
+                components(schemas(VertexInstance, VertexRequest))
+            )]
+            struct VertextApiDoc;
+
+            // limiting mutability to the smallest scope necessary
+            let mut doc = ApiDoc::openapi();
+            doc.merge(VertextApiDoc::openapi());
+            doc
+        }
+        #[cfg(not(feature = "google"))]
+        ApiDoc::openapi()
+    };
+
    // Create router
-    let app = Router::new()
-        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
+    let base_routes = Router::new()
+        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc))
        // Base routes
        .route("/info", get(get_model_info))
        .route("/embed", post(embed))
@@ -1101,6 +1219,8 @@ pub async fn run(
        .route("/tokenize", post(tokenize))
        // OpenAI compat route
        .route("/embeddings", post(openai_embed))
+        // Vertex compat route
+        .route("/vertex", post(vertex_compatibility))
        // Base Health route
        .route("/health", get(health))
        // Inference API health route
@@ -1110,8 +1230,10 @@ pub async fn run(
        // Prometheus metrics route
        .route("/metrics", get(metrics));

+    let mut app = Router::new().merge(base_routes);
+
    // Set default routes
-    let app = match &info.model_type {
+    app = match &info.model_type {
        ModelType::Classifier(_) => {
            app.route("/", post(predict))
                // AWS Sagemaker route
@@ -1129,6 +1251,20 @@ pub async fn run(
        }
    };

+    #[cfg(feature = "google")]
+    {
+        tracing::info!("Built with `google` feature");
+        tracing::info!(
+            "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
+        );
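+        // Vertex AI passes the serving container its predict and health paths through
+        // these environment variables, so the router mounts its handlers there as well.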
+        if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
+            app = app.route(&env_predict_route, post(vertex_compatibility));
+        }
+        if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
+            app = app.route(&env_health_route, get(health));
+        }
+    }
+
    let app = app
        .layer(Extension(infer))
        .layer(Extension(info))