
Commit e3e34b8

Merge branch 'main' into add-bfloat16-support

2 parents e506c93 + cb9de7a

File tree

10 files changed: +154 -97 lines changed


README.md

Lines changed: 10 additions & 5 deletions

@@ -137,14 +137,15 @@ To see all options to serve your models:
 $ text-embeddings-router --help
 Text Embedding Webserver
 
-Usage: text-embeddings-router [OPTIONS]
+Usage: text-embeddings-router [OPTIONS] --model-id <MODEL_ID>
 
 Options:
       --model-id <MODEL_ID>
-          The name of the model to load. Can be a MODEL_ID as listed on <https://hf.co/models> like `BAAI/bge-large-en-v1.5`. Or it can be a local directory containing the necessary files as saved by `save_pretrained(...)` methods of transformers
+          The Hugging Face model ID, can be any model listed on <https://huggingface.co/models> with the `text-embeddings-inference` tag (meaning it's compatible with Text Embeddings Inference).
+
+          Alternatively, the specified ID can also be a path to a local directory containing the necessary model files saved by the `save_pretrained(...)` methods of either Transformers or Sentence Transformers.
 
           [env: MODEL_ID=]
-          [default: BAAI/bge-large-en-v1.5]
 
       --revision <REVISION>
           The actual revision of the model if you're referring to a model on the hub. You can use a specific commit id or a branch like `refs/pr/2`
@@ -162,6 +163,11 @@ Options:
           [env: DTYPE=]
           [possible values: float16, float32]
 
+      --served-model-name <SERVED_MODEL_NAME>
+          The name of the model that is being served. If not specified, defaults to `--model-id`. It is only used for the OpenAI-compatible endpoints via HTTP
+
+          [env: SERVED_MODEL_NAME=]
+
       --pooling <POOLING>
           Optionally control the pooling method for embedding models.
 
@@ -238,10 +244,9 @@ Options:
 
           Some embedding models require an extra `Dense` module which contains a single Linear layer and an activation function. By default, those `Dense` modules are stored under the `2_Dense` directory, but there might be cases where different `Dense` modules are provided, to convert the pooled embeddings into different dimensions, available as `2_Dense_<dims>` e.g. https://huggingface.co/NovaSearch/stella_en_400M_v5.
 
-          Note that this argument is optional, only required to be set if the path to the `Dense` module is other than `2_Dense`. And it also applies when leveraging the `candle` backend.
+          Note that this argument is optional, only required to be set if there is no `modules.json` file or when you want to override a single Dense module path, only when running with the `candle` backend.
 
           [env: DENSE_PATH=]
-          [default: 2_Dense]
 
       --hf-token <HF_TOKEN>
          Your Hugging Face Hub token
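
As a usage illustration for `--served-model-name` (a sketch, not part of this commit): assuming a router started with `--served-model-name my-embedder` listening on localhost:8080, and the `reqwest` (blocking + json features) and `serde_json` crates available, a client call to the OpenAI-compatible endpoint could look like this:

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::blocking::Client::new();
    // `model` is optional; if present, it should match the served model name,
    // otherwise the router only logs a warning (see router/src/http/server.rs below).
    let response: serde_json::Value = client
        .post("http://localhost:8080/v1/embeddings")
        .json(&json!({ "model": "my-embedder", "input": "What is Deep Learning?" }))
        .send()?
        .json()?;
    // The response's `model` field echoes the served model name.
    println!("model = {}", response["model"]);
    Ok(())
}
```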

backends/candle/src/layers/linear.rs

Lines changed: 4 additions & 1 deletion

@@ -5,7 +5,10 @@ use serde::Deserialize;
 #[derive(Debug, Deserialize, PartialEq, Clone)]
 #[serde(rename_all = "lowercase")]
 pub enum HiddenAct {
-    #[serde(alias = "gelu_pytorch_tanh")]
+    // NOTE: `GeluErf` is excluded due to incompatibility with cuBLASLt, as only the GeLU + tanh
+    // approximation is implemented for efficiency, so GeLU is standardized to the tanh
+    // approximation, with slight numerical deviation from GeLU erf (negligible for inference quality)
+    #[serde(alias = "gelu_new", alias = "gelu_pytorch_tanh")]
     Gelu,
     Relu,
     Silu,
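
To make the "negligible deviation" claim in the NOTE concrete, here is a standalone sketch (not part of this commit) comparing exact GeLU (erf) with the tanh approximation that the backend standardizes on; it assumes the external `libm` crate for `erf`, which this repository does not use:

```rust
// Exact GeLU: 0.5 * x * (1 + erf(x / sqrt(2))), via `libm::erf` (assumed dependency).
fn gelu_erf(x: f64) -> f64 {
    0.5 * x * (1.0 + libm::erf(x / std::f64::consts::SQRT_2))
}

// Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
fn gelu_tanh(x: f64) -> f64 {
    let c = (2.0 / std::f64::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
}

fn main() {
    // The two curves differ by at most ~1e-3, which is why the deviation is
    // treated as negligible for inference quality.
    for x in [-3.0_f64, -1.0, -0.1, 0.0, 0.5, 2.0] {
        let (e, t) = (gelu_erf(x), gelu_tanh(x));
        println!("x = {x:>4}: erf = {e:+.6}, tanh = {t:+.6}, diff = {:+.1e}", e - t);
    }
}
```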

backends/candle/src/lib.rs

Lines changed: 13 additions & 1 deletion

@@ -180,7 +180,19 @@ impl CandleBackend {
         let config: String = std::fs::read_to_string(model_path.join("config.json"))
             .context("Unable to read config file")
             .map_err(|err| BackendError::Start(format!("{err:?}")))?;
-        let config: Config = serde_json::from_str(&config)
+
+        let config_value: serde_json::Value = serde_json::from_str(&config)
+            .context("Unable to parse config.json")
+            .map_err(|err| BackendError::Start(format!("{err:?}")))?;
+
+        if let Some(hidden_act) = config_value.get("hidden_act").and_then(|v| v.as_str()) {
+            if hidden_act == "gelu" {
+                // NOTE: https://github.com/huggingface/text-embeddings-inference/pull/753
+                tracing::warn!("The `config.json` contains `hidden_act=gelu`, so the GeLU + tanh approximation will be used instead of exact GeLU (aka GeLU erf), which might lead to subtle differences with Transformers or Sentence Transformers outputs, which use exact GeLU when `hidden_act=gelu` unless specified otherwise. GeLU + tanh is more efficient and more consistent across devices (e.g., cuBLASLt comes with fused GeLU + tanh), and will have negligible impact on inference quality.");
+            }
+        }
+
+        let config: Config = serde_json::from_value(config_value)
             .context("Model is not supported")
             .map_err(|err| BackendError::Start(format!("{err:?}")))?;

backends/candle/tests/snapshots/test_bert__emotions_batch.snap

Lines changed: 82 additions & 82 deletions

@@ -2,87 +2,87 @@
 source: backends/candle/tests/test_bert.rs
 expression: predictions_batch
 ---
-- - -6.548559
-  - -6.302024
-  - -4.8671727
-  - -3.9600255
-  - -4.6329865
-  - -6.2816987
-  - -6.069644
-  - -5.7742686
-  - -6.9259467
-  - -6.1909447
-  - -5.67395
+- - -6.5485673
+  - -6.3020196
+  - -4.86717
+  - -3.9600184
+  - -4.632993
+  - -6.2817054
+  - -6.069636
+  - -5.7742705
+  - -6.925953
+  - -6.190939
+  - -5.6739373
   - -6.1698227
-  - -7.513461
-  - -6.865867
-  - -7.186479
-  - -7.128109
-  - -8.210709
-  - -7.0171394
-  - -7.1321163
-  - -8.533409
-  - -6.2294865
-  - -8.742306
-  - -5.7792044
-  - -8.657227
-  - -8.258305
-  - -6.64832
-  - -7.4060283
-  - 3.046496
-- - -5.8167515
-  - -6.6119466
-  - -5.2771955
-  - -2.6306503
-  - -4.6419163
-  - -5.579778
-  - -5.797174
-  - -6.0305815
-  - -5.8720746
-  - 0.45377323
-  - -3.0235887
-  - -5.3944407
-  - -5.186683
-  - -6.2649117
-  - -6.1962767
-  - -6.97937
-  - -5.5674877
-  - -5.521044
-  - -5.8899207
-  - -4.8699703
-  - -5.6259933
-  - -7.6109924
-  - -4.3881936
-  - -6.039008
-  - -4.934696
-  - -0.6715916
-  - -6.399376
-  - -2.4499295
-- - -6.548559
-  - -6.302024
-  - -4.8671727
-  - -3.9600255
-  - -4.6329865
-  - -6.2816987
-  - -6.069644
-  - -5.7742686
-  - -6.9259467
-  - -6.1909447
-  - -5.67395
+  - -7.5134573
+  - -6.8658743
+  - -7.1864815
+  - -7.128115
+  - -8.2107115
+  - -7.017146
+  - -7.132131
+  - -8.533407
+  - -6.229486
+  - -8.742311
+  - -5.7792006
+  - -8.65723
+  - -8.258308
+  - -6.648321
+  - -7.406026
+  - 3.0464942
+- - -5.816747
+  - -6.611947
+  - -5.2771983
+  - -2.6306484
+  - -4.6419153
+  - -5.5797825
+  - -5.7971735
+  - -6.030578
+  - -5.872076
+  - 0.45378062
+  - -3.0235896
+  - -5.3944383
+  - -5.18668
+  - -6.264913
+  - -6.196284
+  - -6.9793677
+  - -5.567489
+  - -5.5210495
+  - -5.889915
+  - -4.8699794
+  - -5.625993
+  - -7.6109934
+  - -4.388194
+  - -6.0390115
+  - -4.934693
+  - -0.6715966
+  - -6.3993735
+  - -2.4499245
+- - -6.5485673
+  - -6.3020196
+  - -4.86717
+  - -3.9600184
+  - -4.632993
+  - -6.2817054
+  - -6.069636
+  - -5.7742705
+  - -6.925953
+  - -6.190939
+  - -5.6739373
   - -6.1698227
-  - -7.513461
-  - -6.865867
-  - -7.186479
-  - -7.128109
-  - -8.210709
-  - -7.0171394
-  - -7.1321163
-  - -8.533409
-  - -6.2294865
-  - -8.742306
-  - -5.7792044
-  - -8.657227
-  - -8.258305
-  - -6.64832
-  - -7.4060283
-  - 3.046496
+  - -7.5134573
+  - -6.8658743
+  - -7.1864815
+  - -7.128115
+  - -8.2107115
+  - -7.017146
+  - -7.132131
+  - -8.533407
+  - -6.229486
+  - -8.742311
+  - -5.7792006
+  - -8.65723
+  - -8.258308
+  - -6.648321
+  - -7.406026
+  - 3.0464942

docs/openapi.json

Lines changed: 5 additions & 0 deletions

@@ -1215,6 +1215,7 @@
       "required": [
         "model_id",
         "model_dtype",
+        "served_model_name",
         "model_type",
         "max_concurrent_requests",
         "max_input_length",
@@ -1278,6 +1279,10 @@
       "model_type": {
         "$ref": "#/components/schemas/ModelType"
       },
+      "served_model_name": {
+        "type": "string",
+        "example": "thenlper/gte-base"
+      },
       "sha": {
         "type": "string",
         "example": "null",

docs/source/en/cli_arguments.md

Lines changed: 10 additions & 5 deletions

@@ -22,14 +22,15 @@ To see all options to serve your models, run the following:
 $ text-embeddings-router --help
 Text Embedding Webserver
 
-Usage: text-embeddings-router [OPTIONS]
+Usage: text-embeddings-router [OPTIONS] --model-id <MODEL_ID>
 
 Options:
       --model-id <MODEL_ID>
-          The name of the model to load. Can be a MODEL_ID as listed on <https://hf.co/models> like `BAAI/bge-large-en-v1.5`. Or it can be a local directory containing the necessary files as saved by `save_pretrained(...)` methods of transformers
+          The Hugging Face model ID, can be any model listed on <https://huggingface.co/models> with the `text-embeddings-inference` tag (meaning it's compatible with Text Embeddings Inference).
+
+          Alternatively, the specified ID can also be a path to a local directory containing the necessary model files saved by the `save_pretrained(...)` methods of either Transformers or Sentence Transformers.
 
           [env: MODEL_ID=]
-          [default: BAAI/bge-large-en-v1.5]
 
       --revision <REVISION>
           The actual revision of the model if you're referring to a model on the hub. You can use a specific commit id or a branch like `refs/pr/2`
@@ -47,6 +48,11 @@ Options:
           [env: DTYPE=]
           [possible values: float16, float32]
 
+      --served-model-name <SERVED_MODEL_NAME>
+          The name of the model that is being served. If not specified, defaults to `--model-id`. It is only used for the OpenAI-compatible endpoints via HTTP
+
+          [env: SERVED_MODEL_NAME=]
+
       --pooling <POOLING>
           Optionally control the pooling method for embedding models.
 
@@ -123,10 +129,9 @@ Options:
 
           Some embedding models require an extra `Dense` module which contains a single Linear layer and an activation function. By default, those `Dense` modules are stored under the `2_Dense` directory, but there might be cases where different `Dense` modules are provided, to convert the pooled embeddings into different dimensions, available as `2_Dense_<dims>` e.g. https://huggingface.co/NovaSearch/stella_en_400M_v5.
 
-          Note that this argument is optional, only required to be set if the path to the `Dense` module is other than `2_Dense`. And it also applies when leveraging the `candle` backend.
+          Note that this argument is optional, only required to be set if there is no `modules.json` file or when you want to override a single Dense module path, only when running with the `candle` backend.
 
           [env: DENSE_PATH=]
-          [default: 2_Dense]
 
       --hf-token <HF_TOKEN>
          Your Hugging Face Hub token
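
To make the `2_Dense` vs `2_Dense_<dims>` layout described above concrete, here is a toy sketch of the directory selection; it is an illustration under assumed semantics, not TEI's actual resolution logic:

```rust
use std::path::PathBuf;

// Hypothetical helper: default to `2_Dense`, or pick `2_Dense_<dims>` when a
// specific output dimension is requested (as in NovaSearch/stella_en_400M_v5).
fn dense_dir(model_root: &str, dims: Option<usize>) -> PathBuf {
    match dims {
        Some(d) => PathBuf::from(model_root).join(format!("2_Dense_{d}")),
        None => PathBuf::from(model_root).join("2_Dense"),
    }
}

fn main() {
    assert_eq!(dense_dir("/models/stella", None), PathBuf::from("/models/stella/2_Dense"));
    assert_eq!(dense_dir("/models/stella", Some(1024)), PathBuf::from("/models/stella/2_Dense_1024"));
    println!("ok");
}
```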

router/src/http/server.rs

Lines changed: 13 additions & 1 deletion

@@ -1153,6 +1153,18 @@ async fn openai_embed(
         span.set_parent(context);
     }
 
+    // NOTE: Validation of `model` won't fail for the time being, given that Text Embeddings
+    // Inference can only serve a single model at a time, so the `model` parameter is not needed
+    // to differentiate one model from another; but we at least raise a warning.
+    if let Some(requested_model) = &req.model {
+        if requested_model != &info.served_model_name {
+            tracing::warn!(
+                "The provided `model={}` has not been found, the `model` parameter should be provided either empty or with `model={}` instead.",
+                requested_model, info.served_model_name
+            );
+        }
+    }
+
     let start_time = Instant::now();
 
     let truncate = info.auto_truncate;
@@ -1308,7 +1320,7 @@ async fn openai_embed(
     let response = OpenAICompatResponse {
         object: "list",
         data: embeddings,
-        model: info.model_id.clone(),
+        model: info.served_model_name.clone(),
         usage: OpenAICompatUsage {
             prompt_tokens: compute_tokens,
             total_tokens: compute_tokens,
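
For clarity, the soft validation above in isolation: a mismatched `model` only produces a warning and never fails the request. A minimal sketch with a made-up harness (plain `eprintln!` standing in for `tracing::warn!`):

```rust
// Mirrors the check in `openai_embed`: warn on mismatch, never reject.
fn check_requested_model(requested: Option<&str>, served_model_name: &str) {
    if let Some(requested_model) = requested {
        if requested_model != served_model_name {
            eprintln!(
                "warning: the provided `model={requested_model}` has not been found, \
                 provide `model={served_model_name}` or omit the parameter instead."
            );
        }
    }
}

fn main() {
    check_requested_model(None, "my-embedder"); // no warning
    check_requested_model(Some("my-embedder"), "my-embedder"); // no warning
    check_requested_model(Some("gpt-4"), "my-embedder"); // warns
}
```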

router/src/lib.rs

Lines changed: 4 additions & 0 deletions

@@ -47,6 +47,7 @@ pub async fn run(
     revision: Option<String>,
     tokenization_workers: Option<usize>,
     dtype: Option<DType>,
+    served_model_name: String,
     pooling: Option<text_embeddings_backend::Pool>,
     max_concurrent_requests: usize,
     max_batch_tokens: usize,
@@ -332,6 +333,7 @@ pub async fn run(
         model_id,
         model_sha: revision,
         model_dtype: dtype.to_string(),
+        served_model_name,
         model_type,
         max_concurrent_requests,
         max_input_length,
@@ -550,6 +552,8 @@ pub struct Info {
     pub model_sha: Option<String>,
     #[cfg_attr(feature = "http", schema(example = "float16"))]
     pub model_dtype: String,
+    #[cfg_attr(feature = "http", schema(example = "thenlper/gte-base"))]
+    pub served_model_name: String,
     pub model_type: ModelType,
     /// Router Parameters
     #[cfg_attr(feature = "http", schema(example = "128"))]

router/src/main.rs

Lines changed: 11 additions & 1 deletion

@@ -14,7 +14,7 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
 struct Args {
     /// The Hugging Face model ID, can be any model listed on <https://huggingface.co/models> with
     /// the `text-embeddings-inference` tag (meaning it's compatible with Text Embeddings
-    /// Inference)
+    /// Inference).
     ///
     /// Alternatively, the specified ID can also be a path to a local directory containing the
     /// necessary model files saved by the `save_pretrained(...)` methods of either Transformers or
@@ -40,6 +40,11 @@ struct Args {
     #[clap(long, env, value_enum)]
     dtype: Option<DType>,
 
+    /// The name of the model that is being served. If not specified, defaults to `--model-id`. It
+    /// is only used for the OpenAI-compatible endpoints via HTTP.
+    #[clap(long, env)]
+    served_model_name: Option<String>,
+
     /// Optionally control the pooling method for embedding models.
     ///
     /// If `pooling` is not set, the pooling configuration will be parsed from the
@@ -227,11 +232,16 @@ async fn main() -> Result<()> {
     }
     let token = args.hf_token.or(args.hf_api_token);
 
+    let served_model_name = args
+        .served_model_name
+        .unwrap_or_else(|| args.model_id.clone());
+
     text_embeddings_router::run(
         args.model_id,
         args.revision,
         args.tokenization_workers,
         args.dtype,
+        served_model_name,
         args.pooling,
         args.max_concurrent_requests,
         args.max_batch_tokens,
