Skip to content

Commit e3f2827

Browse files
author
Philip Sampaio
authored
Add basic HTTP client to download pretrained files (#21)
Close #20
1 parent 7ec2b83 commit e3f2827

File tree

12 files changed

+218
-1107
lines changed

12 files changed

+218
-1107
lines changed

.github/workflows/release.yml

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,13 @@ jobs:
3333
os: ubuntu-20.04,
3434
nif: "2.16",
3535
use-cross: true,
36-
features: "static_openssl",
3736
}
3837
- { target: aarch64-unknown-linux-gnu, os: ubuntu-20.04, nif: "2.16", use-cross: true }
39-
- { target: aarch64-unknown-linux-musl, os: ubuntu-20.04, nif: "2.16", use-cross: true, features: "static_openssl" }
38+
- { target: aarch64-unknown-linux-musl, os: ubuntu-20.04, nif: "2.16", use-cross: true }
4039
- { target: aarch64-apple-darwin, os: macos-11, nif: "2.16" }
4140
- { target: x86_64-apple-darwin, os: macos-11, nif: "2.16" }
4241
- { target: x86_64-unknown-linux-gnu, os: ubuntu-20.04, nif: "2.16" }
43-
- { target: x86_64-unknown-linux-musl, os: ubuntu-20.04, nif: "2.16", use-cross: true, features: "static_openssl" }
42+
- { target: x86_64-unknown-linux-musl, os: ubuntu-20.04, nif: "2.16", use-cross: true }
4443
- { target: x86_64-pc-windows-gnu, os: windows-2019, nif: "2.16" }
4544
- { target: x86_64-pc-windows-msvc, os: windows-2019, nif: "2.16" }
4645
# NIF version 2.15
@@ -49,14 +48,13 @@ jobs:
4948
os: ubuntu-20.04,
5049
nif: "2.15",
5150
use-cross: true,
52-
features: "static_openssl",
5351
}
5452
- { target: aarch64-unknown-linux-gnu, os: ubuntu-20.04, nif: "2.15", use-cross: true }
55-
- { target: aarch64-unknown-linux-musl, os: ubuntu-20.04, nif: "2.15", use-cross: true, features: "static_openssl" }
53+
- { target: aarch64-unknown-linux-musl, os: ubuntu-20.04, nif: "2.15", use-cross: true }
5654
- { target: aarch64-apple-darwin, os: macos-11, nif: "2.15" }
5755
- { target: x86_64-apple-darwin, os: macos-11, nif: "2.15" }
5856
- { target: x86_64-unknown-linux-gnu, os: ubuntu-20.04, nif: "2.15" }
59-
- { target: x86_64-unknown-linux-musl, os: ubuntu-20.04, nif: "2.15", use-cross: true, features: "static_openssl" }
57+
- { target: x86_64-unknown-linux-musl, os: ubuntu-20.04, nif: "2.15", use-cross: true }
6058
- { target: x86_64-pc-windows-gnu, os: windows-2019, nif: "2.15" }
6159
- { target: x86_64-pc-windows-msvc, os: windows-2019, nif: "2.15" }
6260
# NIF version 2.14
@@ -65,14 +63,13 @@ jobs:
6563
os: ubuntu-20.04,
6664
nif: "2.14",
6765
use-cross: true,
68-
features: "static_openssl",
6966
}
7067
- { target: aarch64-unknown-linux-gnu, os: ubuntu-20.04, nif: "2.14", use-cross: true }
71-
- { target: aarch64-unknown-linux-musl, os: ubuntu-20.04, nif: "2.14", use-cross: true, features: "static_openssl" }
68+
- { target: aarch64-unknown-linux-musl, os: ubuntu-20.04, nif: "2.14", use-cross: true }
7269
- { target: aarch64-apple-darwin, os: macos-11, nif: "2.14" }
7370
- { target: x86_64-apple-darwin, os: macos-11, nif: "2.14" }
7471
- { target: x86_64-unknown-linux-gnu, os: ubuntu-20.04, nif: "2.14" }
75-
- { target: x86_64-unknown-linux-musl, os: ubuntu-20.04, nif: "2.14", use-cross: true, features: "static_openssl" }
72+
- { target: x86_64-unknown-linux-musl, os: ubuntu-20.04, nif: "2.14", use-cross: true }
7673
- { target: x86_64-pc-windows-gnu, os: windows-2019, nif: "2.14" }
7774
- { target: x86_64-pc-windows-msvc, os: windows-2019, nif: "2.14" }
7875

