Skip to content

Commit 5648e30

Browse files
committed
llava cgraph ok
1 parent 6854ad4 commit 5648e30

File tree

8 files changed

+535
-94
lines changed

8 files changed

+535
-94
lines changed

Makefile

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,7 @@ OBJ_LLAMA = \
926926
src/llama-vocab.o \
927927
src/llama-grammar.o \
928928
src/llama-sampling.o \
929+
src/llama-vision.o \
929930
src/unicode.o \
930931
src/unicode-data.o
931932

@@ -937,6 +938,7 @@ OBJ_COMMON = \
937938
common/ngram-cache.o \
938939
common/sampling.o \
939940
common/train.o \
941+
common/vision.o \
940942
common/build-info.o \
941943
common/json-schema-to-grammar.o
942944

@@ -1221,6 +1223,12 @@ common/ngram-cache.o: \
12211223
common/ngram-cache.h
12221224
$(CXX) $(CXXFLAGS) -c $< -o $@
12231225

1226+
common/vision.o: \
1227+
common/vision.cpp \
1228+
common/vision.h \
1229+
common/stb_image.h
1230+
$(CXX) $(CXXFLAGS) -c $< -o $@
1231+
12241232
$(LIB_COMMON): \
12251233
$(OBJ_COMMON) \
12261234
$(LIB_LLAMA) \

common/vision.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#include "vision.h"
2+
3+
#define STB_IMAGE_IMPLEMENTATION
4+
#include "stb_image.h"
5+
6+
#include <vector>
7+
#include <fstream>
8+
9+
llama_img * load_image_from_file(const char * fname) {
10+
std::ifstream file(fname, std::ios::binary);
11+
if (!file) {
12+
throw std::runtime_error("Unable to open file");
13+
}
14+
std::vector<char> image_bytes = std::vector<char>(
15+
std::istreambuf_iterator<char>(file),
16+
std::istreambuf_iterator<char>());
17+
// decode image to byte array
18+
int nx, ny, nc;
19+
auto * bytes = (unsigned char *) image_bytes.data();
20+
auto * img = stbi_load_from_memory(bytes, image_bytes.size(), &nx, &ny, &nc, 3);
21+
if (!img) {
22+
throw std::runtime_error("failed to decode image bytes");
23+
}
24+
// printf("nx=%d ny=%d nc=%d\n", nx, ny, nc);
25+
// GGML_ASSERT(nc == 3);
26+
// for (int y = 0; y < ny; y++) {
27+
// for (int x = 0; x < nx; x++) {
28+
// unsigned char * pix = img + x*nc + y*nc*nx;
29+
// printf("%02x%02x%02x ", pix[0], pix[1], pix[2]);
30+
// }
31+
// printf("\n");
32+
// }
33+
// printf("\n");
34+
llama_img * result = llama_img_alloc(nx, ny);
35+
memcpy(result->data, bytes, nx*ny*nc);
36+
return result;
37+
}

common/vision.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#pragma once
2+
3+
#include "llama.h"
4+
5+
#include <string>
6+
#include <vector>
7+
8+
llama_img * load_image_from_file(const char * fname);

examples/simple/simple.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "common.h"
33
#include "log.h"
44
#include "llama.h"
5+
#include "vision.h"
56

67
#include <vector>
78

@@ -61,6 +62,19 @@ int main(int argc, char ** argv) {
6162

6263
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
6364

65+
66+
67+
68+
// TODO: this is for testing; DELETE ME
69+
llama_img_batch ibatch;
70+
ibatch.n_imgs = 1;
71+
ibatch.imgs = (llama_img **) malloc(1024);
72+
ibatch.imgs[0] = load_image_from_file("media/llama0-logo.png");
73+
llama_vision_encode(ctx, &ibatch);
74+
return 0;
75+
76+
77+
6478
// tokenize the prompt
6579

6680
std::vector<llama_token> tokens_list;

include/llama.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ extern "C" {
234234

235235
// Input data for llama_vision_decode
236236
typedef struct llama_img_batch {
237-
int32_t n_imgs;
238-
llama_img * imgs;
237+
int32_t n_imgs;
238+
llama_img ** imgs;
239239
} llama_img_batch;
240240

241241
// Input data for llama_decode
@@ -893,6 +893,10 @@ extern "C" {
893893
// Vision
894894
//
895895

896+
// create new RGB image for input
897+
LLAMA_API llama_img * llama_img_alloc(int width, int height);
898+
LLAMA_API void llama_img_free(llama_img * img);
899+
896900
// encode image into embeddings
897901
LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, llama_img_batch * batch);
898902

0 commit comments

Comments
 (0)