
Commit 4923a7b

Sync: Add LLaDA 8b Diffusion model
1 parent: e390ab3

File tree: 1 file changed (+10 −0)


llama_cpp/llama_cpp.py

Lines changed: 10 additions & 0 deletions
@@ -1668,6 +1668,14 @@ def llama_model_is_recurrent(model: llama_model_p, /) -> bool:
     ...


+# // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
+# LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
+@ctypes_function("llama_model_is_diffusion", [llama_model_p_ctypes], ctypes.c_bool)
+def llama_model_is_diffusion(model: llama_model_p, /) -> bool:
+    """Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)"""
+    ...
+
+
 # // Returns 0 on success
 # LLAMA_API uint32_t llama_model_quantize(
 #         const char * fname_inp,
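The new binding follows the same pattern as llama_model_is_recurrent just above it: the upstream C declaration is kept as a comment, and the @ctypes_function decorator wires a Python stub to the symbol in libllama. A minimal usage sketch; the GGUF path is a placeholder, and it assumes the bundled libllama is new enough to export llama_model_is_diffusion:

    import llama_cpp

    llama_cpp.llama_backend_init()

    # Load a model via the low-level API (the path is hypothetical).
    params = llama_cpp.llama_model_default_params()
    model = llama_cpp.llama_load_model_from_file(b"/path/to/LLaDA-8B.gguf", params)

    # True for diffusion-based models such as LLaDA or Dream, which generate
    # by iterative denoising rather than strictly left-to-right decoding.
    if llama_cpp.llama_model_is_diffusion(model):
        print("diffusion model detected")

    llama_cpp.llama_free_model(model)
    llama_cpp.llama_backend_free()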
@@ -2619,6 +2627,7 @@ def llama_synchronize(ctx: llama_context_p, /):
 # // in the order they have appeared in the batch.
 # // Rows: number of tokens for which llama_batch.logits[i] != 0
 # // Cols: n_vocab
+# // TODO: deprecate in favor of llama_get_logits_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
 # LLAMA_API float * llama_get_logits(struct llama_context * ctx);
 @ctypes_function(
     "llama_get_logits", [llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float)
@@ -2659,6 +2668,7 @@ def llama_get_logits_ith(
 # // in the order they have appeared in the batch.
 # // shape: [n_outputs*n_embd]
 # // Otherwise, returns NULL.
+# // TODO: deprecate in favor of llama_get_embeddings_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
 # LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 @ctypes_function(
     "llama_get_embeddings", [llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float)
