Commit 6a3b05a

HimariO authored and ggerganov committed
llama : add Qwen2VL support + multimodal RoPE (llama/10361)
* Barebone Qwen2VL LLM convertor
* Add Qwen2VL cli entrypoint
* [WIP] add qwen2vl arch
* Verify m-rope output
* Add vl-rope/2d-rope support for qwen2vl ViT
* update qwen2vl cli tool
* update 5D tensor op workaround
* [WIP] qwen2vl vision model
* make batch and clip utils compatible with qwen2vl
* [WIP] create inference workflow, gguf convert script but fix
* correcting vision-rope behavior, add the missing last layer back to ViT
* add arg parser to qwen2vl_surgery
* replace variable size array with vector
* cuda-gdb cmake preset
* add fp32 mrope, vision rope kernel
* add fp16 support for qwen2vl and m-rope
* add `GGML_ROPE_TYPE_MROPE`, `GGML_ROPE_TYPE_VISION`
* fix rope op mode switching, out dated func args
* update `llama_hparams`
* update to keep up stream changes
* resolve linter, test errors
* add makefile entry, update speical image padding token
* add mrope unit test, fix few compiler warnings
* rename `mrope` related function, params
* minor updates on debug util, bug fixs
* add `m-rope` testcase to `test-backend-ops`
* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <[email protected]>

* fix traililng whitespce
* store `llama_hparams.rope_sections` with fixed size array
* update position id tensor size check in GGML_OP_ROPE
* minor updates
* update `ggml_backend_*_supports_op` of unsupported backends
* remote old `rope_section` compare operator

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 5538f04 commit 6a3b05a

File tree

9 files changed, +564 -42 lines changed


ggml/include/ggml.h

Lines changed: 19 additions & 1 deletion
@@ -237,7 +237,9 @@
 #define GGML_EXIT_SUCCESS 0
 #define GGML_EXIT_ABORTED 1
 
-#define GGML_ROPE_TYPE_NEOX 2
+#define GGML_ROPE_TYPE_NEOX   2
+#define GGML_ROPE_TYPE_MROPE  8
+#define GGML_ROPE_TYPE_VISION 24
 
 #define GGUF_MAGIC "GGUF"
 
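Note (editor's aside, not part of the diff): GGML_ROPE_TYPE_VISION (24) is GGML_ROPE_TYPE_MROPE (8) with an additional bit (16) set, so a plain bit test for the M-RoPE bit also matches the vision variant. That is why the backend checks below test each flag separately, and why the CUDA dispatcher further down distinguishes the vision case with an exact comparison. A minimal C illustration:

    #include <stdio.h>

    #define GGML_ROPE_TYPE_NEOX   2
    #define GGML_ROPE_TYPE_MROPE  8
    #define GGML_ROPE_TYPE_VISION 24

    int main(void) {
        const int mode = GGML_ROPE_TYPE_VISION;
        // the vision mode also carries the mrope bit ...
        printf("mrope bit set: %d\n", (mode & GGML_ROPE_TYPE_MROPE) != 0); // prints 1
        // ... so an exact comparison is needed to tell the two apart
        printf("is vision:     %d\n", mode == GGML_ROPE_TYPE_VISION);      // prints 1
        return 0;
    }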

@@ -1443,6 +1445,22 @@ extern "C" {
             float                 beta_fast,
             float                 beta_slow);
 
+    GGML_API struct ggml_tensor * ggml_rope_multi(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   sections[4],
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
             struct ggml_context * ctx,
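For orientation, a rough caller-side sketch of the new API (editor's illustration, not taken from this commit): the shapes, section sizes, and frequency hyperparameters below are placeholders that a real graph would read from the model hparams. The position tensor is assumed to carry four position ids per token (temporal/height/width plus an unused slot), matching the four-entry sections array.

    #include "ggml.h"

    // cur: [head_dim, n_head, n_tokens]; pos: int32 tensor with 4 position ids per token
    static struct ggml_tensor * apply_mrope_sketch(struct ggml_context * ctx,
                                                   struct ggml_tensor  * cur,
                                                   struct ggml_tensor  * pos) {
        int sections[4] = { 16, 24, 24, 0 };   // illustrative split of the rotated dim pairs
        return ggml_rope_multi(ctx, cur, pos, /*freq_factors=*/NULL,
                /*n_dims     =*/ 128,          // placeholder head_dim
                sections,
                /*mode       =*/ GGML_ROPE_TYPE_MROPE,
                /*n_ctx_orig =*/ 32768,        // placeholder training context
                /*freq_base  =*/ 1000000.0f,
                /*freq_scale =*/ 1.0f,
                /*ext_factor =*/ 0.0f,
                /*attn_factor=*/ 1.0f,
                /*beta_fast  =*/ 32.0f,
                /*beta_slow  =*/ 1.0f);
    }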

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 9 additions & 0 deletions
@@ -1747,6 +1747,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             if (*ext_factor != 0) {
                 return false;
             }
+
+            const int mode = ((const int32_t *) op->op_params)[2];
+            if (mode & GGML_ROPE_TYPE_MROPE) {
+                return false;
+            }
+            if (mode & GGML_ROPE_TYPE_VISION) {
+                return false;
+            }
+
             return true;
         }
         case GGML_OP_UPSCALE: {

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 211 additions & 33 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cuda/rope.cu

Lines changed: 229 additions & 2 deletions
@@ -4,6 +4,11 @@ struct rope_corr_dims {
     float v[2];
 };
 
+
+struct mrope_sections {
+    int v[4];
+};
+
 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
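A small design note (editor's guess at the rationale, not stated in the diff): wrapping the four section sizes in a struct lets them be passed by value through the kernel argument list, whereas a raw int[4] parameter would decay to a pointer into host memory. A trivial C sketch of the by-value behaviour:

    #include <stdio.h>

    struct mrope_sections { int v[4]; };

    // the struct is copied into the callee, so the four ints travel with the call
    static int total_dims(struct mrope_sections s) {
        return s.v[0] + s.v[1] + s.v[2] + s.v[3];
    }

    int main(void) {
        struct mrope_sections s = { { 16, 24, 24, 0 } };
        printf("%d dim pairs covered\n", total_dims(s)); // 64
        return 0;
    }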
@@ -108,6 +113,105 @@ static __global__ void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
+template<typename T, bool has_ff>
+static __global__ void rope_multi(
+    const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i  = row*ne0 + i0/2;
+    const int i2 = row/p_delta_rows;
+
+    int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
+    int sec_w = sections.v[1] + sections.v[0];
+    int sector = (i0 / 2) % sect_dims;
+
+    float theta_base = 0.0;
+    if (sector < sections.v[0]) {
+        theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sections.v[0] && sector < sec_w) {
+        theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+        theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sec_w + sections.v[2]) {
+        theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f);
+    }
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + n_dims/2];
+
+    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T, bool has_ff>
+static __global__ void rope_vision(
+    const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+    const int i  = row*ne0 + i0/2;
+    const int i2 = row/p_delta_rows; // i2-th tokens
+
+    int sect_dims = sections.v[0] + sections.v[1];
+    int sec_w = sections.v[1] + sections.v[0];
+    int sector = (i0 / 2) % sect_dims;
+
+    float theta_base = 0.0;
+    if (sector < sections.v[0]) {
+        const int p = sector;
+        theta_base = pos[i2]*powf(theta_scale, p);
+    }
+    else if (sector >= sections.v[0] && sector < sec_w) {
+        const int p = sector - sections.v[0];
+        theta_base = pos[i2 + ne2]*powf(theta_scale, p);
+    }
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + n_dims];
+
+    dst[i + 0]      = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
+}
+
 template<typename T>
 static void rope_norm_cuda(
     const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
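To make the section logic above easier to follow, here is an editor's CPU-side sketch (not from this diff) of how each dimension pair picks a position stream: the pair index modulo the total section width selects a sector, and the sector determines which of the (for Qwen2VL) temporal/height/width/extra position ids feeds the angle. Note that rope_multi keeps the global pair index as the frequency exponent, while rope_vision uses the offset within its section.

    // j is the dimension-pair index, i.e. i0/2 in the kernels above
    static int mrope_stream_for_pair(int j, const int sections[4]) {
        const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
        const int sector    = j % sect_dims;

        if (sector < sections[0])                             return 0; // temporal positions
        if (sector < sections[0] + sections[1])               return 1; // height positions
        if (sector < sections[0] + sections[1] + sections[2]) return 2; // width positions
        return 3;                                                       // extra/unused section
    }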
@@ -156,6 +260,56 @@ static void rope_neox_cuda(
     }
 }
 
+template<typename T>
+static void rope_multi_cuda(
+    const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors, sections
+        );
+    } else {
+        rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors, sections
+        );
+    }
+}
+
+template<typename T>
+static void rope_vision_cuda(
+    const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+    // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
+    // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_vision<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors, sections
+        );
+    } else {
+        rope_vision<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors, sections
+        );
+    }
+}
+
 static void rope_norm_cuda_f16(
     const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
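For a concrete feel of the launch geometry in the comment above, a quick arithmetic sketch (editor's illustration; CUDA_ROPE_BLOCK_SIZE is assumed to be 256 here, matching its definition elsewhere in ggml-cuda, and the row count is made up):

    #include <stdio.h>

    int main(void) {
        const int block_size = 256;        // assumed CUDA_ROPE_BLOCK_SIZE
        const int ne0        = 128;        // head_dim
        const int nr         = 32 * 4096;  // rows = n_head * n_tokens (example)

        // each thread rotates one dimension pair, so one block covers 2*block_size elements of ne0
        const int n_blocks_x = (ne0 + 2*block_size - 1) / (2*block_size);

        printf("grid = (%d, %d, 1), block = (1, %d, 1)\n", nr, n_blocks_x, block_size); // grid = (131072, 1, 1)
        return 0;
    }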
@@ -185,6 +339,38 @@ static void rope_neox_cuda_f32(
     rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
+static void rope_multi_cuda_f16(
+    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
+static void rope_multi_cuda_f32(
+    const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
+static void rope_vision_cuda_f16(
+    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
+static void rope_vision_cuda_f32(
+    const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_vision_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
@@ -201,15 +387,17 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
     GGML_ASSERT(src0->type == dst->type);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
+    const int64_t ne00 = src0->ne[0]; // head dims
+    const int64_t ne01 = src0->ne[1]; // num heads
+    const int64_t ne02 = src0->ne[2]; // num heads
     const int64_t nr = ggml_nrows(src0);
 
     //const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
     const int mode       = ((int32_t *) dst->op_params)[2];
     //const int n_ctx      = ((int32_t *) dst->op_params)[3];
     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+    mrope_sections sections;
 
     // RoPE alteration for extended context
     float freq_base;
@@ -225,8 +413,19 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections.v,  (int32_t *) dst->op_params + 11, sizeof(int)*4);
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+    if (is_mrope) {
+        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
+    }
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne00/2);
+    }
 
     const int32_t * pos = (const int32_t *) src1_d;
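For reference, a hedged sketch of the int32 op_params layout these reads assume (pieced together from this hunk and the surrounding ggml sources; the slots not visible in the diff are an assumption): 0..4 hold n_past, n_dims, mode, n_ctx, n_ctx_orig; 5..10 hold freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; 11..14 now carry the four section sizes.

    #include <stdint.h>
    #include <string.h>

    // illustrative unpacking of GGML_OP_ROPE op_params (viewed as int32_t[])
    static void unpack_rope_params_sketch(const int32_t * p,
                                          int * n_dims, int * mode,
                                          float * freq_base, int sections[4]) {
        *n_dims = (int) p[1];
        *mode   = (int) p[2];
        memcpy(freq_base, p + 5,  sizeof(float));   // 6..10: freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
        memcpy(sections,  p + 11, sizeof(int) * 4); // new in this commit
    }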

@@ -253,6 +452,34 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         } else {
             GGML_ABORT("fatal error");
         }
+    } else if (is_mrope && !is_vision) {
+        if (src0->type == GGML_TYPE_F32) {
+            rope_multi_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, sections, stream
+            );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_multi_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, sections, stream
+            );
+        } else {
+            GGML_ABORT("fatal error");
+        }
+    } else if (is_vision) {
+        if (src0->type == GGML_TYPE_F32) {
+            rope_vision_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, sections, stream
+            );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_vision_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, sections, stream
+            );
+        } else {
+            GGML_ABORT("fatal error");
+        }
     } else {
         if (src0->type == GGML_TYPE_F32) {
             rope_norm_cuda_f32(

ggml/src/ggml-kompute/ggml-kompute.cpp

Lines changed: 11 additions & 1 deletion
@@ -1419,8 +1419,18 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_NORM:
-        case GGML_OP_ROPE:
             return true;
+        case GGML_OP_ROPE:
+            {
+                const int mode = ((const int32_t *) op->op_params)[2];
+                if (mode & GGML_ROPE_TYPE_MROPE) {
+                    return false;
+                }
+                if (mode & GGML_ROPE_TYPE_VISION) {
+                    return false;
+                }
+                return true;
+            }
         case GGML_OP_DUP:
         case GGML_OP_CPY:
         case GGML_OP_CONT:

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 14 additions & 2 deletions
@@ -1125,8 +1125,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
         case GGML_OP_ARGMAX:
         case GGML_OP_NORM:
-        case GGML_OP_ROPE:
             return true;
+        case GGML_OP_ROPE:
+            {
+                const int mode = ((const int32_t *) op->op_params)[2];
+                if (mode & GGML_ROPE_TYPE_MROPE) {
+                    return false;
+                }
+                if (mode & GGML_ROPE_TYPE_VISION) {
+                    return false;
+                }
+                return true;
+            }
         case GGML_OP_IM2COL:
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_1D:
@@ -3026,7 +3036,9 @@ static void ggml_metal_encode_node(
             } break;
         case GGML_OP_ROPE:
             {
-                GGML_ASSERT(ne10 == ne02);
+                // make sure we have one or more position id(ne10) per token(ne02)
+                GGML_ASSERT(ne10 % ne02 == 0);
+                GGML_ASSERT(ne10 >= ne02);
 
                 const int nth = MIN(1024, ne00);
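The relaxed assertion reflects that the position tensor can now carry more than one position id per token: classic RoPE supplies one id per token, while multimodal RoPE supplies a fixed multiple of it (four streams in the Qwen2VL case). A tiny editor's sketch of the implied relationship (names illustrative):

    // ne10 = number of int32 position ids, ne02 = number of tokens
    static int rope_pos_streams(long long ne10, long long ne02) {
        // mirrors the checks above: ne10 % ne02 == 0 and ne10 >= ne02
        return (int)(ne10 / ne02);   // 1 for classic RoPE, 4 for M-RoPE
    }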
