Commit dd6d36b

Merge branch 'master' into smoldocling-support
2 parents bd137ff + 4a5686d commit dd6d36b

File tree

24 files changed: +760 -460 lines changed

convert_hf_to_gguf.py

Lines changed: 117 additions & 0 deletions
@@ -4974,6 +4974,123 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         yield (new_name, data_torch)
 
 
+@ModelBase.register("JambaForCausalLM")
+class JambaModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.JAMBA
+
+    def get_vocab_base_pre(self, tokenizer) -> str:
+        del tokenizer  # unused
+
+        return "gpt-2"
+
+    def set_vocab(self):
+        if (self.dir_model / "tokenizer.model").is_file():
+            # Using Jamba's tokenizer.json causes errors on model load
+            # (something about "byte not found in vocab"),
+            # but there's a working tokenizer.model
+            self._set_vocab_sentencepiece()
+        else:
+            # Some Jamba models only have a tokenizer.json, which works.
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
+        d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4
+        d_inner = self.hparams["mamba_expand"] * d_model
+        d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
+        # ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+        dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
+        n_kv_head = self.hparams["num_key_value_heads"]
+        attn_offset = self.hparams["attn_layer_offset"]
+        attn_period = self.hparams["attn_layer_period"]
+        n_kv_vec = [0 for _ in range(attn_offset)] + [
+            n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
+        ]
+
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(n_kv_vec)
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        self.gguf_writer.add_ssm_time_step_rank(dt_rank)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+        self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+
+        # Mini-Jamba
+        name = name.replace(".moe.", ".feed_forward.")
+        if bid is not None:
+            moe_offset = self.hparams["expert_layer_offset"]
+            moe_period = self.hparams["expert_layer_period"]
+
+            if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0):
+                name = name.replace(".experts.0.", ".")
+
+        # process the experts separately
+        if ".feed_forward.experts." in name:
+            n_experts = self.hparams["num_experts"]
+
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+
+                # merge the experts into a single 3d tensor
+                for wid in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    # using the same merged name as qwen2moe
+                    merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    yield new_name, data_torch
+            return
+
+        new_name = self.map_tensor_name(name)
+
+        if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
+            data_torch = data_torch.squeeze()
+
+        if name.endswith(".A_log"):
+            logger.debug("A_log --> A ==> " + new_name)
+            data_torch = -torch.exp(data_torch)
+
+        yield (new_name, data_torch)
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("CohereForCausalLM")
 class CommandR2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
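
For intuition about the n_kv_vec list built in set_gguf_parameters above: Jamba interleaves attention and Mamba blocks, and only the layers sitting at attn_layer_offset plus a multiple of attn_layer_period carry KV heads, while every other layer stores 0. A minimal C sketch of that layout with made-up hyperparameters (the real values come from the model's config.json):

#include <stdio.h>

int main(void) {
    // hypothetical hyperparameters, standing in for config.json values
    const int block_count = 16;
    const int n_kv_head   = 8;   // num_key_value_heads
    const int attn_offset = 4;   // attn_layer_offset
    const int attn_period = 8;   // attn_layer_period

    // mirror of the n_kv_vec comprehension: KV heads only on attention layers
    for (int i = 0; i < block_count; ++i) {
        const int is_attn = i >= attn_offset && (i - attn_offset) % attn_period == 0;
        printf("%d ", is_attn ? n_kv_head : 0);
    }
    printf("\n"); // prints: 0 0 0 0 8 0 0 0 0 0 0 0 8 0 0 0
    return 0;
}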

ggml/include/ggml.h

Lines changed: 13 additions & 0 deletions
@@ -1297,6 +1297,19 @@ extern "C" {
             struct ggml_tensor  * a,
             float                 s);
 
+    // x = s * a + b
+    GGML_API struct ggml_tensor * ggml_scale_bias(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s,
+            float                 b);
+
+    GGML_API struct ggml_tensor * ggml_scale_bias_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s,
+            float                 b);
+
     // b -> view(a,offset,nb1,nb2,3), return modified a
     GGML_API struct ggml_tensor * ggml_set(
             struct ggml_context * ctx,
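
A hypothetical usage sketch of the new API, not part of this commit: it assumes the CPU backend and the classic ggml_context workflow, with ggml_graph_compute_with_ctx() coming from ggml-cpu.h as in current ggml. Since x = s*a + b, scaling a tensor of 2s by s = 3 with b = 1 should yield 7s.

#include "ggml.h"
#include "ggml-cpu.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    for (int i = 0; i < 4; ++i) {
        ((float *) a->data)[i] = 2.0f;               // a = [2, 2, 2, 2]
    }

    // x = s*a + b = 3*2 + 1 = 7 for every element
    struct ggml_tensor * x = ggml_scale_bias(ctx, a, 3.0f, 1.0f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, x);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    printf("x[0] = %.1f\n", ((float *) x->data)[0]); // expected: 7.0

    ggml_free(ctx);
    return 0;
}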

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 4 additions & 1 deletion
@@ -2188,7 +2188,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
-        case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_CLAMP:
@@ -2210,6 +2209,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_SCALE:
+            float bias;
+            memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
+            return bias == 0.0f; // TODO: support bias != 0.0f
         case GGML_OP_SOFT_MAX:
             // TODO: support broadcast
             // ref: https://github.com/ggml-org/llama.cpp/pull/14435

ggml/src/ggml-cpu/ops.cpp

Lines changed: 20 additions & 8 deletions
@@ -4643,9 +4643,11 @@ static void ggml_compute_forward_scale_f32(
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    // scale factor
-    float v;
-    memcpy(&v, dst->op_params, sizeof(float));
+    float s; // scale factor
+    float b; // bias
+
+    memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -4664,12 +4666,22 @@ static void ggml_compute_forward_scale_f32(
 
     const size_t nb1 = dst->nb[1];
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        if (dst->data != src0->data) {
-            // src0 is same shape as dst => same indices
-            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+    if (b == 0.0f) {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            if (dst->data != src0->data) {
+                // src0 is same shape as dst => same indices
+                // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
+                memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+            }
+            ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
+        }
+    } else {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            ggml_vec_mad1_f32(nc,
+                    (float *) ((char *) dst->data + i1*nb1),
+                    (float *) ((char *) src0->data + i1*nb1),
+                    s, b);
         }
-        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
     }
 }
 
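
Both parameters travel in dst->op_params, a small array of 32-bit slots, and every backend touched by this commit reads them back with the same pair of memcpy calls. The writer side (the ggml_scale_bias implementation in ggml.c) is not part of this diff, so the packing below is an assumption that simply mirrors how the readers above behave; a standalone sketch:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    int32_t op_params[4] = {0};   // stand-in for ggml_tensor::op_params

    // assumed writer side: scale in slot 0, bias in slot 1
    const float s = 2.0f, b = 0.5f;
    memcpy((float *) op_params + 0, &s, sizeof(float));
    memcpy((float *) op_params + 1, &b, sizeof(float));

    // reader side, as in the CPU/CUDA/Metal/OpenCL code in this commit
    float scale, bias;
    memcpy(&scale, (float *) op_params + 0, sizeof(float));
    memcpy(&bias,  (float *) op_params + 1, sizeof(float));

    printf("scale = %.1f, bias = %.1f\n", scale, bias); // scale = 2.0, bias = 0.5
    return 0;
}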
ggml/src/ggml-cpu/vec.h

Lines changed: 39 additions & 0 deletions
@@ -351,6 +351,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
 #endif
 }
 
+inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
+#if defined(GGML_USE_ACCELERATE)
+    vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
+#elif defined(GGML_SIMD)
+    #if defined(__ARM_FEATURE_SVE)
+        // scalar ; TODO: Write SVE code
+        for (int i = 0; i < n; ++i) {
+            y[i] = x[i]*s + b;
+        }
+    #else
+        const int np = (n & ~(GGML_F32_STEP - 1));
+
+        GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
+        GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);
+
+        GGML_F32_VEC ay[GGML_F32_ARR];
+
+        for (int i = 0; i < np; i += GGML_F32_STEP) {
+            for (int j = 0; j < GGML_F32_ARR; j++) {
+                ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+                ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs);
+
+                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+            }
+        }
+
+        // leftovers
+        for (int i = np; i < n; ++i) {
+            y[i] = x[i]*s + b;
+        }
+    #endif
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = x[i]*s + b;
+    }
+#endif
+}
+
 //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
 inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
 #if defined(GGML_USE_ACCELERATE)
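
One subtlety in the SIMD path: ggml's GGML_F32_VEC_FMA(a, b, c) evaluates a + b*c (vfmaq_f32(a, b, c) on NEON, _mm256_fmadd_ps(b, c, a) on AVX2), so to get y = x*s + b the bias vector has to sit in the accumulator slot. A scalar model of that call order, not ggml code:

#include <stdio.h>

// scalar stand-in for GGML_F32_VEC_FMA: returns a + b*c
static float vec_fma(float a, float b, float c) {
    return a + b*c;
}

int main(void) {
    const float x = 3.0f, s = 2.0f, b = 0.5f;

    // ggml_vec_mad1_f32 wants x*s + b, so b goes in the accumulator argument
    const float y = vec_fma(b, x, s);
    printf("%.1f\n", y); // 6.5 == 3*2 + 0.5
    return 0;
}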

ggml/src/ggml-cuda/scale.cu

Lines changed: 8 additions & 6 deletions
@@ -1,18 +1,18 @@
 #include "scale.cuh"
 
-static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
 
-    dst[i] = scale * x[i];
+    dst[i] = scale * x[i] + bias;
 }
 
-static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
-    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
+    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
 }
 
 void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -25,7 +25,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     float scale;
-    memcpy(&scale, dst->op_params, sizeof(float));
+    float bias;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
 
-    scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
+    scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
 }
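
The launch arithmetic is unchanged by this commit, but it explains the early return in the kernel: the grid size is rounded up with a ceiling division, so the last block can contain threads past the final element. A plain C sketch of that arithmetic, treating the block size (256 upstream) as an assumption:

#include <stdio.h>

int main(void) {
    const int block = 256;  // stand-in for CUDA_SCALE_BLOCK_SIZE
    const int k     = 1000; // number of elements to scale

    const int num_blocks = (k + block - 1) / block; // ceiling division, as in scale_f32_cuda

    printf("blocks = %d, total threads = %d, idle threads = %d\n",
           num_blocks, num_blocks * block, num_blocks * block - k);
    // blocks = 4, total threads = 1024, idle threads = 24 -> those 24 hit `if (i >= k) return;`
    return 0;
}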

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 4 additions & 1 deletion
@@ -2256,7 +2256,9 @@ static bool ggml_metal_encode_node(
                 GGML_ASSERT(ggml_is_contiguous(src0));
 
                 float scale;
-                memcpy(&scale, dst->op_params, sizeof(scale));
+                float bias;
+                memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
+                memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float));
 
                 int64_t n = ggml_nelements(dst);
 
@@ -2273,6 +2275,7 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                 [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+                [encoder setBytes:&bias length:sizeof(bias) atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 4 additions & 2 deletions
@@ -1014,16 +1014,18 @@ kernel void kernel_scale(
         device const float * src0,
         device       float * dst,
         constant     float & scale,
+        constant     float & bias,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * scale;
+    dst[tpig] = src0[tpig] * scale + bias;
 }
 
 kernel void kernel_scale_4(
         device const float4 * src0,
         device       float4 * dst,
         constant     float  & scale,
+        constant     float  & bias,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * scale;
+    dst[tpig] = src0[tpig] * scale + bias;
 }
 
 kernel void kernel_clamp(

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 4 additions & 1 deletion
@@ -5587,7 +5587,9 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
     float scale;
-    memcpy(&scale, dst->op_params, sizeof(scale));
+    float bias;
+    memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(float));
+    memcpy(&bias, ((int32_t *) dst->op_params) + 1, sizeof(float));
 
     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
@@ -5602,6 +5604,7 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
     CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
     CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
     CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias));
 
     int n = ggml_nelements(dst)/4;
 
ggml/src/ggml-opencl/kernels/scale.cl

Lines changed: 3 additions & 2 deletions
@@ -8,9 +8,10 @@ kernel void kernel_scale(
         ulong offset0,
         global float4 * dst,
         ulong offsetd,
-        float scale
+        float scale,
+        float bias
 ) {
     src0 = (global float4*)((global char*)src0 + offset0);
     dst = (global float4*)((global char*)dst + offsetd);
-    dst[get_global_id(0)] = src0[get_global_id(0)] * scale;
+    dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias;
 }
