Skip to content

Commit 0ee543a

Browse files
committed
update nim binding; upload some models.
1 parent 6ce7897 commit 0ee543a

File tree

3 files changed

+291
-149
lines changed

3 files changed

+291
-149
lines changed

bindings/libchatllm.nim

Lines changed: 170 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import tables
1+
import std/[os, tables, json, strformat, strutils, sequtils, httpclient, asyncdispatch, algorithm]
22

33
type
44
PrintType* = enum
@@ -388,13 +388,131 @@ proc chatllm_async_text_embedding*(obj: ptr chatllm_obj; utf8_str: cstring; purp
388388
proc chatllm_async_qa_rank*(obj: ptr chatllm_obj; utf8_str_q: cstring;
389389
utf8_str_a: cstring): cint {.stdcall, dynlib: libName, importc.}
390390

391+
func is_same_command_option(a, b: string): bool =
  ## Compare two command-line option strings, treating '-' and '_'
  ## as interchangeable (so "--multi-line" matches "--multi_line").
  if a.len != b.len:
    return false
  for idx in 0 ..< a.len:
    let x = if a[idx] == '-': '_' else: a[idx]
    let y = if b[idx] == '-': '_' else: b[idx]
    if x != y:
      return false
  result = true
400+
401+
func is_same_command_option(a: string, options: openArray[string]): bool =
  ## True when `a` matches any entry of `options`, with '-' and '_'
  ## treated as interchangeable.
  result = options.anyIt(a.is_same_command_option(it))
405+
406+
# Cached model catalog (scripts/models.json); loaded lazily on first lookup.
var all_models: JsonNode = nil
407+
408+
proc get_model(model_id: string; storage_dir: string): string =
  ## Resolve a model id of the form "id[:variant]" to a local file path,
  ## downloading the quantized model file from ModelScope into `storage_dir`
  ## when it is missing or its size does not match the catalog entry.
  ## Raises `ValueError` for an unknown id, suggesting the closest known ids.
  if not os.dirExists(storage_dir):
    os.createDir(storage_dir)

  func calc_wer[T](ref_words, hyp_words: openArray[T]): float =
    ## Word-error rate: Levenshtein edit distance between the two sequences
    ## (classic DP over substitution/insertion/deletion), divided by the
    ## reference length.
    if len(ref_words) == 0:
      # fix: avoid division by zero for an empty reference
      return (if len(hyp_words) == 0: 0.0 else: 1.0)
    var d = newSeq[seq[int]](len(ref_words) + 1)
    for i in 0 ..< len(d): d[i] = newSeq[int](len(hyp_words) + 1)

    for i in 0..len(ref_words):
      d[i][0] = i
    for j in 0..len(hyp_words):
      d[0][j] = j
    for i in 1..len(ref_words):
      for j in 1..len(hyp_words):
        if ref_words[i - 1] == hyp_words[j - 1]:
          d[i][j] = d[i - 1][j - 1]
        else:
          let substitution = d[i - 1][j - 1] + 1
          let insertion = d[i][j - 1] + 1
          let deletion = d[i - 1][j] + 1
          d[i][j] = min([substitution, insertion, deletion])
    return d[len(ref_words)][len(hyp_words)] / len(ref_words)

  func calc_cer(ref_str, hyp: string): float =
    ## Character-error rate: WER computed over individual characters.
    calc_wer(ref_str.toSeq(), hyp.toSeq())

  func find_nearest_item(s: string; candidates: openArray[string]): seq[string] =
    ## Up to three candidates closest to `s` by character-error rate,
    ## used for "did you mean ...?" suggestions.
    let ranked = candidates.sortedByIt(calc_cer(s, it))
    return ranked[0 ..< min(3, len(ranked))]

  proc get_model_url_on_modelscope(url: seq[string]): string =
    ## Build a ModelScope download URL from [project, file (, user)];
    ## the user component defaults to "judd2024".
    let proj = url[0]
    let fn = url[1]
    let user = if len(url) >= 3: url[2] else: "judd2024"

    return fmt"https://modelscope.cn/api/v1/models/{user}/{proj}/repo?Revision=master&FilePath={fn}"

  proc print_progress_bar(iteration: BiggestInt, total: BiggestInt, prefix = "", suffix = "", decimals = 1, length = 60, fill = "█", printEnd = "\r", auto_nl = true) =
    ## Render a single-line textual progress bar, overwritten in place via
    ## `printEnd` ("\r"). Fix: `fill` defaulted to "" (the fill glyph was
    ## lost), which rendered a bar of dashes only; restored to "█".
    let percent = formatFloat(100.0 * (iteration.float / total.float), ffDecimal, decimals)
    let filledLength = int(length.float * iteration.float / total.float)
    let bar = fill.repeat(filledLength) & '-'.repeat(length - filledLength)
    stdout.write(fmt"{printEnd}{prefix} |{bar}| {percent}% {suffix}")
    if iteration == total and auto_nl:
      echo ""

  proc parse_model_id(model_id: string): JsonNode =
    ## Look up "id[:variant]" in the model catalog and return a copy of the
    ## default quantization entry with `url` rewritten to a ModelScope
    ## download URL and `fn` set to the target file name.
    ## Raises `ValueError` when the id is unknown.
    let parts = model_id.split(":")
    if all_models == nil:
      # prefer models.json next to the executable; fall back to the copy
      # baked in at compile time
      let fn = joinPath([parentDir(paramStr(0)), "../scripts/models.json"])
      const compiled_file = readFile("../scripts/models.json")
      all_models = if fileExists(fn): json.parseFile(fn) else: json.parseJson(compiled_file)

    let id = parts[0]
    if not all_models.contains(id):
      let guess = find_nearest_item(id, all_models.keys().toSeq())
      # fix: message said "is recognized" (missing "not"); the unreachable
      # `return nil` after `raise` was removed
      raise newException(ValueError, fmt"""`{id}` is not recognized as a model id. Did you mean something like `{guess.join(", ")}`?""")
    let model = all_models[id]
    let variants = model["variants"]
    let variant = variants[if len(parts) >= 2: parts[1] else: model["default"].getStr()]
    let r = variant["quantized"][variant["default"].getStr()].copy()
    let url = r["url"].getStr().split("/")
    r["url"] = json.newJString(get_model_url_on_modelscope(url))
    r["fn"] = json.newJString(url[1])
    return r

  proc download_file(url: string, fn: string, prefix: string) =
    ## Download `url` to `fn` synchronously (waitFor over the async client),
    ## reporting progress on stdout.
    echo fmt"Downloading {prefix}"
    let client = newAsyncHttpClient()
    defer: client.close()

    proc onProgressChanged(total, progress, speed: BiggestInt) {.async} =
      print_progress_bar(progress, total, prefix)

    client.onProgressChanged = onProgressChanged
    client.downloadFile(url, fn).waitFor()

  let info = parse_model_id(model_id)
  assert info != nil, fmt"unknown model id {model_id}"

  let fn = joinPath([storage_dir, info["fn"].getStr()])
  if os.fileExists(fn):
    if os.getFileSize(fn) == info["size"].getBiggestInt():
      return fn
    else:
      echo(fmt"{fn} is incomplete, download again")

  download_file(info["url"].getStr(), fn, model_id)
  assert (os.fileExists(fn)) and (os.getFileSize(fn) == info["size"].getBiggestInt())
  print_progress_bar(100, 100)

  return fn
500+
391501
## Streamer in OOP style
392502
type
503+
FrontendOptions* = object
504+
help*: bool = false
505+
interactive*: bool = false
506+
reversed_role*: bool = false
507+
use_multiple_lines*: bool = false
508+
prompt*: string
509+
sys_prompt*: string
510+
393511
StreamerMessageType = enum
394512
Done = 0,
395513
Chunk = 1,
396514
ThoughtChunk = 2,
397-
Meta = 3,
515+
ThoughtDone = 3,
398516

399517
StreamerMessage = tuple[t: StreamerMessageType, chunk: string]
400518

@@ -419,6 +537,7 @@ type
419537
result_token_ids*: string
420538
result_beam_search: seq[string]
421539
model_info*: string
540+
fe_options*: FrontendOptions
422541
chan_output: Channel[StreamerMessage]
423542

424543
var streamer_dict = initTable[int, Streamer]()
@@ -438,16 +557,16 @@ method on_error(streamer: Streamer, text: string) {.base.} =
438557
method on_thought_completed(streamer: Streamer) {.base.} =
439558
discard
440559

441-
method on_async_completed(streamer: Streamer) {.base.} =
442-
streamer.chan_output.send((t: StreamerMessageType.Done, chunk: ""))
560+
method on_print_meta(streamer: Streamer, text: string) {.base.} =
  ## Hook invoked for PRINTLN_META output from the model; the base
  ## implementation discards it — override to capture metadata lines.
  discard
443562

444563
proc streamer_on_print(user_data: pointer, print_type: cint, utf8_str: cstring) {.cdecl.} =
445564
var streamer = get_streamer(user_data)
446565
case cast[PrintType](print_type):
447566
of PrintType.PRINT_CHAT_CHUNK:
448567
streamer.chan_output.send((t: StreamerMessageType.Chunk, chunk: $utf8_str))
449568
of PrintType.PRINTLN_META:
450-
streamer.chan_output.send((t: StreamerMessageType.Meta, chunk: $utf8_str))
569+
streamer.on_print_meta $utf8_str
451570
of PrintType.PRINTLN_ERROR:
452571
on_error(streamer, $utf8_str)
453572
of PrintType.PRINTLN_REF:
@@ -475,18 +594,24 @@ proc streamer_on_print(user_data: pointer, print_type: cint, utf8_str: cstring)
475594
of PrintType.PRINT_THOUGHT_CHUNK:
476595
streamer.chan_output.send((t: StreamerMessageType.ThoughtChunk, chunk: $utf8_str))
477596
of PrintType.PRINT_EVT_ASYNC_COMPLETED:
478-
on_async_completed(streamer)
597+
streamer.chan_output.send((t: StreamerMessageType.Done, chunk: ""))
479598
of PrintType.PRINT_EVT_THOUGHT_COMPLETED:
480-
on_thought_completed(streamer)
599+
streamer.chan_output.send((t: StreamerMessageType.ThoughtDone, chunk: ""))
481600

482601
proc streamer_on_end(user_data: pointer) {.cdecl.} =
  ## C callback invoked when the library finishes generating; clears the
  ## busy flag on the Streamer identified by `user_data` (an id into
  ## streamer_dict).
  var streamer = get_streamer(user_data)
  streamer.is_generating = false
485604

486605
proc initStreamer*(streamer: Streamer; args: openArray[string], auto_restart: bool = false): bool =
606+
const candidates = ["-m", "--model", "--embedding_model", "--reranker_model"]
607+
608+
var storage_dir = getEnv("CHATLLM_QUANTIZED_MODEL_PATH")
609+
if storage_dir == "":
610+
storage_dir = joinPath([parentDir(paramStr(0)), "../quantized"])
611+
487612
let id = streamer_dict.len + 1
488613
streamer_dict[id] = streamer
489-
streamer.llm = chatllm_create()
614+
490615
streamer.chan_output.open()
491616
streamer.system_prompt = ""
492617
streamer.system_prompt_updating = false
@@ -499,7 +624,41 @@ proc initStreamer*(streamer: Streamer; args: openArray[string], auto_restart: bo
499624
streamer.result_ranking = ""
500625
streamer.result_token_ids = ""
501626
streamer.model_info = ""
502-
for s in args:
627+
628+
var args_pp = newSeq[string]()
629+
var i = 0
630+
while i < len(args):
631+
let s = args[i]
632+
if s.is_same_command_option(["-h", "--help"]):
633+
streamer.fe_options.help = true
634+
break
635+
elif s.is_same_command_option(["-i", "--interactive"]):
636+
streamer.fe_options.interactive = true
637+
elif s.is_same_command_option("--reversed_role"):
638+
streamer.fe_options.reversed_role = true
639+
elif s.is_same_command_option("--multi"):
640+
streamer.fe_options.use_multiple_lines = true
641+
elif s.is_same_command_option(["-p", "--prompt"]):
642+
inc i
643+
if i < len(args): streamer.fe_options.prompt = args[i]
644+
elif s.is_same_command_option(["-s", "--system"]):
645+
inc i
646+
if i < len(args): streamer.fe_options.sys_prompt = args[i]
647+
else:
648+
args_pp.add s
649+
if s.is_same_command_option(candidates):
650+
inc i
651+
if i >= len(args): break
652+
if args[i][0] == ':':
653+
args_pp.add get_model(args[i][1..^1], storage_dir)
654+
else:
655+
args_pp.add args[i]
656+
inc i
657+
658+
if streamer.fe_options.help: return true
659+
660+
streamer.llm = chatllm_create()
661+
for s in args_pp:
503662
chatllm_append_param(streamer.llm, s.cstring)
504663

505664
let r = chatllm_start(streamer.llm, streamer_on_print, streamer_on_end, cast[pointer](id))
@@ -560,8 +719,8 @@ iterator chunks*(streamer: Streamer): tuple[t: ChunkType; chunk: string] =
560719
yield (t: ChunkType.Thought, chunk: msg.chunk)
561720
of StreamerMessageType.Done:
562721
break
563-
of StreamerMessageType.Meta:
564-
discard
722+
of StreamerMessageType.ThoughtDone:
723+
streamer.on_thought_completed()
565724

566725
proc set_max_gen_tokens*(streamer: Streamer, max_new_tokens: int) =
  ## Cap the number of tokens generated per completion, forwarding to the
  ## C API (chatllm_set_gen_max_tokens).
  chatllm_set_gen_max_tokens(streamer.llm, cint(max_new_tokens))

0 commit comments

Comments
 (0)