Skip to content

Commit c2deea2

Browse files
Merge pull request #33 from ggml-org/master
Sync from upstream.
2 parents 321574b + 2016f07 commit c2deea2

File tree

10 files changed

+711
-296
lines changed

10 files changed

+711
-296
lines changed

convert_hf_to_gguf.py

Lines changed: 363 additions & 271 deletions
Large diffs are not rendered by default.

convert_lora_to_gguf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import gguf
2525

2626
# reuse model definitions from convert_hf_to_gguf.py
27-
from convert_hf_to_gguf import LazyTorchTensor, Model
27+
from convert_hf_to_gguf import LazyTorchTensor, ModelBase
2828

2929
logger = logging.getLogger("lora-to-gguf")
3030

@@ -340,11 +340,11 @@ def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
340340
sys.exit(1)
341341
else:
342342
logger.info(f"Loading base model: {dir_base_model.name}")
343-
hparams = Model.load_hparams(dir_base_model)
343+
hparams = ModelBase.load_hparams(dir_base_model)
344344

345345
with torch.inference_mode():
346346
try:
347-
model_class = Model.from_model_architecture(hparams["architectures"][0])
347+
model_class = ModelBase.from_model_architecture(hparams["architectures"][0])
348348
except NotImplementedError:
349349
logger.error(f"Model {hparams['architectures'][0]} is not supported")
350350
sys.exit(1)

examples/llava/clip-impl.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
// tensor name constants
5151
//
5252

53-
#define TN_TOKEN_EMBD "%s.token_embd.weight"
5453
#define TN_POS_EMBD "%s.position_embd.weight"
5554
#define TN_CLASS_EMBD "v.class_embd"
5655
#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat
@@ -66,8 +65,6 @@
6665
#define TN_LN_2 "%s.blk.%d.ln2.%s"
6766
#define TN_LN_PRE "%s.pre_ln.%s"
6867
#define TN_LN_POST "%s.post_ln.%s"
69-
#define TN_TEXT_PROJ "text_projection.weight"
70-
#define TN_VIS_PROJ "visual_projection.weight"
7168
#define TN_LLAVA_PROJ "mm.%d.%s"
7269
#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s"
7370
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"

examples/llava/clip.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@ struct clip_image_size {
3030
int height;
3131
};
3232

33+
struct clip_image_f32;
3334
struct clip_image_u8_batch;
3435
struct clip_image_f32_batch;
3536

3637
struct clip_context_params {
3738
bool use_gpu;
38-
ggml_log_level verbosity;
39+
enum ggml_log_level verbosity;
3940
};
4041

4142
// deprecated, use clip_init
@@ -84,7 +85,7 @@ CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
8485
CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
8586
CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
8687
CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
87-
CLIP_API clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
88+
CLIP_API struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
8889

