
Commit 25963a8

add GroveMoE support
1 parent a094f38 commit 25963a8

File tree

20 files changed: +570 −12 lines changed


common/arg.cpp

Lines changed: 2 additions & 2 deletions
@@ -2415,7 +2415,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--cpu-moe", "-cmoe"},
         "keep all Mixture of Experts (MoE) weights in the CPU",
         [](common_params & params) {
-            params.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
+            params.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_(ch|)exps", ggml_backend_cpu_buffer_type()});
         }
     ).set_env("LLAMA_ARG_CPU_MOE"));
     add_opt(common_arg(
@@ -2428,7 +2428,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             for (int i = 0; i < value; ++i) {
                 // keep strings alive and avoid leaking memory by storing them in a static vector
                 static std::list<std::string> buft_overrides;
-                buft_overrides.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
+                buft_overrides.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_(ch|)exps", i));
                 params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), ggml_backend_cpu_buffer_type()});
             }
         }
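The only functional change here is the added "(ch|)" group: the override pattern now also catches GroveMoE's chunk-expert tensors (ffn_*_chexps) in addition to the regular expert tensors (ffn_*_exps), so --cpu-moe and --n-cpu-moe keep both kinds of weights on the CPU. A minimal sketch of how the pattern behaves, using Python's re module as a stand-in for the regex matching applied to tensor names; the tensor names below are made up for illustration:

import re

# the updated override pattern from the diff above
pattern = re.compile(r"\.ffn_(up|down|gate)_(ch|)exps")

# hypothetical tensor names, for illustration only
names = [
    "blk.0.ffn_up_exps.weight",      # regular experts -> matched
    "blk.0.ffn_gate_chexps.weight",  # chunk experts   -> matched
    "blk.0.ffn_up.weight",           # dense FFN       -> not matched
]

for name in names:
    print(name, "->", bool(pattern.search(name)))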

convert_hf_to_gguf.py

Lines changed: 115 additions & 0 deletions
@@ -7654,6 +7654,121 @@ def prepare_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
+class GroveMoeModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.GROVEMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (n_experts := self.hparams.get("num_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+            logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
+        # FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L299
+        self.gguf_writer.add_expert_chunk_feed_forward_length(self.hparams.get("head_dim") or 128)
+        # FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L298
+        self.gguf_writer.add_experts_per_group(2)
+        # FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L376
+        self.gguf_writer.add_expert_group_scale(0.05)
+        # YaRN is not enabled by default
+        # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
+    _experts: list[dict[str, Tensor]] | None = None
+    _chunk_experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.endswith(".expert_bias"):
+            # FIXME?: Unused https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L303
+            return []
+
+        # process the experts separately
+        if name.find("chunk_experts") != -1:
+            n_experts = self.hparams["num_experts"] // 2  # see add_experts_per_group
+            assert bid is not None
+
+            if self._chunk_experts is None:
+                self._chunk_experts = [{} for _ in range(self.block_count)]
+
+            self._chunk_experts[bid][name] = data_torch
+
+            if len(self._chunk_experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.chunk_experts.{xid}.{w_name}.weight"
+                        datas.append(self._chunk_experts[bid][ename])
+                        del self._chunk_experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.chunk_experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+        elif name.find("experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._chunk_experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            chunk_experts = [k for d in self._chunk_experts for k in d.keys()]
+            if len(chunk_experts) > 0:
+                raise ValueError(f"Unprocessed adjugate experts: {chunk_experts}")
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("ChameleonForConditionalGeneration")
 @ModelBase.register("ChameleonForCausalLM")  # obsolete
 class ChameleonModel(TextModel):
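A note on the merge step above: for every block, modify_tensors buffers the per-expert 2D weights until all of them have been seen, then stacks each projection (down/gate/up) into one 3D tensor per block; the chunk experts get the same treatment with num_experts // 2 entries. A self-contained sketch of that stacking with made-up sizes and a single layer, not the converter's real inputs:

import torch

n_experts, n_ff, n_embd = 4, 32, 16  # made-up sizes, for illustration only

# per-expert 2D weights as they appear in the HF checkpoint (random stand-ins)
expert_weights = {
    f"model.layers.0.mlp.experts.{xid}.up_proj.weight": torch.randn(n_ff, n_embd)
    for xid in range(n_experts)
}

# merge into a single 3D tensor in expert order, mirroring torch.stack(datas, dim=0) above
stacked = torch.stack(
    [expert_weights[f"model.layers.0.mlp.experts.{xid}.up_proj.weight"] for xid in range(n_experts)],
    dim=0,
)
print(stacked.shape)  # torch.Size([4, 32, 16]) -> one (n_expert, n_ff, n_embd) tensor per projection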

ggml/include/ggml.h

Lines changed: 10 additions & 0 deletions
@@ -916,6 +916,16 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_div_scalar_i32(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int32_t               b);
+
+    GGML_API struct ggml_tensor * ggml_div_scalar_left_i32(
+            struct ggml_context * ctx,
+            int32_t               a,
+            struct ggml_tensor  * b);
+
     GGML_API struct ggml_tensor * ggml_sqr(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
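Judging from the names, the signatures, and the CPU kernel added in ggml-cpu/binary-ops.cpp below, ggml_div_scalar_i32 divides every element of an I32 tensor by a scalar, while ggml_div_scalar_left_i32 divides the scalar by every element. A rough sketch of the intended element-wise semantics, in Python purely for illustration (ggml is a C library, so this is not how the ops are invoked; // matches C's truncating division for the non-negative values used here):

def div_scalar_i32(a: list[int], b: int) -> list[int]:
    # roughly: dst[i] = a[i] / b, integer division
    return [x // b for x in a]

def div_scalar_left_i32(a: int, b: list[int]) -> list[int]:
    # roughly: dst[i] = a / b[i], integer division
    return [a // x for x in b]

print(div_scalar_i32([10, 7, 3], 2))       # [5, 3, 1]
print(div_scalar_left_i32(12, [3, 4, 6]))  # [4, 3, 2]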

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 6 additions & 1 deletion
@@ -2471,7 +2471,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
-        case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
@@ -2494,6 +2493,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_DIV:
+            {
+                struct ggml_tensor * a = op->src[0];
+                struct ggml_tensor * b = op->src[1];
+                return a && b && a->type != GGML_TYPE_I32 && b->type != GGML_TYPE_I32;
+            } break;
         case GGML_OP_SCALE:
             float bias;
             memcpy(&bias, (float*)op->op_params + 1, sizeof(float));

ggml/src/ggml-cpu/binary-ops.cpp

Lines changed: 38 additions & 1 deletion
@@ -115,13 +115,50 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
     }
 }
 
+static void apply_scalar_div_op(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src  = src0 ? src0 : src1;
+    const int32_t scalar = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(ggml_are_same_shape(src, dst));
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const auto [ir0, ir1] = get_thread_range(params, src);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+              int32_t * dst_ptr = (int32_t *) ((char *) dst->data + i03*nb3  + i02*nb2  + i01*nb1 );
+        const int32_t * src_ptr = (const int32_t *) ((const char *) src->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+        for (int i = 0; i < ne00; i++) {
+            dst_ptr[i] = src0 ? src_ptr[i] / scalar : scalar / src_ptr[i];
+        }
+    }
+}
+
 // TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
 template <float (*op)(float, float)>
 static void binary_op(const ggml_compute_params * params, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    /* */ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
+    /* */ if (!src0 || !src1) { // scalar
+        if (dst->type == GGML_TYPE_I32) {
+            if constexpr (op == op_div) {
+                apply_scalar_div_op(params, dst);
+            } else {
+                GGML_ABORT("%s: unsupported op\n", __func__);
+            }
+        } else {
+            GGML_ABORT("%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+                ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+        }
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
         apply_binary_op<op, float, float, float>(params, dst);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
         apply_binary_op<op, ggml_fp16_t, ggml_fp16_t, ggml_fp16_t>(params, dst);
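apply_scalar_div_op splits the work by rows: each thread gets a contiguous range [ir0, ir1) of flat row indices and recovers the (i03, i02, i01) coordinates it needs to address the source and destination rows. A tiny Python sketch of just that index arithmetic, with toy dimension sizes chosen for illustration (ne00 would be the row length and is not used here):

ne01, ne02, ne03 = 3, 2, 2  # toy row counts along dims 1..3

for ir in range(ne01 * ne02 * ne03):
    # same decomposition as the loop body in apply_scalar_div_op above
    i03 = ir // (ne02 * ne01)
    i02 = (ir - i03 * ne02 * ne01) // ne01
    i01 = ir - i03 * ne02 * ne01 - i02 * ne01
    # recombining gives back the flat row index
    assert ir == (i03 * ne02 + i02) * ne01 + i01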

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 6 additions & 1 deletion
@@ -3434,7 +3434,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
-        case GGML_OP_DIV:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
@@ -3443,6 +3442,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CLAMP:
         case GGML_OP_LOG:
             return true;
+        case GGML_OP_DIV:
+            {
+                struct ggml_tensor * a = op->src[0];
+                struct ggml_tensor * b = op->src[1];
+                return a && b && a->type != GGML_TYPE_I32 && b->type != GGML_TYPE_I32;
+            } break;
         case GGML_OP_SSM_SCAN: {
             if (op->src[3]->ne[0] == 1) {
                 // Mamba2

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 6 additions & 1 deletion
@@ -1814,9 +1814,14 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_ADD:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
-        case GGML_OP_DIV:
         case GGML_OP_ADD_ID:
             return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_DIV:
+            {
+                struct ggml_tensor * a = op->src[0];
+                struct ggml_tensor * b = op->src[1];
+                return a && b && a->type == GGML_TYPE_F32 && b->type == GGML_TYPE_F32;
+            } break;
         case GGML_OP_ACC:
         case GGML_OP_REPEAT:
         case GGML_OP_SCALE:

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 9 additions & 1 deletion
@@ -2608,11 +2608,19 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             }
         }
         case GGML_OP_MUL:
-        case GGML_OP_DIV:
         case GGML_OP_SUB:
             return (op->src[0]->type == op->src[1]->type) &&
                    (op->src[0]->type == op->type) &&
                    (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
+        case GGML_OP_DIV:
+            {
+                struct ggml_tensor * a = op->src[0];
+                struct ggml_tensor * b = op->src[1];
+                return (a && b) &&
+                       (a->type == b->type) &&
+                       (a->type == op->type) &&
+                       (a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16);
+            } break;
         case GGML_OP_ADD_ID:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_UNARY:

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 6 additions & 1 deletion
@@ -4349,9 +4349,14 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
-        case GGML_OP_DIV:
         case GGML_OP_REPEAT:
             return true;
+        case GGML_OP_DIV:
+            {
+                struct ggml_tensor * a = op->src[0];
+                struct ggml_tensor * b = op->src[1];
+                return a && b && a->type != GGML_TYPE_I32 && b->type != GGML_TYPE_I32;
+            } break;
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_SIN:

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 9 additions & 1 deletion
@@ -11397,10 +11397,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_ADD:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
-        case GGML_OP_DIV:
             return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                    (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
                    (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
+        case GGML_OP_DIV:
+            {
+                struct ggml_tensor * a = op->src[0];
+                struct ggml_tensor * b = op->src[1];
+                return (a && b) &&
+                       (a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16) &&
+                       (b->type == GGML_TYPE_F32 || b->type == GGML_TYPE_F16) &&
+                       (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
+            } break;
         case GGML_OP_ADD_ID:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
                    op->type == GGML_TYPE_F32;
