
Commit 9d97ad5

(research) experiment with phi-4-multimodal
1 parent 7ab3643 commit 9d97ad5

5 files changed (+589, -4 lines)

convert_hf_to_gguf.py

Lines changed: 24 additions & 2 deletions
@@ -2398,9 +2398,23 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_add_bos_token(False)
 
 
-@Model.register("Phi3ForCausalLM")
+@Model.register("Phi3ForCausalLM", "Phi4MMForCausalLM")
 class Phi3MiniModel(Model):
     model_arch = gguf.MODEL_ARCH.PHI3
+    has_vision: bool = False
+
+    # we need to merge the text_config into the root level of hparams
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if "vision_lora" in self.hparams:
+            logger.info("Detected vision encoder, but it will be ignored")
+            self.has_vision = True
+
+    def write(self):
+        super().write()
+        if self.has_vision:
+            logger.info("NOTE: this script only convert the language model to GGUF")
+            logger.info("      for the vision model, please use phi4mm_convert_encoder_to_gguf.py")
 
     def set_vocab(self):
         # Phi-4 model uses GPT2Tokenizer
@@ -2409,7 +2423,7 @@ def set_vocab(self):
         with open(tokenizer_config_file, "r", encoding="utf-8") as f:
             tokenizer_config_json = json.load(f)
             tokenizer_class = tokenizer_config_json['tokenizer_class']
-            if tokenizer_class == 'GPT2Tokenizer':
+            if tokenizer_class == 'GPT2Tokenizer' or tokenizer_class == 'GPT2TokenizerFast':
                 return self._set_vocab_gpt2()
 
         from sentencepiece import SentencePieceProcessor
@@ -2575,6 +2589,14 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
             yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
             yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
 
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        if self.has_vision:
+            if name.startswith("model.embed_tokens_extend") or "lora_" in name:
+                return []
+        name = name.replace(".base_layer", "")
+        return [(self.map_tensor_name(name), data_torch)]
+
 
 @Model.register("PhiMoEForCausalLM")
 class PhiMoeModel(Phi3MiniModel):
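Note on the tensor handling above: it is the core of the text-only conversion. When a vision encoder is detected, the vision tower (model.embed_tokens_extend.*) and the LoRA adapter weights are skipped, and the .base_layer suffix that the LoRA-wrapped linear layers carry in the checkpoint is stripped so the remaining names map onto the regular Phi-3 tensor names. A minimal standalone sketch of that filter, with illustrative tensor names that are not taken from the actual checkpoint:

def filter_phi4mm_tensor(name: str, has_vision: bool) -> str | None:
    """Return the cleaned HF tensor name (before map_tensor_name), or None if skipped."""
    if has_vision and (name.startswith("model.embed_tokens_extend") or "lora_" in name):
        return None  # vision tower / LoRA weights go through a separate converter
    return name.replace(".base_layer", "")

# illustrative names only:
assert filter_phi4mm_tensor("model.embed_tokens_extend.image_embed.proj.weight", True) is None
assert filter_phi4mm_tensor("model.layers.0.self_attn.qkv_proj.lora_A.weight", True) is None
assert filter_phi4mm_tensor("model.layers.0.self_attn.qkv_proj.base_layer.weight", True) == \
    "model.layers.0.self_attn.qkv_proj.weight"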

examples/llava/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
@@ -57,3 +57,10 @@ set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
+
+set(TARGET llama-phi4mm-cli)
+add_executable(${TARGET} phi4mm-cli.cpp)
+set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-phi4mm-cli)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)

examples/llava/clip.cpp

Lines changed: 20 additions & 2 deletions
@@ -878,6 +878,24 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         }
     }
 
+    // FIXME: phi-4, wrap this into an "if" condition
+    int n_tokens = embeddings->ne[1];
+    int n_tokens_sqrt = sqrtf(n_tokens);
+    printf("embeddings shape: %d %d %d %d\n", embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]);
+    embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
+    embeddings = ggml_reshape_4d(ctx0, embeddings, n_tokens_sqrt, n_tokens_sqrt, hidden_size, batch_size);
+    embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
+    embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, n_tokens / 4, batch_size);
+    printf("embeddings shape: %d %d %d %d\n", embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]);
+    // mlp
+    embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+    embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+
+    embeddings = ggml_gelu(ctx0, embeddings);
+    embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+    embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+    printf("embeddings shape: %d %d %d %d\n", embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]);
+
     // llava projector
     if (ctx->has_llava_projector) {
         embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
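The block added above is the Phi-4 projector path in the CLIP graph: the patch embeddings are transposed into a square grid, average-pooled 2x2 with stride 2, flattened back to a sequence of n_tokens / 4 embeddings, and pushed through a two-layer MLP (mm_0, GELU, mm_2) into the language-model embedding space. A small sanity check of the token arithmetic, assuming the 1024-token (32x32) patch grid implied by the 256 embedding slots that phi4mm-cli.cpp reserves per image:

import math

n_tokens = 1024                        # assumed patch-token count per image
grid = math.isqrt(n_tokens)            # 32x32 grid (sqrtf(n_tokens) in the graph above)
pooled = (grid // 2) * (grid // 2)     # 2x2 average pooling with stride 2 -> 16x16
assert pooled == n_tokens // 4 == 256  # matches ggml_reshape_3d(..., n_tokens / 4, ...)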
@@ -2758,7 +2776,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
         free(positions_data);
 
-        if (!ctx->has_glm_projector) {
+        /*if (!ctx->has_glm_projector) {
             struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
             // The patches vector is used to get rows to index into the embeds with;
             // we should skip dim 0 only if we have CLS to avoid going out of bounds
@@ -2770,7 +2788,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             }
             ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
             free(patches_data);
-        }
+        }*/
     }
 }

examples/llava/phi4mm-cli.cpp

Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
#include "arg.h"
#include "log.h"
#include "common.h"
#include "sampling.h"
#include "clip.h"
#include "stb_image.h"
#include "llama.h"
#include "ggml.h"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>
#include <iostream>
#include <fstream>

struct phi4mm_context {
    struct clip_ctx * ctx_clip = NULL;
    common_init_result llama_init;

    llama_model * model;
    llama_context * lctx;
    llama_adapter_lora * vision_lora;

    phi4mm_context(common_params & params) : llama_init(common_init_from_params(params)) {
        model = llama_init.model.get();
        lctx = llama_init.context.get();
        vision_lora = llama_init.lora[0].get();
        llama_clear_adapter_lora(lctx);
        init_clip_model(params);
    }

    void init_clip_model(common_params & params) {
        const char * clip_path = params.mmproj.c_str();
        ctx_clip = clip_model_load(clip_path, params.verbosity > 1);
    }

    ~phi4mm_context() {
        clip_free(ctx_clip);
    }
};

struct decode_embd_batch {
    std::vector<llama_pos> pos;
    std::vector<int32_t> n_seq_id;
    std::vector<llama_seq_id> seq_id_0;
    std::vector<llama_seq_id *> seq_ids;
    std::vector<int8_t> logits;
    llama_batch batch;
    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
        pos     .resize(n_tokens);
        n_seq_id.resize(n_tokens);
        seq_ids .resize(n_tokens + 1);
        logits  .resize(n_tokens);
        seq_id_0.resize(1);
        seq_id_0[0] = seq_id;
        seq_ids [n_tokens] = nullptr;
        batch = {
            /*n_tokens =*/ n_tokens,
            /*tokens   =*/ nullptr,
            /*embd     =*/ embd,
            /*pos      =*/ pos.data(),
            /*n_seq_id =*/ n_seq_id.data(),
            /*seq_id   =*/ seq_ids.data(),
            /*logits   =*/ logits.data(),
        };
        for (int i = 0; i < n_tokens; i++) {
            batch.pos     [i] = pos_0 + i;
            batch.n_seq_id[i] = 1;
            batch.seq_id  [i] = seq_id_0.data();
            batch.logits  [i] = false;
        }
    }
};

struct inp_bitmap {
    int nx;
    int ny;
    std::vector<unsigned char> data;
};

static void show_additional_info(int /*argc*/, char ** argv) {
    GGML_UNUSED(argv);
    LOG("TODO\n");
}

static void eval_text(phi4mm_context & ctx, int & n_past, std::string input, bool logits_last = false) {
    llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
    for (llama_token & t : tokens) {
        common_batch_add(batch, t, n_past++, {0}, false);
    }
    if (logits_last) {
        batch.logits[batch.n_tokens - 1] = true;
    }
    LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
    if (llama_decode(ctx.lctx, batch)) {
        GGML_ABORT("Failed to decode\n");
    }
}

int main(int argc, char ** argv) {
    ggml_time_init();

    common_params params;

    // default values
    params.prompt = "<|user|>$what did you see?<|end|><|assistant|>";
    params.n_predict = 64;
    params.sampling.temp = 0.0f;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
        return 1;
    }

    common_init();

    if (params.mmproj.empty() || (params.image.empty())) {
        show_additional_info(argc, argv);
        return 1;
    }

    if (params.lora_adapters.empty()) {
        LOG_ERR("error: no vision lora adapters specified\n");
        return 1;
    }

    phi4mm_context ctx(params);
    printf("%s: %s\n", __func__, params.model.c_str());

    int n_threads = params.cpuparams.n_threads;
    int n_past = 0;

    std::vector<std::string> prompt_parts = string_split<std::string>(params.prompt, '$');
    GGML_ASSERT(prompt_parts.size() == 2);
    eval_text(ctx, n_past, prompt_parts[0], false);

    // process images
    for (auto & image : params.image) {
        //break;
        std::vector<float> image_embd_v;
        int n_embd = llama_model_n_embd(ctx.model);
        int n_tokens = 256;
        image_embd_v.resize(n_tokens * n_embd);

        bool ok;
        struct clip_image_u8 * img_u8 = clip_image_u8_init();
        ok = clip_image_load_from_file(image.c_str(), img_u8);
        if (!ok) {
            LOG_ERR("Unable to load image %s\n", image.c_str());
            return 1;
        }

        clip_image_f32_batch batch_f32;
        ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32);
        if (!ok) {
            LOG_ERR("Unable to preprocess image\n");
            return 1;
        }

        LOG("Encoding image %s\n", image.c_str());
        ok = clip_image_batch_encode(ctx.ctx_clip, n_threads, &batch_f32, image_embd_v.data());
        if (!ok) {
            LOG_ERR("Unable to encode image\n");
            return 1;
        }

        // debug
        // for (int i = 0; i < 10; i++) {
        //     LOG("embd[%d] = %f, %f, %f\n", i, image_embd_v[i*n_embd], image_embd_v[i*n_embd+1], image_embd_v[i*n_embd+2]);
        // }

        clip_image_f32_batch_free(&batch_f32);
        clip_image_u8_free(img_u8);

        // decode image embeddings
        llama_set_adapter_lora(ctx.lctx, ctx.vision_lora, 1.0f);
        decode_embd_batch batch_img(image_embd_v.data(), n_tokens, n_past, 0);
        if (llama_decode(ctx.lctx, batch_img.batch)) {
            LOG_ERR("failed to decode image\n");
            return 1;
        }
        llama_clear_adapter_lora(ctx.lctx);
        n_past += n_tokens;
    }

    eval_text(ctx, n_past, prompt_parts[1], true);

    // generate text
    struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
    const llama_vocab * vocab = llama_model_get_vocab(ctx.model);
    int n_prompt = n_past;
    llama_batch batch = llama_batch_init(1, 0, 1);
    while (true) {
        int n_generated = n_past - n_prompt;
        if (n_generated > params.n_predict) {
            printf("\n");
            break;
        }

        llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
        common_sampler_accept(smpl, token_id, true);
        printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
        fflush(stdout);

        if (llama_vocab_is_eog(vocab, token_id)) {
            printf("\n");
            break;
        }

        // eval the token
        common_batch_clear(batch);
        common_batch_add(batch, token_id, n_past++, {0}, true);
        if (llama_decode(ctx.lctx, batch)) {
            LOG_ERR("failed to decode token\n");
            break;
        }
    }

    llama_batch_free(batch);

    return 0;
}
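As the checks at the top of main() show, the example needs the language model, the vision projector, at least one image, and the vision LoRA adapter to be supplied (in the common llama.cpp CLI these correspond to --mmproj, --image and --lora; treat the exact invocation as an assumption, since the help text is still a TODO). The prompt is split on '$', the image embeddings are decoded between the two halves with the vision LoRA enabled only for that decode, and generation starts after the suffix. A rough sketch of the resulting position layout for a single image; the prefix and suffix token counts are illustrative, the real values depend on the tokenizer:

prompt = "<|user|>$what did you see?<|end|><|assistant|>"
prefix, suffix = prompt.split("$")   # mirrors GGML_ASSERT(prompt_parts.size() == 2)

n_prefix = 3     # illustrative prefix token count
n_image  = 256   # fixed per-image embedding count in phi4mm-cli.cpp
n_suffix = 9     # illustrative suffix token count

layout = [
    ("prefix tokens",                      0,                  n_prefix),
    ("image embeddings (vision LoRA on)",  n_prefix,           n_prefix + n_image),
    ("suffix tokens",                      n_prefix + n_image, n_prefix + n_image + n_suffix),
]
for what, first, last in layout:
    print(f"{what}: positions [{first}, {last})")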
