@@ -1656,172 +1656,6 @@ static void ggml_compute_forward_mul_mat_id(
     }
 }
 
-// ggml_compute_forward_delta_net
-
-static void ggml_compute_forward_delta_net(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0]; // query
-    const struct ggml_tensor * src1 = dst->src[1]; // key
-    const struct ggml_tensor * src2 = dst->src[2]; // value
-    const struct ggml_tensor * src3 = dst->src[3]; // gate
-    const struct ggml_tensor * src4 = dst->src[4]; // beta
-    const struct ggml_tensor * src5 = dst->src[5]; // state
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(src2->type == GGML_TYPE_F32);
-    GGML_ASSERT(src3->type == GGML_TYPE_F32);
-    GGML_ASSERT(src4->type == GGML_TYPE_F32);
-    GGML_ASSERT(src5->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
-
-    GGML_TENSOR_TERNARY_OP_LOCALS;
-    GGML_TENSOR_LOCALS(int64_t, ne3, src3, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb3, src3, nb);
-    GGML_TENSOR_LOCALS(int64_t, ne4, src4, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb4, src4, nb);
-    GGML_TENSOR_LOCALS(int64_t, ne5, src5, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb5, src5, nb);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t S        = src0->ne[0]; // head dimension
-    const int64_t H        = src0->ne[1]; // number of heads
-    const int64_t n_tokens = src0->ne[2];
-    const int64_t n_seqs   = src0->ne[3];
-
-    GGML_ASSERT(ne00 == S && ne01 == H && ne02 == n_tokens && ne03 == n_seqs);
-    GGML_ASSERT(ne10 == S && ne11 == H && ne12 == n_tokens && ne13 == n_seqs);
-    GGML_ASSERT(ne20 == S && ne21 == H && ne22 == n_tokens && ne23 == n_seqs);
-    GGML_ASSERT(ne30 == S && ne31 == H && ne32 == n_tokens && ne33 == n_seqs);
-    GGML_ASSERT(ne40 == H && ne41 == n_tokens && ne42 == n_seqs && ne43 == 1);
-    GGML_ASSERT(ne50 == S && ne51 == S && ne52 == H && ne53 == n_seqs);
-
-    // Get operation parameters
-    bool use_qk_l2norm = ggml_get_op_params_i32(dst, 1) != 0;
-    float scale;
-    memcpy(&scale, ((int32_t *) dst->op_params) + 4, sizeof(float));
-
-    GGML_ASSERT(ne0 == S * H);
-    GGML_ASSERT(ne1 == n_tokens + S * n_seqs);
-
-    // Parallelize over sequences and heads
-    const int64_t n_total      = n_seqs * H;
-    const int64_t n_per_thread = (n_total + nth - 1) / nth;
-    const int64_t n_start      = ith * n_per_thread;
-    const int64_t n_end        = MIN(n_start + n_per_thread, n_total);
-
-    for (int64_t n = n_start; n < n_end; ++n) {
-        const int64_t seq_idx  = n / H;
-        const int64_t head_idx = n % H;
-
-        // Get pointers to current sequence and head
-        float * q_ptr     = (float *) ((char *) src0->data + seq_idx * nb03 + head_idx * nb01);
-        float * k_ptr     = (float *) ((char *) src1->data + seq_idx * nb13 + head_idx * nb11);
-        float * v_ptr     = (float *) ((char *) src2->data + seq_idx * nb23 + head_idx * nb21);
-        float * g_ptr     = (float *) ((char *) src3->data + seq_idx * nb33 + head_idx * nb31);
-        float * beta_ptr  = (float *) ((char *) src4->data + seq_idx * nb43);
-        float * state_ptr = (float *) ((char *) src5->data + seq_idx * nb53 + head_idx * nb51);
-
-        float * out_ptr       = (float *) ((char *) dst->data + n * ne0 * sizeof(float));
-        float * new_state_ptr = out_ptr + n_tokens * S;
-
-        // Apply L2 normalization if requested
-        if (use_qk_l2norm) {
-            // Normalize query and key
-            for (int64_t t = 0; t < n_tokens; ++t) {
-                float q_sum = 0.0f, k_sum = 0.0f;
-                for (int64_t s = 0; s < S; ++s) {
-                    float q_val = q_ptr[t * nb02 / sizeof(float) + s];
-                    float k_val = k_ptr[t * nb12 / sizeof(float) + s];
-                    q_sum += q_val * q_val;
-                    k_sum += k_val * k_val;
-                }
-                float q_norm = sqrtf(q_sum + 1e-6f);
-                float k_norm = sqrtf(k_sum + 1e-6f);
-
-                for (int64_t s = 0; s < S; ++s) {
-                    q_ptr[t * nb02 / sizeof(float) + s] /= q_norm;
-                    k_ptr[t * nb12 / sizeof(float) + s] /= k_norm;
-                }
-            }
-        }
-
-        // Apply scaling to query
-        for (int64_t i = 0; i < n_tokens * S; ++i) {
-            q_ptr[i] *= scale;
-        }
-
-        // Apply sigmoid to beta
-        float * beta_sigmoid = (float *) alloca(n_tokens * sizeof(float));
-        for (int64_t t = 0; t < n_tokens; ++t) {
-            beta_sigmoid[t] = 1.0f / (1.0f + expf(-beta_ptr[t * nb42 / sizeof(float)]));
-        }
-
-        // Complete implementation of gated delta rule
-        // Based on torch_recurrent_gated_delta_rule from the reference implementation
-
-        // Process each token sequentially for recurrent computation
-        for (int64_t t = 0; t < n_tokens; ++t) {
-            // Get pointers to current token data
-            float * q_t = q_ptr + t * (nb02 / sizeof(float));
-            float * k_t = k_ptr + t * (nb12 / sizeof(float));
-            float * v_t = v_ptr + t * (nb22 / sizeof(float));
-            float * g_t = g_ptr + t * (nb32 / sizeof(float));
-
-            // Apply exponential to gate and multiply by beta
-            float g_exp  = expf(g_t[0]); // g is per-head, not per-dimension
-            float beta_t = beta_sigmoid[t];
-
-            // Update recurrent state: state = state * g_exp
-            for (int64_t i = 0; i < S * S; ++i) {
-                state_ptr[i] *= g_exp;
-            }
-
-            // Compute kv_mem = (state * k_t^T).sum(dim=-1)
-            // This is a matrix-vector multiplication: state[S×S] @ k_t[S]
-            float kv_mem[S];
-            for (int64_t i = 0; i < S; ++i) {
-                kv_mem[i] = 0.0f;
-                for (int64_t j = 0; j < S; ++j) {
-                    kv_mem[i] += state_ptr[i * S + j] * k_t[j];
-                }
-            }
-
-            // Compute delta = (v_t - kv_mem) * beta_t
-            float delta[S];
-            for (int64_t i = 0; i < S; ++i) {
-                delta[i] = (v_t[i] - kv_mem[i]) * beta_t;
-            }
-
-            // Update state: state = state + k_t * delta^T
-            // This is an outer product: k_t[S] ⊗ delta[S]
-            for (int64_t i = 0; i < S; ++i) {
-                for (int64_t j = 0; j < S; ++j) {
-                    state_ptr[i * S + j] += k_t[i] * delta[j];
-                }
-            }
-
-            // Compute output: out = (state * q_t^T).sum(dim=-1)
-            // This is a matrix-vector multiplication: state[S×S] @ q_t[S]
-            float * out_t = out_ptr + t * S;
-            for (int64_t i = 0; i < S; ++i) {
-                out_t[i] = 0.0f;
-                for (int64_t j = 0; j < S; ++j) {
-                    out_t[i] += state_ptr[i * S + j] * q_t[j];
-                }
-            }
-        }
-
-        // Copy final state to new_state
-        memcpy(new_state_ptr, state_ptr, S * S * sizeof(float));
-    }
-}
-
-
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -2164,10 +1998,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);
             } break;
-        case GGML_OP_DELTA_NET:
-            {
-                ggml_compute_forward_delta_net(params, tensor);
-            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2461,7 +2291,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
-        case GGML_OP_DELTA_NET:
             {
                 n_tasks = n_threads;
             } break;