Skip to content

Commit 8052b96

Browse files
author
Źmićer Rubinštejn
authored
Complete encoding API (#44)
1 parent 9210206 commit 8052b96

File tree

10 files changed

+1431
-357
lines changed

10 files changed

+1431
-357
lines changed

lib/tokenizers/decoder.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ defmodule Tokenizers.Decoder do
9797

9898
@doc """
9999
Creates new Strip decoder.
100-
100+
101101
It expects a character and the number of times to strip the
102102
character on `left` and `right` sides.
103103
"""

lib/tokenizers/encoding.ex

Lines changed: 159 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -5,123 +5,219 @@ defmodule Tokenizers.Encoding do
55
Use these functions to retrieve the inputs needed for a natural language processing machine learning model.
66
"""
77

8-
@type t :: %__MODULE__{resource: binary(), reference: reference()}
9-
defstruct resource: nil, reference: nil
8+
@type t :: %__MODULE__{resource: reference()}
9+
defstruct resource: nil
1010

11-
alias Tokenizers.Native
12-
alias Tokenizers.Shared
11+
@doc """
12+
Get the number of tokens in an encoding.
13+
"""
14+
@spec get_length(Encoding.t()) :: non_neg_integer()
15+
defdelegate get_length(encoding), to: Tokenizers.Native, as: :encoding_get_length
1316

1417
@doc """
15-
Get the tokens from an encoding.
18+
Return the number of sequences combined in this Encoding
1619
"""
17-
@spec get_tokens(Encoding.t()) :: [binary()]
18-
def get_tokens(encoding), do: encoding |> Native.get_tokens() |> Shared.unwrap()
20+
@spec get_n_sequences(Encoding.t()) :: non_neg_integer()
21+
defdelegate get_n_sequences(encoding), to: Tokenizers.Native, as: :encoding_get_n_sequences
22+
23+
@doc """
24+
Set the given sequence id for the whole range of tokens contained in this Encoding.
25+
"""
26+
@spec set_sequence_id(Encoding.t(), non_neg_integer()) :: Encoding.t()
27+
defdelegate set_sequence_id(encoding, id), to: Tokenizers.Native, as: :encoding_set_sequence_id
1928

2029
@doc """
2130
Get the ids from an encoding.
2231
"""
2332
@spec get_ids(Encoding.t()) :: [integer()]
24-
def get_ids(encoding), do: encoding |> Native.get_ids() |> Shared.unwrap()
33+
defdelegate get_ids(encoding), to: Tokenizers.Native, as: :encoding_get_ids
2534

2635
@doc """
2736
Same as `get_ids/1`, but returns binary with u32 values.
2837
"""
2938
@spec get_u32_ids(Encoding.t()) :: binary()
30-
def get_u32_ids(encoding), do: encoding |> Native.get_u32_ids() |> Shared.unwrap()
39+
defdelegate get_u32_ids(encoding), to: Tokenizers.Native, as: :encoding_get_u32_ids
40+
41+
@doc """
42+
Get token type ids from an encoding.
43+
"""
44+
@spec get_type_ids(Encoding.t()) :: [integer()]
45+
defdelegate get_type_ids(encoding), to: Tokenizers.Native, as: :encoding_get_type_ids
46+
47+
@doc """
48+
Same as `get_type_ids/1`, but returns binary with u32 values.
49+
"""
50+
@spec get_u32_type_ids(Encoding.t()) :: binary()
51+
defdelegate get_u32_type_ids(encoding), to: Tokenizers.Native, as: :encoding_get_u32_type_ids
3152

3253
@doc """
3354
Get the attention mask from an encoding.
3455
"""
3556
@spec get_attention_mask(Encoding.t()) :: [integer()]
36-
def get_attention_mask(encoding), do: encoding |> Native.get_attention_mask() |> Shared.unwrap()
57+
defdelegate get_attention_mask(encoding),
58+
to: Tokenizers.Native,
59+
as: :encoding_get_attention_mask
3760

3861
@doc """
3962
Same as `get_attention_mask/1`, but returns binary with u32 values.
4063
"""
4164
@spec get_u32_attention_mask(Encoding.t()) :: binary()
42-
def get_u32_attention_mask(encoding),
43-
do: encoding |> Native.get_u32_attention_mask() |> Shared.unwrap()
65+
defdelegate get_u32_attention_mask(encoding),
66+
to: Tokenizers.Native,
67+
as: :encoding_get_u32_attention_mask
4468

4569
@doc """
46-
Get token type ids from an encoding.
70+
Get the special tokens mask from an encoding.
4771
"""
48-
@spec get_type_ids(Encoding.t()) :: [integer()]
49-
def get_type_ids(encoding), do: encoding |> Native.get_type_ids() |> Shared.unwrap()
72+
@spec get_special_tokens_mask(Encoding.t()) :: [integer()]
73+
defdelegate get_special_tokens_mask(encoding),
74+
to: Tokenizers.Native,
75+
as: :encoding_get_special_tokens_mask
5076

5177
@doc """
52-
Same as `get_type_ids/1`, but returns binary with u32 values.
78+
Same as `get_special_tokens_mask/1`, but returns binary with u32 values.
5379
"""
54-
@spec get_u32_type_ids(Encoding.t()) :: binary()
55-
def get_u32_type_ids(encoding),
56-
do: encoding |> Native.get_u32_type_ids() |> Shared.unwrap()
80+
@spec get_u32_special_tokens_mask(Encoding.t()) :: binary()
81+
defdelegate get_u32_special_tokens_mask(encoding),
82+
to: Tokenizers.Native,
83+
as: :encoding_get_u32_special_tokens_mask
84+
85+
@doc """
86+
Get the tokens from an encoding.
87+
"""
88+
@spec get_tokens(Encoding.t()) :: [binary()]
89+
defdelegate get_tokens(encoding), to: Tokenizers.Native, as: :encoding_get_tokens
5790

5891
@doc """
59-
Get special tokens mask from an encoding.
92+
Get word ids from an encoding.
6093
"""
61-
@spec get_special_tokens_mask(Encoding.t()) :: [integer()]
62-
def get_special_tokens_mask(encoding),
63-
do: encoding |> Native.get_special_tokens_mask() |> Shared.unwrap()
94+
@spec get_word_ids(Encoding.t()) :: [non_neg_integer() | nil]
95+
defdelegate get_word_ids(encoding), to: Tokenizers.Native, as: :encoding_get_word_ids
6496

6597
@doc """
66-
Same as `get_special_tokens_mask/1`, but returns binary with u32 values.
98+
Get sequence ids from an encoding.
6799
"""
68-
@spec get_u32_special_tokens_mask(Encoding.t()) :: binary()
69-
def get_u32_special_tokens_mask(encoding),
70-
do: encoding |> Native.get_u32_special_tokens_mask() |> Shared.unwrap()
100+
@spec get_sequence_ids(Encoding.t()) :: [non_neg_integer() | nil]
101+
defdelegate get_sequence_ids(encoding), to: Tokenizers.Native, as: :encoding_get_sequence_ids
71102

72103
@doc """
73104
Get offsets from an encoding.
74105
75106
The offsets are expressed in terms of UTF-8 bytes.
76107
"""
77108
@spec get_offsets(Encoding.t()) :: [{integer(), integer()}]
78-
def get_offsets(encoding), do: encoding |> Native.get_offsets() |> Shared.unwrap()
109+
defdelegate get_offsets(encoding), to: Tokenizers.Native, as: :encoding_get_offsets
79110

80111
@doc """
81-
Truncate the encoding to the given length.
112+
Get the overflow from an encoding.
113+
"""
114+
@spec get_overflowing(Encoding.t()) :: [Encoding.t()]
115+
defdelegate get_overflowing(encoding), to: Tokenizers.Native, as: :encoding_get_overflowing
82116

83-
## Options
84-
* `direction` - The truncation direction. Can be `:right` or `:left`. Default: `:right`.
85-
* `stride` - The length of previous content to be included in each overflowing piece. Default: `0`.
117+
@doc """
118+
Get the encoded tokens corresponding to the word at the given index in the input sequence,
119+
with the form (start_token, end_token + 1)
86120
"""
87-
@spec truncate(encoding :: Encoding.t(), length :: integer(), opts :: Keyword.t()) ::
88-
Encoding.t()
89-
def truncate(encoding, max_len, opts \\ []) do
90-
opts = Keyword.validate!(opts, direction: :right, stride: 0)
91-
encoding |> Native.truncate(max_len, opts[:stride], "#{opts[:direction]}") |> Shared.unwrap()
92-
end
121+
@spec word_to_tokens(Encoding.t(), non_neg_integer(), non_neg_integer()) ::
122+
{non_neg_integer(), non_neg_integer()} | nil
123+
defdelegate word_to_tokens(encoding, word, seq_id),
124+
to: Tokenizers.Native,
125+
as: :encoding_word_to_tokens
126+
127+
@doc """
128+
Get the offsets of the word at the given index in the input sequence.
129+
"""
130+
@spec word_to_chars(Encoding.t(), non_neg_integer(), non_neg_integer()) ::
131+
{non_neg_integer(), non_neg_integer()} | nil
132+
defdelegate word_to_chars(encoding, word, seq_id),
133+
to: Tokenizers.Native,
134+
as: :encoding_word_to_chars
135+
136+
@doc """
137+
Returns the index of the sequence containing the given token
138+
"""
139+
@spec token_to_sequence(Encoding.t(), non_neg_integer()) :: non_neg_integer() | nil
140+
defdelegate token_to_sequence(encoding, token),
141+
to: Tokenizers.Native,
142+
as: :encoding_token_to_sequence
143+
144+
@doc """
145+
Get the offsets of the token at the given index.
146+
"""
147+
@spec token_to_chars(Encoding.t(), non_neg_integer()) ::
148+
{non_neg_integer(), {non_neg_integer(), non_neg_integer()}} | nil
149+
defdelegate token_to_chars(encoding, token), to: Tokenizers.Native, as: :encoding_token_to_chars
150+
151+
@doc """
152+
Get the word that contains the token at the given index.
153+
"""
154+
@spec token_to_word(Encoding.t(), non_neg_integer()) ::
155+
{non_neg_integer(), non_neg_integer()} | nil
156+
defdelegate token_to_word(encoding, token), to: Tokenizers.Native, as: :encoding_token_to_word
157+
158+
@doc """
159+
Get the token that contains the given char.
160+
"""
161+
@spec char_to_token(Encoding.t(), non_neg_integer(), non_neg_integer()) ::
162+
non_neg_integer() | nil
163+
defdelegate char_to_token(encoding, position, seq_id),
164+
to: Tokenizers.Native,
165+
as: :encoding_char_to_token
166+
167+
@doc """
168+
Get the word that contains the given char.
169+
"""
170+
@spec char_to_word(Encoding.t(), non_neg_integer(), non_neg_integer()) ::
171+
non_neg_integer() | nil
172+
defdelegate char_to_word(encoding, position, seq_id),
173+
to: Tokenizers.Native,
174+
as: :encoding_char_to_word
175+
176+
@typedoc """
177+
Options for padding. All options can be ommited.
178+
179+
* `direction` (default `:right`) - The padding direction.
180+
* `pad_id` (default `0`) - The id corresponding to the padding token.
181+
* `pad_type_id` (default `0`) - The type ID corresponding to the padding token.
182+
* `pad_token` (default `[PDA]`) - The padding token to use.
183+
184+
"""
185+
@type padding_opts :: [
186+
pad_id: non_neg_integer(),
187+
pad_type_id: non_neg_integer(),
188+
pad_token: String.t(),
189+
direction: :left | :right
190+
]
93191

94192
@doc """
95193
Pad the encoding to the given length.
194+
"""
195+
@spec pad(Encoding.t(), non_neg_integer(), padding_opts()) :: Encoding.t()
196+
defdelegate pad(encoding, target_length, opts \\ []),
197+
to: Tokenizers.Native,
198+
as: :encoding_pad
96199

97-
## Options
98-
* `direction` - The padding direction. Can be `:right` or `:left`. Default: `:right`.
99-
* `pad_id` - The id corresponding to the padding token. Default: `0`.
100-
* `pad_token` - The padding token to use. Default: `"[PAD]"`.
101-
* `pad_type_id` - The type ID corresponding to the padding token. Default: `0`.
102-
"""
103-
@spec pad(encoding :: Encoding.t(), length :: pos_integer(), opts :: Keyword.t()) ::
104-
Encoding.t()
105-
def pad(encoding, length, opts \\ []) do
106-
opts =
107-
Keyword.validate!(opts, direction: :right, pad_id: 0, pad_type_id: 0, pad_token: "[PAD]")
108-
109-
encoding
110-
|> Native.pad(
111-
length,
112-
opts[:pad_id],
113-
opts[:pad_type_id],
114-
opts[:pad_token],
115-
"#{opts[:direction]}"
116-
)
117-
|> Shared.unwrap()
118-
end
200+
@typedoc """
201+
Options for truncation. All options can be ommited.
202+
203+
* `stride` (default `0`) - The length of previous content to be included in each overflowing piece.
204+
* `direction` (default `:right`) - The truncation direction.
205+
"""
206+
@type truncation_opts :: [stride: non_neg_integer(), direction: :left | :right]
207+
208+
@doc """
209+
Truncate the encoding to the given length.
210+
"""
211+
@spec truncate(Encoding.t(), non_neg_integer(), truncation_opts()) :: Encoding.t()
212+
defdelegate truncate(encoding, max_length, opts \\ []),
213+
to: Tokenizers.Native,
214+
as: :encoding_truncate
119215

120216
@doc """
121217
Returns the number of tokens in an `Encoding.t()`.
122218
"""
123219
@spec n_tokens(encoding :: Encoding.t()) :: non_neg_integer()
124-
def n_tokens(encoding), do: encoding |> Native.n_tokens() |> Shared.unwrap()
220+
defdelegate n_tokens(encoding), to: Tokenizers.Native, as: :encoding_get_length
125221
end
126222

127223
defimpl Inspect, for: Tokenizers.Encoding do
@@ -131,10 +227,10 @@ defimpl Inspect, for: Tokenizers.Encoding do
131227

132228
def inspect(encoding, opts) do
133229
attrs = [
134-
n_tokens: Encoding.n_tokens(encoding),
230+
length: Encoding.get_length(encoding),
135231
ids: Encoding.get_ids(encoding)
136232
]
137233

138-
concat(["#Tokenizers.Tokenizer<", to_doc(attrs, opts), ">"])
234+
concat(["#Tokenizers.Encoding<", to_doc(attrs, opts), ">"])
139235
end
140236
end

0 commit comments

Comments
 (0)