Commit f0540fc

Supporting adding special tokens when creating a tokenizer (#26)
1 parent ae34237 commit f0540fc

File tree

4 files changed (+61, −9 lines):

  lib/tokenizers/native.ex
  lib/tokenizers/tokenizer.ex
  native/ex_tokenizers/src/tokenizer.rs
  test/tokenizers/tokenizer_test.exs
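
In short, this commit lets callers append extra special tokens when a tokenizer is created, both via Tokenizer.from_file/2 and via Tokenizer.from_pretrained/2. A minimal usage sketch, assuming the fixture file used by the test added below (the before/after vocabulary sizes are inferred from that test, not stated elsewhere):

{:ok, tokenizer} =
  Tokenizers.Tokenizer.from_file("test/fixtures/bert-base-cased.json",
    additional_special_tokens: ["<|test|>"]
  )

# The appended token is reflected in the vocabulary count
# (presumably 28996 for the untouched fixture, 28997 with the extra token).
Tokenizers.Tokenizer.get_vocab_size(tokenizer)
#=> 28997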

lib/tokenizers/native.ex

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ defmodule Tokenizers.Native do
   def decode_batch(_tokenizer, _ids, _skip_special_tokens), do: err()
   def encode(_tokenizer, _input, _add_special_tokens), do: err()
   def encode_batch(_tokenizer, _input, _add_special_tokens), do: err()
-  def from_file(_path), do: err()
+  def from_file(_path, _additional_special_tokens), do: err()
   def get_attention_mask(_encoding), do: err()
   def get_type_ids(_encoding), do: err()
   def get_ids(_encoding), do: err()
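
These functions are only fallback stubs that Rustler swaps out for the native implementations when the NIF loads, so the stub arity has to track the Rust from_file signature changed below. A sketch of the stub helper this module presumably relies on (err/0 is not shown in this diff; this is an assumption based on the usual Rustler pattern):

# Hypothetical helper, assumed rather than taken from this commit:
defp err(), do: :erlang.nif_error(:nif_not_loaded)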

lib/tokenizers/tokenizer.ex

Lines changed: 19 additions & 6 deletions

@@ -55,6 +55,8 @@ defmodule Tokenizers.Tokenizer do
       even if `:use_cache` is false. By default it uses `:filename.basedir/3` to get
       a cache dir based in the "tokenizers_elixir" application name.
 
+    * `:additional_special_tokens` - A list of special tokens to append to the tokenizer.
+      Defaults to `[]`.
   """
   @spec from_pretrained(String.t(), Keyword.t()) :: {:ok, Tokenizer.t()} | {:error, term()}
   def from_pretrained(identifier, opts \\ []) do
@@ -63,7 +65,8 @@ defmodule Tokenizers.Tokenizer do
         revision: "main",
         use_cache: true,
         cache_dir: :filename.basedir(:user_cache, "tokenizers_elixir"),
-        http_client: {Tokenizers.HTTPClient, []}
+        http_client: {Tokenizers.HTTPClient, []},
+        additional_special_tokens: []
       )
 
     {http_client, http_opts} = opts[:http_client]
@@ -87,19 +90,21 @@ defmodule Tokenizers.Tokenizer do
         Path.join(cache_dir, entry_filename(url, etag))
       end
 
+    tokenizer_opts = Keyword.take(opts, [:additional_special_tokens])
+
     if opts[:use_cache] do
       with {:ok, response} <- request(http_client, Keyword.put(http_opts, :method, :head)) do
         etag = fetch_etag(response.headers)
         file_path = file_path_fun.(etag)
 
         if File.exists?(file_path) do
-          from_file(file_path)
+          from_file(file_path, tokenizer_opts)
         else
           with {:ok, response} <- request(http_client, http_opts) do
             File.mkdir_p!(cache_dir)
             File.write!(file_path, response.body)
 
-            from_file(file_path)
+            from_file(file_path, tokenizer_opts)
           end
         end
       end
@@ -111,7 +116,7 @@ defmodule Tokenizers.Tokenizer do
           File.mkdir_p!(cache_dir)
           File.write!(file_path, response.body)
 
-          from_file(file_path)
+          from_file(file_path, tokenizer_opts)
         end
       end
     end
@@ -156,9 +161,17 @@ defmodule Tokenizers.Tokenizer do
 
   @doc """
   Instantiate a new tokenizer from the file at the given path.
+
+  ## Options
+
+    * `:additional_special_tokens` - A list of special tokens to append to the tokenizer.
+      Defaults to `[]`.
   """
-  @spec from_file(String.t()) :: {:ok, Tokenizer.t()} | {:error, term()}
-  def from_file(path), do: Native.from_file(path)
+  @spec from_file(String.t(), Keyword.t()) :: {:ok, Tokenizer.t()} | {:error, term()}
+  def from_file(path, opts \\ []) do
+    opts = Keyword.validate!(opts, additional_special_tokens: [])
+    Native.from_file(path, opts[:additional_special_tokens])
+  end
 
   @doc """
   Save the tokenizer to the provided path.
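
The new option is also accepted by from_pretrained/2 and forwarded to from_file/2 as tokenizer_opts once the tokenizer file has been downloaded or found in the cache. A hedged sketch, with a placeholder model identifier and token:

{:ok, tokenizer} =
  Tokenizers.Tokenizer.from_pretrained("bert-base-cased",
    additional_special_tokens: ["<|endofprompt|>"]
  )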

native/ex_tokenizers/src/tokenizer.rs

Lines changed: 9 additions & 2 deletions

@@ -2,6 +2,7 @@ use std::collections::HashMap;
 
 use rustler::Term;
 
+use tokenizers::tokenizer::AddedToken;
 use tokenizers::{EncodeInput, Tokenizer};
 
 use crate::encoding::ExTokenizersEncoding;
@@ -31,8 +32,14 @@ impl ExTokenizersTokenizer {
 }
 
 #[rustler::nif(schedule = "DirtyIo")]
-pub fn from_file(path: &str) -> Result<ExTokenizersTokenizer, ExTokenizersError> {
-    let tokenizer = Tokenizer::from_file(path)?;
+pub fn from_file(
+    path: &str,
+    additional_special_tokens: Vec<String>,
+) -> Result<ExTokenizersTokenizer, ExTokenizersError> {
+    let mut tokenizer = Tokenizer::from_file(path)?;
+    for token in additional_special_tokens {
+        tokenizer.add_special_tokens(&[AddedToken::from(token, true)]);
+    }
     Ok(ExTokenizersTokenizer::new(tokenizer))
 }
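
Because each token is registered on the Rust side with AddedToken::from(token, true), i.e. flagged as special, decoding with skip_special_tokens: true strips it back out. A sketch of that round trip, mirroring the batch-style API used in the test below and assuming the same Tokenizer/Encoding aliases (the exact whitespace of the decoded text is not guaranteed):

{:ok, tokenizer} =
  Tokenizer.from_file("test/fixtures/bert-base-cased.json",
    additional_special_tokens: ["<|test|>"]
  )

{:ok, [encoding]} = Tokenizer.encode(tokenizer, ["Hello <|test|> world"])

{:ok, [decoded]} =
  Tokenizer.decode(tokenizer, [Encoding.get_ids(encoding)], skip_special_tokens: true)

# decoded should no longer contain "<|test|>"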

test/tokenizers/tokenizer_test.exs

Lines changed: 32 additions & 0 deletions

@@ -25,6 +25,38 @@ defmodule Tokenizers.TokenizerTest do
     end
   end
 
+  describe "modify tokenizer" do
+    test "can add special tokens" do
+      special_tokens = ["<|test|>"]
+
+      {:ok, tokenizer} =
+        Tokenizer.from_file("test/fixtures/bert-base-cased.json",
+          additional_special_tokens: special_tokens
+        )
+
+      assert Tokenizer.get_vocab_size(tokenizer) == 28997
+    end
+
+    test "can decode special tokens" do
+      text = ["This <|test|>is a test<|also|>", "<|test|>And so<|also|> is this<|test|>"]
+      special_tokens = ["<|test|>", "<|also|>"]
+
+      {:ok, tokenizer} =
+        Tokenizer.from_file("test/fixtures/bert-base-cased.json",
+          additional_special_tokens: special_tokens
+        )
+
+      {:ok, encodings} = Tokenizer.encode(tokenizer, text)
+
+      {:ok, decodings} =
+        Tokenizer.decode(tokenizer, Enum.map(encodings, &Encoding.get_ids/1),
+          skip_special_tokens: true
+        )
+
+      assert ["This is a test", "And so is this"] == decodings
+    end
+  end
+
   describe "from_pretrained/2" do
     defmodule SuccessHTTPClient do
       def request(opts) do
