
Commit 6089b0a

simple example works
1 parent 4897ff6 commit 6089b0a

7 files changed: +111 -33 lines changed


common/vision.cpp

Lines changed: 2 additions & 1 deletion
@@ -32,6 +32,7 @@ llama_img * load_image_from_file(const char * fname) {
     // }
     // printf("\n");
     llama_img * result = llama_img_alloc(nx, ny);
-    memcpy(result->data, bytes, nx*ny*nc);
+    memcpy(result->data, img, nx*ny*3);
+    stbi_image_free(img);
     return result;
 }
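The change above copies from the stb-decoded pixel buffer (img, forced to 3 channels) rather than the raw file bytes, and frees the stb allocation once the copy is done. A minimal, self-contained sketch of the same pattern, assuming the stock stb_image API; the rgb_image struct below is a hypothetical stand-in for llama_img:

// Sketch, not the project's exact code: decode to RGB with stb_image,
// copy into a caller-owned buffer, then release the stb allocation.
#include <cstdio>
#include <cstdlib>
#include <cstring>

#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"

// hypothetical stand-in for llama_img / llama_img_alloc
struct rgb_image {
    int nx, ny;
    unsigned char * data; // nx*ny*3 bytes, tightly packed RGB
};

static rgb_image * load_rgb(const char * fname) {
    int nx, ny, nc;
    // request 3 channels so the copy size is always nx*ny*3,
    // regardless of how many channels the file actually has
    unsigned char * img = stbi_load(fname, &nx, &ny, &nc, 3);
    if (img == nullptr) {
        fprintf(stderr, "failed to load %s\n", fname);
        return nullptr;
    }
    rgb_image * result = (rgb_image *) malloc(sizeof(rgb_image));
    result->nx   = nx;
    result->ny   = ny;
    result->data = (unsigned char *) malloc((size_t) nx * ny * 3);
    memcpy(result->data, img, (size_t) nx * ny * 3);
    stbi_image_free(img); // the decoded buffer is no longer needed after the copy
    return result;
}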

convert_hf_to_gguf.py

Lines changed: 24 additions & 7 deletions
@@ -11,6 +11,7 @@
 import os
 import re
 import sys
+from transformers import AutoConfig
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256

@@ -67,6 +68,7 @@ class Model:
     is_lora: bool

     # for vision model
+    preprocessor_config: dict[str, Any] | None = None
     vparams: dict[str, Any] | None = None
     v_tensor_map: gguf.TensorNameMap
     v_tensor_names: set[str] | None

@@ -100,6 +102,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
         self.is_lora = is_lora  # true if model is used inside convert_lora_to_gguf.py
+        self.preprocessor_config = self.load_preprocessor_config(self.dir_model)

         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:

@@ -463,8 +466,20 @@ def load_hparams(dir_model: Path):
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
             hparams = json.load(f)
         if "text_config" in hparams:
-            hparams = {**hparams, **hparams["text_config"]}
+            text_config = hparams["text_config"]
+            if "_name_or_path" in text_config:
+                text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict()
+            hparams = {**text_config, **hparams}
         return hparams
+
+    @staticmethod
+    def load_preprocessor_config(dir_model: Path):
+        file_path = dir_model / "preprocessor_config.json"
+        if os.path.exists(file_path):
+            with open(file_path, "r", encoding="utf-8") as f:
+                return json.load(f)
+        else:
+            return None

     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:

@@ -1574,7 +1589,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_add_bos_token(False)

         # For vision model
-        if self.vparams is not None:
+        if self.vparams is not None and self.preprocessor_config is not None:
             self.gguf_writer.add_vision_type("clip")
             self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
             self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])

@@ -1583,14 +1598,13 @@ def set_gguf_parameters(self):
             self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
             self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
             self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
+            self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"])
+            self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"])
+            max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
+            self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
             # TODO: should not hardcode these, but they are currently missing from config.json
             self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
-            self.gguf_writer.add_vision_clip_max_position_embeddings(577)
             self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
-            default_image_mean = [0.48145466, 0.4578275, 0.40821073]
-            default_image_std = [0.26862954, 0.26130258, 0.27577711]
-            self.gguf_writer.add_vision_clip_image_mean(default_image_mean)
-            self.gguf_writer.add_vision_clip_image_std(default_image_std)

     @staticmethod
     def permute(weights: Tensor, n_head: int, n_head_kv: int | None):

@@ -1606,8 +1620,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")

+        # For vision model
         if name.startswith("language_model"):
             name = name.replace("language_model.", "")
+        if "post_layernorm" in name:
+            return [] # skip post_layernorm

         if name.endswith(("q_proj.weight", "q_proj.bias")):
             data_torch = LlamaModel.permute(data_torch, n_head, n_head)
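The converter change above derives the CLIP max-position count from image_size and patch_size instead of hardcoding 577, and reads image_mean/image_std from preprocessor_config.json rather than built-in defaults. A quick check that the formula reproduces the old constant; the 336/14 values below are assumed (typical of the CLIP-ViT-L/14-336 tower used by LLaVA), since the converter reads them from the model's config:

#include <cstdio>

int main() {
    // assumed values for illustration; the converter reads them from config.json
    const int image_size = 336;
    const int patch_size = 14;
    const int n_patches  = (image_size / patch_size) * (image_size / patch_size); // 24*24 = 576
    const int max_pos    = n_patches + 1;  // +1 for the class-token position -> 577
    printf("n_patches = %d, max_position_embeddings = %d\n", n_patches, max_pos);
    return 0;
}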

examples/simple/simple.cpp

Lines changed: 42 additions & 10 deletions
@@ -66,12 +66,41 @@ int main(int argc, char ** argv) {
     // TODO: this is for testing; DELETE ME
-    llama_img_batch ibatch;
-    ibatch.n_imgs = 1;
-    ibatch.imgs = (llama_img **) malloc(1024);
-    ibatch.imgs[0] = load_image_from_file("media/llama0-logo.png");
-    llama_vision_encode(ctx, &ibatch);
-    return 0;
+    int n_cur = 0;
+    params.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:";
+    {
+        llama_img_batch ibatch;
+        ibatch.n_imgs = 1;
+        ibatch.imgs = (llama_img **) malloc(1024);
+        ibatch.imgs[0] = load_image_from_file("../models/eiffel-tower-3349075_1280.jpg");
+        llama_vision_encode(ctx, &ibatch);
+
+        auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+        int n_imgs = ibatch.n_imgs;
+        int n_embd = llama_n_embd(model);
+        int n_patches = llama_vision_n_patches(ctx);
+        printf("n_embd = %d ; n_patches = %d \n", n_embd, n_patches);
+        float * output_img = llama_vision_get_embeddings(ctx, 0);
+
+        n_cur += tokens.size();
+        llama_batch batch = llama_batch_init(512, 0, 1);
+        llama_batch_clear(batch);
+        for (auto t : tokens) { llama_batch_add(batch, t, n_cur, { 0 }, false); n_cur++; }
+        if (llama_decode(ctx, batch) != 0) {
+            LOG("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+
+        // for (int k = 0; k < 10; k++) printf("%f\n", output_img[k]);
+        llama_batch_clear(batch);
+        batch = {int32_t(n_patches*n_imgs), nullptr, output_img, nullptr, nullptr, nullptr, nullptr, n_cur, 1, 0, };
+        if (llama_decode(ctx, batch) != 0) {
+            LOG("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+        n_cur += n_embd*n_imgs;
+    }
+    params.prompt = "\nwhat did you see?\nASSISTANT:";

@@ -108,7 +137,10 @@ int main(int argc, char ** argv) {
     // evaluate the initial prompt
     for (size_t i = 0; i < tokens_list.size(); i++) {
-        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+        //llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+        if (i == 0) continue;
+        llama_batch_add(batch, tokens_list[i], n_cur, { 0 }, false);
+        n_cur++;
     }

     // llama_decode will output logits only for the last token of the prompt

@@ -121,18 +153,18 @@ int main(int argc, char ** argv) {
     // main loop

-    int n_cur = batch.n_tokens;
+    //int n_cur = batch.n_tokens;
     int n_decode = 0;

     const auto t_main_start = ggml_time_us();

-    while (n_cur <= n_predict) {
+    for (int i = 0; i < n_predict; i++) {
         // sample the next token
         {
             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1);

             // is it an end of generation?
-            if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
+            if (llama_token_is_eog(model, new_token_id)) {
                 LOG("\n");

                 break;

include/llama.h

Lines changed: 2 additions & 0 deletions
@@ -903,6 +903,8 @@ extern "C" {
     // get output embeddings, to be put into language batch
     LLAMA_API float * llama_vision_get_embeddings(struct llama_context * ctx, int32_t idx);

+    LLAMA_API int32_t llama_vision_n_patches(struct llama_context * ctx);
+
     //
     // Vocab
     //
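A rough usage sketch for the new accessor, modeled on what examples/simple/simple.cpp does in this commit; it assumes an image has already been run through llama_vision_encode and is only meant to show how the two vision getters fit together:

#include "llama.h"
#include <cstdio>

// Sketch: after llama_vision_encode() has run, report how many embedding rows
// the encoded image will contribute to the language-model batch.
static void report_image_embeddings(struct llama_context * ctx, const struct llama_model * model) {
    const int n_patches = llama_vision_n_patches(ctx);             // patches produced per image
    const int n_embd    = llama_n_embd(model);                     // embedding width of the LM
    float * embd        = llama_vision_get_embeddings(ctx, 0);     // embeddings of the first image

    // each patch becomes one embedding row in the batch that is later decoded
    printf("image -> %d rows x %d floats (%d values total)\n",
           n_patches, n_embd, n_patches * n_embd);
    (void) embd; // would be fed to llama_decode via a batch with the embd field set
}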

src/llama-vision.cpp

Lines changed: 28 additions & 10 deletions
@@ -14,7 +14,8 @@
 // export clip_image_u8 to bmp file for debugging
 // https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c
-static int bmp_export(const clip_image_u8 &img, const std::string &location);
+struct clip_image_size;
+static int bmp_export(const struct clip_image_u8 &img, const std::string &location);
 #endif

 struct clip_image_size {

@@ -53,21 +54,21 @@
 using clip_image_f32_batch = std::vector<clip_image_f32>;
 using clip_image_f8_batch = std::vector<clip_image_u8>;

-static int clip_n_patches(const clip_context & ctx) {
+int clip_n_patches(const clip_context & ctx) {
     auto & hparams = ctx.model->hparams;
     int n_patches = (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size);
     return n_patches;
 }

-static int clip_n_mmproj_embd(const clip_context & ctx) {
+int clip_n_mmproj_embd(const clip_context & ctx) {
     if (ctx.model->hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) {
         return ctx.model->mm_b_b->ne[0];
     } else {
         GGML_ASSERT(false && "invalid proj type");
     }
 }

-static int clip_n_embd(const clip_context & ctx) {
+int clip_n_embd(const clip_context & ctx) {
     return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx);
 }

@@ -323,7 +324,7 @@ static bool clip_image_preprocess(const clip_context & ctx, const clip_image_u8
     const int nx = temp.nx;
     const int ny = temp.ny;
-    // clip_image_save_to_bmp(*temp, "resized_vanilla.bmp");
+    // bmp_export(temp, "resized_vanilla.bmp");

     const int nx2 = params.image_size;
     const int ny2 = params.image_size;

@@ -451,11 +452,11 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size,
         embeddings = ggml_norm(ctx0, embeddings, eps);
         ggml_set_name(embeddings, "pre_ln");

-        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_norm_w), model.pre_norm_w);
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_norm_w), model.pre_norm_b);
     }

     // loop over layers
-    for (int il = 0; il < (int)hparams.n_layer - 1; il++) {
+    for (int il = 0; il < (int)hparams.n_layer - 2; il++) {
         struct ggml_tensor * cur = embeddings;

         // layernorm1

@@ -537,6 +538,14 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size,
         embeddings = cur;
     }

+    // post-layernorm
+    if (model.post_norm_w) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "post_ln");
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_norm_w), model.post_norm_b);
+    }
+
     // llava projector
     {
         embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);

@@ -673,6 +682,7 @@ static int32_t encode_image_with_clip(clip_context & ctx, const llama_img img, s
     clip_image_u8 img_u8(img);
     clip_image_f32_batch img_res_v;
     auto & hparams = ctx.model->hparams;
+    // bmp_export(img_u8, "test_inp.bmp");

     if (!clip_image_preprocess(ctx, img_u8, img_res_v)) {
         LLAMA_LOG_ERROR("%s: unable to preprocess image\n", __func__);

@@ -724,7 +734,6 @@ int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch
         // copy output embeddings to result
         for (int k = 0; k < n_embd; k++) {
             ctx.output[n_embd*i + k] = output_single[k];
-            // if (k<10) printf("%f\n", output_single[k]);
         }
     }

@@ -735,10 +744,19 @@
 // for debugging
 #ifndef NDEBUG

-static int bmp_export(const clip_image_u8 &img, const std::string &location) {
+static int bmp_export(const struct clip_image_u8 &img, const std::string &location) {
     const uint32_t width = img.nx;
     const uint32_t height = img.ny;
-    const std::vector<uint8_t> &buffer = img.buf;
+    // swap red and blue channel
+    std::vector<uint8_t> buffer(width*height*3);
+    for (uint32_t y = 0; y < height; y++) {
+        for (uint32_t x = 0; x < width; x++) {
+            size_t base = x*3 + y*3*width;
+            buffer[base+2] = img.buf[base];
+            buffer[base+1] = img.buf[base+1];
+            buffer[base] = img.buf[base+2];
+        }
+    }
     const bool hasAlphaChannel = false;

     std::ofstream fout(location, std::ios::out | std::ios::binary);

src/llama-vision.h

Lines changed: 5 additions & 1 deletion
@@ -6,8 +6,8 @@
 #include <array>

 enum vision_arch {
-    VISION_ARCH_LLAVA,
     VISION_ARCH_UNKNOWN,
+    VISION_ARCH_LLAVA,
 };

 enum clip_projector_type {

@@ -112,4 +112,8 @@ struct clip_context {
     std::vector<float> output; // size == n_output * n_embd
 };

+int clip_n_patches(const clip_context & ctx);
+int clip_n_mmproj_embd(const clip_context & ctx);
+int clip_n_embd(const clip_context & ctx);
+
 int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch);

src/llama.cpp

Lines changed: 8 additions & 4 deletions
@@ -6239,7 +6239,7 @@ static void llm_load_hparams(
         ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, vparams.n_intermediate, true);
         ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, vparams.n_head, true);
         ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, vparams.eps, true);
-        ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, proj_type, true);
+        ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, proj_type, true);
         if (proj_type == "mlp") {
             vparams.proj_type = CLIP_PROJECTOR_TYPE_MLP;
         } else {

@@ -8987,9 +8987,9 @@ static bool llm_load_tensors(
             model.clip.position_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_POS, "weight"), {n_embd, max_pos_embd});

             model.clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "weight"), {n_embd});
-            model.clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "bias" ), {n_embd});
-            model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_embd});
-            model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_embd});
+            model.clip.pre_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "bias" ), {n_embd});
+            // model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_embd});
+            // model.clip.post_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_embd});

             for (int i = 0; i < n_layer; ++i) {
                 ggml_context * ctx_layer = ctx_for_layer(i);

@@ -21815,6 +21815,10 @@ float * llama_vision_get_embeddings(struct llama_context * ctx, int32_t idx) {
     return ctx->clip.output.data();
 }

+int32_t llama_vision_n_patches(struct llama_context * ctx) {
+    return clip_n_patches(ctx->clip);
+}
+
 //
 // model split
 //
