Skip to content

Commit 8a4a856

Browse files
authored
Add LLaDA 8b Diffusion model (#14771)
* Add support for Llada-8b: diffusion model * Add README * Fix README and convert_hf_to_gguf * convert_hf_to_gguf.py: address review comments * Make everything in a single example * Remove model-specific sampling * Remove unused argmax * Remove braced initializers, improve README.md a bit * Add diffusion specific gguf params in set_vocab, remove setting rope_theta and rms_norm_eps * Remove adding the mask token * Move add_add_bos_token to set_vocab * use add_bool in gguf_writer.py
1 parent 11490b3 commit 8a4a856

File tree

12 files changed

+857
-311
lines changed

12 files changed

+857
-311
lines changed

common/arg.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3438,34 +3438,51 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
34383438
}
34393439
).set_examples({LLAMA_EXAMPLE_SERVER}));
34403440

3441-
// diffusion parameters
34423441
add_opt(common_arg(
34433442
{ "--diffusion-steps" }, "N",
34443443
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
34453444
[](common_params & params, int value) { params.diffusion.steps = value; }
34463445
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
3446+
add_opt(common_arg(
3447+
{ "--diffusion-visual" },
3448+
string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
3449+
params.diffusion.visual_mode ? "true" : "false"),
3450+
[](common_params & params) { params.diffusion.visual_mode = true; }
3451+
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
3452+
34473453
add_opt(common_arg(
34483454
{ "--diffusion-eps" }, "F",
34493455
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
34503456
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
34513457
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
34523458
add_opt(common_arg(
34533459
{ "--diffusion-algorithm" }, "N",
3454-
string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)",
3460+
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)",
34553461
params.diffusion.algorithm),
34563462
[](common_params & params, int value) { params.diffusion.algorithm = value; }
34573463
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
34583464
add_opt(common_arg(
34593465
{ "--diffusion-alg-temp" }, "F",
3460-
string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
3466+
string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
34613467
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
34623468
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
3469+
34633470
add_opt(common_arg(
3464-
{ "--diffusion-visual" },
3465-
string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
3466-
params.diffusion.visual_mode ? "true" : "false"),
3467-
[](common_params & params) { params.diffusion.visual_mode = true; }
3471+
{ "--diffusion-block-length" }, "N",
3472+
string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
3473+
[](common_params & params, int value) { params.diffusion.block_length = value; }
3474+
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
3475+
add_opt(common_arg(
3476+
{ "--diffusion-cfg-scale" }, "F",
3477+
string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
3478+
[](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
3479+
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
3480+
add_opt(common_arg(
3481+
{ "--diffusion-add-gumbel-noise" }, "F",
3482+
string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
3483+
[](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
34683484
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
34693485

3486+
34703487
return ctx_arg;
34713488
}

common/common.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,11 +220,17 @@ struct common_params_vocoder {
220220
};
221221

222222
struct common_params_diffusion {
223-
int32_t steps = 64; // number of diffusion steps
224-
float eps = 1e-3f; // epsilon for timesteps
225-
int32_t algorithm = 0; // diffusion algorithm (0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY)
226-
float alg_temp = 0.0f; // algorithm temperature
227-
bool visual_mode = false; // show progressive diffusion on screen
223+
int32_t steps = 128;
224+
bool visual_mode = false;
225+
226+
float eps = 0; // epsilon for timesteps
227+
int32_t block_length = 32; // block length for generation
228+
229+
int32_t algorithm = 4; // default algorithm: low-confidence
230+
float alg_temp = 0.0f; // algorithm temperature
231+
232+
float cfg_scale = 0; // classifier-free guidance scale
233+
bool add_gumbel_noise = false; // add gumbel noise to the logits if temp > 0.0
228234
};
229235

230236
enum common_reasoning_format {

convert_hf_to_gguf.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2904,6 +2904,107 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
29042904
yield from super().modify_tensors(data_torch, name, bid)
29052905

29062906

2907+
@ModelBase.register("LLaDAModelLM")
2908+
class LLaDAModel(TextModel):
2909+
model_arch = gguf.MODEL_ARCH.LLADA
2910+
undo_permute = True
2911+
2912+
def get_vocab_base(self) -> tuple[list[str], list[int], str]:
2913+
tokens: list[str] = []
2914+
toktypes: list[int] = []
2915+
2916+
from transformers import AutoTokenizer
2917+
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
2918+
2919+
vocab_dict = tokenizer.get_vocab()
2920+
vocab_size = self.hparams.get("vocab_size", len(vocab_dict))
2921+
assert max(vocab_dict.values()) < vocab_size
2922+
2923+
tokpre = self.get_vocab_base_pre(tokenizer)
2924+
2925+
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab_dict.items()}
2926+
added_vocab = tokenizer.get_added_vocab()
2927+
2928+
for i in range(vocab_size):
2929+
if i not in reverse_vocab:
2930+
tokens.append(f"[PAD{i}]")
2931+
toktypes.append(gguf.TokenType.UNUSED)
2932+
elif reverse_vocab[i] in added_vocab:
2933+
tokens.append(reverse_vocab[i])
2934+
# Check if it's a special token - treat special tokens as CONTROL tokens
2935+
if hasattr(tokenizer, 'added_tokens_decoder') and i in tokenizer.added_tokens_decoder:
2936+
if tokenizer.added_tokens_decoder[i].special:
2937+
toktypes.append(gguf.TokenType.CONTROL)
2938+
else:
2939+
toktypes.append(gguf.TokenType.USER_DEFINED)
2940+
else:
2941+
# Fallback: treat all added vocab as control tokens for special tokens like <|im_start|>
2942+
toktypes.append(gguf.TokenType.CONTROL)
2943+
else:
2944+
tokens.append(reverse_vocab[i])
2945+
toktypes.append(gguf.TokenType.NORMAL)
2946+
2947+
return tokens, toktypes, tokpre
2948+
2949+
def set_vocab(self):
2950+
self._set_vocab_gpt2()
2951+
2952+
# LLaDA specific parameters
2953+
self.gguf_writer.add_add_bos_token(True)
2954+
2955+
def set_gguf_parameters(self):
2956+
super().set_gguf_parameters()
2957+
self._try_set_pooling_type()
2958+
2959+
# Add parameters similar to LlamaModel
2960+
hparams = self.hparams
2961+
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
2962+
2963+
if (rope_dim := hparams.get("head_dim")) is None:
2964+
n_heads = hparams.get("num_attention_heads", hparams.get("n_heads"))
2965+
rope_dim = hparams.get("hidden_size", hparams.get("d_model")) // n_heads
2966+
self.gguf_writer.add_rope_dimension_count(rope_dim)
2967+
2968+
# Set context length for LLaDA
2969+
context_length = self.hparams.get("max_sequence_length", 4096)
2970+
self.gguf_writer.add_context_length(context_length)
2971+
2972+
# Set embedding length (dimension size)
2973+
embedding_length = self.hparams.get("d_model", 4096)
2974+
self.gguf_writer.add_embedding_length(embedding_length)
2975+
2976+
# Set feed forward length (MLP hidden size)
2977+
feed_forward_length = self.hparams.get("mlp_hidden_size", 12288)
2978+
self.gguf_writer.add_feed_forward_length(feed_forward_length)
2979+
2980+
# LLaDA models use non-causal attention for diffusion, similar to Dream
2981+
self.gguf_writer.add_causal_attention(False)
2982+
2983+
# LLaDA models don't shift their logits
2984+
self.gguf_writer.add_diffusion_shift_logits(False)
2985+
2986+
@staticmethod
2987+
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
2988+
if n_head_kv is not None and n_head != n_head_kv:
2989+
n_head = n_head_kv
2990+
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
2991+
.swapaxes(1, 2)
2992+
.reshape(weights.shape))
2993+
2994+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
2995+
n_head = self.hparams.get("num_attention_heads", self.hparams.get("n_heads"))
2996+
n_kv_head = self.hparams.get("num_key_value_heads", self.hparams.get("n_kv_heads"))
2997+
2998+
if self.undo_permute:
2999+
if name.endswith(("q_proj.weight", "q_proj.bias")):
3000+
data_torch = LLaDAModel.permute(data_torch, n_head, n_head)
3001+
if name.endswith(("k_proj.weight", "k_proj.bias")):
3002+
data_torch = LLaDAModel.permute(data_torch, n_head, n_kv_head)
3003+
3004+
# LLaDA model tensors should be mapped directly since it's the base model
3005+
yield from super().modify_tensors(data_torch, name, bid)
3006+
3007+
29073008
@ModelBase.register("Ernie4_5_ForCausalLM")
29083009
class Ernie4_5Model(TextModel):
29093010
model_arch = gguf.MODEL_ARCH.ERNIE4_5

examples/diffusion/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Diffusion Text Generation
2+
3+
This directory contains implementations for Diffusion LLMs (DLLMs)
4+
5+
More Info:
6+
- https://github.com/ggml-org/llama.cpp/pull/14644
7+
- https://github.com/ggml-org/llama.cpp/pull/14771
8+
9+
10+
Example of using Dream architechture: `llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual`
11+
12+
Example of using LLaDA architechture: `llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual`
13+

0 commit comments

Comments
 (0)