diff --git a/README.md b/README.md
index 76dcf8c7..24d43a57 100644
--- a/README.md
+++ b/README.md
@@ -153,6 +153,11 @@ Options:
           Optionally control the number of tokenizer workers used for payload tokenization, validation and truncation. Default to the number of CPU cores on the machine
 
           [env: TOKENIZATION_WORKERS=]
+
+      --served-model-name <SERVED_MODEL_NAME>
+          The name of the model that is returned when serving OpenAI requests. If not specified, defaults to the value in model-id.
+
+          [env: SERVED_MODEL_NAME=]
 
       --dtype <DTYPE>
           The dtype to be forced upon the model
diff --git a/docs/openapi.json b/docs/openapi.json
index b6484887..b57fa231 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -1052,6 +1052,7 @@
       "required": [
         "model_id",
         "model_dtype",
+        "served_model_name",
         "model_type",
         "max_concurrent_requests",
         "max_input_length",
@@ -1107,6 +1108,11 @@
           "description": "Model info",
           "example": "thenlper/gte-base"
         },
+        "served_model_name": {
+          "type": "string",
+          "description": "Model name specified by user",
+          "example": "thenlper/gte-base"
+        },
         "model_sha": {
           "type": "string",
           "example": "fca14538aa9956a46526bd1d0d11d69e19b5a101",
diff --git a/docs/source/en/cli_arguments.md b/docs/source/en/cli_arguments.md
index 5bf025ee..f0d2158a 100644
--- a/docs/source/en/cli_arguments.md
+++ b/docs/source/en/cli_arguments.md
@@ -46,6 +46,11 @@ Options:
 
           [env: DTYPE=]
           [possible values: float16, float32]
+
+      --served-model-name <SERVED_MODEL_NAME>
+          The name of the model that is returned when serving OpenAI requests. If not specified, defaults to the value in model-id.
+
+          [env: SERVED_MODEL_NAME=]
 
       --pooling <POOLING>
           Optionally control the pooling method for embedding models.
diff --git a/proto/tei.proto b/proto/tei.proto
index aac6c2ba..53a0fd19 100644
--- a/proto/tei.proto
+++ b/proto/tei.proto
@@ -58,6 +58,7 @@ message InfoResponse {
   optional uint32 max_batch_requests = 11;
   uint32 max_client_batch_size = 12;
   uint32 tokenization_workers = 13;
+  optional string served_model_name = 14;
 }
 
 message Metadata {
diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs
index 3c98f8b8..08e35ffe 100644
--- a/router/src/grpc/server.rs
+++ b/router/src/grpc/server.rs
@@ -572,6 +572,7 @@ impl grpc::info_server::Info for TextEmbeddingsService {
             model_id: self.info.model_id.clone(),
             model_sha: self.info.model_sha.clone(),
             model_dtype: self.info.model_dtype.clone(),
+            served_model_name: self.info.served_model_name.clone(),
             model_type: model_type.into(),
             max_concurrent_requests: self.info.max_concurrent_requests as u32,
             max_input_length: self.info.max_input_length as u32,
diff --git a/router/src/http/server.rs b/router/src/http/server.rs
index f805744a..3432b47f 100644
--- a/router/src/http/server.rs
+++ b/router/src/http/server.rs
@@ -1287,7 +1287,7 @@ async fn openai_embed(
     let response = OpenAICompatResponse {
         object: "list",
         data: embeddings,
-        model: info.model_id.clone(),
+        model: info.served_model_name.clone().unwrap_or_else(|| info.model_id.clone()),
         usage: OpenAICompatUsage {
             prompt_tokens: compute_tokens,
             total_tokens: compute_tokens,
diff --git a/router/src/lib.rs b/router/src/lib.rs
index f1b8ba26..b688ac87 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -46,6 +46,7 @@ pub async fn run(
     revision: Option<String>,
     tokenization_workers: Option<usize>,
     dtype: Option<DType>,
+    served_model_name: Option<String>,
     pooling: Option<text_embeddings_backend::Pool>,
     max_concurrent_requests: usize,
     max_batch_tokens: usize,
@@ -279,6 +280,7 @@ pub async fn run(
         model_id,
         model_sha: revision,
         model_dtype: dtype.to_string(),
+        served_model_name,
         model_type,
         max_concurrent_requests,
         max_input_length,
@@ -493,6 +495,8 @@ pub struct Info {
     pub model_sha: Option<String>,
     #[cfg_attr(feature = "http", schema(example = "float16"))]
     pub model_dtype: String,
+    #[cfg_attr(feature = "http", schema(example = "thenlper/gte-base"))]
+    pub served_model_name: Option<String>,
     pub model_type: ModelType,
     /// Router Parameters
     #[cfg_attr(feature = "http", schema(example = "128"))]
diff --git a/router/src/main.rs b/router/src/main.rs
index 39b975d5..5b2d97ea 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -36,6 +36,11 @@ struct Args {
     #[clap(long, env, value_enum)]
     dtype: Option<DType>,
 
+    /// The name of the model that is being served. If not specified, defaults to
+    /// model-id.
+    #[clap(long, env)]
+    served_model_name: Option<String>,
+
     /// Optionally control the pooling method for embedding models.
     ///
     /// If `pooling` is not set, the pooling configuration will be parsed from the
@@ -214,6 +219,7 @@ async fn main() -> Result<()> {
         args.revision,
         args.tokenization_workers,
         args.dtype,
+        args.served_model_name,
         args.pooling,
         args.max_concurrent_requests,
         args.max_batch_tokens,
diff --git a/router/tests/common.rs b/router/tests/common.rs
index 1ae4333b..5975cda9 100644
--- a/router/tests/common.rs
+++ b/router/tests/common.rs
@@ -51,6 +51,7 @@ pub async fn start_server(model_id: String, revision: Option<String>, dtype: DType
         Some(1),
         Some(dtype),
         None,
+        None,
         4,
         1024,
         None,
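For reference, a minimal usage sketch of the new flag (not part of the patch): it assumes a router started with `--served-model-name my-embedding-model` and listening on port 8080, and uses the standard OpenAI Python client against the `/v1/embeddings` route. The model name, port, and input text are illustrative.

```python
# Sketch only (assumptions, not from the patch): the router was started with
#   text-embeddings-router --model-id thenlper/gte-base \
#       --served-model-name my-embedding-model --port 8080
# and the standard OpenAI Python client is pointed at its /v1 routes.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="not-needed")

response = client.embeddings.create(
    # Illustrative value; TEI serves the single model it was launched with.
    model="my-embedding-model",
    input="What is deep learning?",
)

# With this change the response reports served_model_name, falling back to
# model_id when the flag is unset.
print(response.model)                   # -> "my-embedding-model"
print(len(response.data[0].embedding))  # embedding dimension, e.g. 768 for gte-base
```

The same fallback (`served_model_name` if set, otherwise `model_id`) is what the `unwrap_or_else` in `router/src/http/server.rs` above implements for the OpenAI-compatible response.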