Conversation
phnazari left a comment
Can you iterate on my comments? Once there is another GDN (Gated DeltaNet) version, I will review the logic more closely
from discretax.utils.config_mixin import Resolvable


class _RMSNorm(eqx.Module):
Why is this private? Can we create a folder called "modules" and put RMSNorm in there? Similar to this: https://github.com/fla-org/flash-linear-attention/tree/main/fla/modules
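To make the request concrete, here is a minimal sketch of what a shared RMSNorm module could look like. This is a plain-Python illustration, not the PR's implementation; the repo version would presumably subclass `eqx.Module` so that `weight` is registered as a trainable leaf, and the class/parameter names here are assumptions.

```python
import jax.numpy as jnp


class RMSNorm:
    """Minimal RMS normalization: x / rms(x) * weight.

    Plain-Python sketch; a real version in this repo would likely be an
    eqx.Module so `weight` participates in the pytree.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        self.weight = jnp.ones(dim)  # learnable scale in a real module
        self.eps = eps

    def __call__(self, x):
        # Root mean square over the feature axis, eps added for stability.
        rms = jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + self.eps)
        return x / rms * self.weight
```

With a unit `weight`, the output has (approximately) unit RMS per feature vector.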
| """Minimal RMSNorm module. | ||
|
|
||
| This lightweight implementation is local to the GatedDeltaNet block to | ||
| reproduce RMS pre-normalization. |
    *args: Additional positional arguments (ignored).
    **kwargs: Extra kwargs forwarded to mixer resolvers.
    """
    del args
| """ | ||
| block_keys = jr.split(key, len(self.blocks)) | ||
| for block, block_key in zip(self.blocks, block_keys): | ||
| x, state = block(x, state, key=block_key) |
The original block has another norm: https://github.com/fla-org/flash-linear-attention/blob/7cbe461b10bde60180c32c5f9381a04d41887514/fla/models/gated_deltanet/modeling_gated_deltanet.py#L252
Is this applied downstream?
# Miscellaneous
*.DS_Store
data_dir/
extra/
Can we maybe call this dev/? Sorry for the nitpick haha, but then we all have the same structure. I also think this is somehow more common.
return x + jnp.log(-jnp.expm1(-x))


def _l2norm(x: Array, eps: float = 1e-6) -> Array:
Also this: can we make it a module? (See my comment on RMSNorm in the GDN block.)
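In the same spirit, a module-style L2 norm could be as small as the sketch below. The eps handling (added inside the square root) is an assumption and may differ from the PR's `_l2norm`; like the RMSNorm sketch, an actual repo version would likely be an `eqx.Module`.

```python
import jax.numpy as jnp


class L2Norm:
    """Normalize the last axis to unit L2 norm (illustrative sketch)."""

    def __init__(self, eps: float = 1e-6):
        self.eps = eps

    def __call__(self, x):
        # x / sqrt(sum(x^2) + eps); eps keeps near-zero vectors stable.
        sq = jnp.sum(x * x, axis=-1, keepdims=True)
        return x / jnp.sqrt(sq + self.eps)
```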
def _gated_delta_recurrent_step(
    h: Array, x: tuple[Array, Array, Array, Array, Array]
) -> tuple[Array, Array]:
    """Single recurrent update for gated delta rule.
Can you ask Claude to write out the math of this function here in the docstring? Easier to review.
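For reference while reviewing: if the step follows fla's naive recurrent reference, the per-head update with state S in R^{k_dim x v_dim} should read (my reading of the reference, worth double-checking against the PR): decay S~ = exp(g_t) * S_{t-1}, delta correction u_t = beta_t * (v_t - S~^T k_t), rank-1 write S_t = S~ + k_t u_t^T, readout o_t = S_t^T q_t. A self-contained single-head sketch (function names are mine, not the PR's):

```python
import jax
import jax.numpy as jnp


def gated_delta_step(h, x):
    """One gated-delta-rule step for a single head.

    h: (k_dim, v_dim) state S_{t-1};  x = (q_t, k_t, v_t, g_t, beta_t).
    """
    q, k, v, g, beta = x
    h = jnp.exp(g) * h          # gated decay:      S~  = exp(g_t) * S_{t-1}
    u = beta * (v - h.T @ k)    # delta correction: u_t = beta_t * (v_t - S~^T k_t)
    h = h + jnp.outer(k, u)     # rank-1 write:     S_t = S~ + k_t u_t^T
    o = h.T @ q                 # readout:          o_t = S_t^T q_t
    return h, o


def gated_delta_recurrent(q, k, v, g, beta):
    """Scan over time; q/k: (T, k_dim), v: (T, v_dim), g/beta: (T,)."""
    h0 = jnp.zeros((q.shape[-1], v.shape[-1]), dtype=q.dtype)
    _, o = jax.lax.scan(gated_delta_step, h0, (q, k, v, g, beta))
    return o
```

Sanity check: with zero initial state, g = 0 and beta = 1, a single step gives o_1 = (k_1 . q_1) * v_1.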
_, n_heads, k_dim = q.shape
v_dim = v.shape[-1]
h0 = jnp.zeros((n_heads, k_dim, v_dim), dtype=q.dtype)
_, out = jax.lax.scan(_gated_delta_recurrent_step, h0, (q, k, v, g, beta))
Man, this is a purely sequential implementation of Gated DeltaNet.
scale = jnp.sqrt(float(self.head_k_dim))
q = q / scale

if self.mode == "chunk":
Look here, it is calling the same recursion regardless. One of these should be chunked.
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/naive.py
Look, there are naive recurrent and naive chunked versions there.
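To illustrate the distinction: the sketch below is NOT the parallel intra-chunk algorithm from fla's naive.py, just the chunk-boundary scaffolding, i.e. an outer scan over chunks with an inner scan inside each chunk, which is numerically equivalent to the fully recurrent path. A real chunked kernel would replace the inner scan with parallel intra-chunk matrix products. All function names here are mine, and the single-head step mirrors my reading of the fla reference.

```python
import jax
import jax.numpy as jnp


def gated_delta_step(h, x):
    # Single-head gated delta rule step (sketch): decay, delta, rank-1 write.
    q, k, v, g, beta = x
    h = jnp.exp(g) * h
    u = beta * (v - h.T @ k)
    h = h + jnp.outer(k, u)
    return h, h.T @ q


def recurrent(q, k, v, g, beta):
    # Fully sequential baseline: one scan over all T timesteps.
    h0 = jnp.zeros((q.shape[-1], v.shape[-1]), dtype=q.dtype)
    _, o = jax.lax.scan(gated_delta_step, h0, (q, k, v, g, beta))
    return o


def chunked(q, k, v, g, beta, chunk_size=4):
    """Outer scan over chunks, inner scan within each chunk.

    Only demonstrates the state hand-off at chunk boundaries; a real chunked
    implementation parallelizes the inner loop with matrix products.
    """
    T = q.shape[0]
    split = lambda a: a.reshape(T // chunk_size, chunk_size, *a.shape[1:])

    def chunk_step(h, xs):
        return jax.lax.scan(gated_delta_step, h, xs)

    h0 = jnp.zeros((q.shape[-1], v.shape[-1]), dtype=q.dtype)
    _, o = jax.lax.scan(chunk_step, h0, tuple(map(split, (q, k, v, g, beta))))
    return o.reshape(T, -1)
```

Because the carried state is identical at every chunk boundary, both functions produce the same outputs up to floating-point tolerance.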
Description
Type of Change
Changes Made
GatedDeltaNet