
Commit 6395a7a

feat(router): add /tokenize route (#139)
1 parent 3d9cc20 commit 6395a7a

9 files changed: +556 additions, -38 deletions

core/src/infer.rs
Lines changed: 17 additions & 1 deletion

@@ -1,5 +1,5 @@
 use crate::queue::{Entry, Metadata, NextBatch, Queue};
-use crate::tokenization::{EncodingInput, Tokenization};
+use crate::tokenization::{EncodingInput, RawEncoding, Tokenization};
 use crate::TextEmbeddingsError;
 use std::sync::Arc;
 use std::time::{Duration, Instant};
@@ -58,6 +58,22 @@ impl Infer {
         }
     }
 
+    #[instrument(skip(self))]
+    pub async fn tokenize<I: Into<EncodingInput> + std::fmt::Debug>(
+        &self,
+        inputs: I,
+        add_special_tokens: bool,
+    ) -> Result<RawEncoding, TextEmbeddingsError> {
+        self.tokenization
+            .tokenize(inputs.into(), add_special_tokens)
+            .await
+            .map_err(|err| {
+                metrics::increment_counter!("te_request_failure", "err" => "tokenization");
+                tracing::error!("{err}");
+                err
+            })
+    }
+
     #[instrument(skip(self))]
     pub fn try_acquire_permit(&self) -> Result<OwnedSemaphorePermit, TextEmbeddingsError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
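
The router-side handler is among the 9 changed files but is not shown in this view. As a rough, hypothetical sketch only (not the code from this commit): an axum-style handler could call the new Infer::tokenize and flatten the returned RawEncoding into the SimpleToken shape declared in docs/openapi.json below. The Arc state wiring, crate path, and serde types here are assumptions.

// Hypothetical handler sketch, not taken from this commit.
// Assumes axum + serde, an `Arc<Infer>` as router state, and request/response
// types mirroring the TokenizeRequest / SimpleToken schemas from openapi.json.
use std::sync::Arc;

use axum::{extract::State, http::StatusCode, Json};
use text_embeddings_core::infer::Infer; // crate path assumed

#[derive(serde::Deserialize)]
struct TokenizeRequest {
    inputs: String,
    #[serde(default = "default_true")]
    add_special_tokens: bool,
}

fn default_true() -> bool {
    true
}

#[derive(serde::Serialize)]
struct SimpleToken {
    id: u32,
    text: String,
    special: bool,
    start: Option<usize>,
    stop: Option<usize>,
}

async fn tokenize(
    State(infer): State<Arc<Infer>>,
    Json(req): Json<TokenizeRequest>,
) -> Result<Json<Vec<Vec<SimpleToken>>>, (StatusCode, String)> {
    let encoding = infer
        .tokenize(req.inputs, req.add_special_tokens)
        .await
        .map_err(|err| (StatusCode::UNPROCESSABLE_ENTITY, err.to_string()))?;

    // Zip ids, token strings, offsets and the special-token mask from the
    // RawEncoding into the flat SimpleToken representation of the route.
    let tokens: Vec<SimpleToken> = encoding
        .get_ids()
        .iter()
        .zip(encoding.get_tokens())
        .zip(encoding.get_offsets())
        .zip(encoding.get_special_tokens_mask())
        .map(|(((&id, text), &(start, stop)), &special)| SimpleToken {
            id,
            text: text.clone(),
            special: special == 1,
            start: Some(start),
            stop: Some(stop),
        })
        .collect();

    // TokenizeResponse is an array of arrays: one inner array per input.
    Ok(Json(vec![tokens]))
}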

core/src/queue.rs
Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
 use crate::infer::InferResponse;
-use crate::tokenization::Encoding;
+use crate::tokenization::ValidEncoding;
 use std::cmp::max;
 use std::collections::VecDeque;
 use std::time::{Duration, Instant};
@@ -11,7 +11,7 @@ use tracing::{instrument, Span};
 #[derive(Debug)]
 pub struct Entry {
     /// Payload
-    pub encoding: Encoding,
+    pub encoding: ValidEncoding,
     /// Entry metadata
     pub metadata: Metadata,
 }

core/src/tokenization.rs
Lines changed: 104 additions & 32 deletions

@@ -1,6 +1,7 @@
 /// Payload tokenization logic
 use crate::TextEmbeddingsError;
 use tokenizers::tokenizer::Tokenizer;
+pub use tokenizers::Encoding as RawEncoding;
 use tokenizers::{EncodeInput, TruncationDirection, TruncationParams, TruncationStrategy};
 use tokio::sync::{mpsc, oneshot};
 use tracing::{instrument, Span};
@@ -63,7 +64,7 @@ impl Tokenization {
         &self,
         inputs: EncodingInput,
         truncate: bool,
-    ) -> Result<Encoding, TextEmbeddingsError> {
+    ) -> Result<ValidEncoding, TextEmbeddingsError> {
         // Check if inputs is empty
         if inputs.is_empty() {
             return Err(TextEmbeddingsError::Validation(
@@ -76,7 +77,43 @@ impl Tokenization {
         // Send request to the background validation task
         // Unwrap is safe here
         self.sender
-            .send((inputs, truncate, response_sender, Span::current()))
+            .send(TokenizerRequest::Encode(
+                inputs,
+                truncate,
+                response_sender,
+                Span::current(),
+            ))
+            .expect("Tokenization background task dropped the receiver. This is a bug.");
+
+        // Await on response channel
+        // Unwrap is safe here
+        response_receiver.await.expect("Tokenization background task dropped the sender without sending a response. This is a bug.")
+    }
+
+    #[instrument(skip_all)]
+    pub async fn tokenize(
+        &self,
+        inputs: EncodingInput,
+        add_special_tokens: bool,
+    ) -> Result<RawEncoding, TextEmbeddingsError> {
+        // Check if inputs is empty
+        if inputs.is_empty() {
+            return Err(TextEmbeddingsError::Validation(
+                "`inputs` cannot be empty".to_string(),
+            ));
+        }
+
+        // Create response channel
+        let (response_sender, response_receiver) = oneshot::channel();
+        // Send request to the background validation task
+        // Unwrap is safe here
+        self.sender
+            .send(TokenizerRequest::Tokenize(
+                inputs,
+                add_special_tokens,
+                response_sender,
+                Span::current(),
+            ))
             .expect("Tokenization background task dropped the receiver. This is a bug.");
 
         // Await on response channel
@@ -93,31 +130,65 @@ fn tokenizer_worker(
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
 ) {
     // Loop over requests
-    while let Some((inputs, truncate, response_tx, parent_span)) = receiver.blocking_recv() {
-        parent_span.in_scope(|| {
-            if !response_tx.is_closed() {
-                // It's possible that the user dropped its request resulting in a send error.
-                // We just discard the error
-                let _ = response_tx.send(encode_input(
-                    inputs,
-                    truncate,
-                    max_input_length,
-                    position_offset,
-                    &mut tokenizer,
-                ));
+    while let Some(request) = receiver.blocking_recv() {
+        match request {
+            TokenizerRequest::Encode(inputs, truncate, response_tx, parent_span) => {
+                parent_span.in_scope(|| {
+                    if !response_tx.is_closed() {
+                        // It's possible that the user dropped its request resulting in a send error.
+                        // We just discard the error
+                        let _ = response_tx.send(encode_input(
+                            inputs,
+                            truncate,
+                            max_input_length,
+                            position_offset,
+                            &mut tokenizer,
+                        ));
+                    }
+                })
+            }
+            TokenizerRequest::Tokenize(inputs, add_special_tokens, response_tx, parent_span) => {
+                parent_span.in_scope(|| {
+                    if !response_tx.is_closed() {
+                        // It's possible that the user dropped its request resulting in a send error.
+                        // We just discard the error
+                        let _ = response_tx.send(tokenize_input(
+                            inputs,
+                            add_special_tokens,
+                            None,
+                            &mut tokenizer,
+                        ));
+                    }
+                })
             }
-        })
+        }
     }
 }
 
+fn tokenize_input(
+    inputs: EncodingInput,
+    add_special_tokens: bool,
+    truncate_params: Option<TruncationParams>,
+    tokenizer: &mut Tokenizer,
+) -> Result<RawEncoding, TextEmbeddingsError> {
+    let inputs: EncodeInput = match inputs {
+        EncodingInput::Single(s) => s.into(),
+        EncodingInput::Dual(s1, s2) => (s1, s2).into(),
+    };
+
+    Ok(tokenizer
+        .with_truncation(truncate_params)?
+        .encode(inputs, add_special_tokens)?)
+}
+
 /// Get input length and optionally truncate it
 fn encode_input(
     inputs: EncodingInput,
     truncate: bool,
     max_input_length: usize,
     position_offset: usize,
     tokenizer: &mut Tokenizer,
-) -> Result<Encoding, TextEmbeddingsError> {
+) -> Result<ValidEncoding, TextEmbeddingsError> {
     // Default truncation params
     let truncate_params = truncate.then_some(TruncationParams {
         direction: TruncationDirection::Right,
@@ -126,14 +197,7 @@ fn encode_input(
         stride: 0,
     });
 
-    let inputs: EncodeInput = match inputs {
-        EncodingInput::Single(s) => s.into(),
-        EncodingInput::Dual(s1, s2) => (s1, s2).into(),
-    };
-
-    let encoding = tokenizer
-        .with_truncation(truncate_params)?
-        .encode(inputs, true)?;
+    let encoding = tokenize_input(inputs, true, truncate_params, tokenizer)?;
     let seq_len = encoding.len();
 
     if seq_len > max_input_length {
@@ -144,7 +208,7 @@ fn encode_input(
 
     metrics::histogram!("te_request_input_length", seq_len as f64);
 
-    Ok(Encoding {
+    Ok(ValidEncoding {
         input_ids: encoding.get_ids().to_vec(),
         token_type_ids: encoding.get_type_ids().to_vec(),
         position_ids: (position_offset as u32..(seq_len + position_offset) as u32)
@@ -153,7 +217,7 @@ fn encode_input(
 }
 
 #[derive(Debug)]
-pub struct Encoding {
+pub struct ValidEncoding {
     pub input_ids: Vec<u32>,
     pub token_type_ids: Vec<u32>,
     pub position_ids: Vec<u32>,
@@ -186,9 +250,17 @@ impl From<(String, String)> for EncodingInput {
     }
 }
 
-type TokenizerRequest = (
-    EncodingInput,
-    bool,
-    oneshot::Sender<Result<Encoding, TextEmbeddingsError>>,
-    Span,
-);
+enum TokenizerRequest {
+    Encode(
+        EncodingInput,
+        bool,
+        oneshot::Sender<Result<ValidEncoding, TextEmbeddingsError>>,
+        Span,
+    ),
+    Tokenize(
+        EncodingInput,
+        bool,
+        oneshot::Sender<Result<RawEncoding, TextEmbeddingsError>>,
+        Span,
+    ),
+}
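
The worker refactor above replaces the tuple-typed request with the TokenizerRequest enum so one background thread can serve both encode and tokenize calls. A standalone sketch of the same pattern (illustrative only, not TEI code): callers push an enum variant carrying a tokio oneshot sender onto an unbounded mpsc channel, and a dedicated thread receives with blocking_recv, dispatches on the variant, and replies through the oneshot sender.

// Minimal sketch of the enum-dispatch request/response pattern (not TEI code).
use tokio::sync::{mpsc, oneshot};

enum Request {
    // Each variant carries its inputs plus a oneshot sender for the reply.
    Encode(String, oneshot::Sender<usize>),
    Tokenize(String, oneshot::Sender<Vec<String>>),
}

fn worker(mut receiver: mpsc::UnboundedReceiver<Request>) {
    // Runs on a plain thread, like tokenizer_worker; blocking_recv is fine here.
    while let Some(request) = receiver.blocking_recv() {
        match request {
            Request::Encode(text, reply) => {
                // Discard the send error if the caller dropped its receiver.
                let _ = reply.send(text.len());
            }
            Request::Tokenize(text, reply) => {
                let _ = reply.send(text.split_whitespace().map(str::to_owned).collect());
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (sender, receiver) = mpsc::unbounded_channel();
    std::thread::spawn(move || worker(receiver));

    let (reply_tx, reply_rx) = oneshot::channel();
    sender
        .send(Request::Tokenize("hello world".to_string(), reply_tx))
        .expect("worker dropped the receiver");
    let pieces = reply_rx.await.expect("worker dropped the reply sender");
    println!("{pieces:?}");
}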

docs/openapi.json
Lines changed: 129 additions & 0 deletions

@@ -436,6 +436,52 @@
           }
         }
       }
+    },
+    "/tokenize": {
+      "post": {
+        "tags": [
+          "Text Embeddings Inference"
+        ],
+        "summary": "Tokenize inputs",
+        "description": "Tokenize inputs",
+        "operationId": "tokenize",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/TokenizeRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Tokenized ids",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/TokenizeResponse"
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Tokenization error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+                },
+                "example": {
+                  "message": "Tokenization error",
+                  "type": "tokenizer"
+                }
+              }
+            }
+          }
+        }
+      }
     }
   },
   "components": {
@@ -660,6 +706,17 @@
                 "$ref": "#/components/schemas/EmbeddingModel"
               }
             }
+          },
+          {
+            "type": "object",
+            "required": [
+              "reranker"
+            ],
+            "properties": {
+              "reranker": {
+                "$ref": "#/components/schemas/ClassifierModel"
+              }
+            }
           }
         ]
       },
@@ -953,6 +1010,78 @@
         "items": {
           "$ref": "#/components/schemas/Rank"
         }
+      },
+      "SimpleToken": {
+        "type": "object",
+        "required": [
+          "id",
+          "text",
+          "special"
+        ],
+        "properties": {
+          "id": {
+            "type": "integer",
+            "format": "int32",
+            "example": 0,
+            "minimum": 0
+          },
+          "special": {
+            "type": "boolean",
+            "example": "false"
+          },
+          "start": {
+            "type": "integer",
+            "example": 0,
+            "nullable": true,
+            "minimum": 0
+          },
+          "stop": {
+            "type": "integer",
+            "example": 2,
+            "nullable": true,
+            "minimum": 0
+          },
+          "text": {
+            "type": "string",
+            "example": "test"
+          }
+        }
+      },
+      "TokenizeRequest": {
+        "type": "object",
+        "required": [
+          "inputs"
+        ],
+        "properties": {
+          "add_special_tokens": {
+            "type": "boolean",
+            "default": "true",
+            "example": "true"
+          },
+          "inputs": {
+            "$ref": "#/components/schemas/Input"
+          }
+        }
+      },
+      "TokenizeResponse": {
+        "type": "array",
+        "items": {
+          "type": "array",
+          "items": {
+            "$ref": "#/components/schemas/SimpleToken"
+          }
+        },
+        "example": [
+          [
+            {
+              "id": 0,
+              "special": false,
+              "start": 0,
+              "stop": 2,
+              "text": "test"
+            }
+          ]
+        ]
       }
     }
   },
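
With the schema additions above, the new route accepts a POST carrying an Input payload and returns a TokenizeResponse: an array of SimpleToken arrays, one inner array per input. A hedged client sketch using reqwest; the base URL and port are placeholders, not part of this commit:

// Client sketch for POST /tokenize (base URL is an assumption).
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let response = reqwest::Client::new()
        .post("http://localhost:8080/tokenize")
        .json(&json!({ "inputs": "test", "add_special_tokens": true }))
        .send()
        .await?
        .error_for_status()?;

    // Expect the TokenizeResponse shape: [[{"id": 0, "text": "test", ...}]]
    let tokens: serde_json::Value = response.json().await?;
    println!("{tokens:#}");
    Ok(())
}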
