Skip to content

Commit 7eb7d6d

Browse files
author
Philip Sampaio
authored
Add option to use cache when download pretrained files (#23)
* Add option to use cache when downloading pretrained files This makes the usage simpler and faster. * Use etag to load cached file * Encode URL and Etag in the build of file name
1 parent 7b2a7ab commit 7eb7d6d

File tree

3 files changed

+112
-15
lines changed

3 files changed

+112
-15
lines changed

lib/tokenizers/http_client.ex

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ defmodule Tokenizers.HTTPClient do
6262

6363
case :httpc.request(opts[:method], {url, headers}, http_options, options) do
6464
{:ok, {{_, status, _}, headers, body}} ->
65-
{:ok, %{status: status, headers: headers, body: body}}
65+
{:ok, %{status: status, headers: normalize_headers(headers), body: body}}
6666

6767
{:ok, {status, body}} ->
6868
{:ok, %{status: status, body: body, headers: []}}
@@ -71,4 +71,10 @@ defmodule Tokenizers.HTTPClient do
7171
{:error, "could not make request #{url}: #{inspect(reason)}"}
7272
end
7373
end
74+
75+
defp normalize_headers(headers) do
76+
for {key, value} <- headers do
77+
{List.to_string(key), List.to_string(value)}
78+
end
79+
end
7480
end

lib/tokenizers/tokenizer.ex

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ defmodule Tokenizers.Tokenizer do
3131
@doc """
3232
Instantiate a new tokenizer from an existing file on the Hugging Face Hub.
3333
34-
This is going to download a tokenizer file, save to a file and load that file.
34+
This is going to download a tokenizer file, save it to disk and load that file.
3535
3636
## Options
3737
@@ -48,10 +48,23 @@ defmodule Tokenizers.Tokenizer do
4848
* `:revision` - The revision name that should be used for fetching the tokenizers
4949
from Hugging Face.
5050
51+
* `:use_cache` - Whether to read from the cache when the file already exists.
52+
Defaults to `true`.
53+
54+
* `:cache_dir` - The directory where cache is saved. Files are written to cache
55+
even if `:use_cache` is false. By default it uses `:filename.basedir/3` to get
56+
a cache dir based on the "tokenizers_elixir" application name.
57+
5158
"""
5259
@spec from_pretrained(String.t(), Keyword.t()) :: {:ok, Tokenizer.t()} | {:error, term()}
5360
def from_pretrained(identifier, opts \\ []) do
54-
opts = Keyword.validate!(opts, revision: "main", http_client: {Tokenizers.HTTPClient, []})
61+
opts =
62+
Keyword.validate!(opts,
63+
revision: "main",
64+
use_cache: true,
65+
cache_dir: :filename.basedir(:user_cache, "tokenizers_elixir"),
66+
http_client: {Tokenizers.HTTPClient, []}
67+
)
5568

5669
{http_client, http_opts} = opts[:http_client]
5770

@@ -68,17 +81,53 @@ defmodule Tokenizers.Tokenizer do
6881
|> Keyword.put(:method, :get)
6982
|> Keyword.update(:headers, headers, fn existing -> existing ++ headers end)
7083

84+
cache_dir = opts[:cache_dir]
85+
86+
file_path_fun = fn etag ->
87+
Path.join(cache_dir, entry_filename(url, etag))
88+
end
89+
90+
if opts[:use_cache] do
91+
with {:ok, response} <- request(http_client, Keyword.put(http_opts, :method, :head)) do
92+
etag = fetch_etag(response.headers)
93+
file_path = file_path_fun.(etag)
94+
95+
if File.exists?(file_path) do
96+
from_file(file_path)
97+
else
98+
with {:ok, response} <- request(http_client, http_opts) do
99+
File.mkdir_p!(cache_dir)
100+
File.write!(file_path, response.body)
101+
102+
from_file(file_path)
103+
end
104+
end
105+
end
106+
else
107+
with {:ok, response} <- request(http_client, http_opts) do
108+
etag = fetch_etag(response.headers)
109+
file_path = file_path_fun.(etag)
110+
111+
File.mkdir_p!(cache_dir)
112+
File.write!(file_path, response.body)
113+
114+
from_file(file_path)
115+
end
116+
end
117+
end
118+
119+
defp fetch_etag(headers) do
120+
{_, etag} = List.keyfind!(headers, "etag", 0)
121+
122+
etag
123+
end
124+
125+
defp request(http_client, http_opts) do
71126
case http_client.request(http_opts) do
72127
{:ok, response} ->
73128
case response.status do
74129
status when status in 200..299 ->
75-
cache_dir = :filename.basedir(:user_cache, "tokenizers_elixir")
76-
:ok = File.mkdir_p(cache_dir)
77-
file_path = Path.join(cache_dir, "#{identifier}.json")
78-
79-
:ok = File.write(file_path, response.body)
80-
81-
from_file(file_path)
130+
{:ok, response}
82131

83132
404 ->
84133
{:error, :not_found}
@@ -93,6 +142,18 @@ defmodule Tokenizers.Tokenizer do
93142
end
94143
end
95144

145+
defp entry_filename(url, etag) do
146+
encode_url(url) <> "." <> encode_etag(etag)
147+
end
148+
149+
defp encode_url(url) do
150+
url |> :erlang.md5() |> Base.encode32(case: :lower, padding: false)
151+
end
152+
153+
defp encode_etag(etag) do
154+
Base.encode32(etag, case: :lower, padding: false)
155+
end
156+
96157
@doc """
97158
Instantiate a new tokenizer from the file at the given path.
98159
"""

test/tokenizers/tokenizer_test.exs

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,19 @@ defmodule Tokenizers.TokenizerTest do
3030
def request(opts) do
3131
send(self(), {:request, opts})
3232

33+
body =
34+
case opts[:method] do
35+
:get ->
36+
File.read!("test/fixtures/bert-base-cased.json")
37+
38+
:head ->
39+
""
40+
end
41+
3342
{:ok,
3443
%{
35-
body: File.read!("test/fixtures/bert-base-cased.json"),
36-
headers: [],
44+
body: body,
45+
headers: [{"etag", "test-etag"}],
3746
status: opts[:test_status]
3847
}}
3948
end
@@ -46,9 +55,12 @@ defmodule Tokenizers.TokenizerTest do
4655
end
4756
end
4857

49-
test "load from pretrained successfully" do
58+
@tag :tmp_dir
59+
test "load from pretrained successfully", %{tmp_dir: tmp_dir} do
5060
{:ok, tokenizer} =
5161
Tokenizer.from_pretrained("bert-base-cased",
62+
use_cache: false,
63+
cache_dir: tmp_dir,
5264
http_client: {SuccessHTTPClient, [test_status: 200, headers: [{"test-header", "42"}]]}
5365
)
5466

@@ -62,18 +74,36 @@ defmodule Tokenizers.TokenizerTest do
6274

6375
assert [{"test-header", "42"}, {"user-agent", "tokenizers-elixir/" <> _app_version}] =
6476
opts[:headers]
77+
78+
{:ok, tokenizer} =
79+
Tokenizer.from_pretrained("bert-base-cased",
80+
use_cache: true,
81+
cache_dir: tmp_dir,
82+
http_client: {SuccessHTTPClient, [test_status: 200]}
83+
)
84+
85+
assert Tokenizer.get_vocab_size(tokenizer) == 28996
86+
87+
assert_received {:request, opts}
88+
assert opts[:method] == :head
6589
end
6690

67-
test "returns error when status is not found" do
91+
@tag :tmp_dir
92+
test "returns error when status is not found", %{tmp_dir: tmp_dir} do
6893
assert {:error, :not_found} =
6994
Tokenizer.from_pretrained("bert-base-cased",
95+
use_cache: false,
96+
cache_dir: tmp_dir,
7097
http_client: {SuccessHTTPClient, [test_status: 404]}
7198
)
7299
end
73100

74-
test "returns error when request is not successful" do
101+
@tag :tmp_dir
102+
test "returns error when request is not successful", %{tmp_dir: tmp_dir} do
75103
assert {:error, error} =
76104
Tokenizer.from_pretrained("bert-base-cased",
105+
use_cache: false,
106+
cache_dir: tmp_dir,
77107
http_client: {ErrorHTTPClient, []}
78108
)
79109

0 commit comments

Comments
 (0)