Skip to content

Commit eeb3453

Browse files
Add binary variants for accessing encoding data (#32)
1 parent be0b12d commit eeb3453

File tree

6 files changed

+113
-3
lines changed

6 files changed

+113
-3
lines changed

lib/tokenizers/encoding.ex

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,52 @@ defmodule Tokenizers.Encoding do
2323
@spec get_ids(Encoding.t()) :: [integer()]
2424
def get_ids(encoding), do: encoding |> Native.get_ids() |> Shared.unwrap()
2525

26+
@doc """
27+
Same as `get_ids/1`, but returns binary with u32 values.
28+
"""
29+
@spec get_u32_ids(Encoding.t()) :: binary()
30+
def get_u32_ids(encoding), do: encoding |> Native.get_u32_ids() |> Shared.unwrap()
31+
2632
@doc """
2733
Get the attention mask from an encoding.
2834
"""
2935
@spec get_attention_mask(Encoding.t()) :: [integer()]
3036
def get_attention_mask(encoding), do: encoding |> Native.get_attention_mask() |> Shared.unwrap()
3137

38+
@doc """
39+
Same as `get_attention_mask/1`, but returns binary with u32 values.
40+
"""
41+
@spec get_u32_attention_mask(Encoding.t()) :: binary()
42+
def get_u32_attention_mask(encoding),
43+
do: encoding |> Native.get_u32_attention_mask() |> Shared.unwrap()
44+
3245
@doc """
3346
Get token type ids from an encoding.
3447
"""
3548
@spec get_type_ids(Encoding.t()) :: [integer()]
3649
def get_type_ids(encoding), do: encoding |> Native.get_type_ids() |> Shared.unwrap()
3750

51+
@doc """
52+
Same as `get_type_ids/1`, but returns binary with u32 values.
53+
"""
54+
@spec get_u32_type_ids(Encoding.t()) :: binary()
55+
def get_u32_type_ids(encoding),
56+
do: encoding |> Native.get_u32_type_ids() |> Shared.unwrap()
57+
3858
@doc """
3959
Get special tokens mask from an encoding.
4060
"""
4161
@spec get_special_tokens_mask(Encoding.t()) :: [integer()]
4262
def get_special_tokens_mask(encoding),
4363
do: encoding |> Native.get_special_tokens_mask() |> Shared.unwrap()
4464

65+
@doc """
66+
Same as `get_special_tokens_mask/1`, but returns binary with u32 values.
67+
"""
68+
@spec get_u32_special_tokens_mask(Encoding.t()) :: binary()
69+
def get_u32_special_tokens_mask(encoding),
70+
do: encoding |> Native.get_u32_special_tokens_mask() |> Shared.unwrap()
71+
4572
@doc """
4673
Get offsets from an encoding.
4774

lib/tokenizers/native.ex

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@ defmodule Tokenizers.Native do
1616
def encode_batch(_tokenizer, _input, _add_special_tokens), do: err()
1717
def from_file(_path, _additional_special_tokens), do: err()
1818
# Encoding accessors: list variants plus raw u32-binary variants.
# NOTE(review): these look like rustler NIF stubs — presumably `err/0` raises
# only when the native library failed to load; confirm against the module's
# `use Rustler` setup (outside this view).
def get_attention_mask(_encoding), do: err()
def get_u32_attention_mask(_encoding), do: err()
def get_type_ids(_encoding), do: err()
def get_u32_type_ids(_encoding), do: err()
def get_ids(_encoding), do: err()
def get_u32_ids(_encoding), do: err()
def get_tokens(_encoding), do: err()
def get_special_tokens_mask(_encoding), do: err()
def get_u32_special_tokens_mask(_encoding), do: err()
2327
def get_offsets(_encoding), do: err()
2428
def get_vocab(_tokenizer, _with_added_tokens), do: err()
2529
def get_vocab_size(_tokenizer, _with_added_tokens), do: err()

native/ex_tokenizers/src/encoding.rs

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
use crate::error::ExTokenizersError;
2+
use rustler::resource::ResourceArc;
3+
use rustler::{Binary, Env};
24
use tokenizers::utils::padding::PaddingDirection;
35
use tokenizers::utils::truncation::TruncationDirection;
46
use tokenizers::Encoding;
@@ -8,7 +10,7 @@ pub struct ExTokenizersEncodingRef(pub Encoding);
810
#[derive(rustler::NifStruct)]
911
#[module = "Tokenizers.Encoding"]
1012
pub struct ExTokenizersEncoding {
11-
pub resource: rustler::resource::ResourceArc<ExTokenizersEncodingRef>,
13+
pub resource: ResourceArc<ExTokenizersEncodingRef>,
1214
}
1315

1416
impl ExTokenizersEncodingRef {
@@ -20,7 +22,7 @@ impl ExTokenizersEncodingRef {
2022
impl ExTokenizersEncoding {
2123
pub fn new(data: Encoding) -> Self {
2224
Self {
23-
resource: rustler::resource::ResourceArc::new(ExTokenizersEncodingRef::new(data)),
25+
resource: ResourceArc::new(ExTokenizersEncodingRef::new(data)),
2426
}
2527
}
2628
}
@@ -35,23 +37,60 @@ pub fn get_ids(encoding: ExTokenizersEncoding) -> Result<Vec<u32>, ExTokenizersE
3537
Ok(encoding.resource.0.get_ids().to_vec())
3638
}
3739

40+
#[rustler::nif]
41+
pub fn get_u32_ids(env: Env, encoding: ExTokenizersEncoding) -> Result<Binary, ExTokenizersError> {
42+
Ok(encoding
43+
.resource
44+
.make_binary(env, |r| slice_u32_to_u8(r.0.get_ids())))
45+
}
46+
3847
#[rustler::nif]
3948
pub fn get_attention_mask(encoding: ExTokenizersEncoding) -> Result<Vec<u32>, ExTokenizersError> {
4049
Ok(encoding.resource.0.get_attention_mask().to_vec())
4150
}
4251

52+
#[rustler::nif]
53+
pub fn get_u32_attention_mask(
54+
env: Env,
55+
encoding: ExTokenizersEncoding,
56+
) -> Result<Binary, ExTokenizersError> {
57+
Ok(encoding
58+
.resource
59+
.make_binary(env, |r| slice_u32_to_u8(r.0.get_attention_mask())))
60+
}
61+
4362
#[rustler::nif]
4463
pub fn get_type_ids(encoding: ExTokenizersEncoding) -> Result<Vec<u32>, ExTokenizersError> {
4564
Ok(encoding.resource.0.get_type_ids().to_vec())
4665
}
4766

67+
#[rustler::nif]
68+
pub fn get_u32_type_ids(
69+
env: Env,
70+
encoding: ExTokenizersEncoding,
71+
) -> Result<Binary, ExTokenizersError> {
72+
Ok(encoding
73+
.resource
74+
.make_binary(env, |r| slice_u32_to_u8(r.0.get_type_ids())))
75+
}
76+
4877
#[rustler::nif]
4978
pub fn get_special_tokens_mask(
5079
encoding: ExTokenizersEncoding,
5180
) -> Result<Vec<u32>, ExTokenizersError> {
5281
Ok(encoding.resource.0.get_special_tokens_mask().to_vec())
5382
}
5483

84+
#[rustler::nif]
85+
pub fn get_u32_special_tokens_mask(
86+
env: Env,
87+
encoding: ExTokenizersEncoding,
88+
) -> Result<Binary, ExTokenizersError> {
89+
Ok(encoding
90+
.resource
91+
.make_binary(env, |r| slice_u32_to_u8(r.0.get_special_tokens_mask())))
92+
}
93+
5594
#[rustler::nif]
5695
pub fn get_offsets(
5796
encoding: ExTokenizersEncoding,
@@ -99,3 +138,7 @@ pub fn pad(
99138
new_encoding.pad(target_length, pad_id, pad_type_id, pad_token, direction);
100139
Ok(ExTokenizersEncoding::new(new_encoding))
101140
}
141+
142+
/// Reinterprets a `&[u32]` slice as its raw byte representation (native
/// endianness) without copying. The returned slice borrows `slice`, so it
/// cannot outlive the source data.
fn slice_u32_to_u8(slice: &[u32]) -> &[u8] {
    // SAFETY: the pointer and length come from a valid, initialized slice;
    // `u8` has alignment 1 so any pointer is suitably aligned; every byte
    // pattern is a valid `u8`; the output lifetime is tied to the input
    // borrow. `size_of_val` replaces the magic constant `4` and stays
    // correct even if the element type ever changes.
    unsafe {
        std::slice::from_raw_parts(slice.as_ptr() as *const u8, std::mem::size_of_val(slice))
    }
}

native/ex_tokenizers/src/error.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@ pub enum ExTokenizersError {
2323

2424
impl Encoder for ExTokenizersError {
2525
fn encode<'b>(&self, env: Env<'b>) -> Term<'b> {
26-
format!("{:?}", self).encode(env)
26+
format!("{self:?}").encode(env)
2727
}
2828
}

native/ex_tokenizers/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,13 @@ rustler::init!(
2626
encode_batch,
2727
from_file,
2828
get_attention_mask,
29+
get_u32_attention_mask,
2930
get_type_ids,
31+
get_u32_type_ids,
3032
get_ids,
33+
get_u32_ids,
3134
get_special_tokens_mask,
35+
get_u32_special_tokens_mask,
3236
get_offsets,
3337
get_model,
3438
get_model_details,

test/tokenizers/tokenizer_test.exs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,43 @@ defmodule Tokenizers.TokenizerTest do
196196
ids = Enum.map(encodings, &Encoding.get_ids/1)
197197
{:ok, decoded} = Tokenizer.decode(tokenizer, ids)
198198
assert decoded == text
199+
200+
assert Enum.map(ids, &list_to_u32/1) == Enum.map(encodings, &Encoding.get_u32_ids/1)
199201
end
200202
end
201203

202204
describe "encode metadata" do
205+
test "can return attention mask", %{tokenizer: tokenizer} do
206+
text = ["Hello world", "Yes sir hello indeed"]
207+
{:ok, encodings} = Tokenizer.encode(tokenizer, text)
208+
209+
attention_mask = Enum.map(encodings, &Encoding.get_attention_mask/1)
210+
assert [[1, 1, 1, 1], [1, 1, 1, 1, 1, 1]] == attention_mask
211+
212+
assert Enum.map(attention_mask, &list_to_u32/1) ==
213+
Enum.map(encodings, &Encoding.get_u32_attention_mask/1)
214+
end
215+
216+
test "can return type ids", %{tokenizer: tokenizer} do
217+
text = [{"Hello", "world"}, {"Yes sir", "hello indeed"}]
218+
{:ok, encodings} = Tokenizer.encode(tokenizer, text)
219+
220+
type_ids = Enum.map(encodings, &Encoding.get_type_ids/1)
221+
assert [[0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1, 1]] == type_ids
222+
223+
assert Enum.map(type_ids, &list_to_u32/1) ==
224+
Enum.map(encodings, &Encoding.get_u32_type_ids/1)
225+
end
226+
203227
test "can return special tokens mask", %{tokenizer: tokenizer} do
204228
text = ["This is a test", "And so is this"]
205229
{:ok, encodings} = Tokenizer.encode(tokenizer, text)
230+
206231
special_tokens_mask = Enum.map(encodings, &Encoding.get_special_tokens_mask/1)
207232
assert [[1, 0, 0, 0, 0, 1], [1, 0, 0, 0, 0, 1]] == special_tokens_mask
233+
234+
assert Enum.map(special_tokens_mask, &list_to_u32/1) ==
235+
Enum.map(encodings, &Encoding.get_u32_special_tokens_mask/1)
208236
end
209237

210238
test "can return offsets", %{tokenizer: tokenizer} do
@@ -218,4 +246,8 @@ defmodule Tokenizers.TokenizerTest do
218246
] == offsets
219247
end
220248
end
249+
250+
# Packs a list of integers into a binary of native-endian unsigned 32-bit
# values, mirroring the layout the `get_u32_*` NIF variants return.
defp list_to_u32(list) do
  Enum.into(list, <<>>, fn value -> <<value::native-unsigned-32>> end)
end
221253
end

0 commit comments

Comments
 (0)