
Commit ec04b9d

drbh and OlivierDehaene authored
feat: support vertex api endpoint (#184)
Co-authored-by: OlivierDehaene <[email protected]>
1 parent 9ab2f2c commit ec04b9d

File tree

router/Cargo.toml
router/src/http/server.rs
router/src/http/types.rs
router/src/lib.rs

4 files changed: +168 −4 lines changed


router/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -78,3 +78,4 @@ candle-cuda = ["candle", "text-embeddings-backend/flash-attn"]
 candle-cuda-turing = ["candle", "text-embeddings-backend/flash-attn-v1"]
 candle-cuda-volta = ["candle", "text-embeddings-backend/cuda"]
 static-linking = ["text-embeddings-backend/static-linking"]
+google = []

router/src/http/server.rs

Lines changed: 140 additions & 4 deletions
@@ -3,7 +3,7 @@ use crate::http::types::{
     EmbedAllRequest, EmbedAllResponse, EmbedRequest, EmbedResponse, Input, OpenAICompatEmbedding,
     OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage,
     PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
-    Sequence, SimpleToken, TokenizeRequest, TokenizeResponse,
+    Sequence, SimpleToken, TokenizeRequest, TokenizeResponse, VertexRequest,
 };
 use crate::{
     shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
@@ -992,6 +992,101 @@ async fn tokenize(
     Ok(Json(TokenizeResponse(tokens)))
 }
 
+/// Generate embeddings from a Vertex request
+#[utoipa::path(
+post,
+tag = "Text Embeddings Inference",
+path = "/vertex",
+request_body = VertexRequest,
+responses(
+(status = 200, description = "Embeddings", body = EmbedResponse),
+(status = 424, description = "Embedding Error", body = ErrorResponse,
+example = json ! ({"error": "Inference failed", "error_type": "backend"})),
+(status = 429, description = "Model is overloaded", body = ErrorResponse,
+example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
+(status = 422, description = "Tokenization error", body = ErrorResponse,
+example = json ! ({"error": "Tokenization error", "error_type": "tokenizer"})),
+(status = 413, description = "Batch size error", body = ErrorResponse,
+example = json ! ({"error": "Batch size error", "error_type": "validation"})),
+)
+)]
+#[instrument(skip_all)]
+async fn vertex_compatibility(
+    infer: Extension<Infer>,
+    info: Extension<Info>,
+    Json(req): Json<VertexRequest>,
+) -> Result<(HeaderMap, Json<EmbedResponse>), (StatusCode, Json<ErrorResponse>)> {
+    let span = tracing::Span::current();
+    let start_time = Instant::now();
+
+    let batch_size = req.instances.len();
+    if batch_size > info.max_client_batch_size {
+        let message = format!(
+            "batch size {batch_size} > maximum allowed batch size {}",
+            info.max_client_batch_size
+        );
+        tracing::error!("{message}");
+        let err = ErrorResponse {
+            error: message,
+            error_type: ErrorType::Validation,
+        };
+        metrics::increment_counter!("te_request_failure", "err" => "batch_size");
+        Err(err)?;
+    }
+
+    let mut futures = Vec::with_capacity(batch_size);
+    let mut compute_chars = 0;
+
+    for instance in req.instances.iter() {
+        let input = instance.inputs.clone();
+        compute_chars += input.chars().count();
+
+        let local_infer = infer.clone();
+        futures.push(async move {
+            let permit = local_infer.acquire_permit().await;
+            local_infer
+                .embed_pooled(input, instance.truncate, instance.normalize, permit)
+                .await
+        })
+    }
+    let results = join_all(futures)
+        .await
+        .into_iter()
+        .collect::<Result<Vec<PooledEmbeddingsInferResponse>, TextEmbeddingsError>>()
+        .map_err(ErrorResponse::from)?;
+
+    let mut embeddings = Vec::with_capacity(batch_size);
+    let mut total_tokenization_time = 0;
+    let mut total_queue_time = 0;
+    let mut total_inference_time = 0;
+    let mut total_compute_tokens = 0;
+
+    for r in results {
+        total_tokenization_time += r.metadata.tokenization.as_nanos() as u64;
+        total_queue_time += r.metadata.queue.as_nanos() as u64;
+        total_inference_time += r.metadata.inference.as_nanos() as u64;
+        total_compute_tokens += r.metadata.prompt_tokens;
+        embeddings.push(r.results);
+    }
+    let batch_size = batch_size as u64;
+
+    let response = EmbedResponse(embeddings);
+    let metadata = ResponseMetadata::new(
+        compute_chars,
+        total_compute_tokens,
+        start_time,
+        Duration::from_nanos(total_tokenization_time / batch_size),
+        Duration::from_nanos(total_queue_time / batch_size),
+        Duration::from_nanos(total_inference_time / batch_size),
+    );
+
+    metadata.record_span(&span);
+    metadata.record_metrics();
+    tracing::info!("Success");
+
+    Ok((HeaderMap::from(metadata), Json(response)))
+}
+
 /// Prometheus metrics scrape endpoint
 #[utoipa::path(
 get,
@@ -1089,9 +1184,32 @@ pub async fn run(
         .allow_headers([http::header::CONTENT_TYPE])
         .allow_origin(allow_origin);
 
+    // Define VertextApiDoc conditionally only if the "google" feature is enabled
+    let doc = {
+        // avoid `mut` if possible
+        #[cfg(feature = "google")]
+        {
+            use crate::http::types::VertexInstance;
+
+            #[derive(OpenApi)]
+            #[openapi(
+                paths(vertex_compatibility),
+                components(schemas(VertexInstance, VertexRequest))
+            )]
+            struct VertextApiDoc;
+
+            // limiting mutability to the smallest scope necessary
+            let mut doc = ApiDoc::openapi();
+            doc.merge(VertextApiDoc::openapi());
+            doc
+        }
+        #[cfg(not(feature = "google"))]
+        ApiDoc::openapi()
+    };
+
     // Create router
-    let app = Router::new()
-        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
+    let base_routes = Router::new()
+        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc))
         // Base routes
         .route("/info", get(get_model_info))
         .route("/embed", post(embed))
@@ -1101,6 +1219,8 @@ pub async fn run(
         .route("/tokenize", post(tokenize))
         // OpenAI compat route
         .route("/embeddings", post(openai_embed))
+        // Vertex compat route
+        .route("/vertex", post(vertex_compatibility))
         // Base Health route
         .route("/health", get(health))
         // Inference API health route
@@ -1110,8 +1230,10 @@ pub async fn run(
         // Prometheus metrics route
         .route("/metrics", get(metrics));
 
+    let mut app = Router::new().merge(base_routes);
+
     // Set default routes
-    let app = match &info.model_type {
+    app = match &info.model_type {
         ModelType::Classifier(_) => {
             app.route("/", post(predict))
                 // AWS Sagemaker route
@@ -1129,6 +1251,20 @@ pub async fn run(
         }
     };
 
+    #[cfg(feature = "google")]
+    {
+        tracing::info!("Built with `google` feature");
+        tracing::info!(
+            "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
+        );
+        if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
+            app = app.route(&env_predict_route, post(vertex_compatibility));
+        }
+        if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
+            app = app.route(&env_health_route, get(health));
+        }
+    }
+
     let app = app
         .layer(Extension(infer))
         .layer(Extension(info))
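
For illustration, a call to the new `/vertex` route could look like the sketch below. This is a hypothetical client, not part of the commit: it assumes a server listening on localhost:8080 and the `tokio`, `reqwest` (with its `json` feature), and `serde_json` crates, with field names taken from `VertexRequest`/`VertexInstance` in the next file. The handler returns an `EmbedResponse`, the same body type as `/embed`.

```rust
// Hypothetical client sketch (not part of this commit).
// Assumptions: server on localhost:8080; tokio, reqwest (json feature), serde_json.
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // `truncate` defaults to false and `normalize` to true when omitted.
    let body = json!({
        "instances": [
            { "inputs": "What is Deep Learning?" },
            { "inputs": "What is Machine Learning?", "truncate": true, "normalize": false }
        ]
    });

    let response = reqwest::Client::new()
        .post("http://localhost:8080/vertex")
        .json(&body)
        .send()
        .await?;

    // The response body is an EmbedResponse: one embedding per instance.
    println!("{}", response.text().await?);
    Ok(())
}
```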

router/src/http/types.rs

Lines changed: 17 additions & 0 deletions
@@ -364,3 +364,20 @@ pub(crate) struct SimpleToken {
 #[derive(Serialize, ToSchema)]
 #[schema(example = json!([[{"id": 0, "text": "test", "special": false, "start": 0, "stop": 2}]]))]
 pub(crate) struct TokenizeResponse(pub Vec<Vec<SimpleToken>>);
+
+#[derive(Clone, Deserialize, ToSchema)]
+pub(crate) struct VertexInstance {
+    #[schema(example = "What is Deep Learning?")]
+    pub inputs: String,
+    #[serde(default)]
+    #[schema(default = "false", example = "false")]
+    pub truncate: bool,
+    #[serde(default = "default_normalize")]
+    #[schema(default = "true", example = "true")]
+    pub normalize: bool,
+}
+
+#[derive(Deserialize, ToSchema)]
+pub(crate) struct VertexRequest {
+    pub instances: Vec<VertexInstance>,
+}
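
To make the serde attributes above concrete, the standalone sketch below (not part of the commit) mirrors `VertexInstance` and shows how a minimal payload deserializes. It assumes `serde` and `serde_json`, and that `default_normalize`, which already exists in `types.rs`, returns `true`, as the declared schema default suggests.

```rust
// Illustrative sketch (not part of this commit): a simplified mirror of
// VertexInstance showing the effect of the serde defaults above.
use serde::Deserialize;

fn default_normalize() -> bool {
    true // assumed to match the existing helper in types.rs
}

#[derive(Deserialize)]
struct VertexInstance {
    inputs: String,
    #[serde(default)]
    truncate: bool, // omitted => false
    #[serde(default = "default_normalize")]
    normalize: bool, // omitted => true
}

fn main() {
    let instance: VertexInstance =
        serde_json::from_str(r#"{ "inputs": "What is Deep Learning?" }"#).unwrap();
    assert_eq!(instance.inputs, "What is Deep Learning?");
    assert!(!instance.truncate);
    assert!(instance.normalize);
}
```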

router/src/lib.rs

Lines changed: 10 additions & 0 deletions
@@ -239,6 +239,16 @@ pub async fn run(
         docker_label: option_env!("DOCKER_LABEL"),
     };
 
+    // use AIP_HTTP_PORT if google feature is enabled
+    let port = if cfg!(feature = "google") {
+        std::env::var("AIP_HTTP_PORT")
+            .ok()
+            .and_then(|p| p.parse().ok())
+            .expect("Invalid or unset AIP_HTTP_PORT")
+    } else {
+        port
+    };
+
     let addr = match hostname.unwrap_or("0.0.0.0".to_string()).parse() {
         Ok(ip) => SocketAddr::new(ip, port),
         Err(_) => {
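
Read as a standalone helper (an equivalent sketch for clarity only; this function does not exist in the codebase), the port selection above behaves as follows: with the `google` feature enabled, `AIP_HTTP_PORT` must be set to a valid port number or the process panics at startup; without the feature, the CLI-provided port is used unchanged.

```rust
// Equivalent sketch of the port selection above (this helper does not exist in
// the codebase). With the `google` feature, the HTTP port must come from the
// AIP_HTTP_PORT environment variable; otherwise the CLI value is kept.
fn resolve_port(cli_port: u16) -> u16 {
    if cfg!(feature = "google") {
        std::env::var("AIP_HTTP_PORT")
            .ok()
            .and_then(|p| p.parse::<u16>().ok())
            .expect("Invalid or unset AIP_HTTP_PORT")
    } else {
        cli_port
    }
}
```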
