Conversation

@zacharydenton
Contributor

This adjusts the normalize logic to match PyTorch (i.e., always take variance + epsilon instead of ignoring it when variance < epsilon), fixing #621.
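A minimal NumPy sketch of the PyTorch-style formula the PR moves to (function names and the epsilon default here are illustrative only; the actual change is Elixir/Nx code in the diff). The point is that epsilon is added unconditionally, so the denominator stays well-behaved even when the variance underflows to zero:

```python
import numpy as np

def normalize(x, eps=1e-5):
    # PyTorch-style normalization: epsilon is ALWAYS added to the
    # variance before the square root, never conditionally skipped.
    mean = x.mean()
    var = x.var()
    return (x - mean) / np.sqrt(var + eps)

# Constant input: variance is exactly 0, but the unconditional epsilon
# keeps the denominator at sqrt(eps) instead of dividing by zero.
print(normalize(np.array([3.0, 3.0, 3.0])))  # [0. 0. 0.]

# Ordinary input: epsilon only perturbs the result by O(eps).
print(normalize(np.array([0.0, 2.0])))
```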

@seanmor5
Contributor

Thanks! IIRC this was based on an old JAX reference implementation that I can't seem to find anymore. The closest I can see now is jax.nn.standardize, and AFAICT this is pretty much equivalent: https://github.com/jax-ml/jax/blob/9af721622fa57a5740730669692c0896bde6e50e/jax/_src/nn/functions.py#L652

We rely on Nx.variance, which uses the more accurate variance implementation and can never be negative, so we don't need to clip.
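The variance-formula point is worth spelling out. A NumPy sketch (illustrative only, not the Nx implementation) contrasting the naive one-pass formula E[x²] − (E[x])², which can cancel catastrophically and even round to a negative value, with the two-pass formula E[(x − mean)²], whose result is a mean of squares and therefore never negative, so it is safe to feed to rsqrt(var + eps) without clipping:

```python
import numpy as np

def var_naive(x):
    # One-pass formula E[x^2] - E[x]^2: prone to catastrophic
    # cancellation, and rounding can even drive the result negative.
    return (x ** 2).mean() - x.mean() ** 2

def var_two_pass(x):
    # Two-pass formula E[(x - mean)^2]: a mean of squares, so the
    # result is always >= 0 -- no clipping needed before rsqrt.
    m = x.mean()
    return ((x - m) ** 2).mean()

# A large offset with a small spread: the one-pass formula loses
# everything to cancellation; the two-pass formula is exact here.
x = np.array([1e9, 1e9 + 1.0])
print(var_naive(x))     # 0.0  (true variance is 0.25)
print(var_two_pass(x))  # 0.25
```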

@seanmor5
Contributor

@zacharydenton Can you run `mix format`?

@seanmor5 seanmor5 merged commit d846c97 into elixir-nx:main Dec 18, 2025
5 checks passed

Development

Successfully merging this pull request may close these issues.

normalize logic differs from pytorch