8990
/**
9091
* Build image from pixels decoded by other libraries instead of stb_image.h for better performance.

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
481481
GGML_METAL_KERNEL_TYPE_SQRT,
482482
GGML_METAL_KERNEL_TYPE_SIN,
483483
GGML_METAL_KERNEL_TYPE_COS,
484+
GGML_METAL_KERNEL_TYPE_NEG,
484485
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
485486
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
486487
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1159,6 +1160,7 @@ @implementation GGMLMetalClass
11591160
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
11601161
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
11611162
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
1163+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
11621164
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
11631165
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
11641166
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
@@ -1320,6 +1322,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
13201322
case GGML_UNARY_OP_GELU_QUICK:
13211323
case GGML_UNARY_OP_SILU:
13221324
case GGML_UNARY_OP_ELU:
1325+
case GGML_UNARY_OP_NEG:
13231326
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
13241327
default:
13251328
return false;
@@ -2010,6 +2013,18 @@ static void ggml_metal_encode_node(
20102013

20112014
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
20122015
} break;
2016+
case GGML_UNARY_OP_NEG:
2017+
{
2018+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
2019+
2020+
[encoder setComputePipelineState:pipeline];
2021+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2022+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2023+
2024+
const int64_t n = ggml_nelements(dst);
2025+
2026+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2027+
} break;
20132028
default:
20142029
{
20152030
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -949,6 +949,13 @@ kernel void kernel_cos(
949949
dst[tpig] = cos(src0[tpig]);
950950
}
951951

952+
kernel void kernel_neg(
953+
device const float * src0,
954+
device float * dst,
955+
uint tpig[[thread_position_in_grid]]) {
956+
dst[tpig] = -src0[tpig];
957+
}
958+
952959
kernel void kernel_sum_rows(
953960
device const float * src0,
954961
device float * dst,

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2397,7 +2397,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
23972397

23982398
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
23992399
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2400-
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2400+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
24012401
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
24022402
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
24032403

@@ -6006,6 +6006,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
60066006
case GGML_OP_REPEAT:
60076007
case GGML_OP_REPEAT_BACK:
60086008
case GGML_OP_ROPE:
6009+
case GGML_OP_RMS_NORM:
60096010
return true;
60106011
default:
60116012
return false;
@@ -6216,7 +6217,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
62166217

62176218
switch (op) {
62186219
case GGML_OP_NORM:
6219-
case GGML_OP_RMS_NORM:
62206220
case GGML_OP_RMS_NORM_BACK:
62216221
case GGML_OP_L2_NORM:
62226222
case GGML_OP_SOFT_MAX:
@@ -6233,6 +6233,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
62336233
elements = { nr, 1, 1 };
62346234
}
62356235
} break;
6236+
case GGML_OP_RMS_NORM:
6237+
elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
6238+
break;
6239+
62366240
case GGML_OP_SUM:
62376241
// We use GGML_OP_SUM_ROWS with 1 row.
62386242
elements = { 1, 1, 1 };
@@ -6883,7 +6887,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
68836887

68846888
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
68856889
float * op_params = (float *)dst->op_params;
6886-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
6890+
const uint32_t src0_type_size = ggml_type_size(src0->type);
6891+
const uint32_t dst_type_size = ggml_type_size(dst->type);
6892+
6893+
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
6894+
(uint32_t)ggml_nelements(src0),
6895+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
6896+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
6897+
0,
6898+
op_params[0], 0.0f,
6899+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
6900+
}, dryrun);
68876901
}
68886902

68896903
static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9388,10 +9402,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
93889402
case GGML_OP_VIEW:
93899403
case GGML_OP_PERMUTE:
93909404
case GGML_OP_TRANSPOSE:
9405+
case GGML_OP_RMS_NORM:
93919406
return true;
93929407
case GGML_OP_NORM:
93939408
case GGML_OP_GROUP_NORM:
9394-
case GGML_OP_RMS_NORM:
93959409
case GGML_OP_L2_NORM:
93969410
return ggml_is_contiguous(op->src[0]);
93979411
case GGML_OP_ADD:
Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,36 @@
11
#version 450
22

3-
#include "generic_head.comp"
3+
#include "generic_unary_head.comp"
44
#include "types.comp"
55

66
#extension GL_EXT_control_flow_attributes : enable
77
#define BLOCK_SIZE 512
88

99
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
1010

11-
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
12-
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
13-
1411
shared FLOAT_TYPE sum[BLOCK_SIZE];
1512

1613
void main() {
17-
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
18-
const uint tid = gl_LocalInvocationID.x;
14+
const uint ncols = p.ne00;
15+
const uint nrows = gl_NumWorkGroups.x;
16+
const uint nchannels = gl_NumWorkGroups.y;
17+
18+
const uint row = gl_WorkGroupID.x;
19+
const uint channel = gl_WorkGroupID.y;
20+
const uint samp = gl_WorkGroupID.z;
21+
const uint tid = gl_LocalInvocationID.x;
22+
23+
const uint stride_row = p.nb01;
24+
const uint stride_channel = p.nb02;
25+
const uint stride_sample = p.nb03;
26+
27+
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
28+
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
1929

2030
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
2131

22-
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
23-
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
32+
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
33+
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
2434
sum[tid] += xi * xi;
2535
}
2636

@@ -33,10 +43,10 @@ void main() {
3343
barrier();
3444
}
3545

36-
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
46+
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
3747
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
3848

39-
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
40-
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
49+
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
50+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
4151
}
4252
}

0 commit comments

Comments
 (0)