|
| 1 | +#include "arg.h" |
| 2 | +#include "common.h" |
| 3 | + |
| 4 | +#include <string> |
| 5 | +#include <fstream> |
| 6 | +#include <vector> |
| 7 | +#include <json.hpp> |
| 8 | + |
| 9 | +using json = nlohmann::json; |
| 10 | + |
| 11 | +#undef NDEBUG |
| 12 | +#include <cassert> |
| 13 | + |
// Base URL of the Hugging Face hub. It must end with '/' because API and
// download URLs are built by appending to it directly.
static const std::string endpoint = "https://huggingface.co/";
// Hub repository ("<owner>/<name>") that hosts the vocab .gguf files together
// with their .inp/.out tokenizer test vectors.
static const std::string repo = "ggml-org/vocabs";
| 16 | + |
// Best-effort helper that writes `content` to `fname`, replacing any existing file.
//
// The file is opened in binary mode so the bytes on disk are exactly the bytes
// in `content` (text mode would translate newlines on Windows).
// Failures are reported on stderr but never thrown: the only caller uses this
// to populate a cache, and a cache that cannot be written must not be fatal.
static void write_file(const std::string & fname, const std::string & content) {
    std::ofstream file(fname, std::ios::binary | std::ios::trunc);
    if (!file) {
        fprintf(stderr, "warning: failed to open '%s' for writing\n", fname.c_str());
        return;
    }

    file.write(content.data(), static_cast<std::streamsize>(content.size()));
    // close() flushes; a failed flush sets failbit, which is checked below
    file.close();
    if (!file) {
        fprintf(stderr, "warning: failed to write '%s'\n", fname.c_str());
    }
}
| 24 | + |
| 25 | +static json get_hf_repo_dir(const std::string & hf_repo_with_branch, bool recursive, const std::string & repo_path, const std::string & bearer_token) { |
| 26 | + auto parts = string_split<std::string>(hf_repo_with_branch, ':'); |
| 27 | + std::string branch = parts.size() > 1 ? parts.back() : "main"; |
| 28 | + std::string hf_repo = parts[0]; |
| 29 | + std::string url = endpoint + "api/models/" + hf_repo + "/tree/" + branch; |
| 30 | + std::string path = repo_path; |
| 31 | + |
| 32 | + if (!path.empty()) { |
| 33 | + // FIXME: path should be properly url-encoded! |
| 34 | + string_replace_all(path, "/", "%2F"); |
| 35 | + url += "/" + path; |
| 36 | + } |
| 37 | + |
| 38 | + if (recursive) { |
| 39 | + url += "?recursive=true"; |
| 40 | + } |
| 41 | + |
| 42 | + // headers |
| 43 | + std::vector<std::string> headers; |
| 44 | + headers.push_back("Accept: application/json"); |
| 45 | + if (!bearer_token.empty()) { |
| 46 | + headers.push_back("Authorization: Bearer " + bearer_token); |
| 47 | + } |
| 48 | + |
| 49 | + // we use "=" to avoid clashing with other component, while still being allowed on windows |
| 50 | + std::string cached_response_fname = "tree=" + hf_repo + "/" + repo_path + "=" + branch + ".json"; |
| 51 | + string_replace_all(cached_response_fname, "/", "_"); |
| 52 | + std::string cached_response_path = fs_get_cache_file(cached_response_fname); |
| 53 | + |
| 54 | + // make the request |
| 55 | + common_remote_params params; |
| 56 | + params.headers = headers; |
| 57 | + json res_data; |
| 58 | + try { |
| 59 | + // TODO: For pagination links we need response headers, which is not provided by common_remote_get_content() |
| 60 | + auto res = common_remote_get_content(url, params); |
| 61 | + long res_code = res.first; |
| 62 | + std::string res_str = std::string(res.second.data(), res.second.size()); |
| 63 | + |
| 64 | + if (res_code == 200) { |
| 65 | + write_file(cached_response_path, res_str); |
| 66 | + } else if (res_code == 401) { |
| 67 | + throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); |
| 68 | + } else { |
| 69 | + throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); |
| 70 | + } |
| 71 | + } catch (const std::exception & e) { |
| 72 | + fprintf(stderr, "error: failed to get repo tree: %s\n", e.what()); |
| 73 | + fprintf(stderr, "try reading from cache\n"); |
| 74 | + } |
| 75 | + |
| 76 | + // try to read from cache |
| 77 | + try { |
| 78 | + std::ifstream f(cached_response_path); |
| 79 | + res_data = json::parse(f); |
| 80 | + } catch (const std::exception & e) { |
| 81 | + fprintf(stderr, "error: failed to get repo tree (check your internet connection)\n"); |
| 82 | + } |
| 83 | + |
| 84 | + return res_data; |
| 85 | +} |
| 86 | + |
| 87 | +int main(void) { |
| 88 | + if (common_has_curl()) { |
| 89 | + json tree = get_hf_repo_dir(repo, true, {}, {}); |
| 90 | + |
| 91 | + if (!tree.empty()) { |
| 92 | + std::vector<std::pair<std::string, std::string>> files; |
| 93 | + |
| 94 | + for (const auto & item : tree) { |
| 95 | + if (item.at("type") == "file") { |
| 96 | + std::string path = item.at("path"); |
| 97 | + |
| 98 | + if (string_ends_with(path, ".gguf") || string_ends_with(path, ".gguf.inp") || string_ends_with(path, ".gguf.out")) { |
| 99 | + // this is to avoid different repo having same file name, or same file name in different subdirs |
| 100 | + std::string filepath = repo + "_" + path; |
| 101 | + // to make sure we don't have any slashes in the filename |
| 102 | + string_replace_all(filepath, "/", "_"); |
| 103 | + // to make sure we don't have any quotes in the filename |
| 104 | + string_replace_all(filepath, "'", "_"); |
| 105 | + filepath = fs_get_cache_file(filepath); |
| 106 | + |
| 107 | + files.push_back({endpoint + repo + "/resolve/main/" + path, filepath}); |
| 108 | + } |
| 109 | + } |
| 110 | + } |
| 111 | + |
| 112 | + if (common_download_file_multiple(files, {}, false)) { |
| 113 | + std::string dir_sep(1, DIRECTORY_SEPARATOR); |
| 114 | + |
| 115 | + for (auto const & item : files) { |
| 116 | + std::string filepath = item.second; |
| 117 | + |
| 118 | + if (string_ends_with(filepath, ".gguf")) { |
| 119 | + std::string vocab_inp = filepath + ".inp"; |
| 120 | + std::string vocab_out = filepath + ".out"; |
| 121 | + auto matching_inp = std::find_if(files.begin(), files.end(), [&vocab_inp](const auto & p) { |
| 122 | + return p.second == vocab_inp; |
| 123 | + }); |
| 124 | + auto matching_out = std::find_if(files.begin(), files.end(), [&vocab_out](const auto & p) { |
| 125 | + return p.second == vocab_out; |
| 126 | + }); |
| 127 | + |
| 128 | + if (matching_inp != files.end() && matching_out != files.end()) { |
| 129 | + std::string test_command = "." + dir_sep + "test-tokenizer-0 '" + filepath + "'"; |
| 130 | + assert(std::system(test_command.c_str()) == 0); |
| 131 | + } else { |
| 132 | + printf("test-tokenizers-remote: %s found without .inp/out vocab files, skipping...\n", filepath.c_str()); |
| 133 | + } |
| 134 | + } |
| 135 | + } |
| 136 | + } else { |
| 137 | + printf("test-tokenizers-remote: failed to download files, unable to perform tests...\n"); |
| 138 | + } |
| 139 | + } else { |
| 140 | + printf("test-tokenizers-remote: failed to retrieve repository info, unable to perform tests...\n"); |
| 141 | + } |
| 142 | + } else { |
| 143 | + printf("test-tokenizers-remote: no curl, unable to perform tests...\n"); |
| 144 | + } |
| 145 | +} |
0 commit comments