
Commit 8c88371

Merge branch 'ggml-org:master' into tmp
2 parents: d256aa0 + 401af80

File tree: 14 files changed, +585 −16 lines
Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
# Granite Vision

Download the model and point your `GRANITE_MODEL` environment variable to the path.

```bash
$ git clone https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview
$ export GRANITE_MODEL=./granite-vision-3.1-2b-preview
```

### 1. Running llava surgery v2
First, we need to run the llava surgery script as shown below:

`python llava_surgery_v2.py -C -m $GRANITE_MODEL`

You should see two new files (`llava.clip` and `llava.projector`) written into your model's directory, as shown below.

```bash
$ ls $GRANITE_MODEL | grep -i llava
llava.clip
llava.projector
```

We should see that the projector and visual encoder have been split out into the llava files. As a quick check, make sure they aren't empty:
```python
import os
import torch

MODEL_PATH = os.getenv("GRANITE_MODEL")
if not MODEL_PATH:
    raise ValueError("env var GRANITE_MODEL is unset!")

encoder_tensors = torch.load(os.path.join(MODEL_PATH, "llava.clip"))
projector_tensors = torch.load(os.path.join(MODEL_PATH, "llava.projector"))

assert len(encoder_tensors) > 0
assert len(projector_tensors) > 0
```

If you inspect the `.keys()` of the loaded tensors, you should see many `vision_model` tensors in `encoder_tensors`, and 5 tensors (`'multi_modal_projector.linear_1.bias'`, `'multi_modal_projector.linear_1.weight'`, `'multi_modal_projector.linear_2.bias'`, `'multi_modal_projector.linear_2.weight'`, `'image_newline'`) in the multimodal `projector_tensors`.

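For example, a quick way to eyeball those keys (a minimal sketch, assuming `encoder_tensors` and `projector_tensors` were loaded as in the check above):

```python
# Assumes encoder_tensors / projector_tensors from the snippet above.
vision_keys = [k for k in encoder_tensors.keys() if "vision_model" in k]
print(f"{len(vision_keys)} vision_model tensors in encoder_tensors")
print("projector tensors:", sorted(projector_tensors.keys()))
```
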
### 2. Creating the Visual Component GGUF
To create the GGUF for the visual components, we need to write a config for the visual encoder. Make sure the config contains the correct `image_grid_pinpoints`.

Note: we refer to this file as `$VISION_CONFIG` later on.
```json
{
    "_name_or_path": "siglip-model",
    "architectures": [
        "SiglipVisionModel"
    ],
    "image_grid_pinpoints": [
        [384,768],
        [384,1152],
        [384,1536],
        [384,1920],
        [384,2304],
        [384,2688],
        [384,3072],
        [384,3456],
        [384,3840],
        [768,384],
        [768,768],
        [768,1152],
        [768,1536],
        [768,1920],
        [1152,384],
        [1152,768],
        [1152,1152],
        [1536,384],
        [1536,768],
        [1920,384],
        [1920,768],
        [2304,384],
        [2688,384],
        [3072,384],
        [3456,384],
        [3840,384]
    ],
    "mm_patch_merge_type": "spatial_unpad",
    "hidden_size": 1152,
    "image_size": 384,
    "intermediate_size": 4304,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_hidden_layers": 27,
    "patch_size": 14,
    "layer_norm_eps": 1e-6,
    "hidden_act": "gelu_pytorch_tanh",
    "projection_dim": 0,
    "vision_feature_layer": [-24, -20, -12, -1]
}
```

Create a new directory to hold the visual components, and copy the llava.clip/projector files, as well as the vision config, into it.

```bash
$ ENCODER_PATH=$PWD/visual_encoder
$ mkdir $ENCODER_PATH

$ cp $GRANITE_MODEL/llava.clip $ENCODER_PATH/pytorch_model.bin
$ cp $GRANITE_MODEL/llava.projector $ENCODER_PATH/
$ cp $VISION_CONFIG $ENCODER_PATH/config.json
```

At this point you should have something like this:
```bash
$ ls $ENCODER_PATH
config.json llava.projector pytorch_model.bin
```

Now convert the components to GGUF. Note that we also override the image mean/std to `[.5,.5,.5]` since we use the SigLIP visual encoder; in the transformers model, you can find these numbers in the [preprocessor_config.json](https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview/blob/main/preprocessor_config.json).
```bash
$ python convert_image_encoder_to_gguf.py \
    -m $ENCODER_PATH \
    --llava-projector $ENCODER_PATH/llava.projector \
    --output-dir $ENCODER_PATH \
    --clip-model-is-vision \
    --clip-model-is-siglip \
    --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5
```

This will create the first GGUF file at `$ENCODER_PATH/mmproj-model-f16.gguf`; we will refer to the absolute path of this file as `$VISUAL_GGUF_PATH`.
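For example, you can export it so the later commands can reference it:

```bash
$ export VISUAL_GGUF_PATH=$ENCODER_PATH/mmproj-model-f16.gguf
```
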
### 3. Creating the LLM GGUF
The Granite Vision model contains a Granite LLM as its language model. For now, the easiest way to get the GGUF for the LLM is by loading the composite model in `transformers` and exporting the LLM so that it can be directly converted with the normal conversion path.

First, set `LLM_EXPORT_PATH` to the path to export the `transformers` LLM to.
```bash
$ export LLM_EXPORT_PATH=$PWD/granite_vision_llm
```

```python
import os
import transformers

MODEL_PATH = os.getenv("GRANITE_MODEL")
if not MODEL_PATH:
    raise ValueError("env var GRANITE_MODEL is unset!")

LLM_EXPORT_PATH = os.getenv("LLM_EXPORT_PATH")
if not LLM_EXPORT_PATH:
    raise ValueError("env var LLM_EXPORT_PATH is unset!")

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_PATH)

# NOTE: granite vision support was added to transformers very recently (4.49);
# if you get size mismatches, your version is too old.
# If you are running with an older version, set `ignore_mismatched_sizes=True`
# as shown below; it won't be loaded correctly, but the LLM part of the model that
# we are exporting will be loaded correctly.
model = transformers.AutoModelForImageTextToText.from_pretrained(MODEL_PATH, ignore_mismatched_sizes=True)

tokenizer.save_pretrained(LLM_EXPORT_PATH)
model.language_model.save_pretrained(LLM_EXPORT_PATH)
```

Now you can convert the exported LLM to GGUF with the normal converter in the root of the llama.cpp project.
```bash
$ LLM_GGUF_PATH=$LLM_EXPORT_PATH/granite_llm.gguf
...
$ python convert_hf_to_gguf.py --outfile $LLM_GGUF_PATH $LLM_EXPORT_PATH
```

### 4. Running the Model in llama.cpp
Build llama.cpp normally (a minimal build sketch is shown below); you should end up with a target binary named `llama-llava-cli`, to which you can pass the two GGUF files created above.
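For example, assuming a standard CMake setup with no extra backend flags:

```bash
$ cmake -B build
$ cmake --build build --config Release
```
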
Note: the test image shown below can be found [here](https://github-production-user-asset-6210df.s3.amazonaws.com/10740300/415512792-d90d5562-8844-4f34-a0a5-77f62d5a58b5.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20250221%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250221T054145Z&X-Amz-Expires=300&X-Amz-Signature=86c60be490aa49ef7d53f25d6c973580a8273904fed11ed2453d0a38240ee40a&X-Amz-SignedHeaders=host).

Sample usage:

```bash
$ ./build/bin/llama-llava-cli -m $LLM_GGUF_PATH \
    --mmproj $VISUAL_GGUF_PATH \
    --image cherry_blossom.jpg \
    -c 16384 \
    -p "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|user|>\n\<image>\nWhat type of flowers are in this picture?\n<|assistant|>\n" \
    --temp 0
```

Sample response: `The flowers in the picture are cherry blossoms, which are known for their delicate pink petals and are often associated with the beauty of spring.`

examples/server/utils.hpp

Lines changed: 7 additions & 2 deletions
```diff
@@ -521,8 +521,13 @@ static json oaicompat_completion_params_parse(const json & body) {
         throw std::runtime_error("Only one completion choice is allowed");
     }
 
+    // Handle "echo" field
+    if (json_value(body, "echo", false)) {
+        throw std::runtime_error("Only no echo is supported");
+    }
+
     // Params supported by OAI but unsupported by llama.cpp
-    static const std::vector<std::string> unsupported_params { "best_of", "echo", "suffix" };
+    static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
     for (const auto & param : unsupported_params) {
         if (body.contains(param)) {
             throw std::runtime_error("Unsupported param: " + param);
@@ -598,7 +603,7 @@ static json oaicompat_completion_params_parse(
     inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
     inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
     inputs.grammar = grammar;
-    inputs.add_generation_prompt = true;
+    inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
     inputs.use_jinja = use_jinja;
     inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
     inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
```
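
As a hypothetical usage sketch (the URL, port, and payload shape are assumptions based on the server's OpenAI-compatible API, not part of this diff), a client can now control `add_generation_prompt` per request, while a truthy `echo` is rejected:

```bash
# Hypothetical request; "add_generation_prompt" is now read from the body,
# and sending "echo": true would fail with "Only no echo is supported".
curl http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [{"role": "user", "content": "Say hi"}],
    "add_generation_prompt": false
  }'
```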

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 2 additions & 2 deletions
```diff
@@ -5265,6 +5265,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #if defined(__ARM_FEATURE_SVE)
 
+    uint32_t aux[3];
     uint32_t utmp[4];
 
     const int8_t m32 = 32;
@@ -5276,7 +5277,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1);
     const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2);
     const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3);
-    svbool_t pred_s32 = svnot_b_z (svptrue_b32(), svptrue_pat_b32(SV_VL4));
 
     float sum = 0;
 
@@ -5289,7 +5289,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const int8_t * restrict q8_sv = y[i].qs;
 
         // Set up scales
-        uint32_t * aux = &x[i].scales;
+        memcpy(aux, x[i].scales, 12);
         utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
         utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
         utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
```

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 77 additions & 5 deletions
```diff
@@ -407,6 +407,16 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
     GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1012,6 +1022,16 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
@@ -1287,6 +1307,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                     default:
                         return false;
                 }
+            case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q4_1:
+            case GGML_TYPE_Q5_0:
+            case GGML_TYPE_Q5_1:
+            case GGML_TYPE_Q8_0:
+                switch (op->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_F16:
+                        return true;
+                    default:
+                        return false;
+                }
             default:
                 return false;
         };
@@ -3899,10 +3931,6 @@ static void ggml_metal_encode_node(
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             {
-                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
-                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
                 id<MTLComputePipelineState> pipeline = nil;
 
                 switch (src0t) {
@@ -3936,7 +3964,47 @@ static void ggml_metal_encode_node(
                         switch (dstt) {
                             case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
                             case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
-                            default: GGML_ASSERT(false && "not implemented");
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q4_0:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q4_1:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q5_0:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q5_1:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q8_0:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
                         };
                     } break;
                 default: GGML_ABORT("not implemented");
@@ -3966,7 +4034,11 @@ static void ggml_metal_encode_node(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
 
+                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+
             } break;
         case GGML_OP_SET:
```
