|
1 | | -import strutils |
2 | | -import os |
| 1 | +import strutils, std/strformat, std/httpclient, os, json, asyncdispatch |
3 | 2 | import libchatllm |
4 | 3 | import packages/docutils/highlite, terminal |
5 | 4 |
|
6 | 5 | import std/terminal |
7 | 6 | import std/[os, strutils] |
8 | 7 |
|
| | +# Model catalogue from ../scripts/models.json, loaded lazily on first use.
| 8 | +var all_models: JsonNode = nil
| 9 | + |
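| | +# Build the ModelScope file-download URL from a split "project/file[/user]" entry (user defaults to "judd2024").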
| 10 | +proc get_model_url_on_modelscope(url: seq[string]): string = |
| 11 | +  let proj = url[0]
| 12 | +  let fn = url[1]
| 13 | +  let user = if len(url) >= 3: url[2] else: "judd2024"
| 14 | +
| 15 | +  return fmt"https://modelscope.cn/api/v1/models/{user}/{proj}/repo?Revision=master&FilePath={fn}"
| 16 | + |
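| | +# Resolve a model id of the form "name[:variant]" against the catalogue and return its default
| | +# quantized-file entry, rewriting "url" to a ModelScope download URL and storing the bare file
| | +# name under "fn". The catalogue shape assumed here (inferred from the lookups below) is roughly:
| | +#   { "<name>": { "default": "<variant>", "variants": { "<variant>": { "default": "<quant>",
| | +#     "quantized": { "<quant>": { "url": "project/file[/user]", "size": <bytes> } } } } } }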
| 17 | +proc parse_model_id(model_id: string): JsonNode = |
| 18 | +  let parts = model_id.split(":")
| 19 | +  if all_models == nil:
| 20 | +    all_models = json.parseFile("../scripts/models.json")
| 21 | +
| 22 | +  let model = all_models[parts[0]]
| 23 | +  let variants = model["variants"]
| 24 | +  let variant = variants[if len(parts) >= 2: parts[1] else: model["default"].getStr()]
| 25 | +  let r = variant["quantized"][variant["default"].getStr()]
| 26 | +  let url = r["url"].getStr().split("/")
| 27 | +  r["url"] = json.newJString(get_model_url_on_modelscope(url))
| 28 | +  r["fn"] = json.newJString(url[1])
| 29 | +  return r
| 30 | + |
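| | +# Draw a one-line text progress bar; printEnd (default "\r") is written first so successive calls redraw the same line.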
| 31 | +proc print_progress_bar(iteration: BiggestInt, total: BiggestInt, prefix = "", suffix = "", decimals = 1, length = 60, fill = "█", printEnd = "\r", auto_nl = true) = |
| 32 | +  let percent = formatFloat(100.0 * (iteration.float / total.float), ffDecimal, decimals)
| 33 | +  let filledLength = int(length.float * iteration.float / total.float)
| 34 | +  let bar = fill.repeat(filledLength) & '-'.repeat(length - filledLength)
| 35 | +  stdout.write(fmt"{printEnd}{prefix} |{bar}| {percent}% {suffix}")
| 36 | +  if iteration == total and auto_nl:
| 37 | +    echo ""
| 38 | + |
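| | +# Download url to fn with an async HTTP client, reporting progress through print_progress_bar.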
| 39 | +proc download_file(url: string, fn: string, prefix: string) = |
| 40 | +  echo fmt"Downloading {prefix}"
| 41 | +  let client = newAsyncHttpClient()
| 42 | +  defer: client.close()
| 43 | +
| 44 | +  proc onProgressChanged(total, progress, speed: BiggestInt) {.async.} =
| 45 | +    print_progress_bar(progress, total, prefix)
| 46 | +
| 47 | +  client.onProgressChanged = onProgressChanged
| 48 | +  client.downloadFile(url, fn).waitFor()
| 49 | + |
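| | +# Return the local path for model_id, reusing an existing file whose size matches the catalogue entry and downloading it into storage_dir otherwise.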
| 50 | +proc get_model(model_id: string; storage_dir: string): string = |
| 51 | +  if not os.dirExists(storage_dir):
| 52 | +    os.createDir(storage_dir)
| 53 | +
| 54 | +  let info = parse_model_id(model_id)
| 55 | +  let fn = joinPath([storage_dir, info["fn"].getStr()])
| 56 | +  if os.fileExists(fn):
| 57 | +    if os.getFileSize(fn) == info["size"].getBiggestInt():
| 58 | +      return fn
| 59 | +    else:
| 60 | +      echo(fmt"{fn} is incomplete, download again")
| 61 | +
| 62 | +  download_file(info["url"].getStr(), fn, model_id)
| 63 | +  assert (os.fileExists(fn)) and (os.getFileSize(fn) == info["size"].getBiggestInt())
| 64 | +  print_progress_bar(100, 100)
| 65 | +
| 66 | +  return fn
| 67 | + |
9 | 68 | type |
10 | 69 |   highlighter = object
11 | 70 |     line_acc: string
@@ -78,10 +137,19 @@ proc chatllm_print(user_data: pointer, print_type: cint, utf8_str: cstring) {.cd |
78 | 137 | proc chatllm_end(user_data: pointer) {.cdecl.} = |
79 | 138 |   echo ""
80 | 139 |
|
| | +# Flags whose value may be given as a ":model_id" instead of a file path; matching models are downloaded into storage_dir.
| 140 | +const candidates = ["-m", "--model", "--embedding_model", "--reranker_model"]
| 141 | +var storage_dir: string = "../quantized"
| 142 | + |
81 | 143 | var ht = highlighter(line_acc: "", lang: langNone) |
82 | 144 | let chat = chatllm_create() |
| 145 | + |
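| | +# A value that starts with ":" after one of the flags in candidates (e.g. "-m :<model_id>") names a
| | +# catalogue entry; it is resolved with get_model and replaced by the downloaded file's local path.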
83 | 146 | for i in 1 .. paramCount(): |
84 | | -  chatllm_append_param(chat, paramStr(i).cstring)
| 147 | +  if (i > 1) and (paramStr(i - 1) in candidates) and paramStr(i).startsWith(":"):
| 148 | +    var m = paramStr(i)
| 149 | +    m = m[1..<len(m)]
| 150 | +    chatllm_append_param(chat, get_model(m, storage_dir).cstring)
| 151 | +  else:
| 152 | +    chatllm_append_param(chat, paramStr(i).cstring)
85 | 153 |
|
86 | 154 | let r = chatllm_start(chat, chatllm_print, chatllm_end, addr(ht)) |
87 | 155 | if r != 0: |
|