Skip to content

Commit 62c9d8c

Browse files
author
Judd
committed
add model downloader to main.nim
1 parent 49c2b41 commit 62c9d8c

File tree

1 file changed

+71
-3
lines changed

1 file changed

+71
-3
lines changed

bindings/main.nim

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,70 @@
1-
import strutils
2-
import os
1+
import strutils, std/strformat, std/httpclient, os, json, asyncdispatch
32
import libchatllm
43
import packages/docutils/highlite, terminal
54

65
import std/terminal
76
import std/[os, strutils]
87

8+
var all_models: JsonNode = nil
9+
10+
proc get_model_url_on_modelscope(url: seq[string]): string =
11+
let proj = url[0]
12+
let fn = url[1]
13+
let user = if len(url) >= 3: url[2] else: "judd2024"
14+
15+
return fmt"https://modelscope.cn/api/v1/models/{user}/{proj}/repo?Revision=master&FilePath={fn}"
16+
17+
proc parse_model_id(model_id: string): JsonNode =
18+
let parts = model_id.split(":")
19+
if all_models == nil:
20+
all_models = json.parseFile("../scripts/models.json")
21+
22+
let model = all_models[parts[0]]
23+
let variants = model["variants"]
24+
let variant = variants[if len(parts) >= 2: parts[1] else: model["default"].getStr()]
25+
let r = variant["quantized"][variant["default"].getStr()]
26+
let url = r["url"].getStr().split("/")
27+
r["url"] = json.newJString(get_model_url_on_modelscope(url))
28+
r["fn"] = json.newJString(url[1])
29+
return r
30+
31+
proc print_progress_bar(iteration: BiggestInt, total: BiggestInt, prefix = "", suffix = "", decimals = 1, length = 60, fill = "", printEnd = "\r", auto_nl = true) =
32+
let percent = formatFloat(100.0 * (iteration.float / total.float), ffDecimal, decimals)
33+
let filledLength = int(length.float * iteration.float / total.float)
34+
let bar = fill.repeat(filledLength) & '-'.repeat(length - filledLength)
35+
stdout.write(fmt"{printEnd}{prefix} |{bar}| {percent}% {suffix}")
36+
if iteration == total and auto_nl:
37+
echo ""
38+
39+
proc download_file(url: string, fn: string, prefix: string) =
40+
echo fmt"Downloading {prefix}"
41+
let client = newAsyncHttpClient()
42+
defer: client.close()
43+
44+
proc onProgressChanged(total, progress, speed: BiggestInt) {.async} =
45+
print_progress_bar(progress, total, prefix)
46+
47+
client.onProgressChanged = onProgressChanged
48+
client.downloadFile(url, fn).waitFor()
49+
50+
proc get_model(model_id: string; storage_dir: string): string =
51+
if not os.dirExists(storage_dir):
52+
os.createDir(storage_dir)
53+
54+
let info = parse_model_id(model_id)
55+
let fn = joinPath([storage_dir, info["fn"].getStr()])
56+
if os.fileExists(fn):
57+
if os.getFileSize(fn) == info["size"].getBiggestInt():
58+
return fn
59+
else:
60+
echo(fmt"{fn} is incomplete, download again")
61+
62+
download_file(info["url"].getStr(), fn, model_id)
63+
assert (os.fileExists(fn)) and (os.getFileSize(fn) == info["size"].getBiggestInt())
64+
print_progress_bar(100, 100)
65+
66+
return fn
67+
968
type
1069
highlighter = object
1170
line_acc: string
@@ -78,10 +137,19 @@ proc chatllm_print(user_data: pointer, print_type: cint, utf8_str: cstring) {.cd
78137
proc chatllm_end(user_data: pointer) {.cdecl.} =
79138
echo ""
80139

140+
const candidates = ["-m", "--model", "--embedding_model", "--reranker_model"]
141+
var storage_dir: string = "../quantized"
142+
81143
var ht = highlighter(line_acc: "", lang: langNone)
82144
let chat = chatllm_create()
145+
83146
for i in 1 .. paramCount():
84-
chatllm_append_param(chat, paramStr(i).cstring)
147+
if (i > 1) and (paramStr(i - 1) in candidates) and paramStr(i).startsWith(":"):
148+
var m = paramStr(i)
149+
m = m[1..<len(m)]
150+
chatllm_append_param(chat, get_model(m, storage_dir).cstring)
151+
else:
152+
chatllm_append_param(chat, paramStr(i).cstring)
85153

86154
let r = chatllm_start(chat, chatllm_print, chatllm_end, addr(ht))
87155
if r != 0:

0 commit comments

Comments
 (0)