Skip to content

Commit 0aae295

Browse files
Źmićer Rubinštejnjonatanklosko
andauthored
Add bulk transformations for encoding (#49)
Co-authored-by: Jonatan Kłosko <[email protected]>
1 parent 13e12b9 commit 0aae295

File tree

10 files changed

+241
-66
lines changed

10 files changed

+241
-66
lines changed

lib/tokenizers/encoding.ex

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -180,46 +180,54 @@ defmodule Tokenizers.Encoding do
180180
to: Tokenizers.Native,
181181
as: :encoding_char_to_word
182182

183-
@doc """
184-
Pad the encoding to the given length.
183+
@typedoc """
184+
Padding configuration.
185185
186-
## Options
186+
* `:direction` - the padding direction. Defaults to `:right`
187187
188-
* `direction` (default `:right`) - the padding direction
188+
* `:pad_id` - the id corresponding to the padding token. Defaults
189+
to `0`
189190
190-
* `pad_id` (default `0`) - the id corresponding to the padding
191-
token
191+
* `:pad_type_id` - the type ID corresponding to the padding token.
192+
Defaults to `0`
192193
193-
* `pad_type_id` (default `0`) - the type ID corresponding to the
194-
padding token
194+
* `:pad_token` - the padding token to use. Defaults to `"[PAD]"`
195195
196-
* `pad_token` (default `[PAD]`) - the padding token to use
196+
"""
197+
@type padding_opts :: [
198+
pad_id: non_neg_integer(),
199+
pad_type_id: non_neg_integer(),
200+
pad_token: String.t(),
201+
direction: :left | :right
202+
]
197203

204+
@doc """
205+
Pad the encoding to the given length.
206+
207+
For available options see `t:padding_opts/0`.
198208
"""
199-
@spec pad(t(), non_neg_integer(), opts) :: t()
200-
when opts: [
201-
pad_id: non_neg_integer(),
202-
pad_type_id: non_neg_integer(),
203-
pad_token: String.t(),
204-
direction: :left | :right
205-
]
209+
@spec pad(t(), non_neg_integer(), opts :: padding_opts()) :: t()
206210
defdelegate pad(encoding, target_length, opts \\ []),
207211
to: Tokenizers.Native,
208212
as: :encoding_pad
209213

210-
@doc """
211-
Truncate the encoding to the given length.
214+
@typedoc """
215+
Truncation configuration.
212216
213-
## Options
217+
* `:stride` - the length of previous content to be included in each
218+
overflowing piece. Defaults to `0`
214219
215-
* `stride` (default `0`) - the length of previous content to be
216-
included in each overflowing piece
220+
* `:direction` - the truncation direction. Defaults to `:right`
217221
218-
* `direction` (default `:right`) - the truncation direction
222+
"""
223+
@type truncation_opts :: [stride: non_neg_integer(), direction: :left | :right]
219224

225+
@doc """
226+
Truncate the encoding to the given length.
227+
228+
For available options see `t:truncation_opts/0`.
220229
"""
221-
@spec truncate(t(), non_neg_integer(), opts) :: t()
222-
when opts: [stride: non_neg_integer(), direction: :left | :right]
230+
@spec truncate(t(), non_neg_integer(), opts :: truncation_opts()) :: t()
223231
defdelegate truncate(encoding, max_length, opts \\ []),
224232
to: Tokenizers.Native,
225233
as: :encoding_truncate
@@ -229,6 +237,20 @@ defmodule Tokenizers.Encoding do
229237
"""
230238
@spec n_tokens(encoding :: t()) :: non_neg_integer()
231239
defdelegate n_tokens(encoding), to: Tokenizers.Native, as: :encoding_get_length
240+
241+
@doc """
242+
Performs set of transformations to given encoding, creating a new one.
243+
Transformations are applied in order they are given.
244+
245+
While all these transformations can be done one by one, this function
246+
is more efficient as it avoids multiple allocations and Garbage Collection
247+
for intermediate encodings.
248+
249+
Check the module `Tokenizers.Encoding.Transformation` for handy functions,
250+
that can be used to build the transformations list.
251+
Also, you can build this list manually, as long as it follows the format.
252+
"""
253+
defdelegate transform(encoding, transformations), to: Tokenizers.Native, as: :encoding_transform
232254
end
233255

234256
defimpl Inspect, for: Tokenizers.Encoding do
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
defmodule Tokenizers.Encoding.Transformation do
2+
@moduledoc """
3+
Module containing handy functions to build the transformations list.
4+
5+
This list is aplied to an encoding using `Tokenizers.Encoding.transform/2`.
6+
"""
7+
8+
@type t :: [
9+
{:pad, {non_neg_integer(), Tokenizers.Encoding.padding_opts()}},
10+
{:truncate, {non_neg_integer(), Tokenizers.Encoding.truncation_opts()}},
11+
{:set_sequence_id, non_neg_integer()}
12+
]
13+
14+
@doc """
15+
Generates the padding transformation.
16+
17+
Check `Tokenizers.Encoding.pad/3` for more information.
18+
"""
19+
@spec pad(non_neg_integer(), Tokenizers.Encoding.padding_opts()) ::
20+
{:pad, {non_neg_integer(), Tokenizers.Encoding.padding_opts()}}
21+
def pad(target_length, opts \\ []) do
22+
{:pad, {target_length, opts}}
23+
end
24+
25+
@doc """
26+
Generates the truncation transformation.
27+
28+
Check `Tokenizers.Encoding.truncate/3` for more information.
29+
"""
30+
@spec truncate(non_neg_integer(), Tokenizers.Encoding.truncation_opts()) ::
31+
{:truncate, {non_neg_integer(), Tokenizers.Encoding.truncation_opts()}}
32+
def truncate(max_length, opts \\ []) do
33+
{:truncate, {max_length, opts}}
34+
end
35+
36+
@doc """
37+
Generates the set_sequence_id transformation.
38+
39+
Check `Tokenizers.Encoding.set_sequence_id/2` for more information.
40+
"""
41+
@spec set_sequence_id(non_neg_integer()) ::
42+
{:set_sequence_id, non_neg_integer()}
43+
def set_sequence_id(id) do
44+
{:set_sequence_id, id}
45+
end
46+
end

lib/tokenizers/native.ex

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ defmodule Tokenizers.Native do
5959
def encoding_char_to_word(_encoding, _position, _seq_id), do: err()
6060
def encoding_pad(_encoding, _target_length, _opts), do: err()
6161
def encoding_truncate(_encoding, _max_length, _opts), do: err()
62+
#
63+
def encoding_transform(_encoding, _transformers), do: err()
6264

6365
# Models
6466
def models_save(_model, _folder, _opts), do: err()

lib/tokenizers/tokenizer.ex

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,10 @@ defmodule Tokenizers.Tokenizer do
452452
* `:add_special_tokens` - whether to add special tokens to the
453453
sequence. Defaults to `true`
454454
455+
* `:encoding_transformations` - a list of `t:Tokenizers.Encoding.Transformation.t/0`
456+
to apply to the encoding. Check `Tokenizers.Encoding.transform/2`
457+
for more information. Defaults to `[]`
458+
455459
"""
456460
@doc type: :inference
457461
@spec encode(t(), encode_input(), keyword()) :: {:ok, Encoding.t()} | {:error, term()}

native/ex_tokenizers/src/encoding.rs

Lines changed: 87 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -187,33 +187,38 @@ pub enum PadOption {
187187
Direction(Direction),
188188
}
189189

190-
#[rustler::nif]
191-
pub fn encoding_pad(
192-
encoding: ExTokenizersEncoding,
193-
target_length: usize,
194-
opts: Vec<PadOption>,
195-
) -> ExTokenizersEncoding {
196-
struct Padding {
197-
pad_id: u32,
198-
pad_type_id: u32,
199-
pad_token: String,
200-
direction: Direction,
201-
}
190+
struct Padding {
191+
pad_id: u32,
192+
pad_type_id: u32,
193+
pad_token: String,
194+
direction: Direction,
195+
}
196+
197+
fn parse_pad_options(opts: &Vec<PadOption>) -> Padding {
202198
let mut default = Padding {
203199
pad_id: 0,
204200
pad_type_id: 0,
205201
pad_token: "[PAD]".to_string(),
206202
direction: Direction::Right,
207203
};
208-
209204
for opt in opts {
210205
match opt {
211-
PadOption::PadId(id) => default.pad_id = id,
212-
PadOption::PadTypeId(id) => default.pad_type_id = id,
213-
PadOption::PadToken(token) => default.pad_token = token,
214-
PadOption::Direction(direction) => default.direction = direction,
206+
PadOption::PadId(id) => default.pad_id = *id,
207+
PadOption::PadTypeId(id) => default.pad_type_id = *id,
208+
PadOption::PadToken(token) => default.pad_token = token.clone(),
209+
PadOption::Direction(direction) => default.direction = direction.clone(),
215210
}
216211
}
212+
default
213+
}
214+
215+
#[rustler::nif]
216+
pub fn encoding_pad(
217+
encoding: ExTokenizersEncoding,
218+
target_length: usize,
219+
opts: Vec<PadOption>,
220+
) -> ExTokenizersEncoding {
221+
let default = parse_pad_options(&opts);
217222

218223
let mut encoding = encoding.resource.0.clone();
219224
encoding.pad(
@@ -232,27 +237,33 @@ pub enum TruncationOption {
232237
Direction(Direction),
233238
}
234239

235-
#[rustler::nif]
236-
pub fn encoding_truncate(
237-
encoding: ExTokenizersEncoding,
238-
max_len: usize,
239-
opts: Vec<TruncationOption>,
240-
) -> ExTokenizersEncoding {
241-
struct Truncation {
242-
stride: usize,
243-
direction: Direction,
244-
}
240+
struct Truncation {
241+
stride: usize,
242+
direction: Direction,
243+
}
244+
245+
fn parse_truncation_options(opts: &Vec<TruncationOption>) -> Truncation {
245246
let mut default = Truncation {
246247
stride: 0,
247248
direction: Direction::Right,
248249
};
249250

250251
for opt in opts {
251252
match opt {
252-
TruncationOption::Stride(stride) => default.stride = stride,
253-
TruncationOption::Direction(direction) => default.direction = direction,
253+
TruncationOption::Stride(stride) => default.stride = *stride,
254+
TruncationOption::Direction(direction) => default.direction = direction.clone(),
254255
}
255256
}
257+
default
258+
}
259+
260+
#[rustler::nif]
261+
pub fn encoding_truncate(
262+
encoding: ExTokenizersEncoding,
263+
max_len: usize,
264+
opts: Vec<TruncationOption>,
265+
) -> ExTokenizersEncoding {
266+
let default = parse_truncation_options(&opts);
256267

257268
let mut encoding = encoding.resource.0.clone();
258269

@@ -263,3 +274,50 @@ pub fn encoding_truncate(
263274
fn slice_u32_to_u8(slice: &[u32]) -> &[u8] {
264275
unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, slice.len() * 4) }
265276
}
277+
278+
///////////////////////////////////////////////////////////////////////////////
279+
/// Encoding transformations
280+
///////////////////////////////////////////////////////////////////////////////
281+
282+
#[derive(NifTaggedEnum)]
283+
pub enum TransformationElement {
284+
Pad((usize, Vec<PadOption>)), // {:pad, {target_length, opts}}
285+
Truncate((usize, Vec<TruncationOption>)), // {:truncate, {max_len, opts}}
286+
SetSequenceId(usize), // {:set_sequence_id, seq_id}
287+
}
288+
289+
#[rustler::nif]
290+
pub fn encoding_transform(
291+
encoding: ExTokenizersEncoding,
292+
transformations: Vec<TransformationElement>,
293+
) -> ExTokenizersEncoding {
294+
let mut encoding = encoding.resource.0.clone();
295+
apply_transformations(&mut encoding, &transformations);
296+
encoding.into()
297+
}
298+
299+
pub fn apply_transformations(
300+
encoding: &mut Encoding,
301+
transformations: &Vec<TransformationElement>,
302+
) {
303+
for transformation in transformations {
304+
match transformation {
305+
TransformationElement::Pad((target_length, opts)) => {
306+
let default = parse_pad_options(opts);
307+
308+
encoding.pad(
309+
*target_length,
310+
default.pad_id,
311+
default.pad_type_id,
312+
&default.pad_token,
313+
default.direction.into(),
314+
)
315+
}
316+
TransformationElement::Truncate((max_len, opts)) => {
317+
let default = parse_truncation_options(opts);
318+
encoding.truncate(*max_len, default.stride, default.direction.into())
319+
}
320+
TransformationElement::SetSequenceId(seq_id) => encoding.set_sequence_id(*seq_id),
321+
}
322+
}
323+
}

native/ex_tokenizers/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ rustler::init!(
8484
encoding_char_to_word,
8585
encoding_pad,
8686
encoding_truncate,
87+
//
88+
encoding_transform,
8789
// Models
8890
models_save,
8991
//

0 commit comments

Comments
 (0)