
Commit 4902eeb

wp4032 and CISC authored
models : Added support for RND1 Diffusion Language Model (ggml-org#17433)
* Converted RND1 model to GGUF weights
* RND1 llama.cpp support v1
* RND1 llama.cpp support v2: fix non-causal bug
* RND1 llama.cpp support v3: documentation
* RND1 llama.cpp support v4: clean code
* Fix linting issues
* RND1 PR fixes v1
* RND1 PR fixes v2
* Diffusion documentation edits

Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 923ae3c commit 4902eeb

File tree

9 files changed: +257 −3 lines

convert_hf_to_gguf.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -4183,6 +4183,21 @@ def set_vocab(self):
         super().set_vocab()
 
 
+@ModelBase.register("RND1")
+class RND1Model(Qwen2MoeModel):
+    model_arch = gguf.MODEL_ARCH.RND1
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # RND1 specific parameters
+        # RND1 uses bidirectional attention
+        self.gguf_writer.add_causal_attention(False)
+
+        if (mask_token_id := self.hparams.get("mask_token_id")) is not None:
+            self.gguf_writer.add_mask_token_id(mask_token_id)
+
+
 @ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
 class Qwen3VLVisionModel(MmprojModel):
     def __init__(self, *args, **kwargs):
```
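The two writer calls above are what distinguish an RND1 conversion from a plain Qwen2MoE one: the file is flagged non-causal and carries the mask token the diffusion sampler fills in. A quick way to sanity-check a converted file is to read those keys back with gguf-py — a minimal sketch, assuming an output file named `rnd1.gguf` (the key strings follow gguf-py's `Keys` constants; the scalar-extraction idiom mirrors gguf-py's dump script):

```python
from gguf import GGUFReader

reader = GGUFReader("rnd1.gguf")  # hypothetical output path

def scalar(key: str):
    field = reader.get_field(key)
    if field is None:
        return None
    return field.parts[field.data[0]][0]  # scalar KV fields hold a single value

print(scalar("rnd1.attention.causal"))         # expect False (bidirectional attention)
print(scalar("tokenizer.ggml.mask_token_id"))  # token id the sampler starts from
```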

examples/diffusion/README.md

Lines changed: 48 additions & 2 deletions
````diff
@@ -6,8 +6,54 @@ More Info:
 - https://github.com/ggml-org/llama.cpp/pull/14644
 - https://github.com/ggml-org/llama.cpp/pull/14771
 
+## Parameters
+The diffusion CLI supports various parameters to control the generation process:
 
-Example of using Dream architechture: `llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual`
+### Core Diffusion Parameters
+- `--diffusion-steps`: Number of diffusion steps (default: 256)
+- `--diffusion-algorithm`: Algorithm for token selection
+  - `0`: ORIGIN - Tokens are generated in a purely random order (https://arxiv.org/abs/2107.03006)
+  - `1`: ENTROPY_BASED - Entropy-based selection
+  - `2`: MARGIN_BASED - Margin-based selection
+  - `3`: RANDOM - Random selection
+  - `4`: CONFIDENCE_BASED - Confidence-based selection (default)
+  - More documentation: https://github.com/DreamLM/Dream
+- `--diffusion-visual`: Enable live visualization during generation
 
-Example of using LLaDA architechture: `llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual`
+### Scheduling Parameters
+Choose one of the following scheduling methods:
 
+**Timestep-based scheduling:**
+- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001)
+
+**Block-based scheduling:**
+- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32)
+
+### Sampling Parameters
+- `--temp`: Temperature for sampling (0.0 = greedy/deterministic, higher = more random)
+- `--top-k`: Top-k filtering for sampling
+- `--top-p`: Top-p (nucleus) filtering for sampling
+- `--seed`: Random seed for reproducibility
+
+### Model Parameters
+- `-m`: Path to the GGUF model file
+- `-p`: Input prompt text
+- `-ub`: Maximum sequence length (ubatch size)
+- `-c`: Context size
+- `-b`: Batch size
+
+### Examples
+#### Dream architecture:
+```
+llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual
+```
+
+#### LLaDA architecture:
+```
+llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual
+```
+
+#### RND1 architecture:
+```
+llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001
+```
````
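The selection algorithms listed above differ only in how they score the still-masked positions before deciding which ones to unmask at each step; `RANDOM` and `ORIGIN` simply replace the score with a random ordering. A rough NumPy sketch of the three scored variants (illustrative only, not the CLI's internals; `probs` is assumed to be an `(n_masked, vocab)` matrix of per-position token probabilities):

```python
import numpy as np

def confidence_score(probs):
    # CONFIDENCE_BASED: prefer positions where the top token dominates
    return probs.max(axis=-1)

def margin_score(probs):
    # MARGIN_BASED: prefer a large gap between the top-1 and top-2 tokens
    top2 = np.sort(probs, axis=-1)[:, -2:]
    return top2[:, 1] - top2[:, 0]

def entropy_score(probs):
    # ENTROPY_BASED: low entropy means high certainty, so score by -H(p)
    entropy = -(probs * np.log(probs + 1e-12)).sum(axis=-1)
    return -entropy

# Each step then unmasks the k best-scoring positions:
#   reveal = masked_positions[np.argsort(-score)[:k]]
```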

gguf-py/gguf/constants.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -427,6 +427,7 @@ class MODEL_ARCH(IntEnum):
     APERTUS = auto()
     COGVLM = auto()
     MINIMAXM2 = auto()
+    RND1 = auto()
     PANGU_EMBED = auto()
 
 
@@ -797,6 +798,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.APERTUS: "apertus",
     MODEL_ARCH.MINIMAXM2: "minimax-m2",
     MODEL_ARCH.COGVLM: "cogvlm",
+    MODEL_ARCH.RND1: "rnd1",
     MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
 }
 
@@ -2991,6 +2993,23 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.VISEXP_UP,
         MODEL_TENSOR.VISEXP_DOWN,
     ],
+    MODEL_ARCH.RND1: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     MODEL_ARCH.PANGU_EMBED: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
```

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -115,6 +115,7 @@ add_library(llama
             models/qwen3vl-moe.cpp
             models/qwen3moe.cpp
             models/refact.cpp
+            models/rnd1.cpp
             models/rwkv6-base.cpp
             models/rwkv6.cpp
             models/rwkv6qwen2.cpp
```

src/llama-arch.cpp

Lines changed: 22 additions & 0 deletions
```diff
@@ -108,6 +108,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_APERTUS,          "apertus"          },
     { LLM_ARCH_MINIMAX_M2,       "minimax-m2"       },
     { LLM_ARCH_COGVLM,           "cogvlm"           },
+    { LLM_ARCH_RND1,             "rnd1"             },
     { LLM_ARCH_PANGU_EMBED,      "pangu-embedded"   },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
@@ -2446,6 +2447,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_VISEXP_FFN_UP,   "blk.%d.vis_up" },
         },
     },
+    {
+        LLM_ARCH_RND1,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2722,6 +2743,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
         case LLM_ARCH_LLADA_MOE:
+        case LLM_ARCH_RND1:
             return true;
         default:
             return false;
```

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -112,6 +112,7 @@ enum llm_arch {
    LLM_ARCH_APERTUS,
    LLM_ARCH_MINIMAX_M2,
    LLM_ARCH_COGVLM,
+   LLM_ARCH_RND1,
    LLM_ARCH_PANGU_EMBED,
    LLM_ARCH_UNKNOWN,
};
```

src/llama-model.cpp

Lines changed: 21 additions & 1 deletion
```diff
@@ -1036,6 +1036,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_RND1:
+            {
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 48: type = LLM_TYPE_30B_A3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+                // Set non-causal attention for diffusion models
+                hparams.causal_attn = false;
+            } break;
         case LLM_ARCH_QWEN2MOE:
             {
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
@@ -3402,6 +3414,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_QWEN3MOE:
         case LLM_ARCH_QWEN3VLMOE:
+        case LLM_ARCH_RND1:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
@@ -6720,7 +6733,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
     }
 
-    if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE) {
+    if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) {
         LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
     }
 
@@ -6882,6 +6895,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
         case LLM_ARCH_LLADA_MOE:
+        case LLM_ARCH_RND1:
             {
                 res = nullptr;
             } break;
@@ -7075,6 +7089,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
                 llm = std::make_unique<llm_build_llada_moe>(*this, params);
             }
             break;
+        case LLM_ARCH_RND1:
+            {
+                llm = std::make_unique<llm_build_rnd1>(*this, params);
+            }
+            break;
         case LLM_ARCH_QWEN2VL:
             {
                 llm = std::make_unique<llm_build_qwen2vl>(*this, params);
@@ -7595,6 +7614,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_QWEN3:
         case LLM_ARCH_QWEN3MOE:
         case LLM_ARCH_LLADA_MOE:
+        case LLM_ARCH_RND1:
         case LLM_ARCH_OLMO2:
        case LLM_ARCH_OLMOE:
         case LLM_ARCH_PHI2:
```

src/models/models.h

Lines changed: 4 additions & 0 deletions
```diff
@@ -431,6 +431,10 @@ struct llm_build_refact : public llm_graph_context {
     llm_build_refact(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_rnd1 : public llm_graph_context {
+    llm_build_rnd1(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_rwkv6 : public llm_build_rwkv6_base {
     llm_build_rwkv6(const llama_model & model, const llm_graph_params & params);
 };
```

src/models/rnd1.cpp

Lines changed: 126 additions & 0 deletions
```diff
@@ -0,0 +1,126 @@
+#include "models.h"
+
+// RND1 is a Qwen3Moe AR model converted to diffusion model.
+llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    // Non-causal attention for diffusion
+    auto * inp_attn = build_attn_inp_no_cache();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // norm
+        cur = build_norm(inpL,
+                model.layers[il].attn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // self_attention
+        {
+            // compute Q and K and RoPE them
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+            cb(Qcur, "Qcur_normed", il);
+
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+            cb(Kcur, "Kcur_normed", il);
+
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                model.layers[il].wo, model.layers[il].bo,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        }
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // MoE branch
+        cur = build_norm(ffn_inp,
+                model.layers[il].ffn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        ggml_tensor * moe_out =
+            build_moe_ffn(cur,
+                model.layers[il].ffn_gate_inp,
+                model.layers[il].ffn_up_exps,
+                model.layers[il].ffn_gate_exps,
+                model.layers[il].ffn_down_exps,
+                nullptr,
+                n_expert, n_expert_used,
+                LLM_FFN_SILU, true,
+                false, 0.0,
+                LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                il);
+        cb(moe_out, "ffn_moe_out", il);
+        cur = moe_out;
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    cur = build_norm(cur,
+            model.output_norm, NULL,
+            LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
```
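As the header comment notes, this graph is Qwen3MoE's with the causal mask and KV cache dropped (`build_attn_inp_no_cache`). The MoE branch routes each token through `n_expert_used` of `n_expert` SiLU-gated experts under softmax gating, renormalizing the selected experts' weights. A per-token NumPy sketch of that routing math (names and shapes are illustrative, not ggml's batched implementation):

```python
import numpy as np

def silu(v):
    return v / (1.0 + np.exp(-v))

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def moe_ffn(x, gate_inp, w_gate, w_up, w_down, n_expert_used):
    """One token through the MoE FFN.
    x: (n_embd,)             gate_inp: (n_expert, n_embd)
    w_gate, w_up: (n_expert, n_ff_exp, n_embd)
    w_down: (n_expert, n_embd, n_ff_exp)
    """
    probs = softmax(gate_inp @ x)               # router distribution over experts
    top = np.argsort(-probs)[:n_expert_used]    # keep only the top-k experts
    weights = probs[top] / probs[top].sum()     # renormalize selected weights
    out = np.zeros_like(x)
    for w, e in zip(weights, top):
        hidden = silu(w_gate[e] @ x) * (w_up[e] @ x)  # SiLU-gated expert FFN
        out += w * (w_down[e] @ hidden)
    return out
```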