@@ -122,9 +119,9 @@ jobs:
122119
shell: bash
123120
run: |
124121
if [ "${{ matrix.job.use-cross }}" == "true" ]; then
125-
cross build --release --target=${{ matrix.job.target }} --features=${{ matrix.job.features }}
122+
cross build --release --target=${{ matrix.job.target }}
126123
else
127-
cargo build --release --target=${{ matrix.job.target }} --features=${{ matrix.job.features }}
124+
cargo build --release --target=${{ matrix.job.target }}
128125
fi
129126
130127
- name: Rename lib to the final name

lib/tokenizers/http_client.ex

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
defmodule Tokenizers.HTTPClient do
  @moduledoc """
  A simple implementation of an HTTP client.

  This is using the built-in `:httpc` module, configured to use SSL.
  The `request/1` function is similar to `Req.request/1`.
  """

  # Must agree with the Hub host that `Tokenizers.Tokenizer.from_pretrained/2` targets.
  @base_url "https://huggingface.co"

  @doc """
  Makes an HTTP(S) request.

  Proxy settings are read from the `HTTP_PROXY`/`http_proxy` and
  `HTTPS_PROXY`/`https_proxy` environment variables. Values that do not parse
  into a host and a port are ignored.

  ## Options

    * `:method` - An HTTP method. By default it uses the `:get` method.

    * `:base_url` - The base URL to make requests. By default is #{inspect(@base_url)}.

    * `:url` - A path to a resource. By default is "".

    * `:headers` - A list of tuples representing HTTP headers. By default it's empty.

  ## Return values

  Returns `{:ok, %{status: status, headers: headers, body: body}}` when the
  request could be performed (whatever the HTTP status is), or
  `{:error, message}` when `:httpc` reports a failure.

  Raises `ArgumentError` if an unknown option is given.
  """
  @spec request(Keyword.t()) ::
          {:ok, %{status: non_neg_integer(), headers: list(), body: binary()}}
          | {:error, String.t()}
  def request(opts) when is_list(opts) do
    opts = Keyword.validate!(opts, base_url: @base_url, headers: [], method: :get, url: "")

    # `:httpc` expects the URL and the header names as charlists.
    url = Path.join([opts[:base_url], opts[:url]]) |> String.to_charlist()
    headers = Enum.map(opts[:headers], fn {key, value} -> {String.to_charlist(key), value} end)

    {:ok, _} = Application.ensure_all_started(:inets)
    {:ok, _} = Application.ensure_all_started(:ssl)

    set_proxy(:proxy, System.get_env("HTTP_PROXY") || System.get_env("http_proxy"))
    set_proxy(:https_proxy, System.get_env("HTTPS_PROXY") || System.get_env("https_proxy"))

    # https://erlef.github.io/security-wg/secure_coding_and_deployment_hardening/inets
    cacertfile = CAStore.file_path() |> String.to_charlist()

    http_options = [
      ssl: [
        verify: :verify_peer,
        cacertfile: cacertfile,
        depth: 3,
        customize_hostname_check: [
          match_fun: :public_key.pkix_verify_hostname_match_fun(:https)
        ]
      ]
    ]

    options = [body_format: :binary]

    case :httpc.request(opts[:method], {url, headers}, http_options, options) do
      {:ok, {{_, status, _}, headers, body}} ->
        {:ok, %{status: status, headers: headers, body: body}}

      {:ok, {status, body}} ->
        {:ok, %{status: status, body: body, headers: []}}

      {:error, reason} ->
        {:error, "could not make request #{url}: #{inspect(reason)}"}
    end
  end

  # Configures `:httpc` with the given proxy option (`:proxy` or `:https_proxy`).
  # The proxy is silently skipped when the variable is unset or when it does not
  # parse into a host and a port (for example an empty or scheme-less value),
  # instead of crashing on `String.to_charlist(nil)`.
  defp set_proxy(option, proxy) when is_binary(proxy) do
    case URI.parse(proxy) do
      %{host: host, port: port} when is_binary(host) and is_integer(port) ->
        :httpc.set_options([{option, {{String.to_charlist(host), port}, []}}])

      _invalid ->
        :ok
    end
  end

  defp set_proxy(_option, _proxy), do: :ok
end

