Commit e24c9df

Remove OP_DELTA_NET, fix flake8 and editorchecker because why not
1 parent 6e3abeb commit e24c9df

5 files changed: +17, −179 lines

convert_hf_to_gguf.py

Lines changed: 1 addition & 0 deletions

@@ -3748,6 +3748,7 @@ def set_vocab(self):
 
         super().set_vocab()
 
+
 @ModelBase.register("Qwen3NextForCausalLM")
 class Qwen3NextModel(Qwen3MoeModel):
     model_arch = gguf.MODEL_ARCH.QWEN3NEXT

ggml/include/ggml.h

Lines changed: 1 addition & 2 deletions

@@ -539,8 +539,7 @@ extern "C" {
         GGML_OP_RWKV_WKV6,
         GGML_OP_GATED_LINEAR_ATTN,
         GGML_OP_RWKV_WKV7,
-        GGML_OP_DELTA_NET,
-
+
         GGML_OP_UNARY,
 
         GGML_OP_MAP_CUSTOM1,

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 0 additions & 171 deletions

@@ -1656,172 +1656,6 @@ static void ggml_compute_forward_mul_mat_id(
     }
 }
 
-// ggml_compute_forward_delta_net
-
-static void ggml_compute_forward_delta_net(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0]; // query
-    const struct ggml_tensor * src1 = dst->src[1]; // key
-    const struct ggml_tensor * src2 = dst->src[2]; // value
-    const struct ggml_tensor * src3 = dst->src[3]; // gate
-    const struct ggml_tensor * src4 = dst->src[4]; // beta
-    const struct ggml_tensor * src5 = dst->src[5]; // state
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(src2->type == GGML_TYPE_F32);
-    GGML_ASSERT(src3->type == GGML_TYPE_F32);
-    GGML_ASSERT(src4->type == GGML_TYPE_F32);
-    GGML_ASSERT(src5->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_TERNARY_OP_LOCALS;
-    GGML_TENSOR_LOCALS(int64_t, ne3, src3, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb3, src3, nb);
-    GGML_TENSOR_LOCALS(int64_t, ne4, src4, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb4, src4, nb);
-    GGML_TENSOR_LOCALS(int64_t, ne5, src5, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb5, src5, nb);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t S        = src0->ne[0]; // head dimension
-    const int64_t H        = src0->ne[1]; // number of heads
-    const int64_t n_tokens = src0->ne[2];
-    const int64_t n_seqs   = src0->ne[3];
-
-    GGML_ASSERT(ne00 == S && ne01 == H && ne02 == n_tokens && ne03 == n_seqs);
-    GGML_ASSERT(ne10 == S && ne11 == H && ne12 == n_tokens && ne13 == n_seqs);
-    GGML_ASSERT(ne20 == S && ne21 == H && ne22 == n_tokens && ne23 == n_seqs);
-    GGML_ASSERT(ne30 == S && ne31 == H && ne32 == n_tokens && ne33 == n_seqs);
-    GGML_ASSERT(ne40 == H && ne41 == n_tokens && ne42 == n_seqs && ne43 == 1);
-    GGML_ASSERT(ne50 == S && ne51 == S && ne52 == H && ne53 == n_seqs);
-
-    // Get operation parameters
-    bool use_qk_l2norm = ggml_get_op_params_i32(dst, 1) != 0;
-    float scale;
-    memcpy(&scale, ((int32_t *) dst->op_params) + 4, sizeof(float));
-
-    GGML_ASSERT(ne0 == S * H);
-    GGML_ASSERT(ne1 == n_tokens + S * n_seqs);
-
-    // Parallelize over sequences and heads
-    const int64_t n_total      = n_seqs * H;
-    const int64_t n_per_thread = (n_total + nth - 1) / nth;
-    const int64_t n_start      = ith * n_per_thread;
-    const int64_t n_end        = MIN(n_start + n_per_thread, n_total);
-
-    for (int64_t n = n_start; n < n_end; ++n) {
-        const int64_t seq_idx  = n / H;
-        const int64_t head_idx = n % H;
-
-        // Get pointers to current sequence and head
-        float * q_ptr     = (float *)((char *) src0->data + seq_idx * nb03 + head_idx * nb01);
-        float * k_ptr     = (float *)((char *) src1->data + seq_idx * nb13 + head_idx * nb11);
-        float * v_ptr     = (float *)((char *) src2->data + seq_idx * nb23 + head_idx * nb21);
-        float * g_ptr     = (float *)((char *) src3->data + seq_idx * nb33 + head_idx * nb31);
-        float * beta_ptr  = (float *)((char *) src4->data + seq_idx * nb43);
-        float * state_ptr = (float *)((char *) src5->data + seq_idx * nb53 + head_idx * nb51);
-
-        float * out_ptr       = (float *)((char *) dst->data + n * ne0 * sizeof(float));
-        float * new_state_ptr = out_ptr + n_tokens * S;
-
-        // Apply L2 normalization if requested
-        if (use_qk_l2norm) {
-            // Normalize query and key
-            for (int64_t t = 0; t < n_tokens; ++t) {
-                float q_sum = 0.0f, k_sum = 0.0f;
-                for (int64_t s = 0; s < S; ++s) {
-                    float q_val = q_ptr[t * nb02 / sizeof(float) + s];
-                    float k_val = k_ptr[t * nb12 / sizeof(float) + s];
-                    q_sum += q_val * q_val;
-                    k_sum += k_val * k_val;
-                }
-                float q_norm = sqrtf(q_sum + 1e-6f);
-                float k_norm = sqrtf(k_sum + 1e-6f);
-
-                for (int64_t s = 0; s < S; ++s) {
-                    q_ptr[t * nb02 / sizeof(float) + s] /= q_norm;
-                    k_ptr[t * nb12 / sizeof(float) + s] /= k_norm;
-                }
-            }
-        }
-
-        // Apply scaling to query
-        for (int64_t i = 0; i < n_tokens * S; ++i) {
-            q_ptr[i] *= scale;
-        }
-
-        // Apply sigmoid to beta
-        float * beta_sigmoid = (float *) alloca(n_tokens * sizeof(float));
-        for (int64_t t = 0; t < n_tokens; ++t) {
-            beta_sigmoid[t] = 1.0f / (1.0f + expf(-beta_ptr[t * nb42 / sizeof(float)]));
-        }
-
-        // Complete implementation of gated delta rule
-        // Based on torch_recurrent_gated_delta_rule from the reference implementation
-
-        // Process each token sequentially for recurrent computation
-        for (int64_t t = 0; t < n_tokens; ++t) {
-            // Get pointers to current token data
-            float * q_t = q_ptr + t * (nb02 / sizeof(float));
-            float * k_t = k_ptr + t * (nb12 / sizeof(float));
-            float * v_t = v_ptr + t * (nb22 / sizeof(float));
-            float * g_t = g_ptr + t * (nb32 / sizeof(float));
-
-            // Apply exponential to gate and multiply by beta
-            float g_exp  = expf(g_t[0]); // g is per-head, not per-dimension
-            float beta_t = beta_sigmoid[t];
-
-            // Update recurrent state: state = state * g_exp
-            for (int64_t i = 0; i < S * S; ++i) {
-                state_ptr[i] *= g_exp;
-            }
-
-            // Compute kv_mem = (state * k_t^T).sum(dim=-1)
-            // This is a matrix-vector multiplication: state[S×S] @ k_t[S]
-            float kv_mem[S];
-            for (int64_t i = 0; i < S; ++i) {
-                kv_mem[i] = 0.0f;
-                for (int64_t j = 0; j < S; ++j) {
-                    kv_mem[i] += state_ptr[i * S + j] * k_t[j];
-                }
-            }
-
-            // Compute delta = (v_t - kv_mem) * beta_t
-            float delta[S];
-            for (int64_t i = 0; i < S; ++i) {
-                delta[i] = (v_t[i] - kv_mem[i]) * beta_t;
-            }
-
-            // Update state: state = state + k_t * delta^T
-            // This is an outer product: k_t[S] ⊗ delta[S]
-            for (int64_t i = 0; i < S; ++i) {
-                for (int64_t j = 0; j < S; ++j) {
-                    state_ptr[i * S + j] += k_t[i] * delta[j];
-                }
-            }
-
-            // Compute output: out = (state * q_t^T).sum(dim=-1)
-            // This is a matrix-vector multiplication: state[S×S] @ q_t[S]
-            float * out_t = out_ptr + t * S;
-            for (int64_t i = 0; i < S; ++i) {
-                out_t[i] = 0.0f;
-                for (int64_t j = 0; j < S; ++j) {
-                    out_t[i] += state_ptr[i * S + j] * q_t[j];
-                }
-            }
-        }
-
-        // Copy final state to new_state
-        memcpy(new_state_ptr, state_ptr, S * S * sizeof(float));
-    }
-}
-
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {

@@ -2164,10 +1998,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);
             } break;
-        case GGML_OP_DELTA_NET:
-            {
-                ggml_compute_forward_delta_net(params, tensor);
-            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);

@@ -2461,7 +2291,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
        case GGML_OP_RWKV_WKV7:
-        case GGML_OP_DELTA_NET:
            {
                n_tasks = n_threads;
            } break;

ggml/src/ggml.c

Lines changed: 2 additions & 4 deletions

@@ -1002,7 +1002,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "RWKV_WKV6",
     "GATED_LINEAR_ATTN",
     "RWKV_WKV7",
-    "DELTA_NET",
 
     "UNARY",
 
@@ -1020,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1107,7 +1106,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rwkv_wkv6(k, v, r, tf, td, s)",
     "gated_linear_attn(k, v, q, gate, s)",
     "rwkv_wkv7(r, w, k, v, a, b, s)",
-    "delta_net(k, v, q, g, conv_w, conv_b, beta, state, chunk_size, use_qk_l2norm, scale)",
 
     "unary(x)",
 
@@ -1125,7 +1123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
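
Both static_asserts drop from 91 to 90 because GGML_OP_COUNT is the trailing sentinel of the ggml_op enum: deleting GGML_OP_DELTA_NET shifts the sentinel down by one, and the asserts pin the two parallel string tables to it. A minimal sketch of that pattern, with illustrative names rather than ggml's:

```c
// Enum-sentinel + parallel-table pattern: removing an op entry shrinks
// the sentinel, and the static_assert forces the tables to be updated too.
#include <assert.h>

enum demo_op {
    DEMO_OP_NONE,
    DEMO_OP_ADD,
    DEMO_OP_MUL,
    DEMO_OP_COUNT, // trailing sentinel: tracks the number of entries above
};

static const char * DEMO_OP_NAME[DEMO_OP_COUNT] = {
    "none",
    "ADD",
    "MUL",
};

// Fails to compile if the enum changes but this file is not revisited.
static_assert(DEMO_OP_COUNT == 3, "DEMO_OP_COUNT != 3");
```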

src/llama-model.cpp

Lines changed: 13 additions & 2 deletions

@@ -18958,6 +18958,7 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         ggml_tensor * inpL;
 
         inpL = build_inp_embd(model.tok_embd);
+        cb(inpL, "model.embed_tokens", -1);
 
         auto * inp = build_inp_mem_hybrid();
 
@@ -19259,21 +19260,25 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads,
                                                            n_tokens, n_seqs, split_sizes_qkvz[0] * sizeof(float),
                                                            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], 0));
+        cb(query, "q", il);
 
         ggml_tensor * key =
             ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
                                          split_sizes_qkvz[1] * sizeof(float), mixed_qkvz_reshaped->nb[1],
                                          mixed_qkvz_reshaped->nb[2], split_sizes_qkvz[0] * sizeof(float)));
+        cb(query, "k", il);
 
         ggml_tensor * value =
             ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
                          split_sizes_qkvz[2] * sizeof(float), mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
                          (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
+        cb(query, "v", il);
 
         ggml_tensor * z =
             ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
                          split_sizes_qkvz[3] * sizeof(float), mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
                          (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
+        cb(query, "z", il);
 
         // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
         ggml_tensor * value_reshaped =
@@ -19293,10 +19298,12 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         ggml_tensor * b =
             ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tokens, n_seqs,
                          split_sizes_ba[0] * sizeof(float), mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], 0);
+        cb(query, "b", il);
 
         ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tokens, n_seqs,
                                        split_sizes_ba[1] * sizeof(float), mixed_ba_reshaped->nb[1],
                                        mixed_ba_reshaped->nb[2], split_sizes_ba[0] * sizeof(float));
+        cb(query, "a", il);
 
         // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
         ggml_tensor * beta = ggml_reshape_3d(ctx0, ggml_cont(ctx0, b), num_v_heads, n_tokens, n_seqs);
@@ -19305,16 +19312,21 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
 
         ggml_tensor * alpha_softplus = softplus(alpha, model.layers[il].ssm_dt);
+        cb(alpha_softplus, "a_softplus", il);
         ggml_tensor * A_log_exp = ggml_exp(ctx0, model.layers[il].ssm_a);      // A_log.exp()
+        cb(A_log_exp, "a_logexp", il);
         ggml_tensor * gate_scaled = ggml_mul(ctx0, alpha_softplus, A_log_exp); // A_log.exp() * softplus
+        cb(gate_scaled, "gate_scaled", il);
         ggml_tensor * gate = ggml_scale(ctx0, gate_scaled, -1.0f);             // - (A_log.exp() * softplus)
+        cb(gate, "gate", il);
 
         // Get convolution states from cache
         ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
         ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
 
         // Build the convolution states tensor
         ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+        cb(conv_states, "conv_states", il);
 
         // Calculate convolution kernel size
         const int64_t conv_kernel_size = model.layers[il].ssm_conv1d->ne[0];
@@ -19396,7 +19408,6 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         ggml_tensor * target_gate = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_heads, n_tokens, n_seqs);
         ggml_tensor * gate_broadcast = ggml_reshape_4d(ctx0, gate, 1, n_heads, n_tokens, n_seqs);
         gate = ggml_repeat(ctx0, gate_broadcast, target_gate);
-        cb(gate, "gate", il);
 
         // Call the new ggml_delta_net function with the corrected flow
         ggml_tensor * output = ggml_delta_net(k_conv, v_conv, q_conv, gate, beta, state_broadcast, true, 1.0f, il);
@@ -20190,6 +20201,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_ARCEE:
         case LLM_ARCH_ERNIE4_5:
         case LLM_ARCH_ERNIE4_5_MOE:
+        case LLM_ARCH_QWEN3NEXT:
            return LLAMA_ROPE_TYPE_NORM;
 
        // the pairs of head values are offset by n_rot/2
@@ -20209,7 +20221,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_QWEN2MOE:
         case LLM_ARCH_QWEN3:
         case LLM_ARCH_QWEN3MOE:
-        case LLM_ARCH_QWEN3NEXT:
         case LLM_ARCH_LLADA_MOE:
         case LLM_ARCH_OLMO2:
         case LLM_ARCH_OLMOE:
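
The gate instrumented above is gate = −(exp(A_log) · softplus(α + dt_bias)), per the inline comments in the hunk. A scalar sketch in C, assuming the `softplus(alpha, ssm_dt)` helper adds the `ssm_dt` bias before the softplus (its body is not shown in this diff, so that part is an assumption):

```c
// Hypothetical scalar version of the per-head gate:
//   gate = -(exp(A_log) * softplus(alpha + dt_bias))
#include <math.h>
#include <stdio.h>

static float gate_scalar(float alpha, float a_log, float dt_bias) {
    const float sp = log1pf(expf(alpha + dt_bias)); // softplus(alpha + dt_bias)
    return -expf(a_log) * sp;                       // -(A_log.exp() * softplus)
}

int main(void) {
    // made-up values, just to exercise the formula
    printf("gate = %f\n", gate_scalar(0.5f, -1.0f, 0.1f));
    return 0;
}
```

The result is always negative, which is what the delta-net recurrence expects: exp(gate) then acts as a decay factor in (0, 1) on the recurrent state.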
