Skip to content

Commit 128455f

Browse files
author
Alexander
committed
Fixed a bug where mpx would not work with distributed training in JAX
1 parent d2ac686 commit 128455f

4 files changed

+69
-3
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
\begin{Verbatim}[commandchars=\\\{\},codes={\catcode`\$=3\catcode`\^=7\catcode`\_=8\relax}]
2+
\PYG{k}{class} \PYG{n+nc}{MultiHeadAttentionBlock}\PYG{p}{(}\PYG{n}{eqx}\PYG{o}{.}\PYG{n}{Module}\PYG{p}{):}
3+
\PYG{n}{dense\PYGZus{}qs}\PYG{p}{:} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{Linear}
4+
\PYG{n}{dense\PYGZus{}ks}\PYG{p}{:} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{Linear}
5+
\PYG{n}{dense\PYGZus{}vs}\PYG{p}{:} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{Linear}
6+
\PYG{n}{dense\PYGZus{}o}\PYG{p}{:} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{Linear}
7+
\PYG{n}{num\PYGZus{}heads}\PYG{p}{:} \PYG{n+nb}{int}
8+
\PYG{n}{layer\PYGZus{}norm}\PYG{p}{:} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{LayerNorm}
9+
10+
\PYG{k}{def} \PYG{n+nf+fm}{\PYGZus{}\PYGZus{}init\PYGZus{}\PYGZus{}}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,} \PYG{n}{feature\PYGZus{}dim}\PYG{p}{,} \PYG{n}{num\PYGZus{}heads}\PYG{p}{,} \PYG{n}{key}\PYG{p}{):}
11+
\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{num\PYGZus{}heads} \PYG{o}{=} \PYG{n}{num\PYGZus{}heads}
12+
\PYG{n}{key}\PYG{p}{,} \PYG{n}{subkey} \PYG{o}{=} \PYG{n}{jax}\PYG{o}{.}\PYG{n}{random}\PYG{o}{.}\PYG{n}{split}\PYG{p}{(}\PYG{n}{key}\PYG{p}{)}
13+
\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{dense\PYGZus{}qs} \PYG{o}{=} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{Linear}\PYG{p}{(}
14+
\PYG{n}{feature\PYGZus{}dim}\PYG{p}{,} \PYG{n}{feature\PYGZus{}dim}\PYG{p}{,} \PYG{n}{key}\PYG{o}{=}\PYG{n}{subkey}\PYG{p}{)}
15+
\PYG{c+c1}{\PYGZsh{} same for dense\PYGZus{}ks, dense\PYGZus{}vs, dense\PYGZus{}o}
16+
17+
\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{layer\PYGZus{}norm} \PYG{o}{=} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{LayerNorm}\PYG{p}{(}\PYG{n}{feature\PYGZus{}dim}\PYG{p}{)}
18+
19+
\PYG{k}{def} \PYG{n+nf}{attention}\PYG{p}{(}\PYG{n}{q}\PYG{p}{,} \PYG{n}{k}\PYG{p}{,} \PYG{n}{v}\PYG{p}{):}
20+
\PYG{n}{attention\PYGZus{}scores} \PYG{o}{=} \PYG{n}{q} \PYG{o}{@} \PYG{n}{k}\PYG{o}{.}\PYG{n}{T} \PYG{o}{/} \PYG{n}{jnp}\PYG{o}{.}\PYG{n}{sqrt}\PYG{p}{(}\PYG{n}{q}\PYG{o}{.}\PYG{n}{shape}\PYG{p}{[}\PYG{o}{\PYGZhy{}}\PYG{l+m+mi}{1}\PYG{p}{])}
21+
\PYG{n}{attention\PYGZus{}scores} \PYG{o}{=} \PYG{n}{mpx}\PYG{o}{.}\PYG{n}{force\PYGZus{}full\PYGZus{}precision}\PYG{p}{(}
22+
\PYG{n}{jax}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{softmax}\PYG{p}{,} \PYG{n}{attention\PYGZus{}scores}\PYG{o}{.}\PYG{n}{dtype}\PYG{p}{)(}\PYG{n}{attention\PYGZus{}scores}\PYG{p}{,} \PYG{n}{axis}\PYG{o}{=\PYGZhy{}}\PYG{l+m+mi}{1}\PYG{p}{)}
23+
\PYG{k}{return} \PYG{n}{attention\PYGZus{}scores} \PYG{o}{@} \PYG{n}{v}
24+
25+
\PYG{k}{def} \PYG{n+nf+fm}{\PYGZus{}\PYGZus{}call\PYGZus{}\PYGZus{}}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,} \PYG{n}{inputs}\PYG{p}{):}
26+
\PYG{n}{inputs\PYGZus{}after\PYGZus{}layernorm} \PYG{o}{=} \PYG{n}{jax}\PYG{o}{.}\PYG{n}{vmap}\PYG{p}{(}\PYG{n}{mpx}\PYG{o}{.}\PYG{n}{force\PYGZus{}full\PYGZus{}precision}\PYG{p}{(}
27+
\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{layer\PYGZus{}norm}\PYG{p}{,} \PYG{n}{inputs}\PYG{o}{.}\PYG{n}{dtype}\PYG{p}{))(}\PYG{n}{inputs}\PYG{p}{)}
28+
\PYG{n}{qs} \PYG{o}{=} \PYG{n}{jax}\PYG{o}{.}\PYG{n}{vmap}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{dense\PYGZus{}qs}\PYG{p}{)(}\PYG{n}{inputs\PYGZus{}after\PYGZus{}layernorm}\PYG{p}{)}
29+
\PYG{n}{qs} \PYG{o}{=} \PYG{n}{es}\PYG{o}{.}\PYG{n}{jax\PYGZus{}einshape}\PYG{p}{(}\PYG{l+s+s2}{\PYGZdq{}n(hf)\PYGZhy{}\PYGZgt{}hnf\PYGZdq{}}\PYG{p}{,} \PYG{n}{qs}\PYG{p}{,} \PYG{n}{h}\PYG{o}{=}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{num\PYGZus{}heads}\PYG{p}{)}
30+
\PYG{c+c1}{\PYGZsh{} same for ks and vs...}
31+
32+
\PYG{n}{outputs} \PYG{o}{=} \PYG{n}{jax}\PYG{o}{.}\PYG{n}{vmap}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{attention}\PYG{p}{,} \PYG{n}{in\PYGZus{}axes}\PYG{o}{=}\PYG{p}{(}\PYG{l+m+mi}{0}\PYG{p}{,} \PYG{l+m+mi}{0}\PYG{p}{,} \PYG{l+m+mi}{0}\PYG{p}{))(}\PYG{n}{qs}\PYG{p}{,} \PYG{n}{ks}\PYG{p}{,} \PYG{n}{vs}\PYG{p}{)}
33+
\PYG{n}{outputs} \PYG{o}{=} \PYG{n}{es}\PYG{o}{.}\PYG{n}{jax\PYGZus{}einshape}\PYG{p}{(}\PYG{l+s+s2}{\PYGZdq{}hnf\PYGZhy{}\PYGZgt{}n(hf)\PYGZdq{}}\PYG{p}{,} \PYG{n}{outputs}\PYG{p}{)}
34+
\PYG{n}{outputs} \PYG{o}{=} \PYG{n}{jax}\PYG{o}{.}\PYG{n}{vmap}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{dense\PYGZus{}o}\PYG{p}{)(}\PYG{n}{outputs}\PYG{p}{)}
35+
\PYG{n}{outputs} \PYG{o}{+=} \PYG{n}{inputs}
36+
37+
\PYG{k}{return} \PYG{n}{outputs}
38+
\end{Verbatim}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
\begin{Verbatim}[commandchars=\\\{\},codes={\catcode`\$=3\catcode`\^=7\catcode`\_=8\relax}]
2+
\PYG{n}{grads} \PYG{o}{=} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{filter\PYGZus{}grad}\PYG{p}{(}\PYG{n}{loss}\PYG{p}{)(}\PYG{n}{model}\PYG{p}{,} \PYG{n}{batch}\PYG{p}{)}
3+
\PYG{n}{updates}\PYG{p}{,} \PYG{n}{optimizer\PYGZus{}state} \PYG{o}{=} \PYG{n}{optimizer}\PYG{o}{.}\PYG{n}{update}\PYG{p}{(}
4+
\PYG{n}{grads}\PYG{p}{,} \PYG{n}{optimizer\PYGZus{}state}\PYG{p}{,} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{filter}\PYG{p}{(}\PYG{n}{model}\PYG{p}{,} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{is\PYGZus{}array}\PYG{p}{))}
5+
\PYG{n}{model} \PYG{o}{=} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{apply\PYGZus{}updates}\PYG{p}{(}\PYG{n}{model}\PYG{p}{,} \PYG{n}{updates}\PYG{p}{)}
6+
\end{Verbatim}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
\begin{Verbatim}[commandchars=\\\{\},codes={\catcode`\$=3\catcode`\^=7\catcode`\_=8\relax}]
2+
\PYG{n}{loss\PYGZus{}scaling}\PYG{p}{,} \PYG{n}{grads\PYGZus{}finite}\PYG{p}{,} \PYG{n}{grads} \PYG{o}{=} \PYG{n}{mpx}\PYG{o}{.}\PYG{n}{filter\PYGZus{}grad}\PYG{p}{(}\PYG{n}{loss}\PYG{p}{,} \PYG{n}{loss\PYGZus{}scaling}\PYG{p}{)(}
3+
\PYG{n}{model}\PYG{p}{,} \PYG{n}{batch}\PYG{p}{)}
4+
\PYG{n}{model}\PYG{p}{,} \PYG{n}{optimizer\PYGZus{}state} \PYG{o}{=} \PYG{n}{mpx}\PYG{o}{.}\PYG{n}{optimizer\PYGZus{}update}\PYG{p}{(}
5+
\PYG{n}{model}\PYG{p}{,} \PYG{n}{optimizer}\PYG{p}{,} \PYG{n}{optimizer\PYGZus{}state}\PYG{p}{,} \PYG{n}{grads}\PYG{p}{,}\PYG{n}{grads\PYGZus{}finite}\PYG{p}{)}
6+
\end{Verbatim}

mpx/_dtypes.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,24 @@
11
import jax.numpy as jnp
22

3-
HALF_PRECISION_DATATYPE = jnp.float16
4-
FLOAT16_MAX = jnp.ones([], dtype=jnp.float32) * (2 - 2**(-10)) * 2**15
5-
BFLOAT16_MAX = jnp.array([((2**8 - 1) * 2**(120))], dtype=jnp.float32)[0]
3+
import sys
4+
import types
5+
6+
# We do this to avoid calling jax directly when this module is imported.
7+
# This is to ensure that mpx works with distributed training.
8+
class _MaxConstantsLazyInit(types.ModuleType):
9+
@property
10+
def HALF_PRECISION_DATATYPE(self):
11+
return jnp.float16
12+
13+
@property
14+
def FLOAT16_MAX(self):
15+
return jnp.ones([], dtype=jnp.float32) * (2 - 2**(-10)) * 2**15
16+
17+
@property
18+
def BFLOAT16_MAX(self):
19+
return jnp.array([((2**8 - 1) * 2**(120))], dtype=jnp.float32)[0]
20+
21+
sys.modules[__name__].__class__ = _MaxConstantsLazyInit
622

723
def set_half_precision_datatype(datatype):
824
"""

0 commit comments

Comments
 (0)