Commit 7aa97b9

omni vlm add streaming
1 parent 5962b50 commit 7aa97b9

File tree: 2 files changed, +133 -3 lines changed


examples/omni-vlm/omni-vlm-wrapper.cpp

Lines changed: 123 additions & 1 deletion
@@ -15,10 +15,10 @@
 #include <vector>
 #include <string>
 #include <iostream>
+#include <memory>
 
 #include "omni-vlm-wrapper.h"
 
-
 struct omnivlm_context {
     struct clip_ctx * ctx_clip = NULL;
     struct llama_context * ctx_llama = NULL;
@@ -30,6 +30,50 @@ void* internal_chars = nullptr;
 static struct common_params params;
 static struct llama_model* model;
 static struct omnivlm_context* ctx_omnivlm;
+static std::unique_ptr<struct omni_streaming_sample> g_oss = nullptr;
+
+struct omni_streaming_sample {
+    struct common_sampler * ctx_sampling_;
+    std::string image_;
+    std::string ret_str_;
+    int32_t n_past_;
+    int32_t dec_cnt_;
+
+    omni_streaming_sample() = delete;
+    omni_streaming_sample(const std::string& image)
+        :image_(image) {
+        n_past_ = 0;
+        dec_cnt_ = 0;
+        params.sparams.top_k = 1;
+        params.sparams.top_p = 1.0f;
+        ctx_sampling_ = common_sampler_init(model, params.sparams);
+    }
+
+    int32_t sample() {
+        const llama_token id = common_sampler_sample(ctx_sampling_, ctx_omnivlm->ctx_llama, -1);
+        common_sampler_accept(ctx_sampling_, id, true);
+        if (llama_token_is_eog(llama_get_model(ctx_omnivlm->ctx_llama), id)) {
+            ret_str_ = "</s>";
+        } else {
+            ret_str_ = common_token_to_piece(ctx_omnivlm->ctx_llama, id);
+        }
+        eval_id(ctx_omnivlm->ctx_llama, id, &n_past_);
+
+        ++dec_cnt_;
+        return id;
+    }
+
+    ~omni_streaming_sample() {
+        common_sampler_free(ctx_sampling_);
+        if (ctx_omnivlm != nullptr) {
+            ctx_omnivlm->model = nullptr;
+            omnivlm_free(ctx_omnivlm);
+            free(ctx_omnivlm);
+            ctx_omnivlm = nullptr;
+        }
+    }
+};
+
 
 static struct omni_image_embed * load_image(omnivlm_context * ctx_omnivlm, common_params * params, const std::string & fname) {
@@ -286,3 +330,81 @@ void omnivlm_free() {
     }
     llama_free_model(model);
 }
+
+
+struct omni_streaming_sample* omnivlm_inference_streaming(const char *prompt, const char *imag_path) {
+    if (g_oss) {
+        g_oss.reset();
+    }
+    g_oss = std::make_unique<omni_streaming_sample>(std::string(imag_path));
+
+    ctx_omnivlm = omnivlm_init_context(&params, model);
+
+    params.prompt = prompt;
+
+    if (params.omni_vlm_version == "vlm-81-ocr") {
+        params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n <|ocr_start|><|vision_start|><|image_pad|><|vision_end|><|ocr_end|><|im_end|>";
+    } else if (params.omni_vlm_version == "vlm-81-instruct" || params.omni_vlm_version == "nano-vlm-instruct") {
+        params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n\n<|vision_start|><|image_pad|><|vision_end|>" + params.prompt + "<|im_end|>";
+    } else {
+        LOG_ERR("%s : error: you set wrong vlm version info:'%s'.\n", __func__, params.omni_vlm_version.c_str());
+        throw std::runtime_error("You set wrong vlm_version info strings.");
+    }
+
+    return g_oss.get();
+}
+
+int32_t sample(omni_streaming_sample* oss) {
+    const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
+    int32_t ret_id;
+    if (oss->n_past_ == 0) {
+        auto * image_embed = load_image(ctx_omnivlm, &params, oss->image_);
+        if (!image_embed) {
+            LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, oss->image_.c_str());
+            throw std::runtime_error("failed to load image " + oss->image_);
+        }
+
+        size_t image_pos = params.prompt.find("<|image_pad|>");
+        std::string system_prompt, user_prompt;
+
+        system_prompt = params.prompt.substr(0, image_pos);
+        user_prompt = params.prompt.substr(image_pos + std::string("<|image_pad|>").length());
+        if (params.verbose_prompt) {
+            auto tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, system_prompt, true, true);
+            for (int i = 0; i < (int) tmp.size(); i++) {
+                LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
+            }
+        }
+        if (params.verbose_prompt) {
+            auto tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, user_prompt, true, true);
+            for (int i = 0; i < (int) tmp.size(); i++) {
+                LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
+            }
+        }
+
+        eval_string(ctx_omnivlm->ctx_llama, system_prompt.c_str(), params.n_batch, &(oss->n_past_), true);
+        omnivlm_eval_image_embed(ctx_omnivlm->ctx_llama, image_embed, params.n_batch, &(oss->n_past_));
+        eval_string(ctx_omnivlm->ctx_llama, user_prompt.c_str(), params.n_batch, &(oss->n_past_), false);
+
+        omnivlm_image_embed_free(image_embed);
+
+        ret_id = oss->sample();
+        if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>") {
+            ret_id = -1;
+        }
+    } else {
+        if (oss->dec_cnt_ == max_tgt_len) {
+            ret_id = -2;
+        } else {
+            ret_id = oss->sample();
+            if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>") {
+                ret_id = -1;
+            }
+        }
+    }
+    return ret_id;
+}
+
+const char* get_str(omni_streaming_sample* oss) {
+    return oss->ret_str_.c_str();
+}
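
For callers of the new streaming API, the int32_t returned by sample() doubles as a status channel: nonnegative values are ordinary token ids, -1 means the end-of-generation piece ("<|im_end|>" or "</s>") was just sampled, and -2 means the decode count hit the n_predict cap (256 when n_predict < 0). Since the omni_streaming_sample constructor pins top_k = 1 and top_p = 1.0f, decoding is greedy, so the stream is deterministic for a given prompt and image. A minimal driver against that contract could look as follows; this is a sketch rather than code from the commit, and the model paths, prompt, and image path are placeholders:

// Hypothetical usage sketch for the streaming API added in this commit.
#include <cstdio>

#include "omni-vlm-wrapper.h"

int main() {
    // Placeholder paths and version string; see omnivlm_init for the real arguments.
    omnivlm_init("llm.gguf", "projector.gguf", "vlm-81-instruct");

    omni_streaming_sample * oss =
        omnivlm_inference_streaming("Describe this image.", "image.png");

    // The first sample() call also evaluates the prompt and the image embedding
    // (the n_past_ == 0 branch); each later call decodes exactly one token.
    for (int32_t id = sample(oss); id >= 0; id = sample(oss)) {
        std::fputs(get_str(oss), stdout);   // piece produced by the last sample()
        std::fflush(stdout);
    }

    omnivlm_free();
    return 0;
}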

examples/omni-vlm/omni-vlm-wrapper.h

Lines changed: 10 additions & 2 deletions
@@ -1,6 +1,6 @@
-
 #ifndef OMNIVLMWRAPPER_H
 #define OMNIVLMWRAPPER_H
+#include <stdint.h>
 
 #ifdef LLAMA_SHARED
 # if defined(_WIN32) && !defined(__MINGW32__)
@@ -20,14 +20,22 @@
 extern "C" {
 #endif
 
+struct omni_streaming_sample;
+
 OMNIVLM_API void omnivlm_init(const char* llm_model_path, const char* projector_model_path, const char* omni_vlm_version);
 
 OMNIVLM_API const char* omnivlm_inference(const char* prompt, const char* imag_path);
 
+OMNIVLM_API struct omni_streaming_sample* omnivlm_inference_streaming(const char* prompt, const char* imag_path);
+
+OMNIVLM_API int32_t sample(struct omni_streaming_sample *);
+
+OMNIVLM_API const char* get_str(struct omni_streaming_sample *);
+
 OMNIVLM_API void omnivlm_free();
 
 #ifdef __cplusplus
 }
 #endif
 
-#endif
+#endif
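
Note that the header only forward-declares struct omni_streaming_sample, so its members stay private to the .cpp and callers handle it purely as an opaque pointer; the new <stdint.h> include supplies the int32_t used by the added signatures, keeping the header consumable from C as well as C++ (everything sits inside the extern "C" block). A compile-only illustration of that boundary (hypothetical, not from the commit):

#include "omni-vlm-wrapper.h"

void demo(struct omni_streaming_sample * oss) {
    int32_t id = sample(oss);            // fine: the opaque handle just passes through the C API
    const char * piece = get_str(oss);   // fine
    // oss->ret_str_;                    // would not compile: the type is incomplete here
    (void) id; (void) piece;
}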
