Commit 26d864b

Add option to disable padding and truncation when loading tokenizer (#46)
Co-authored-by: José Valim <[email protected]>
1 parent 20295cf commit 26d864b

File tree

  lib/tokenizers/encoding.ex
  lib/tokenizers/tokenizer.ex
  native/ex_tokenizers/src/pre_tokenizers.rs
  native/ex_tokenizers/src/tokenizer.rs

4 files changed: +84 -46 lines

lib/tokenizers/encoding.ex

Lines changed: 1 addition & 1 deletion
@@ -179,7 +179,7 @@ defmodule Tokenizers.Encoding do
     * `direction` (default `:right`) - The padding direction.
     * `pad_id` (default `0`) - The id corresponding to the padding token.
     * `pad_type_id` (default `0`) - The type ID corresponding to the padding token.
-    * `pad_token` (default `[PDA]`) - The padding token to use.
+    * `pad_token` (default `[PAD]`) - The padding token to use.
 
   """
   @type padding_opts :: [
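
The corrected default is easiest to see when padding an encoding by hand. A minimal sketch, assuming the `Tokenizers.Encoding.pad/3` signature these docs describe (the model name, target length, and printed output are illustrative only):

    # Pad an encoding to a fixed length; pad_token now defaults to the
    # conventional "[PAD]" rather than the misspelled "[PDA]".
    {:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("bert-base-cased")
    {:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, "Hello world")

    encoding
    |> Tokenizers.Encoding.pad(16, pad_id: 0, pad_token: "[PAD]")
    |> Tokenizers.Encoding.get_tokens()
    #=> ["[CLS]", "Hello", "world", "[SEP]", "[PAD]", "[PAD]", ...]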

lib/tokenizers/tokenizer.ex

Lines changed: 49 additions & 21 deletions
@@ -67,16 +67,30 @@ defmodule Tokenizers.Tokenizer do
 
     * `:additional_special_tokens` - A list of special tokens to append to the tokenizer.
       Defaults to `[]`.
+
+    * `:padding` - Override for padding configuration. Currently the only supported
+      value is `:none` to disable padding. By default the configuration is restored
+      from the file.
+
+    * `:truncation` - Override for truncation configuration. Currently the only supported
+      value is `:none` to disable truncation. By default the configuration is restored
+      from the file.
+
   """
   @spec from_pretrained(String.t(), Keyword.t()) :: {:ok, t()} | {:error, term()}
   def from_pretrained(identifier, opts \\ []) do
     opts =
-      Keyword.validate!(opts,
-        revision: "main",
-        use_cache: true,
-        cache_dir: :filename.basedir(:user_cache, "tokenizers_elixir"),
-        http_client: {Tokenizers.HTTPClient, []},
-        additional_special_tokens: []
+      Keyword.validate!(
+        opts,
+        [
+          :padding,
+          :truncation,
+          revision: "main",
+          use_cache: true,
+          cache_dir: :filename.basedir(:user_cache, "tokenizers_elixir"),
+          http_client: {Tokenizers.HTTPClient, []},
+          additional_special_tokens: []
+        ]
       )
 
     {http_client, http_opts} = opts[:http_client]
@@ -100,19 +114,21 @@ defmodule Tokenizers.Tokenizer do
         Path.join(cache_dir, entry_filename(url, etag))
       end
 
+    load_opts = Keyword.take(opts, [:additional_special_tokens, :padding, :truncation])
+
     if opts[:use_cache] do
       with {:ok, response} <- request(http_client, Keyword.put(http_opts, :method, :head)) do
         etag = fetch_etag(response.headers)
         file_path = file_path_fun.(etag)
 
         if File.exists?(file_path) do
-          from_file(file_path, Keyword.take(opts, [:additional_special_tokens]))
+          from_file(file_path, load_opts)
         else
           with {:ok, response} <- request(http_client, http_opts) do
             File.mkdir_p!(cache_dir)
             File.write!(file_path, response.body)
 
-            from_file(file_path, Keyword.take(opts, [:additional_special_tokens]))
+            from_file(file_path, load_opts)
           end
         end
       end
@@ -124,7 +140,7 @@ defmodule Tokenizers.Tokenizer do
         File.mkdir_p!(cache_dir)
         File.write!(file_path, response.body)
 
-        from_file(file_path, Keyword.take(opts, [:additional_special_tokens]))
+        from_file(file_path, load_opts)
       end
     end
   end
@@ -167,28 +183,40 @@ defmodule Tokenizers.Tokenizer do
     Base.encode32(etag, case: :lower, padding: false)
   end
 
+  @typedoc """
+  Options to set on the loaded tokenizer.
+
+    * `:additional_special_tokens` - a list of special tokens to append to the tokenizer.
+      Defaults to `[]`.
+
+    * `:padding` - Override for padding configuration. Currently the only supported
+      value is `:none` to disable padding. By default the configuration is restored
+      from the file.
+
+    * `:truncation` - Override for truncation configuration. Currently the only supported
+      value is `:none` to disable truncation. By default the configuration is restored
+      from the file.
+
+  """
+  @type load_options ::
+          [
+            additional_special_tokens: [String.t() | Tokenizers.AddedToken.t()],
+            padding: :none,
+            truncation: :none
+          ]
+
   @doc """
   Instantiate a new tokenizer from the file at the given path.
-  You can specify a list of special tokens to append to the tokenizer.
   """
-  @spec from_file(
-          path :: String.t(),
-          options :: [additional_special_tokens :: [String.t() | Tokenizers.AddedToken.t()]]
-        ) ::
-          {:ok, t()} | {:error, term()}
+  @spec from_file(path :: String.t(), load_options()) :: {:ok, t()} | {:error, term()}
   defdelegate from_file(path, options \\ []),
     to: Tokenizers.Native,
     as: :tokenizer_from_file
 
   @doc """
   Instantiate a new tokenizer from the buffer.
-  You can specify a list of special tokens to append to the tokenizer.
   """
-  @spec from_buffer(
-          data :: String.t(),
-          options :: [additional_special_tokens :: [String.t() | Tokenizers.AddedToken.t()]]
-        ) ::
-          {:ok, t()} | {:error, term()}
+  @spec from_buffer(data :: String.t(), load_options()) :: {:ok, t()} | {:error, term()}
   defdelegate from_buffer(data, options \\ []),
     to: Tokenizers.Native,
     as: :tokenizer_from_buffer
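
Taken together, the new options can be passed to any of the load functions. A minimal usage sketch (the model name and file path are illustrative; per the docs above, `:padding` and `:truncation` currently accept only `:none`):

    # Load from the Hugging Face Hub, discarding any padding/truncation
    # configuration stored in the tokenizer file.
    {:ok, tokenizer} =
      Tokenizers.Tokenizer.from_pretrained("bert-base-cased",
        padding: :none,
        truncation: :none
      )

    # The same options are accepted by from_file/2 and from_buffer/2.
    {:ok, tokenizer} = Tokenizers.Tokenizer.from_file("tokenizer.json", truncation: :none)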

native/ex_tokenizers/src/pre_tokenizers.rs

Lines changed: 1 addition & 1 deletion
@@ -164,7 +164,7 @@ pub fn pre_tokenizers_byte_level_alphabet() -> Vec<u32> {
 
 #[rustler::nif]
 pub fn pre_tokenizers_whitespace() -> ExTokenizersPreTokenizer {
-    ExTokenizersPreTokenizer::new(tokenizers::pre_tokenizers::whitespace::Whitespace::default())
+    ExTokenizersPreTokenizer::new(tokenizers::pre_tokenizers::whitespace::Whitespace)
 }
 
 #[rustler::nif]

native/ex_tokenizers/src/tokenizer.rs

Lines changed: 33 additions & 23 deletions
@@ -60,35 +60,18 @@ pub fn tokenizer_init(
 #[derive(NifTaggedEnum)]
 pub enum LoadOption {
     AdditionalSpecialTokens(Vec<AddedSpecialTokenInput>),
+    // Currently only :none is supported
+    Padding(rustler::Atom),
+    Truncation(rustler::Atom),
 }
 
 #[rustler::nif(schedule = "DirtyIo")]
 pub fn tokenizer_from_file(
     path: &str,
     options: Vec<LoadOption>,
 ) -> Result<ExTokenizersTokenizer, ExTokenizersError> {
-    struct Opts {
-        additional_special_tokens: Vec<AddedSpecialTokenInput>,
-    }
-    let mut opts = Opts {
-        additional_special_tokens: vec![],
-    };
-    for opt in options {
-        match opt {
-            LoadOption::AdditionalSpecialTokens(tokens) => {
-                opts.additional_special_tokens = tokens;
-            }
-        }
-    }
-
     let mut tokenizer = TokenizerImpl::from_file(path)?;
-    tokenizer.add_special_tokens(
-        opts.additional_special_tokens
-            .iter()
-            .map(|t| t.into())
-            .collect::<Vec<_>>()
-            .as_ref(),
-    );
+    tokenizer = apply_load_options(tokenizer, options);
     Ok(tokenizer.into())
 }
 
@@ -97,28 +80,55 @@ pub fn tokenizer_from_buffer(
     data: String,
     options: Vec<LoadOption>,
 ) -> Result<ExTokenizersTokenizer, ExTokenizersError> {
+    let mut tokenizer: ExTokenizerImpl = data.parse()?;
+    tokenizer = apply_load_options(tokenizer, options);
+    Ok(tokenizer.into())
+}
+
+fn apply_load_options(mut tokenizer: ExTokenizerImpl, options: Vec<LoadOption>) -> ExTokenizerImpl {
     struct Opts {
         additional_special_tokens: Vec<AddedSpecialTokenInput>,
+        disable_padding: bool,
+        disable_truncation: bool,
     }
+
     let mut opts = Opts {
         additional_special_tokens: vec![],
+        disable_padding: false,
+        disable_truncation: false,
     };
+
     for opt in options {
         match opt {
             LoadOption::AdditionalSpecialTokens(tokens) => {
                 opts.additional_special_tokens = tokens;
             }
+            LoadOption::Padding(_) => {
+                opts.disable_padding = true;
+            }
+            LoadOption::Truncation(_) => {
+                opts.disable_truncation = true;
+            }
         }
     }
-    let mut tokenizer: ExTokenizerImpl = data.parse()?;
+
     tokenizer.add_special_tokens(
         opts.additional_special_tokens
             .iter()
             .map(|t| t.into())
             .collect::<Vec<_>>()
             .as_ref(),
     );
-    Ok(tokenizer.into())
+
+    if opts.disable_padding {
+        tokenizer.with_padding(None);
+    }
+
+    if opts.disable_truncation {
+        tokenizer.with_truncation(None);
+    }
+
+    tokenizer
 }
 
 #[derive(NifTaggedEnum)]
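
One way to sanity-check the native implementation from Elixir is to encode the same input with and without the override and compare lengths. A hedged sketch (model and input are illustrative, and the effect depends on the truncation settings stored in the tokenizer file):

    input = String.duplicate("long input ", 500)

    # Configuration restored from the file (may cap the sequence length).
    {:ok, t1} = Tokenizers.Tokenizer.from_pretrained("bert-base-cased")
    {:ok, e1} = Tokenizers.Tokenizer.encode(t1, input)

    # Truncation disabled: the full sequence is encoded.
    {:ok, t2} = Tokenizers.Tokenizer.from_pretrained("bert-base-cased", truncation: :none)
    {:ok, e2} = Tokenizers.Tokenizer.encode(t2, input)

    length(Tokenizers.Encoding.get_ids(e2)) >= length(Tokenizers.Encoding.get_ids(e1))
    #=> true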
