-import tables
+import std/[os, tables, json, strformat, strutils, sequtils, httpclient, asyncdispatch, algorithm]
 
 type
   PrintType* = enum
@@ -388,13 +388,131 @@ proc chatllm_async_text_embedding*(obj: ptr chatllm_obj; utf8_str: cstring; purp
 proc chatllm_async_qa_rank*(obj: ptr chatllm_obj; utf8_str_q: cstring;
                             utf8_str_a: cstring): cint {.stdcall, dynlib: libName, importc.}
 
+func is_same_command_option(a, b: string): bool =
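+  ## True when the two option spellings match, treating '-' and '_' as interchangeable.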
+  if len(a) != len(b): return false
+  for i in 0 ..< len(a):
+    var c1 = a[i]
+    var c2 = b[i]
+    if c1 == '-': c1 = '_'
+    if c2 == '-': c2 = '_'
+    if c1 != c2: return false
+  return true
+
+func is_same_command_option(a: string, options: openArray[string]): bool =
+  for s in options:
+    if a.is_same_command_option(s): return true
+  return false
+
+var all_models: JsonNode = nil
+
+proc get_model(model_id: string; storage_dir: string): string =
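+  ## Resolves `model_id` to a local file path under `storage_dir`, downloading
+  ## the quantized model when it is missing or incomplete.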
+  if not os.dirExists(storage_dir):
+    os.createDir(storage_dir)
+
+  func calc_wer[T](ref_words, hyp_words: openArray[T]): float =
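+    ## Word error rate: Levenshtein edit distance between the two sequences
+    ## (computed by dynamic programming) divided by the reference length.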
+    var d = newSeq[seq[int]](len(ref_words) + 1)
+    for i in 0 ..< len(d): d[i] = newSeq[int](len(hyp_words) + 1)
+
+    for i in 0 .. len(ref_words):
+      d[i][0] = i
+    for j in 0 .. len(hyp_words):
+      d[0][j] = j
+    for i in 1 .. len(ref_words):
+      for j in 1 .. len(hyp_words):
+        if ref_words[i - 1] == hyp_words[j - 1]:
+          d[i][j] = d[i - 1][j - 1]
+        else:
+          let substitution = d[i - 1][j - 1] + 1
+          let insertion = d[i][j - 1] + 1
+          let deletion = d[i - 1][j] + 1
+          d[i][j] = min([substitution, insertion, deletion])
+    let wer = d[len(ref_words)][len(hyp_words)] / len(ref_words)
+    return wer
+
+  func calc_cer(ref_str, hyp: string): float = calc_wer(ref_str.toSeq(), hyp.toSeq())
+
+  func find_nearest_item(s: string; candidates: openArray[string]): seq[string] =
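+    ## Returns up to 3 candidates closest to `s` by character error rate,
+    ## used for "did you mean" suggestions.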
+    var l = candidates.sortedByIt(calc_cer(s, it))
+    return l[0 ..< min(3, len(l))]
+
+  proc get_model_url_on_modelscope(url: seq[string]): string =
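+    ## Builds a ModelScope download URL from [project, file] or
+    ## [project, file, user]; the user defaults to "judd2024".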
+    let proj = url[0]
+    let fn = url[1]
+    let user = if len(url) >= 3: url[2] else: "judd2024"
+
+    return fmt"https://modelscope.cn/api/v1/models/{user}/{proj}/repo?Revision=master&FilePath={fn}"
+
+  proc print_progress_bar(iteration: BiggestInt, total: BiggestInt, prefix = "", suffix = "", decimals = 1, length = 60, fill = "█", printEnd = "\r", auto_nl = true) =
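+    ## Renders a textual progress bar, rewriting the current line via "\r".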
+    let percent = formatFloat(100.0 * (iteration.float / total.float), ffDecimal, decimals)
+    let filledLength = int(length.float * iteration.float / total.float)
+    let bar = fill.repeat(filledLength) & '-'.repeat(length - filledLength)
+    stdout.write(fmt"{printEnd}{prefix} |{bar}| {percent}% {suffix}")
+    if iteration == total and auto_nl:
+      echo ""
+
+  proc parse_model_id(model_id: string): JsonNode =
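+    ## Looks up `model_id` ("name[:variant]") in models.json (loaded from disk
+    ## when available, else from the copy compiled into the binary) and returns
+    ## the metadata of the selected quantization, with `url` rewritten to a
+    ## ModelScope link and `fn` set to the target file name.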
+    let parts = model_id.split(":")
+    if all_models == nil:
+      let fn = joinPath([parentDir(paramStr(0)), "../scripts/models.json"])
+      const compiled_file = readFile("../scripts/models.json")
+      all_models = if fileExists(fn): json.parseFile(fn) else: json.parseJson(compiled_file)
+
+    let id = parts[0]
+    if not all_models.contains(id):
+      let guess = find_nearest_item(id, all_models.keys().toSeq())
+      raise newException(ValueError, fmt"""`{id}` is not recognized as a model id. Did you mean something like `{guess.join(", ")}`?""")
+      return nil
+    let model = all_models[id]
+    let variants = model["variants"]
+    let variant = variants[if len(parts) >= 2: parts[1] else: model["default"].getStr()]
+    let r = variant["quantized"][variant["default"].getStr()].copy()
+    let url = r["url"].getStr().split("/")
+    r["url"] = json.newJString(get_model_url_on_modelscope(url))
+    r["fn"] = json.newJString(url[1])
+    return r
+
+  proc download_file(url: string, fn: string, prefix: string) =
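+    ## Downloads `url` to `fn`, reporting progress on stdout.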
+    echo fmt"Downloading {prefix}"
+    let client = newAsyncHttpClient()
+    defer: client.close()
+
+    proc onProgressChanged(total, progress, speed: BiggestInt) {.async.} =
+      print_progress_bar(progress, total, prefix)
+
+    client.onProgressChanged = onProgressChanged
+    client.downloadFile(url, fn).waitFor()
+
+  let info = parse_model_id(model_id)
+  assert info != nil, fmt"unknown model id {model_id}"
+
+  let fn = joinPath([storage_dir, info["fn"].getStr()])
+  if os.fileExists(fn):
+    if os.getFileSize(fn) == info["size"].getBiggestInt():
+      return fn
+    else:
+      echo fmt"{fn} is incomplete, download again"
+
+  download_file(info["url"].getStr(), fn, model_id)
+  assert (os.fileExists(fn)) and (os.getFileSize(fn) == info["size"].getBiggestInt())
+  print_progress_bar(100, 100)
+
+  return fn
+
 ## Streamer in OOP style
 type
+  FrontendOptions* = object
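+    ## Options handled by the frontend itself rather than forwarded to libchatllm.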
+    help*: bool = false
+    interactive*: bool = false
+    reversed_role*: bool = false
+    use_multiple_lines*: bool = false
+    prompt*: string
+    sys_prompt*: string
+
   StreamerMessageType = enum
     Done = 0,
     Chunk = 1,
     ThoughtChunk = 2,
-    Meta = 3,
+    ThoughtDone = 3,
 
   StreamerMessage = tuple[t: StreamerMessageType, chunk: string]
 
     result_token_ids*: string
     result_beam_search: seq[string]
     model_info*: string
+    fe_options*: FrontendOptions
     chan_output: Channel[StreamerMessage]
 
 var streamer_dict = initTable[int, Streamer]()
@@ -438,16 +557,16 @@ method on_error(streamer: Streamer, text: string) {.base.} =
 method on_thought_completed(streamer: Streamer) {.base.} =
   discard
 
-method on_async_completed(streamer: Streamer) {.base.} =
-  streamer.chan_output.send((t: StreamerMessageType.Done, chunk: ""))
+method on_print_meta(streamer: Streamer, text: string) {.base.} =
+  discard
 
 proc streamer_on_print(user_data: pointer, print_type: cint, utf8_str: cstring) {.cdecl.} =
   var streamer = get_streamer(user_data)
   case cast[PrintType](print_type):
   of PrintType.PRINT_CHAT_CHUNK:
     streamer.chan_output.send((t: StreamerMessageType.Chunk, chunk: $utf8_str))
   of PrintType.PRINTLN_META:
-    streamer.chan_output.send((t: StreamerMessageType.Meta, chunk: $utf8_str))
+    streamer.on_print_meta($utf8_str)
   of PrintType.PRINTLN_ERROR:
     on_error(streamer, $utf8_str)
   of PrintType.PRINTLN_REF:
@@ -475,18 +594,24 @@ proc streamer_on_print(user_data: pointer, print_type: cint, utf8_str: cstring)
   of PrintType.PRINT_THOUGHT_CHUNK:
     streamer.chan_output.send((t: StreamerMessageType.ThoughtChunk, chunk: $utf8_str))
   of PrintType.PRINT_EVT_ASYNC_COMPLETED:
-    on_async_completed(streamer)
+    streamer.chan_output.send((t: StreamerMessageType.Done, chunk: ""))
   of PrintType.PRINT_EVT_THOUGHT_COMPLETED:
-    on_thought_completed(streamer)
+    streamer.chan_output.send((t: StreamerMessageType.ThoughtDone, chunk: ""))
 
 proc streamer_on_end(user_data: pointer) {.cdecl.} =
   var streamer = get_streamer(user_data)
   streamer.is_generating = false
 
 proc initStreamer*(streamer: Streamer; args: openArray[string], auto_restart: bool = false): bool =
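+  # Options whose value may be given as ":model_id" and resolved to a local file via get_model.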
+  const candidates = ["-m", "--model", "--embedding_model", "--reranker_model"]
+
+  var storage_dir = getEnv("CHATLLM_QUANTIZED_MODEL_PATH")
+  if storage_dir == "":
+    storage_dir = joinPath([parentDir(paramStr(0)), "../quantized"])
+
   let id = streamer_dict.len + 1
   streamer_dict[id] = streamer
-  streamer.llm = chatllm_create()
+
   streamer.chan_output.open()
   streamer.system_prompt = ""
   streamer.system_prompt_updating = false
@@ -499,7 +624,41 @@ proc initStreamer*(streamer: Streamer; args: openArray[string], auto_restart: bo
   streamer.result_ranking = ""
   streamer.result_token_ids = ""
   streamer.model_info = ""
-  for s in args:
+
+  var args_pp = newSeq[string]()
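+  # Pre-process args: capture frontend-only flags into fe_options; for model
+  # options, expand a ":model_id" value into the path of the downloaded file.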
+  var i = 0
+  while i < len(args):
+    let s = args[i]
+    if s.is_same_command_option(["-h", "--help"]):
+      streamer.fe_options.help = true
+      break
+    elif s.is_same_command_option(["-i", "--interactive"]):
+      streamer.fe_options.interactive = true
+    elif s.is_same_command_option("--reversed_role"):
+      streamer.fe_options.reversed_role = true
+    elif s.is_same_command_option("--multi"):
+      streamer.fe_options.use_multiple_lines = true
+    elif s.is_same_command_option(["-p", "--prompt"]):
+      inc i
+      if i < len(args): streamer.fe_options.prompt = args[i]
+    elif s.is_same_command_option(["-s", "--system"]):
+      inc i
+      if i < len(args): streamer.fe_options.sys_prompt = args[i]
+    else:
+      args_pp.add s
+      if s.is_same_command_option(candidates):
+        inc i
+        if i >= len(args): break
+        if args[i][0] == ':':
+          args_pp.add get_model(args[i][1..^1], storage_dir)
+        else:
+          args_pp.add args[i]
+    inc i
+
+  if streamer.fe_options.help: return true
+
+  streamer.llm = chatllm_create()
+  for s in args_pp:
     chatllm_append_param(streamer.llm, s.cstring)
 
   let r = chatllm_start(streamer.llm, streamer_on_print, streamer_on_end, cast[pointer](id))
@@ -560,8 +719,8 @@ iterator chunks*(streamer: Streamer): tuple[t: ChunkType; chunk: string] =
       yield (t: ChunkType.Thought, chunk: msg.chunk)
     of StreamerMessageType.Done:
       break
-    of StreamerMessageType.Meta:
-      discard
+    of StreamerMessageType.ThoughtDone:
+      streamer.on_thought_completed()
 
 proc set_max_gen_tokens*(streamer: Streamer, max_new_tokens: int) =
   chatllm_set_gen_max_tokens(streamer.llm, cint(max_new_tokens))