Skip to content

Commit 42ff186

Browse files
authored
add test-tokenizers-remote
1 parent 2d2e059 commit 42ff186

File tree

2 files changed

+149
-0
lines changed

2 files changed

+149
-0
lines changed

tests/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2 ARGS ${CMAKE
9898
llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
9999
llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
100100

101+
if (LLAMA_CURL)
102+
llama_build_and_test(test-tokenizers-remote.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
103+
endif()
104+
101105
if (LLAMA_LLGUIDANCE)
102106
llama_build_and_test(test-grammar-llguidance.cpp ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
103107
endif ()

tests/test-tokenizers-remote.cpp

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
#include "arg.h"
2+
#include "common.h"
3+
4+
#include <string>
5+
#include <fstream>
6+
#include <vector>
7+
#include <json.hpp>
8+
9+
using json = nlohmann::json;
10+
11+
#undef NDEBUG
12+
#include <cassert>
13+
14+
std::string endpoint = "https://huggingface.co/";
15+
std::string repo = "ggml-org/vocabs";
16+
17+
static void write_file(const std::string & fname, const std::string & content) {
18+
std::ofstream file(fname);
19+
if (file) {
20+
file << content;
21+
file.close();
22+
}
23+
}
24+
25+
static json get_hf_repo_dir(const std::string & hf_repo_with_branch, bool recursive, const std::string & repo_path, const std::string & bearer_token) {
26+
auto parts = string_split<std::string>(hf_repo_with_branch, ':');
27+
std::string branch = parts.size() > 1 ? parts.back() : "main";
28+
std::string hf_repo = parts[0];
29+
std::string url = endpoint + "api/models/" + hf_repo + "/tree/" + branch;
30+
std::string path = repo_path;
31+
32+
if (!path.empty()) {
33+
// FIXME: path should be properly url-encoded!
34+
string_replace_all(path, "/", "%2F");
35+
url += "/" + path;
36+
}
37+
38+
if (recursive) {
39+
url += "?recursive=true";
40+
}
41+
42+
// headers
43+
std::vector<std::string> headers;
44+
headers.push_back("Accept: application/json");
45+
if (!bearer_token.empty()) {
46+
headers.push_back("Authorization: Bearer " + bearer_token);
47+
}
48+
49+
// we use "=" to avoid clashing with other component, while still being allowed on windows
50+
std::string cached_response_fname = "tree=" + hf_repo + "/" + repo_path + "=" + branch + ".json";
51+
string_replace_all(cached_response_fname, "/", "_");
52+
std::string cached_response_path = fs_get_cache_file(cached_response_fname);
53+
54+
// make the request
55+
common_remote_params params;
56+
params.headers = headers;
57+
json res_data;
58+
try {
59+
// TODO: For pagination links we need response headers, which is not provided by common_remote_get_content()
60+
auto res = common_remote_get_content(url, params);
61+
long res_code = res.first;
62+
std::string res_str = std::string(res.second.data(), res.second.size());
63+
64+
if (res_code == 200) {
65+
write_file(cached_response_path, res_str);
66+
} else if (res_code == 401) {
67+
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
68+
} else {
69+
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
70+
}
71+
} catch (const std::exception & e) {
72+
fprintf(stderr, "error: failed to get repo tree: %s\n", e.what());
73+
fprintf(stderr, "try reading from cache\n");
74+
}
75+
76+
// try to read from cache
77+
try {
78+
std::ifstream f(cached_response_path);
79+
res_data = json::parse(f);
80+
} catch (const std::exception & e) {
81+
fprintf(stderr, "error: failed to get repo tree (check your internet connection)\n");
82+
}
83+
84+
return res_data;
85+
}
86+
87+
int main(void) {
88+
if (common_has_curl()) {
89+
json tree = get_hf_repo_dir(repo, true, {}, {});
90+
91+
if (!tree.empty()) {
92+
std::vector<std::pair<std::string, std::string>> files;
93+
94+
for (const auto & item : tree) {
95+
if (item.at("type") == "file") {
96+
std::string path = item.at("path");
97+
98+
if (string_ends_with(path, ".gguf") || string_ends_with(path, ".gguf.inp") || string_ends_with(path, ".gguf.out")) {
99+
// this is to avoid different repo having same file name, or same file name in different subdirs
100+
std::string filepath = repo + "_" + path;
101+
// to make sure we don't have any slashes in the filename
102+
string_replace_all(filepath, "/", "_");
103+
// to make sure we don't have any quotes in the filename
104+
string_replace_all(filepath, "'", "_");
105+
filepath = fs_get_cache_file(filepath);
106+
107+
files.push_back({endpoint + repo + "/resolve/main/" + path, filepath});
108+
}
109+
}
110+
}
111+
112+
if (common_download_file_multiple(files, {}, false)) {
113+
std::string dir_sep(1, DIRECTORY_SEPARATOR);
114+
115+
for (auto const & item : files) {
116+
std::string filepath = item.second;
117+
118+
if (string_ends_with(filepath, ".gguf")) {
119+
std::string vocab_inp = filepath + ".inp";
120+
std::string vocab_out = filepath + ".out";
121+
auto matching_inp = std::find_if(files.begin(), files.end(), [&vocab_inp](const auto & p) {
122+
return p.second == vocab_inp;
123+
});
124+
auto matching_out = std::find_if(files.begin(), files.end(), [&vocab_out](const auto & p) {
125+
return p.second == vocab_out;
126+
});
127+
128+
if (matching_inp != files.end() && matching_out != files.end()) {
129+
std::string test_command = "." + dir_sep + "test-tokenizer-0 '" + filepath + "'";
130+
assert(std::system(test_command.c_str()) == 0);
131+
} else {
132+
printf("test-tokenizers-remote: %s found without .inp/out vocab files, skipping...\n", filepath.c_str());
133+
}
134+
}
135+
}
136+
} else {
137+
printf("test-tokenizers-remote: failed to download files, unable to perform tests...\n");
138+
}
139+
} else {
140+
printf("test-tokenizers-remote: failed to retrieve repository info, unable to perform tests...\n");
141+
}
142+
} else {
143+
printf("test-tokenizers-remote: no curl, unable to perform tests...\n");
144+
}
145+
}

0 commit comments

Comments
 (0)