Feature/gated dn #73

Open
francescoshox wants to merge 4 commits into main from feature/gated-dn

Conversation


@francescoshox francescoshox commented Mar 9, 2026

Description

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📚 Documentation
  • Other

Changes Made

  • Added GatedDeltaNet



@phnazari left a comment


Can you iterate on my comments? Once there is another GDN (Gated DeltaNet) version, I will review the logic more closely.

from discretax.utils.config_mixin import Resolvable


class _RMSNorm(eqx.Module):
Collaborator

Why is this private? Can we create a folder called "modules" and put RMSNorm in there? Similar to this: https://github.com/fla-org/flash-linear-attention/tree/main/fla/modules

"""Minimal RMSNorm module.

This lightweight implementation is local to the GatedDeltaNet block to
reproduce RMS pre-normalization.
Collaborator

remove this comment

*args: Additional positional arguments (ignored).
**kwargs: Extra kwargs forwarded to mixer resolvers.
"""
del args
Collaborator

This `del args` pattern is weird, no?

"""
block_keys = jr.split(key, len(self.blocks))
for block, block_key in zip(self.blocks, block_keys):
x, state = block(x, state, key=block_key)

# Miscellaneous
*.DS_Store
data_dir/
extra/
Collaborator

Can we maybe call this dev/? Sorry for the nitpick haha, but then we'd all have the same structure; I also think dev/ is the more common name.

return x + jnp.log(-jnp.expm1(-x))


def _l2norm(x: Array, eps: float = 1e-6) -> Array:
Collaborator

Same here: can we make this a module? (See my comment on RMSNorm in the GDN block.)

def _gated_delta_recurrent_step(
h: Array, x: tuple[Array, Array, Array, Array, Array]
) -> tuple[Array, Array]:
"""Single recurrent update for gated delta rule.
Collaborator

Can you ask Claude to write the math of this function here in the docstring? That would make it easier to review.
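For reference while reviewing: assuming the step implements the standard gated delta rule (as in the fla-org naive reference), with state $S_t \in \mathbb{R}^{K \times V}$ per head, log-decay $g_t$, and write strength $\beta_t$ (my notation, to be checked against the code), the update would be:

```latex
S_t = e^{g_t}\, S_{t-1} \;+\; \beta_t \, k_t \left( v_t - S_{t-1}^{\top} k_t \right)^{\top},
\qquad
o_t = S_t^{\top} q_t
```

That is, the old state is decayed by $e^{g_t}$, the value prediction error $v_t - S_{t-1}^{\top} k_t$ is written back along $k_t$ with strength $\beta_t$, and the output is a read-out of the state at the query.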

_, n_heads, k_dim = q.shape
v_dim = v.shape[-1]
h0 = jnp.zeros((n_heads, k_dim, v_dim), dtype=q.dtype)
_, out = jax.lax.scan(_gated_delta_recurrent_step, h0, (q, k, v, g, beta))
Collaborator

Man, this is a purely sequential implementation of Gated DeltaNet.

scale = jnp.sqrt(float(self.head_k_dim))
q = q / scale

if self.mode == "chunk":
Collaborator

Look here: it calls the same recursion regardless of the mode. One of them should be chunked.

Collaborator

https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/naive.py

Look, there are both a naive recurrent and a naive chunked version there.
