Commit 7efa697

feat: add /embed_sparse route (#191)
1 parent: 2b8ad5f

File tree

7 files changed: +724 −26 lines changed

README.md

Lines changed: 21 additions & 0 deletions
@@ -35,6 +35,7 @@ length of 512 tokens:
 - [Using a private or gated model](#using-a-private-or-gated-model)
 - [Using Re-rankers models](#using-re-rankers-models)
 - [Using Sequence Classification models](#using-sequence-classification-models)
+- [Using SPLADE pooling](#using-splade-pooling)
 - [Distributed Tracing](#distributed-tracing)
 - [gRPC](#grpc)
 - [Local Install](#local-install)
@@ -331,6 +332,26 @@ curl 127.0.0.1:8080/predict \
     -H 'Content-Type: application/json'
 ```
 
+### Using SPLADE pooling
+
+You can choose to activate SPLADE pooling for Bert and Distilbert MaskedLM architectures:
+
+```shell
+model=naver/efficient-splade-VI-BT-large-query
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:1.1 --model-id $model --pooling splade
+```
+
+Once you have deployed the model you can use the `/embed_sparse` endpoint to get the sparse embedding:
+
+```bash
+curl 127.0.0.1:8080/embed_sparse \
+    -X POST \
+    -d '{"inputs":"I like you."}' \
+    -H 'Content-Type: application/json'
+```
+
 ### Distributed Tracing
 
 `text-embeddings-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
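For callers outside the shell, the response follows the `EmbedSparseResponse` schema added to `docs/openapi.json` below: one array of `{index, value}` pairs per input. A minimal Rust sketch of calling the endpoint and deserializing that shape; the `reqwest`, `serde`, `serde_json`, and `tokio` dependencies are illustration-only assumptions, not part of this commit:

```rust
// Sketch of a client for /embed_sparse. Assumed Cargo deps (not in this
// commit): reqwest (json feature), serde (derive), serde_json, tokio (full).
use serde::Deserialize;

// Mirrors the `SparseValue` schema: a position and its float weight.
#[derive(Debug, Deserialize)]
struct SparseValue {
    index: u32,
    value: f32,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // One Vec<SparseValue> per input string, per `EmbedSparseResponse`.
    let embeddings: Vec<Vec<SparseValue>> = reqwest::Client::new()
        .post("http://127.0.0.1:8080/embed_sparse")
        .json(&serde_json::json!({ "inputs": "I like you." }))
        .send()
        .await?
        .json()
        .await?;

    for sv in &embeddings[0] {
        println!("index {} -> weight {}", sv.index, sv.value);
    }
    Ok(())
}
```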

core/src/infer.rs

Lines changed: 48 additions & 0 deletions
@@ -144,6 +144,54 @@ impl Infer {
         Ok(response)
     }
 
+    #[instrument(skip(self, permit))]
+    pub async fn embed_sparse<I: Into<EncodingInput> + std::fmt::Debug>(
+        &self,
+        inputs: I,
+        truncate: bool,
+        permit: OwnedSemaphorePermit,
+    ) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
+        let start_time = Instant::now();
+
+        if !self.is_splade() {
+            metrics::increment_counter!("te_request_failure", "err" => "model_type");
+            let message = "Model is not an embedding model with SPLADE pooling".to_string();
+            tracing::error!("{message}");
+            return Err(TextEmbeddingsError::Backend(BackendError::Inference(
+                message,
+            )));
+        }
+
+        let results = self
+            .embed(inputs, truncate, true, &start_time, permit)
+            .await?;
+
+        let InferResult::PooledEmbedding(response) = results else {
+            panic!("unexpected enum variant")
+        };
+
+        // Timings
+        let total_time = start_time.elapsed();
+
+        // Metrics
+        metrics::increment_counter!("te_embed_success");
+        metrics::histogram!("te_embed_duration", total_time.as_secs_f64());
+        metrics::histogram!(
+            "te_embed_tokenization_duration",
+            response.metadata.tokenization.as_secs_f64()
+        );
+        metrics::histogram!(
+            "te_embed_queue_duration",
+            response.metadata.queue.as_secs_f64()
+        );
+        metrics::histogram!(
+            "te_embed_inference_duration",
+            response.metadata.inference.as_secs_f64()
+        );
+
+        Ok(response)
+    }
+
     #[instrument(skip(self, permit))]
     pub async fn embed_pooled<I: Into<EncodingInput> + std::fmt::Debug>(
         &self,
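The guard at the top of `embed_sparse` delegates to an `is_splade` helper that this diff does not include. A minimal sketch of what such a check could look like, assuming the backend exposes its model type and a SPLADE pooling variant; every name below is an assumption, not code from this commit:

```rust
// Illustrative sketch only: `is_splade` is not shown in this diff, so the
// `ModelType` / `Pool` shapes here are assumptions about the surrounding code.
impl Infer {
    pub fn is_splade(&self) -> bool {
        // Sparse embeddings are only defined for embedding models that were
        // launched with `--pooling splade`.
        matches!(self.backend.model_type, ModelType::Embedding(Pool::Splade))
    }
}
```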

docs/openapi.json

Lines changed: 256 additions & 0 deletions
@@ -100,6 +100,182 @@
         }
       }
     },
+    "/embed_all": {
+      "post": {
+        "tags": [
+          "Text Embeddings Inference"
+        ],
+        "summary": "Get all Embeddings without Pooling.",
+        "description": "Get all Embeddings without Pooling.\nReturns a 424 status code if the model is not an embedding model.",
+        "operationId": "embed_all",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/EmbedAllRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Embeddings",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/EmbedAllResponse"
+                }
+              }
+            }
+          },
+          "413": {
+            "description": "Batch size error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Batch size error",
+                  "error_type": "validation"
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Tokenization error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Tokenization error",
+                  "error_type": "tokenizer"
+                }
+              }
+            }
+          },
+          "424": {
+            "description": "Embedding Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Inference failed",
+                  "error_type": "backend"
+                }
+              }
+            }
+          },
+          "429": {
+            "description": "Model is overloaded",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Model is overloaded",
+                  "error_type": "overloaded"
+                }
+              }
+            }
+          }
+        }
+      }
+    },
+    "/embed_sparse": {
+      "post": {
+        "tags": [
+          "Text Embeddings Inference"
+        ],
+        "summary": "Get Sparse Embeddings. Returns a 424 status code if the model is not an embedding model with SPLADE pooling.",
+        "description": "Get Sparse Embeddings. Returns a 424 status code if the model is not an embedding model with SPLADE pooling.",
+        "operationId": "embed_sparse",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/EmbedSparseRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Embeddings",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/EmbedSparseResponse"
+                }
+              }
+            }
+          },
+          "413": {
+            "description": "Batch size error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Batch size error",
+                  "error_type": "validation"
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Tokenization error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Tokenization error",
+                  "error_type": "tokenizer"
+                }
+              }
+            }
+          },
+          "424": {
+            "description": "Embedding Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Inference failed",
+                  "error_type": "backend"
+                }
+              }
+            }
+          },
+          "429": {
+            "description": "Model is overloaded",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Model is overloaded",
+                  "error_type": "overloaded"
+                }
+              }
+            }
+          }
+        }
+      }
+    },
     "/embeddings": {
       "post": {
         "tags": [
@@ -514,6 +690,44 @@
           }
         }
       },
+      "EmbedAllRequest": {
+        "type": "object",
+        "required": [
+          "inputs"
+        ],
+        "properties": {
+          "inputs": {
+            "$ref": "#/components/schemas/Input"
+          },
+          "truncate": {
+            "type": "boolean",
+            "default": "false",
+            "example": "false"
+          }
+        }
+      },
+      "EmbedAllResponse": {
+        "type": "array",
+        "items": {
+          "type": "array",
+          "items": {
+            "type": "array",
+            "items": {
+              "type": "number",
+              "format": "float"
+            }
+          }
+        },
+        "example": [
+          [
+            [
+              0.0,
+              1.0,
+              2.0
+            ]
+          ]
+        ]
+      },
       "EmbedRequest": {
         "type": "object",
         "required": [
@@ -552,6 +766,31 @@
          ]
        ]
      },
+      "EmbedSparseRequest": {
+        "type": "object",
+        "required": [
+          "inputs"
+        ],
+        "properties": {
+          "inputs": {
+            "$ref": "#/components/schemas/Input"
+          },
+          "truncate": {
+            "type": "boolean",
+            "default": "false",
+            "example": "false"
+          }
+        }
+      },
+      "EmbedSparseResponse": {
+        "type": "array",
+        "items": {
+          "type": "array",
+          "items": {
+            "$ref": "#/components/schemas/SparseValue"
+          }
+        }
+      },
       "EmbeddingModel": {
         "type": "object",
         "required": [
@@ -1047,6 +1286,23 @@
           }
         }
       },
+      "SparseValue": {
+        "type": "object",
+        "required": [
+          "index",
+          "value"
+        ],
+        "properties": {
+          "index": {
+            "type": "integer",
+            "minimum": 0
+          },
+          "value": {
+            "type": "number",
+            "format": "float"
+          }
+        }
+      },
       "TokenizeRequest": {
         "type": "object",
         "required": [
