
Commit 235340d

wip llava2
1 parent 2dabf75 commit 235340d

6 files changed (+350, -27 lines)

examples/llava/clip-impl.h

Lines changed: 55 additions & 0 deletions
@@ -1,6 +1,8 @@
 #include "ggml.h"
 #include "gguf.h"
 
+#include "clip.h"
+
 #include <climits>
 #include <cstdarg>
 #include <string>
@@ -120,6 +122,23 @@ static projector_type clip_projector_type_from_string(const std::string & str) {
     return PROJECTOR_TYPE_UNKNOWN;
 }
 
+// RGB uint8 image
+struct clip_image_u8 {
+    int nx;
+    int ny;
+
+    std::vector<uint8_t> buf;
+};
+
+// RGB float32 image (NHWC)
+// Memory layout: RGBRGBRGB...
+struct clip_image_f32 {
+    int nx;
+    int ny;
+
+    std::vector<float> buf;
+};
+
 //
 // logging
 //
@@ -178,6 +197,28 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, ..
 #define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
 #define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, __VA_ARGS__)
 
+//
+// cpp wrappers
+//
+
+struct clip_image_u8_deleter {
+    void operator()(clip_image_u8 * val) { clip_image_u8_free(val); }
+};
+
+struct clip_image_f32_deleter {
+    void operator()(clip_image_f32 * val) { clip_image_f32_free(val); }
+};
+
+struct clip_image_f32_batch_deleter {
+    void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); }
+};
+
+typedef std::unique_ptr<clip_image_u8, clip_image_u8_deleter> clip_image_u8_ptr;
+typedef std::unique_ptr<clip_image_f32, clip_image_f32_deleter> clip_image_f32_ptr;
+typedef std::unique_ptr<clip_image_f32_batch, clip_image_f32_batch_deleter> clip_image_f32_batch_ptr;
+
+// TODO @ngxson : we're currently having a naming clash between struct clip_image_size and function clip_image_size()
+
 //
 // common utils
 //
@@ -214,6 +255,20 @@ static void string_replace_all(std::string & s, const std::string & search, cons
     s = std::move(builder);
 }
 
+// split string by a `std::string delim` instead of `char delim`
+static std::vector<std::string> string_split_str(std::string s, const std::string & delimiter) {
+    std::vector<std::string> tokens;
+    size_t pos = 0;
+    std::string token;
+    while ((pos = s.find(delimiter)) != std::string::npos) {
+        token = s.substr(0, pos);
+        tokens.push_back(token);
+        s.erase(0, pos + delimiter.length());
+    }
+    tokens.push_back(s);
+    return tokens;
+}
+
 //
 // gguf utils
 //

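Note on the new C++ wrappers: the clip_image_*_ptr typedefs let callers own the C-allocated image objects through RAII instead of calling the free functions by hand. A minimal usage sketch (not part of this commit; the file name and error handling are illustrative, the clip_* calls are the existing API):

    {
        clip_image_u8_ptr img(clip_image_u8_init());            // allocated by the C API
        if (!clip_image_load_from_file("input.png", img.get())) {
            // handle the error; img is still released automatically
        }
        // ... pass img.get() to other clip_* functions ...
    }   // clip_image_u8_deleter calls clip_image_u8_free() here

    // string_split_str splits on a multi-character delimiter, e.g. the "<image>" marker:
    std::vector<std::string> parts = string_split_str("a<image>b<image>c", "<image>");
    // parts == {"a", "b", "c"}
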
examples/llava/clip.cpp

Lines changed: 6 additions & 17 deletions
@@ -32,23 +32,6 @@ struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callbac
 
 //#define CLIP_DEBUG_FUNCTIONS
 
-// RGB uint8 image
-struct clip_image_u8 {
-    int nx;
-    int ny;
-
-    std::vector<uint8_t> buf;
-};
-
-// RGB float32 image (NHWC)
-// Memory layout: RGBRGBRGB...
-struct clip_image_f32 {
-    int nx;
-    int ny;
-
-    std::vector<float> buf;
-};
-
 #ifdef CLIP_DEBUG_FUNCTIONS
 static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) {
     std::ofstream file(filename, std::ios::binary);
@@ -1618,6 +1601,12 @@ struct clip_image_f32 * clip_image_f32_init() {
     return new clip_image_f32();
 }
 
+unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny) {
+    if (nx) *nx = img->nx;
+    if (ny) *ny = img->ny;
+    return img->buf.data();
+}
+
 void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
 void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
 void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) {

examples/llava/clip.h

Lines changed: 3 additions & 0 deletions
@@ -77,6 +77,9 @@ CLIP_API struct clip_image_size * clip_image_size_init();
 CLIP_API struct clip_image_u8 * clip_image_u8_init ();
 CLIP_API struct clip_image_f32 * clip_image_f32_init();
 
+// nx, ny are the output image dimensions
+CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
+
 CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
 CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
 CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);

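The clip_image_u8_get_data accessor added above exposes the decoded RGB buffer without requiring the struct definition from clip-impl.h. A small sketch of the access pattern (not part of this commit; the file name is illustrative):

    clip_image_u8_ptr img(clip_image_u8_init());
    if (clip_image_load_from_file("photo.jpg", img.get())) {
        uint32_t nx, ny;
        unsigned char * rgb = clip_image_u8_get_data(img.get(), &nx, &ny);
        // rgb points to nx * ny * 3 bytes of interleaved RGBRGB... pixel data
    }

This is the same pattern llava2_bitmap_init_from_file (in the new llava2.cpp below) uses to copy pixels into a llava2_bitmap.
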
examples/llava/gemma3-cli.cpp

Lines changed: 10 additions & 10 deletions
@@ -2,11 +2,10 @@
 #include "log.h"
 #include "common.h"
 #include "sampling.h"
-#include "clip.h"
-#include "stb_image.h"
 #include "llama.h"
 #include "ggml.h"
 #include "console.h"
+#include "llava2.h"
 
 #include <vector>
 #include <limits.h>
@@ -57,8 +56,8 @@ static void sigint_handler(int signo) {
 #endif
 
 struct gemma3_context {
-    struct clip_ctx * ctx_clip = NULL;
-    common_init_result llama_init;
+    llava2_context_ptr ctx_llava2;
+    common_init_result llama_init;
 
     llama_model * model;
     llama_context * lctx;
@@ -79,16 +78,16 @@ struct gemma3_context {
 
     void init_clip_model(common_params & params) {
         const char * clip_path = params.mmproj.path.c_str();
-        ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO);
-        if (!ctx_clip) {
+        ctx_llava2 = llava2_init_from_file(clip_path, model, llava2_context_params{
+            /* use_gpu */ true,
+            /* n_threads */ params.cpuparams.n_threads,
+            /* verbosity */ GGML_LOG_LEVEL_INFO,
+        });
+        if (!ctx_llava2.get()) {
             LOG_ERR("Failed to load CLIP model from %s\n", clip_path);
             exit(1);
         }
     }
-
-    ~gemma3_context() {
-        clip_free(ctx_clip);
-    }
 };
 
 struct decode_embd_batch {
@@ -271,6 +270,7 @@ int main(int argc, char ** argv) {
 
     if (is_single_turn) {
         g_is_generating = true;
+        std::string prompt = "<start_of_turn>user\n<image>" + params.prompt + "<end_of_turn><start_of_turn>model\n";
         if (eval_text(ctx, "<start_of_turn>user\n")) {
             return 1;
         }

examples/llava/llava2.cpp

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
+#include "clip.h"
+#include "clip-impl.h"
+#include "llava2.h"
+
+#include "llama.h"
+
+#include <algorithm>
+#include <cerrno>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <limits>
+#include <vector>
+
+static const char * IMG_MARKER = "<image>";
+
+struct llava2_context {
+    struct clip_ctx * ctx_clip;
+    const struct llama_model * text_model;
+    std::vector<float> image_embd_v; // image embedding vector
+    int n_threads;
+
+    llava2_context(const char * mmproj_fname,
+            const struct llama_model * text_model,
+            const struct llava2_context_params & ctx_params) : n_threads(ctx_params.n_threads) {
+        clip_context_params ctx_clip_params;
+        ctx_clip_params.use_gpu = ctx_params.use_gpu;
+        ctx_clip_params.verbosity = ctx_params.verbosity;
+        ctx_clip = clip_init(mmproj_fname, ctx_clip_params);
+        if (!ctx_clip) {
+            throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
+        }
+        this->text_model = text_model;
+    }
+
+    ~llava2_context() {
+        clip_free(ctx_clip);
+    }
+};
+
+struct llava2_image_tokens_data {
+    clip_image_f32_batch_ptr batch_f32; // preprocessed image patches
+};
+
+llava2_context_ptr llava2_init_from_file(const char * mmproj_fname,
+        const struct llama_model * text_model,
+        const struct llava2_context_params ctx_params) {
+    try {
+        auto ctx = std::make_shared<llava2_context>(mmproj_fname, text_model, ctx_params);
+        return ctx;
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: error: %s\n", __func__, e.what());
+        return nullptr;
+    }
+}
+
+int32_t llava2_bitmap_init_from_file(const char * fname, llava2_bitmap & output) {
+    clip_image_u8_ptr img_u8(clip_image_u8_init());
+    bool ok = clip_image_load_from_file(fname, img_u8.get());
+    if (!ok) {
+        LOG_ERR("Unable to load image %s\n", fname);
+        return 1;
+    }
+    unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny);
+    output.data.resize(output.nx * output.ny * 3);
+    std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
+    return 0;
+}
+
+// copied from common_tokenize
+static std::vector<llama_token> llava2_tokenize_text_internal(
+    const struct llama_vocab * vocab,
+    const std::string & text,
+    bool add_special,
+    bool parse_special) {
+    // upper limit for the number of tokens
+    int n_tokens = text.length() + 2 * add_special;
+    std::vector<llama_token> result(n_tokens);
+    n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    if (n_tokens < 0) {
+        result.resize(-n_tokens);
+        int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        GGML_ASSERT(check == -n_tokens);
+    } else {
+        result.resize(n_tokens);
+    }
+    return result;
+}
+
+int32_t llava2_tokenize(llava2_context_ptr & ctx,
+        std::vector<llava2_input_chunk> & output,
+        const std::string & prompt,
+        bool add_special,
+        bool parse_special,
+        const std::vector<llava2_bitmap> & bitmaps) {
+    auto vocab = llama_model_get_vocab(ctx->text_model);
+
+    std::vector<std::string> parts = string_split_str(prompt, IMG_MARKER);
+    output.clear();
+    output.reserve(parts.size());
+
+    size_t i_img = 0;
+
+    for (const auto & part : parts) {
+        //printf("tokenizing part: %s\n", part.c_str());
+        bool add_bos = &parts.front() == &part;
+        auto tokens = llava2_tokenize_text_internal(vocab, part, add_special && add_bos, parse_special);
+        if (tokens.empty()) {
+            continue;
+        }
+        output.push_back({
+            LLAVA2_INPUT_CHUNK_TYPE_TEXT,
+            std::move(tokens),
+            {},
+        });
+
+        if (&parts.back() != &part) {
+            // add image token to middle of 2 parts
+
+            if (i_img >= bitmaps.size()) {
+                LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
+                return 2;
+            }
+
+            // shim layer
+            clip_image_u8_ptr img_u8(clip_image_u8_init());
+            img_u8->nx = bitmaps[i_img].nx;
+            img_u8->ny = bitmaps[i_img].ny;
+            img_u8->buf.resize(bitmaps[i_img].data.size());
+            std::memcpy(img_u8->buf.data(), bitmaps[i_img].data.data(), img_u8->nx * img_u8->ny * 3);
+
+            // preprocess image
+            clip_image_f32_batch_ptr batch_f32;
+            bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), batch_f32.get());
+            if (!ok) {
+                LOG_ERR("Unable to preprocess image\n");
+                return 1;
+            }
+
+            llava2_image_tokens image_tokens;
+            //image_tokens.nx = ...;
+            //image_tokens.ny = ...;
+            image_tokens.n_tokens = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
+            image_tokens.data = std::unique_ptr<llava2_image_tokens_data>(
+                new llava2_image_tokens_data{
+                    std::move(batch_f32),
+                }
+            );
+
+            output.push_back({
+                LLAVA2_INPUT_CHUNK_TYPE_IMAGE,
+                {},
+                std::move(image_tokens),
+            });
+            i_img++;
+        }
+    }
+
+    return 0;
+}
+
+LLAVA2_API int32_t llava2_encode(llava2_context_ptr & ctx,
+        const llava2_image_tokens & image_tokens) {
+    ctx->image_embd_v.reserve(image_tokens.n_tokens * clip_n_mmproj_embd(ctx->ctx_clip));
+    return clip_image_batch_encode(
+        ctx->ctx_clip,
+        ctx->n_threads,
+        image_tokens.data->batch_f32.get(),
+        ctx->image_embd_v.data());
+}
+
+LLAVA2_API float * llava2_get_output_embd(llava2_context_ptr & ctx) {
+    return ctx->image_embd_v.data();
+}

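Taken together, the new entry points sketch the intended call sequence for a CLI such as gemma3-cli. The example below is not part of this commit: llava2.h is the sixth changed file but is not shown here, so the chunk and bitmap field names (type, tokens_text, tokens_image) are assumptions inferred from how llava2.cpp uses them, and model is assumed to be an already loaded llama_model *.

    llava2_context_ptr ctx = llava2_init_from_file("mmproj.gguf", model, llava2_context_params{
        /* use_gpu */   true,
        /* n_threads */ 4,
        /* verbosity */ GGML_LOG_LEVEL_INFO,
    });

    llava2_bitmap bmp;
    if (llava2_bitmap_init_from_file("photo.jpg", bmp) != 0) {
        // handle image load error
    }

    // the "<image>" marker (IMG_MARKER) in the prompt becomes an image chunk backed by bmp
    std::vector<llava2_input_chunk> chunks;
    llava2_tokenize(ctx, chunks, "describe <image> please",
                    /* add_special */ true, /* parse_special */ true, { bmp });

    for (auto & chunk : chunks) {
        if (chunk.type == LLAVA2_INPUT_CHUNK_TYPE_IMAGE) {
            llava2_encode(ctx, chunk.tokens_image);      // runs the CLIP/mmproj encoder
            float * embd = llava2_get_output_embd(ctx);  // n_tokens x n_embd image embeddings
            // feed embd to llama_decode via an embedding batch (cf. decode_embd_batch in gemma3-cli)
        } else {
            // feed chunk.tokens_text to llama_decode as ordinary text tokens
        }
    }
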
0 commit comments