lib/tokenizers/native.ex

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ defmodule Tokenizers.Native do
1616
def encode(_tokenizer, _input, _add_special_tokens), do: err()
1717
def encode_batch(_tokenizer, _input, _add_special_tokens), do: err()
1818
def from_file(_path), do: err()
19-
def from_pretrained(_identifier), do: err()
2019
def get_attention_mask(_encoding), do: err()
2120
def get_type_ids(_encoding), do: err()
2221
def get_ids(_encoding), do: err()

lib/tokenizers/tokenizer.ex

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,68 @@ defmodule Tokenizers.Tokenizer do
3030

3131
@doc """
3232
Instantiate a new tokenizer from an existing file on the Hugging Face Hub.
33+
34+
This is going to download a tokenizer file, save to a file and load that file.
35+
36+
## Options
37+
38+
* `:http_client` - A tuple with a module and options. This module should implement
39+
the `request/1` function, accepting a keyword list with the options for a request.
40+
This is inspired by `Req.request/1`: https://hexdocs.pm/req/Req.html#request/1
41+
42+
The default HTTP client config is: `{Tokenizers.HTTPClient, []}`.
43+
Since it's inspired by `Req`, it's possible to use that client without any adjustments.
44+
45+
When making a request, the options `:url` and `:method` are going to be overridden.
46+
`:headers` contains the "user-agent" set by default.
47+
48+
* `:revision` - The revision name that should be used for fetching the tokenizers
49+
from Hugging Face.
50+
3351
"""
34-
@spec from_pretrained(String.t()) :: {:ok, Tokenizer.t()} | {:error, term()}
35-
def from_pretrained(identifier), do: Native.from_pretrained(identifier)
52+
@spec from_pretrained(String.t(), Keyword.t()) :: {:ok, Tokenizer.t()} | {:error, term()}
def from_pretrained(identifier, opts \\ []) do
  opts = Keyword.validate!(opts, revision: "main", http_client: {Tokenizers.HTTPClient, []})

  {http_client, http_opts} = opts[:http_client]

  {:ok, app_version} = :application.get_key(:tokenizers, :vsn)
  app_version = List.to_string(app_version)

  headers = [{"user-agent", "tokenizers-elixir/#{app_version}"}]
  url = "/#{identifier}/resolve/#{opts[:revision]}/tokenizer.json"

  # `:url` and `:method` are always overridden; `:base_url` only gets a default,
  # and the user-agent is appended to any headers the caller configured.
  http_opts =
    http_opts
    |> Keyword.put_new(:base_url, "https://huggingface.co")
    |> Keyword.put(:url, url)
    |> Keyword.put(:method, :get)
    |> Keyword.update(:headers, headers, fn existing -> existing ++ headers end)

  case http_client.request(http_opts) do
    {:ok, %{status: status, body: body}} when status in 200..299 ->
      cache_dir = :filename.basedir(:user_cache, "tokenizers_elixir")
      file_path = Path.join(cache_dir, "#{identifier}.json")

      # Namespaced identifiers ("org/model") resolve to a subdirectory of the
      # cache dir, so create the parent of the file rather than only the cache
      # root. Filesystem failures are returned as `{:error, posix}` instead of
      # raising a `MatchError`.
      with :ok <- File.mkdir_p(Path.dirname(file_path)),
           :ok <- File.write(file_path, body) do
        from_file(file_path)
      end

    {:ok, %{status: 404}} ->
      {:error, :not_found}

    {:ok, %{status: other} = response} ->
      {:error,
       "download of pretrained file failed with status #{other}. Response: #{inspect(response.body)}"}

    {:error, _} = error ->
      error
  end
end
3695

3796
@doc """
3897
Instantiate a new tokenizer from the file at the given path.

mix.exs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ defmodule Tokenizers.MixProject do
22
use Mix.Project
33

44
@source_url "https://github.com/elixir-nx/tokenizers"
5-
@version "0.1.2"
5+
@version "0.1.2-dev"
66

77
def project do
88
[
@@ -22,11 +22,14 @@ defmodule Tokenizers.MixProject do
2222
end
2323

2424
def application do
25-
[extra_applications: [:logger]]
25+
[
26+
extra_applications: [:logger, :inets, :public_key]
27+
]
2628
end
2729

2830
defp deps do
2931
[
32+
{:castore, "~> 0.1"},
3033
{:ex_doc, "~> 0.28", only: :docs, runtime: false},
3134
{:rustler, ">= 0.0.0", optional: true},
3235
{:rustler_precompiled, "~> 0.5"}

0 commit comments

Comments
 (0)