Skip to content

Commit 8b1d615

Browse files
authored
Merge pull request #34 from JJJYmmm/add_qwen3vl
fix rope fail
2 parents 0537774 + 0518b0a commit 8b1d615

File tree

7 files changed

+21
-13
lines changed

7 files changed

+21
-13
lines changed

convert_hf_to_gguf.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4085,7 +4085,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         # Skip text model tensors - they go in the text model file
         if name.startswith("model.language_model.") or name.startswith("lm_head."):
             return []
-
+
         if name.startswith("model.visual."):
             name = name.replace("model.visual.", "visual.", 1)

@@ -4174,7 +4174,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])

         logger.info(f"MRoPE sections: {mrope_section[:4]}")
-
+
         vision_config = self.hparams.get("vision_config", {})
         deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
         self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
@@ -4183,7 +4183,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         # Skip vision tensors - they go in the mmproj file
         if name.startswith("model.visual."):
             return []
-
+
         return super().modify_tensors(data_torch, name, bid)

41894189

@@ -4217,7 +4217,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         # Skip vision tensors - they go in the mmproj file
         if name.startswith("model.visual."):
             return []
-
+
         return super().modify_tensors(data_torch, name, bid)
42224222

42234223

ggml/src/ggml-cpu/ops.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5514,6 +5514,8 @@ static void ggml_mrope_cache_init(
                 theta = theta_h;
             } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
                 theta = theta_w;
+            } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
+                theta = theta_t;
             } else {
                 theta = theta_e;
             }
@@ -5599,7 +5601,7 @@ static void ggml_compute_forward_rope_f32(

     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
-    const bool is_imrope = mode & GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

     if (is_mrope) {
@@ -5786,7 +5788,7 @@ static void ggml_compute_forward_rope_f16(

     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
-    const bool is_imrope = mode & GGML_ROPE_TYPE_IMROPE;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

     if (is_mrope) {

ggml/src/ggml-cuda/rope.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,10 @@ static __global__ void rope_multi(
            theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
            theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
-       } else { // t
+       } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
            theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+       } else {
+           theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
        }
    } else {
        if (sector < sections.v[0]) {
@@ -379,7 +381,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)

    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
-   const bool is_imrope = mode & GGML_ROPE_TYPE_IMROPE;
+   const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

    if (is_mrope) {

gguf-py/gguf/gguf_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1073,7 +1073,7 @@ def add_vision_projector_scale_factor(self, value: int) -> None:

    def add_vision_n_wa_pattern(self, value: int) -> None:
        self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
-
+
    def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
        self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
10791079

tests/test-backend-ops.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7076,6 +7076,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
        test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
        test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
        test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
+       test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B)
+       test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 7B)
+       test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
+       test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
        test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
    }
70817085

@@ -7092,7 +7096,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

    // single inplace test per type/mode/ff
    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-       for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION}) {
+       for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_IMROPE, GGML_ROPE_TYPE_VISION}) {
            for (bool ff : {false, true}) {
                test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true));
            }

tests/test-rope.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ int main(int /*argc*/, const char ** /*argv*/) {
    struct ggml_tensor * x;

    // rope f32
-   for (int m = 0; m < 5; ++m) {
+   for (int m = 0; m < 6; ++m) {
        const int ndims = 4;

        const int64_t n_rot = 128;
@@ -180,7 +180,7 @@ int main(int /*argc*/, const char ** /*argv*/) {
        struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);

        int sections[4] = {16, 24, 24, 0};
-       mode = (m == 3) ? GGML_ROPE_TYPE_MROPE : GGML_ROPE_TYPE_VISION;
+       mode = (m == 3) ? GGML_ROPE_TYPE_MROPE : (m == 4) ? GGML_ROPE_TYPE_VISION : GGML_ROPE_TYPE_IMROPE;

        for (int i = 0; i < ne[2]; ++i) {
            for (int j = 0; j < 4; ++j) {

tools/mtmd/clip.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ struct clip_graph {
        // residual 2
        cur = ggml_add(ctx0, inpL, cur);
        cb(cur, "layer_out", il);
-
+
        if (layer.has_deepstack()) {
            ggml_tensor * feat = ggml_reshape_3d(ctx0, cur, n_embd * merge_factor, n_pos / merge_factor, batch_size);
            feat = build_norm(feat, layer.deepstack_norm_w, layer.deepstack_norm_b, norm_t, eps, il);

0 commit comments

Comments (0)