#include <vector>
#include <string>
#include <iostream>
#include <memory>

#include "omni-vlm-wrapper.h"

struct omnivlm_context {
    struct clip_ctx * ctx_clip = NULL;
    struct llama_context * ctx_llama = NULL;
    struct llama_model * model = NULL;
};

void* internal_chars = nullptr;

static struct common_params params;
static struct llama_model * model;
static struct omnivlm_context * ctx_omnivlm;

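// Per-request streaming state: the sampler, the source image path, the most
// recently decoded piece, and the KV-cache position / decode counters.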
struct omni_streaming_sample {
    struct common_sampler * ctx_sampling_;
    std::string image_;
    std::string ret_str_;
    int32_t n_past_;
    int32_t dec_cnt_;

    omni_streaming_sample() = delete;
    omni_streaming_sample(const std::string& image)
        : image_(image) {
        n_past_  = 0;
        dec_cnt_ = 0;
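        // force greedy decoding: top-k of 1 with top-p at 1.0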
        params.sparams.top_k = 1;
        params.sparams.top_p = 1.0f;
        ctx_sampling_ = common_sampler_init(model, params.sparams);
    }
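    // Sample one token, record its text (or "</s>" at end of generation),
    // and feed it back through eval_id() to advance the context.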
    int32_t sample() {
        const llama_token id = common_sampler_sample(ctx_sampling_, ctx_omnivlm->ctx_llama, -1);
        common_sampler_accept(ctx_sampling_, id, true);
        if (llama_token_is_eog(llama_get_model(ctx_omnivlm->ctx_llama), id)) {
            ret_str_ = "</s>";
        } else {
            ret_str_ = common_token_to_piece(ctx_omnivlm->ctx_llama, id);
        }
        eval_id(ctx_omnivlm->ctx_llama, id, &n_past_);

        ++dec_cnt_;
        return id;
    }
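    // Free the sampler and tear down the shared omnivlm context; the model
    // pointer is cleared first so the shared model is not released with it.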
    ~omni_streaming_sample() {
        common_sampler_free(ctx_sampling_);
        if (ctx_omnivlm != nullptr) {
            ctx_omnivlm->model = nullptr;
            omnivlm_free(ctx_omnivlm);
            free(ctx_omnivlm);
            ctx_omnivlm = nullptr;
        }
    }
};

static std::unique_ptr<omni_streaming_sample> g_oss;


static struct omni_image_embed * load_image(omnivlm_context * ctx_omnivlm, common_params * params, const std::string & fname) {

    // ... (body elided)
}

// ... (intervening code elided)

void omnivlm_free() {
    // ...
    }
    llama_free_model(model);
}


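// Begins a new streaming request: resets the global sample object, rebuilds
// the omnivlm context, and wraps the user prompt in the version-specific chat
// template around the <|image_pad|> placeholder.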
struct omni_streaming_sample * omnivlm_inference_streaming(const char *prompt, const char *imag_path) {
    if (g_oss) {
        g_oss.reset();
    }
    g_oss = std::make_unique<omni_streaming_sample>(std::string(imag_path));

    ctx_omnivlm = omnivlm_init_context(&params, model);

    params.prompt = prompt;

    if (params.omni_vlm_version == "vlm-81-ocr") {
        params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|ocr_start|><|vision_start|><|image_pad|><|vision_end|><|ocr_end|><|im_end|>";
    } else if (params.omni_vlm_version == "vlm-81-instruct" || params.omni_vlm_version == "nano-vlm-instruct") {
        params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n\n<|vision_start|><|image_pad|><|vision_end|>" + params.prompt + "<|im_end|>";
    } else {
        LOG_ERR("%s: error: invalid vlm version: '%s'.\n", __func__, params.omni_vlm_version.c_str());
        throw std::runtime_error("invalid vlm_version string");
    }

    return g_oss.get();
}

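// Advances generation by one token. The first call also evaluates the system
// prompt, the image embedding, and the user prompt. Returns the sampled token
// id, -1 at end of generation, or -2 once the prediction budget is exhausted.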
int32_t sample(omni_streaming_sample* oss) {
    const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
    int32_t ret_id;
    if (oss->n_past_ == 0) {
        auto * image_embed = load_image(ctx_omnivlm, &params, oss->image_);
        if (!image_embed) {
            LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, oss->image_.c_str());
            throw std::runtime_error("failed to load image " + oss->image_);
        }

        size_t image_pos = params.prompt.find("<|image_pad|>");
        std::string system_prompt, user_prompt;

        system_prompt = params.prompt.substr(0, image_pos);
        user_prompt   = params.prompt.substr(image_pos + std::string("<|image_pad|>").length());
        if (params.verbose_prompt) {
            auto tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, system_prompt, true, true);
            for (int i = 0; i < (int) tmp.size(); i++) {
                LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
            }
            tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, user_prompt, true, true);
            for (int i = 0; i < (int) tmp.size(); i++) {
                LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
            }
        }
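        // evaluate in chat order: system prompt, image embedding, user prompt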
        eval_string(ctx_omnivlm->ctx_llama, system_prompt.c_str(), params.n_batch, &(oss->n_past_), true);
        omnivlm_eval_image_embed(ctx_omnivlm->ctx_llama, image_embed, params.n_batch, &(oss->n_past_));
        eval_string(ctx_omnivlm->ctx_llama, user_prompt.c_str(), params.n_batch, &(oss->n_past_), false);

        omnivlm_image_embed_free(image_embed);

        ret_id = oss->sample();
        if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>") {
            ret_id = -1;
        }
    } else {
        if (oss->dec_cnt_ == max_tgt_len) {
            ret_id = -2;
        } else {
            ret_id = oss->sample();
            if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>") {
                ret_id = -1;
            }
        }
    }
    return ret_id;
}
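// Returns the text of the most recently sampled token.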
const char * get_str(omni_streaming_sample* oss) {
    return oss->ret_str_.c_str();
}
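
// Usage sketch (illustrative only; the prompt and image path are placeholders,
// not part of this file): a caller drives the streaming API by obtaining a
// handle once, then pulling tokens until sample() reports end of generation
// (-1) or the budget sentinel (-2), reading each decoded piece via get_str().
//
//     omni_streaming_sample * oss =
//         omnivlm_inference_streaming("Describe this image.", "/tmp/cat.png");
//     int32_t id;
//     while ((id = sample(oss)) >= 0) {
//         fputs(get_str(oss), stdout);   // emit each piece as it is decoded
//         fflush(stdout);
//     }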