1
1
/// Payload tokenization logic
2
2
use crate :: TextEmbeddingsError ;
3
3
use tokenizers:: tokenizer:: Tokenizer ;
4
+ pub use tokenizers:: Encoding as RawEncoding ;
4
5
use tokenizers:: { EncodeInput , TruncationDirection , TruncationParams , TruncationStrategy } ;
5
6
use tokio:: sync:: { mpsc, oneshot} ;
6
7
use tracing:: { instrument, Span } ;
@@ -63,7 +64,7 @@ impl Tokenization {
63
64
& self ,
64
65
inputs : EncodingInput ,
65
66
truncate : bool ,
66
- ) -> Result < Encoding , TextEmbeddingsError > {
67
+ ) -> Result < ValidEncoding , TextEmbeddingsError > {
67
68
// Check if inputs is empty
68
69
if inputs. is_empty ( ) {
69
70
return Err ( TextEmbeddingsError :: Validation (
@@ -76,7 +77,43 @@ impl Tokenization {
76
77
// Send request to the background validation task
77
78
// Unwrap is safe here
78
79
self . sender
79
- . send ( ( inputs, truncate, response_sender, Span :: current ( ) ) )
80
+ . send ( TokenizerRequest :: Encode (
81
+ inputs,
82
+ truncate,
83
+ response_sender,
84
+ Span :: current ( ) ,
85
+ ) )
86
+ . expect ( "Tokenization background task dropped the receiver. This is a bug." ) ;
87
+
88
+ // Await on response channel
89
+ // Unwrap is safe here
90
+ response_receiver. await . expect ( "Tokenization background task dropped the sender without sending a response. This is a bug." )
91
+ }
92
+
93
+ #[ instrument( skip_all) ]
94
+ pub async fn tokenize (
95
+ & self ,
96
+ inputs : EncodingInput ,
97
+ add_special_tokens : bool ,
98
+ ) -> Result < RawEncoding , TextEmbeddingsError > {
99
+ // Check if inputs is empty
100
+ if inputs. is_empty ( ) {
101
+ return Err ( TextEmbeddingsError :: Validation (
102
+ "`inputs` cannot be empty" . to_string ( ) ,
103
+ ) ) ;
104
+ }
105
+
106
+ // Create response channel
107
+ let ( response_sender, response_receiver) = oneshot:: channel ( ) ;
108
+ // Send request to the background validation task
109
+ // Unwrap is safe here
110
+ self . sender
111
+ . send ( TokenizerRequest :: Tokenize (
112
+ inputs,
113
+ add_special_tokens,
114
+ response_sender,
115
+ Span :: current ( ) ,
116
+ ) )
80
117
. expect ( "Tokenization background task dropped the receiver. This is a bug." ) ;
81
118
82
119
// Await on response channel
@@ -93,31 +130,65 @@ fn tokenizer_worker(
93
130
mut receiver : mpsc:: UnboundedReceiver < TokenizerRequest > ,
94
131
) {
95
132
// Loop over requests
96
- while let Some ( ( inputs, truncate, response_tx, parent_span) ) = receiver. blocking_recv ( ) {
97
- parent_span. in_scope ( || {
98
- if !response_tx. is_closed ( ) {
99
- // It's possible that the user dropped its request resulting in a send error.
100
- // We just discard the error
101
- let _ = response_tx. send ( encode_input (
102
- inputs,
103
- truncate,
104
- max_input_length,
105
- position_offset,
106
- & mut tokenizer,
107
- ) ) ;
133
+ while let Some ( request) = receiver. blocking_recv ( ) {
134
+ match request {
135
+ TokenizerRequest :: Encode ( inputs, truncate, response_tx, parent_span) => {
136
+ parent_span. in_scope ( || {
137
+ if !response_tx. is_closed ( ) {
138
+ // It's possible that the user dropped its request resulting in a send error.
139
+ // We just discard the error
140
+ let _ = response_tx. send ( encode_input (
141
+ inputs,
142
+ truncate,
143
+ max_input_length,
144
+ position_offset,
145
+ & mut tokenizer,
146
+ ) ) ;
147
+ }
148
+ } )
149
+ }
150
+ TokenizerRequest :: Tokenize ( inputs, add_special_tokens, response_tx, parent_span) => {
151
+ parent_span. in_scope ( || {
152
+ if !response_tx. is_closed ( ) {
153
+ // It's possible that the user dropped its request resulting in a send error.
154
+ // We just discard the error
155
+ let _ = response_tx. send ( tokenize_input (
156
+ inputs,
157
+ add_special_tokens,
158
+ None ,
159
+ & mut tokenizer,
160
+ ) ) ;
161
+ }
162
+ } )
108
163
}
109
- } )
164
+ }
110
165
}
111
166
}
112
167
168
+ fn tokenize_input (
169
+ inputs : EncodingInput ,
170
+ add_special_tokens : bool ,
171
+ truncate_params : Option < TruncationParams > ,
172
+ tokenizer : & mut Tokenizer ,
173
+ ) -> Result < RawEncoding , TextEmbeddingsError > {
174
+ let inputs: EncodeInput = match inputs {
175
+ EncodingInput :: Single ( s) => s. into ( ) ,
176
+ EncodingInput :: Dual ( s1, s2) => ( s1, s2) . into ( ) ,
177
+ } ;
178
+
179
+ Ok ( tokenizer
180
+ . with_truncation ( truncate_params) ?
181
+ . encode ( inputs, add_special_tokens) ?)
182
+ }
183
+
113
184
/// Get input length and optionally truncate it
114
185
fn encode_input (
115
186
inputs : EncodingInput ,
116
187
truncate : bool ,
117
188
max_input_length : usize ,
118
189
position_offset : usize ,
119
190
tokenizer : & mut Tokenizer ,
120
- ) -> Result < Encoding , TextEmbeddingsError > {
191
+ ) -> Result < ValidEncoding , TextEmbeddingsError > {
121
192
// Default truncation params
122
193
let truncate_params = truncate. then_some ( TruncationParams {
123
194
direction : TruncationDirection :: Right ,
@@ -126,14 +197,7 @@ fn encode_input(
126
197
stride : 0 ,
127
198
} ) ;
128
199
129
- let inputs: EncodeInput = match inputs {
130
- EncodingInput :: Single ( s) => s. into ( ) ,
131
- EncodingInput :: Dual ( s1, s2) => ( s1, s2) . into ( ) ,
132
- } ;
133
-
134
- let encoding = tokenizer
135
- . with_truncation ( truncate_params) ?
136
- . encode ( inputs, true ) ?;
200
+ let encoding = tokenize_input ( inputs, true , truncate_params, tokenizer) ?;
137
201
let seq_len = encoding. len ( ) ;
138
202
139
203
if seq_len > max_input_length {
@@ -144,7 +208,7 @@ fn encode_input(
144
208
145
209
metrics:: histogram!( "te_request_input_length" , seq_len as f64 ) ;
146
210
147
- Ok ( Encoding {
211
+ Ok ( ValidEncoding {
148
212
input_ids : encoding. get_ids ( ) . to_vec ( ) ,
149
213
token_type_ids : encoding. get_type_ids ( ) . to_vec ( ) ,
150
214
position_ids : ( position_offset as u32 ..( seq_len + position_offset) as u32 )
@@ -153,7 +217,7 @@ fn encode_input(
153
217
}
154
218
155
219
#[ derive( Debug ) ]
156
- pub struct Encoding {
220
+ pub struct ValidEncoding {
157
221
pub input_ids : Vec < u32 > ,
158
222
pub token_type_ids : Vec < u32 > ,
159
223
pub position_ids : Vec < u32 > ,
@@ -186,9 +250,17 @@ impl From<(String, String)> for EncodingInput {
186
250
}
187
251
}
188
252
189
- type TokenizerRequest = (
190
- EncodingInput ,
191
- bool ,
192
- oneshot:: Sender < Result < Encoding , TextEmbeddingsError > > ,
193
- Span ,
194
- ) ;
253
+ enum TokenizerRequest {
254
+ Encode (
255
+ EncodingInput ,
256
+ bool ,
257
+ oneshot:: Sender < Result < ValidEncoding , TextEmbeddingsError > > ,
258
+ Span ,
259
+ ) ,
260
+ Tokenize (
261
+ EncodingInput ,
262
+ bool ,
263
+ oneshot:: Sender < Result < RawEncoding , TextEmbeddingsError > > ,
264
+ Span ,
265
+ ) ,
266
+ }
0 commit comments