22 changes: 20 additions & 2 deletions Makefile
@@ -926,6 +926,7 @@ OBJ_LLAMA = \
src/llama-vocab.o \
src/llama-grammar.o \
src/llama-sampling.o \
src/llama-vision.o \
src/unicode.o \
src/unicode-data.o

@@ -937,6 +938,7 @@ OBJ_COMMON = \
common/ngram-cache.o \
common/sampling.o \
common/train.o \
common/vision.o \
common/build-info.o \
common/json-schema-to-grammar.o

@@ -1120,6 +1122,7 @@ src/llama.o: \
src/llama-vocab.h \
src/llama-grammar.h \
src/llama-sampling.h \
src/llama-vision.h \
src/unicode.h \
include/llama.h \
ggml/include/ggml-cuda.h \
@@ -1152,6 +1155,17 @@ src/llama-sampling.o: \
include/llama.h
$(CXX) $(CXXFLAGS) -c $< -o $@

src/llama-vision.o: \
src/llama-vision.cpp \
src/llama-vision.h \
include/llama.h \
ggml/include/ggml-cuda.h \
ggml/include/ggml-metal.h \
ggml/include/ggml.h \
ggml/include/ggml-alloc.h \
ggml/include/ggml-backend.h
$(CXX) $(CXXFLAGS) -c $< -o $@

$(LIB_LLAMA): \
$(OBJ_LLAMA) \
$(LIB_GGML)
@@ -1209,6 +1223,12 @@ common/ngram-cache.o: \
common/ngram-cache.h
$(CXX) $(CXXFLAGS) -c $< -o $@

common/vision.o: \
common/vision.cpp \
common/vision.h \
common/stb_image.h
$(CXX) $(CXXFLAGS) -c $< -o $@

$(LIB_COMMON): \
$(OBJ_COMMON) \
$(LIB_LLAMA) \
@@ -1457,7 +1477,6 @@ llama-server: \
examples/server/json-schema-to-grammar.mjs.hpp \
examples/server/loading.html.hpp \
common/json.hpp \
common/stb_image.h \
$(OBJ_ALL)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
@@ -1480,7 +1499,6 @@ libllava.a: examples/llava/llava.cpp \
examples/llava/llava.h \
examples/llava/clip.cpp \
examples/llava/clip.h \
common/stb_image.h \
common/base64.hpp \
$(OBJ_ALL)
$(CXX) $(CXXFLAGS) -static -fPIC -c $< -o $@ -Wno-cast-qual
38 changes: 38 additions & 0 deletions common/vision.cpp
@@ -0,0 +1,38 @@
#include "vision.h"

#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"

#include <vector>
#include <fstream>

llama_img * load_image_from_file(const char * fname) {
std::ifstream file(fname, std::ios::binary);
if (!file) {
throw std::runtime_error("Unable to open file");
}
std::vector<char> image_bytes = std::vector<char>(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>());
// decode image to byte array
int nx, ny, nc;
auto * bytes = (unsigned char *) image_bytes.data();
auto * img = stbi_load_from_memory(bytes, image_bytes.size(), &nx, &ny, &nc, 3);
if (!img) {
throw std::runtime_error("failed to decode image bytes");
}
// printf("nx=%d ny=%d nc=%d\n", nx, ny, nc);
// GGML_ASSERT(nc == 3);
// for (int y = 0; y < ny; y++) {
// for (int x = 0; x < nx; x++) {
// unsigned char * pix = img + x*nc + y*nc*nx;
// printf("%02x%02x%02x ", pix[0], pix[1], pix[2]);
// }
// printf("\n");
// }
// printf("\n");
llama_img * result = llama_img_alloc(nx, ny);
memcpy(result->data, img, nx*ny*3);
stbi_image_free(img);
return result;
}
8 changes: 8 additions & 0 deletions common/vision.h
@@ -0,0 +1,8 @@
#pragma once

#include "llama.h"

#include <string>
#include <vector>

llama_img * load_image_from_file(const char * fname);
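
A minimal caller for the new helper (illustrative only, not part of the diff). The nx/ny field names on llama_img and the llama_img_free() call are assumptions inferred from llama_img_alloc(nx, ny) and result->data in vision.cpp above; check llama.h on this branch for the actual struct.

#include "vision.h"   // load_image_from_file (declared above)
#include <cstdio>
#include <stdexcept>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s IMAGE_FILE\n", argv[0]);
        return 1;
    }
    try {
        // decodes any stb_image-supported format and converts it to packed 3-channel RGB
        llama_img * img = load_image_from_file(argv[1]);
        printf("loaded %d x %d RGB image\n", (int) img->nx, (int) img->ny);
        // llama_img_free(img); // hypothetical counterpart to llama_img_alloc(); not shown in this diff
    } catch (const std::runtime_error & err) {
        fprintf(stderr, "error: %s\n", err.what());
        return 1;
    }
    return 0;
}
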
66 changes: 62 additions & 4 deletions convert_hf_to_gguf.py
@@ -11,6 +11,7 @@
import os
import re
import sys
from transformers import AutoConfig
from enum import IntEnum
from pathlib import Path
from hashlib import sha256
@@ -66,6 +67,12 @@ class Model:
dir_model_card: Path
is_lora: bool

# for vision model
preprocessor_config: dict[str, Any] | None = None
vparams: dict[str, Any] | None = None
v_tensor_map: gguf.TensorNameMap
v_tensor_names: set[str] | None

# subclasses should define this!
model_arch: gguf.MODEL_ARCH

@@ -95,6 +102,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
self.is_lora = is_lora # true if model is used inside convert_lora_to_gguf.py
self.preprocessor_config = self.load_preprocessor_config(self.dir_model)

# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
if self.ftype == gguf.LlamaFileType.GUESSED:
@@ -210,9 +218,13 @@ def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int |

def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
if new_name is None:
new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes)
if new_name is not None:
return new_name
elif new_name_vision is not None:
return new_name_vision
else:
raise ValueError(f"Can not map tensor {name!r}")
return new_name

def set_gguf_parameters(self):
self.gguf_writer.add_block_count(self.block_count)
@@ -452,7 +464,22 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]
@staticmethod
def load_hparams(dir_model: Path):
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
return json.load(f)
hparams = json.load(f)
if "text_config" in hparams:
text_config = hparams["text_config"]
if "_name_or_path" in text_config:
text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict()
hparams = {**text_config, **hparams}
return hparams

@staticmethod
def load_preprocessor_config(dir_model: Path):
file_path = dir_model / "preprocessor_config.json"
if os.path.exists(file_path):
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
else:
return None

@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
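
Note on the merge in load_hparams(): because hparams is spread last in {**text_config, **hparams}, top-level keys in config.json take precedence, and the text_config values (possibly fetched via AutoConfig from _name_or_path) only fill in keys the multimodal config does not already define. For a LlavaForConditionalGeneration checkpoint this exposes the wrapped language model's hidden_size, num_hidden_layers, etc. at the top level, where the rest of the converter expects them.
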
@@ -1501,10 +1528,17 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed norms: {norms}")


@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration")
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if "vision_config" in self.hparams:
self.vparams = self.hparams["vision_config"]
if self.vparams is not None:
self.v_tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAVA_VISION, self.vparams["num_hidden_layers"])

def set_vocab(self):
try:
self._set_vocab_sentencepiece()
@@ -1554,6 +1588,24 @@ def set_gguf_parameters(self):
if self.hparams.get("vocab_size", 32000) == 49152:
self.gguf_writer.add_add_bos_token(False)

# For vision model
if self.vparams is not None and self.preprocessor_config is not None:
self.gguf_writer.add_vision_type("clip")
self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])
self.gguf_writer.add_vision_clip_architecture("llava")
self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"])
self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"])
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
# TODO: should not hardcode these, but they are currently missing from config.json
self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
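
Worked example for max_pos_embd: LLaVA-1.5 uses a CLIP ViT-L/14 encoder with image_size 336 and patch_size 14, so the patch grid is 336/14 = 24 per side, i.e. 24^2 = 576 patch positions plus one class-token position = 577, matching CLIP's position-embedding table.
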

@staticmethod
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
if n_head_kv is not None and n_head != n_head_kv:
@@ -1568,6 +1620,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")

# For vision model
if name.startswith("language_model"):
name = name.replace("language_model.", "")
if "post_layernorm" in name:
return [] # skip post_layernorm

if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
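
The modify_tensors() change strips the language_model. prefix so the text-side tensors of a LlavaForConditionalGeneration checkpoint map onto the existing LLaMA names, and drops the vision tower's post_layernorm; skipping it is presumably safe because LLaVA takes image features from a penultimate CLIP layer, so the encoder's final layer norm is never applied.
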
2 changes: 1 addition & 1 deletion examples/llava/clip.cpp
@@ -23,7 +23,7 @@
#include "ggml-vulkan.h"
#endif

#define STB_IMAGE_IMPLEMENTATION
#include "vision.h" // without this, we get duplicated symbol error
#include "stb_image.h"

#include <cassert>
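
Context for the change above: common/vision.cpp (new in this PR) now carries the single STB_IMAGE_IMPLEMENTATION definition, so clip.cpp must stop defining the macro as well or the linker sees two copies of every stb_image symbol; including vision.h here pulls in only the declarations.
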
54 changes: 50 additions & 4 deletions examples/simple/simple.cpp
@@ -2,6 +2,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "vision.h"

#include <vector>

@@ -61,6 +62,48 @@ int main(int argc, char ** argv) {

llama_sampler_chain_add(smpl, llama_sampler_init_greedy());




// TODO: this is for testing; DELETE ME
int n_cur = 0;
params.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:";
{
llama_img_batch ibatch;
ibatch.n_imgs = 1;
ibatch.imgs = (llama_img **) malloc(1024);
ibatch.imgs[0] = load_image_from_file("../models/eiffel-tower-3349075_1280.jpg");
llama_vision_encode(ctx, &ibatch);

auto tokens = ::llama_tokenize(ctx, params.prompt, true);
int n_imgs = ibatch.n_imgs;
int n_embd = llama_n_embd(model);
int n_patches = llama_vision_n_patches(ctx);
printf("n_embd = %d ; n_patches = %d \n", n_embd, n_patches);
float * output_img = llama_vision_get_embeddings(ctx, 0);

n_cur += tokens.size();
llama_batch batch = llama_batch_init(512, 0, 1);
llama_batch_clear(batch);
for (auto t : tokens) { llama_batch_add(batch, t, n_cur, { 0 }, false); n_cur++; }
if (llama_decode(ctx, batch) != 0) {
LOG("%s: llama_decode() failed\n", __func__);
return 1;
}

// for (int k = 0; k < 10; k++) printf("%f\n", output_img[k]);
llama_batch_clear(batch);
batch = {int32_t(n_patches*n_imgs), nullptr, output_img, nullptr, nullptr, nullptr, nullptr, n_cur, 1, 0, };
if (llama_decode(ctx, batch) != 0) {
LOG("%s: llama_decode() failed\n", __func__);
return 1;
}
n_cur += n_embd*n_imgs;
}
params.prompt = "\nwhat did you see?\nASSISTANT:";



// tokenize the prompt

std::vector<llama_token> tokens_list;
@@ -94,7 +137,10 @@ int main(int argc, char ** argv) {

// evaluate the initial prompt
for (size_t i = 0; i < tokens_list.size(); i++) {
llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
//llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
if (i == 0) continue;
llama_batch_add(batch, tokens_list[i], n_cur, { 0 }, false);
n_cur++;
}

// llama_decode will output logits only for the last token of the prompt
@@ -107,18 +153,18 @@

// main loop

int n_cur = batch.n_tokens;
//int n_cur = batch.n_tokens;
int n_decode = 0;

const auto t_main_start = ggml_time_us();

while (n_cur <= n_predict) {
for (int i = 0; i < n_predict; i++) {
// sample the next token
{
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1);

// is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
if (llama_token_is_eog(model, new_token_id)) {
LOG("\n");

break;
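
For readability, here is the aggregate initializer from the test block above written out with field names (illustrative only; the names follow the llama_batch struct as it stood around this change and should be treated as an assumption):

llama_batch img_batch = {
    /*.n_tokens   =*/ int32_t(n_patches*n_imgs), // one slot per image patch
    /*.token      =*/ nullptr,                   // no text tokens in this batch
    /*.embd       =*/ output_img,                // patch embeddings from llama_vision_get_embeddings()
    /*.pos        =*/ nullptr,                   // positions generated from all_pos_0/all_pos_1
    /*.n_seq_id   =*/ nullptr,
    /*.seq_id     =*/ nullptr,
    /*.logits     =*/ nullptr,                   // no logits requested for the image span
    /*.all_pos_0  =*/ n_cur,                     // first position, right after the text prompt
    /*.all_pos_1  =*/ 1,                         // position stride
    /*.all_seq_id =*/ 0,                         // same sequence id as the text
};

One thing a reviewer may want to double-check: the test block advances n_cur by n_embd*n_imgs after decoding the image batch, while the batch itself occupies n_patches*n_imgs positions.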