
Commit 65349f2

tdakhran, ngxson, and CISC authored
model : support vision LiquidAI LFM2-VL family (#15347)
* wip lfm2 vision model * Fix conv weight * Implement dynamic resolution * Fix cuda * support LFM2-VL-450M * happy CI * Remove extra `ggml_conv` and put others into the right place Co-authored-by: Sigbjørn Skjæret <[email protected]> --------- Co-authored-by: Xuan Son Nguyen <[email protected]> Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 1fe0029 commit 65349f2

File tree

5 files changed (+171, −3)


convert_hf_to_gguf.py

Lines changed: 43 additions & 2 deletions
@@ -8251,8 +8251,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))
 
 
-@ModelBase.register("Lfm2ForCausalLM")
-@ModelBase.register("LFM2ForCausalLM")
+@ModelBase.register("Lfm2ForCausalLM", "LFM2ForCausalLM")
 class LFM2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.LFM2
 
@@ -8287,13 +8286,55 @@ def set_gguf_parameters(self):
         self._add_feed_forward_length()
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
+        if is_vision_tensor:
+            # skip vision tensors
+            return []
+
+        name = name.replace("language_model.", "")
+
         # conv op requires 2d tensor
         if 'conv.conv' in name:
             data_torch = data_torch.squeeze(1)
 
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("Lfm2VlForConditionalGeneration")
+class LFM2VLModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        # TODO(tarek): for dynamic resolution image_size is not specified, setting here for compatibility
+        self.hparams_vision["image_size"] = 256
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LFM2)
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.find_vparam(["layer_norm_eps"]))
+        self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("downsample_factor", 2))
+        self.gguf_writer.add_vision_use_gelu(True)
+        # python notation, e.g. for vision_feature_layer == -1, we pick last layer -> vision_feature_layers_to_drop = 0
+        vision_feature_layers_to_drop = -(self.global_config.get("vision_feature_layer", -1) + 1)
+        self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys) - vision_feature_layers_to_drop)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
+
+        if is_vision_tensor:
+            # remove "model." prefix
+            name = name.replace("model.vision_tower.", "vision_tower.")
+            name = name.replace("model.multi_modal_projector.", "multi_modal_projector.")
+
+            if "patch_embedding.weight" in name:
+                data_torch = data_torch.view(data_torch.shape[0], 16, 16, 3).permute(0, 3, 1, 2)
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return []  # skip other tensors
+
+
 @ModelBase.register("SmallThinkerForCausalLM")
 class SmallThinkerModel(TextModel):
     model_arch = gguf.MODEL_ARCH.SMALLTHINKER
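Note: the patch_embedding.weight reshape above exists because the HF checkpoint stores the vision tower's patch embedding as a flattened linear weight, while clip.cpp consumes it as a conv2d kernel. A minimal standalone sketch of the same transform (sizes are illustrative, not taken from the checkpoint):

import torch

n_embd, patch_size, channels = 768, 16, 3  # illustrative sizes

# HF layout: one row per output channel over flattened (pH, pW, C) patches
w_linear = torch.randn(n_embd, patch_size * patch_size * channels)

# same op as modify_tensors: (n_embd, pH, pW, C) -> (n_embd, C, pH, pW)
w_conv = w_linear.view(n_embd, patch_size, patch_size, channels).permute(0, 3, 1, 2)
assert w_conv.shape == (n_embd, channels, patch_size, patch_size)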

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
@@ -2832,6 +2832,7 @@ class VisionProjectorType:
     QWEN2A  = "qwen2a" # audio
     QWEN25O = "qwen2.5o" # omni
     VOXTRAL = "voxtral"
+    LFM2    = "lfm2"
 
 
 # Items here are (block size, type size)

gguf-py/gguf/tensor_mapping.py

Lines changed: 1 addition & 0 deletions
@@ -1272,6 +1272,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_MM_INP_NORM: (
             "multi_modal_projector.norm",
+            "multi_modal_projector.layer_norm",
             "pre_mm_projector_norm",
         ),

tools/mtmd/clip-impl.h

Lines changed: 3 additions & 0 deletions
@@ -82,6 +82,7 @@
 #define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
 #define TN_IMAGE_NEWLINE "model.image_newline"
 #define TN_MM_INP_NORM   "mm.input_norm.weight"
+#define TN_MM_INP_NORM_B "mm.input_norm.bias"
 #define TN_MM_INP_PROJ   "mm.input_projection.weight" // gemma3
 #define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
 #define TN_MM_PROJECTOR  "mm.model.fc.weight" // idefics3
 
@@ -133,6 +134,7 @@ enum projector_type {
     PROJECTOR_TYPE_QWEN2A,
     PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
     PROJECTOR_TYPE_VOXTRAL,
+    PROJECTOR_TYPE_LFM2,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -153,6 +155,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_QWEN2A,  "qwen2a"},
     { PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
     { PROJECTOR_TYPE_VOXTRAL, "voxtral"},
+    { PROJECTOR_TYPE_LFM2,    "lfm2"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {

tools/mtmd/clip.cpp

Lines changed: 123 additions & 1 deletion
@@ -265,6 +265,7 @@ struct clip_model {
 
     // LLaVA projection
     ggml_tensor * mm_input_norm_w = nullptr;
+    ggml_tensor * mm_input_norm_b = nullptr;
     ggml_tensor * mm_0_w = nullptr;
     ggml_tensor * mm_0_b = nullptr;
     ggml_tensor * mm_2_w = nullptr;
@@ -488,11 +489,17 @@ struct clip_graph {
 
     ggml_cgraph * build_siglip() {
         ggml_tensor * inp = build_inp();
+
+        ggml_tensor * learned_pos_embd = model.position_embeddings;
+        if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
+            learned_pos_embd = resize_position_embeddings();
+        }
+
         ggml_tensor * cur = build_vit(
                                 inp, n_patches,
                                 NORM_TYPE_NORMAL,
                                 hparams.ffn_op,
-                                model.position_embeddings,
+                                learned_pos_embd,
                                 nullptr);
 
         if (ctx->proj_type() == PROJECTOR_TYPE_GEMMA3) {
@@ -542,6 +549,45 @@ struct clip_graph {
                 bsz);
 
             cur = ggml_mul_mat(ctx0, model.projection, cur);
+        } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
+            // pixel unshuffle block
+            const int scale_factor = model.hparams.proj_scale_factor;
+            GGML_ASSERT(scale_factor > 1);
+
+            const int n_embd = cur->ne[0];
+            int width  = img.nx / patch_size;
+            int height = img.ny / patch_size;
+
+            // pad width and height to factor
+            const int64_t pad_width  = CLIP_ALIGN(width,  scale_factor) - width;
+            const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
+            cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
+            if (pad_width || pad_height) {
+                cur     = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
+                width  += pad_width;
+                height += pad_height;
+            }
+
+            // unshuffle h
+            cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
+            cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
+
+            // unshuffle w
+            cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
+            cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
+
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+
+            // projection
+            cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
+            cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
+            cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
+
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_1_b);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_2_b);
         } else {
             GGML_ABORT("SigLIP: Unsupported projector type");
         }
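Review note: the reshape/permute pair above implements pixel unshuffle, folding each scale_factor x scale_factor block of patch embeddings into a single token, so the projector sees scale_factor^2 fewer but proportionally wider tokens. A rough PyTorch equivalent of the same index gymnastics (a sketch for a single pre-padded image; the helper name is ours, not from the PR):

import torch

def pixel_unshuffle(x: torch.Tensor, scale: int) -> torch.Tensor:
    # x: (height, width, n_embd) grid of patch embeddings
    h, w, d = x.shape
    assert h % scale == 0 and w % scale == 0, "pad first, as clip.cpp does"
    x = x.reshape(h, w // scale, d * scale)                   # merge adjacent columns
    x = x.permute(1, 0, 2)                                    # (w/scale, h, d*scale)
    x = x.reshape(w // scale, h // scale, d * scale * scale)  # merge adjacent rows
    return x.reshape(-1, d * scale * scale)                   # flatten to tokens

tokens = pixel_unshuffle(torch.randn(16, 12, 768), scale=2)
assert tokens.shape == (48, 3072)  # 192 patches -> 48 tokens, 4x the channel width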
@@ -1560,6 +1606,27 @@ struct clip_graph {
         }
     }
 
+    // siglip2 naflex
+    ggml_tensor * resize_position_embeddings() {
+        ggml_tensor * pos_embd = model.position_embeddings;
+        const int height = img.ny / patch_size;
+        const int width  = img.nx / patch_size;
+
+        if (!pos_embd || height * width == pos_embd->ne[1]) {
+            return pos_embd;
+        }
+
+        const int n_pos_embd = std::sqrt(pos_embd->ne[1]);
+        pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_pos_embd, n_pos_embd);  // -> (n_embd, n_pos_embd, n_pos_embd)
+        pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3);                         // -> (n_pos_embd, n_pos_embd, n_embd)
+        pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, 1);    // -> (width, height, n_embd)
+        pos_embd = ggml_reshape_2d(ctx0, pos_embd, height * width, n_embd);          // -> (height * width, n_embd)
+        pos_embd = ggml_transpose(ctx0, pos_embd);                                   // -> (n_embd, height * width)
+        pos_embd = ggml_cont(ctx0, pos_embd);
+
+        return pos_embd;
+    }
+
     // build vision transformer (ViT) cgraph
     // this function should cover most of the models
     // if your model has specific features, you should probably duplicate this function
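Review note: resize_position_embeddings is the SigLIP2 NaFlex recipe that makes dynamic resolution work: the learned square position-embedding grid is resampled to the patch grid of the actual image, so any aspect ratio is supported. A rough torch sketch of the same steps (assumes bilinear resampling; the exact ggml_interpolate mode may differ):

import math
import torch
import torch.nn.functional as F

def resize_pos_embd(pos_embd: torch.Tensor, height: int, width: int) -> torch.Tensor:
    # pos_embd: (n_pos, n_embd), n_pos a perfect square (the pretraining grid)
    n_pos, n_embd = pos_embd.shape
    if height * width == n_pos:
        return pos_embd
    side = int(math.sqrt(n_pos))
    grid = pos_embd.reshape(side, side, n_embd).permute(2, 0, 1)  # (n_embd, side, side)
    grid = F.interpolate(grid.unsqueeze(0), size=(height, width),
                         mode="bilinear", align_corners=False)    # (1, n_embd, h, w)
    return grid.squeeze(0).permute(1, 2, 0).reshape(height * width, n_embd)

pe = resize_pos_embd(torch.randn(256, 768), height=12, width=16)
assert pe.shape == (12 * 16, 768)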
@@ -1966,6 +2033,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     switch (ctx->proj_type()) {
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_LFM2:
             {
                 res = graph.build_siglip();
             } break;
@@ -2230,6 +2298,7 @@ struct clip_model_loader {
                 }
             } break;
         case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_INTERNVL:
             {
                 get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
@@ -2533,6 +2602,15 @@ struct clip_model_loader {
             {
                 model.projection = get_tensor(TN_MM_PROJECTOR);
             } break;
+        case PROJECTOR_TYPE_LFM2:
+            {
+                model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
+                model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
+                model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
+                model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
+                model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+            } break;
         case PROJECTOR_TYPE_PIXTRAL:
             {
                 model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
@@ -3428,6 +3506,43 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->grid_y = inst.grid_size.height;
         return true;
 
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
+        GGML_ASSERT(params.proj_scale_factor);
+
+        // smart resize
+        const int width  = img->nx;
+        const int height = img->ny;
+        const int total_factor = params.patch_size * params.proj_scale_factor;
+        constexpr int min_image_tokens = 64;
+        constexpr int max_image_tokens = 256;
+        const float min_pixels = min_image_tokens * total_factor * total_factor;
+        const float max_pixels = max_image_tokens * total_factor * total_factor;
+
+        auto round_by_factor = [f = total_factor](float x) { return static_cast<int>(std::nearbyintf(x / static_cast<float>(f))) * f; };
+        auto ceil_by_factor  = [f = total_factor](float x) { return static_cast<int>(std::ceil(x / static_cast<float>(f))) * f; };
+        auto floor_by_factor = [f = total_factor](float x) { return static_cast<int>(std::floor(x / static_cast<float>(f))) * f; };
+
+        int h_bar = std::max(total_factor, round_by_factor(height));
+        int w_bar = std::max(total_factor, round_by_factor(width));
+
+        if (h_bar * w_bar > max_pixels) {
+            const auto beta = std::sqrt((height * width) / max_pixels);
+            h_bar = std::max(total_factor, floor_by_factor(height / beta));
+            w_bar = std::max(total_factor, floor_by_factor(width / beta));
+        } else if (h_bar * w_bar < min_pixels) {
+            const auto beta = std::sqrt(min_pixels / (height * width));
+            h_bar = ceil_by_factor(height * beta);
+            w_bar = ceil_by_factor(width * beta);
+        }
+
+        const std::array<uint8_t, 3> pad_color = {122, 116, 104};
+
+        clip_image_u8 resized_img;
+        image_manipulation::resize_and_pad_image(*img, resized_img, clip_image_size{w_bar, h_bar}, pad_color);
+        clip_image_f32_ptr res(clip_image_f32_init());
+        normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
+        res_imgs->entries.push_back(std::move(res));
+        return true;
     }
 
     // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
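Review note: the "smart resize" above keeps the eventual token count inside [min_image_tokens, max_image_tokens] = [64, 256]. Each side is snapped to a multiple of total_factor = patch_size * proj_scale_factor (the pixels one output token covers per side), then the area is rescaled toward the budget when it over- or undershoots. The same arithmetic as a standalone sketch (assuming patch_size = 16 and scale_factor = 2, i.e. 32 px per token side):

import math

def smart_resize(width: int, height: int, patch_size: int = 16, scale_factor: int = 2,
                 min_tokens: int = 64, max_tokens: int = 256) -> tuple[int, int]:
    f = patch_size * scale_factor            # pixels per output token side
    min_px = min_tokens * f * f
    max_px = max_tokens * f * f

    def round_by(x): return max(f, round(x / f) * f)
    def floor_by(x): return max(f, math.floor(x / f) * f)
    def ceil_by(x):  return math.ceil(x / f) * f

    w_bar, h_bar = round_by(width), round_by(height)
    if w_bar * h_bar > max_px:               # too many tokens: shrink
        beta = math.sqrt((height * width) / max_px)
        w_bar, h_bar = floor_by(width / beta), floor_by(height / beta)
    elif w_bar * h_bar < min_px:             # too few tokens: grow
        beta = math.sqrt(min_px / (height * width))
        w_bar, h_bar = ceil_by(width * beta), ceil_by(height * beta)
    return w_bar, h_bar

print(smart_resize(1920, 1080))  # (672, 384): (672/32) * (384/32) = 252 tokens, under the 256 cap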
@@ -3630,6 +3745,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                     n_patches_sq /= 2;
                 }
             } break;
+        case PROJECTOR_TYPE_LFM2:
+            {
+                n_patches_sq = (img->nx / (params.patch_size * params.proj_scale_factor)) * (img->ny / (params.patch_size * params.proj_scale_factor));
+            } break;
         default:
             GGML_ABORT("unsupported projector type");
     }
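In other words, the token count falls out of the preprocessed size directly: one token per (patch_size * proj_scale_factor)^2 pixel block. For the 672x384 example above, with patch_size = 16 and proj_scale_factor = 2, that is (672 / 32) * (384 / 32) = 21 * 12 = 252 tokens, consistent with the 64-256 budget the preprocessing enforces.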
@@ -4034,6 +4153,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         case PROJECTOR_TYPE_INTERNVL:
         case PROJECTOR_TYPE_QWEN2A:
         case PROJECTOR_TYPE_ULTRAVOX:
+        case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_VOXTRAL:
            {
                 // do nothing
@@ -4135,6 +4255,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_model_proj->ne[1];
         case PROJECTOR_TYPE_QWEN2A:
             return ctx->model.mm_fc_w->ne[1];
+        case PROJECTOR_TYPE_LFM2:
+            return ctx->model.mm_2_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }
