|
| 1 | +\begin{Verbatim}[commandchars=\\\{\},codes={\catcode`\$=3\catcode`\^=7\catcode`\_=8\relax}] |
| 2 | +\PYG{k}{class} \PYG{n+nc}{MultiHeadAttentionBlock}\PYG{p}{(}\PYG{n}{eqx}\PYG{o}{.}\PYG{n}{Module}\PYG{p}{):} |
| 3 | + \PYG{n}{dense\PYGZus{}qs}\PYG{p}{:} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{Linear} |
| 4 | + \PYG{n}{dense\PYGZus{}ks}\PYG{p}{:} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{Linear} |
| 5 | + \PYG{n}{dense\PYGZus{}vs}\PYG{p}{:} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{Linear} |
| 6 | + \PYG{n}{dense\PYGZus{}o}\PYG{p}{:} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{Linear} |
| 7 | + \PYG{n}{num\PYGZus{}heads}\PYG{p}{:} \PYG{n+nb}{int} |
| 8 | + \PYG{n}{layer\PYGZus{}norm}\PYG{p}{:} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{LayerNorm} |
| 9 | + |
| 10 | + \PYG{k}{def} \PYG{n+nf+fm}{\PYGZus{}\PYGZus{}init\PYGZus{}\PYGZus{}}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,} \PYG{n}{feature\PYGZus{}dim}\PYG{p}{,} \PYG{n}{num\PYGZus{}heads}\PYG{p}{,} \PYG{n}{key}\PYG{p}{):} |
| 11 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{num\PYGZus{}heads} \PYG{o}{=} \PYG{n}{num\PYGZus{}heads} |
| 12 | + \PYG{n}{key}\PYG{p}{,} \PYG{n}{subkey} \PYG{o}{=} \PYG{n}{jax}\PYG{o}{.}\PYG{n}{random}\PYG{o}{.}\PYG{n}{split}\PYG{p}{(}\PYG{n}{key}\PYG{p}{)} |
| 13 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{dense\PYGZus{}qs} \PYG{o}{=} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{Linear}\PYG{p}{(} |
| 14 | + \PYG{n}{feature\PYGZus{}dim}\PYG{p}{,} \PYG{n}{feature\PYGZus{}dim}\PYG{p}{,} \PYG{n}{key}\PYG{o}{=}\PYG{n}{subkey}\PYG{p}{)} |
| 15 | + \PYG{c+c1}{\PYGZsh{} same for dense\PYGZus{}ks, dense\PYGZus{}vs, dense\PYGZus{}o} |
| 16 | + |
| 17 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{layer\PYGZus{}norm} \PYG{o}{=} \PYG{n}{eqx}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{LayerNorm}\PYG{p}{(}\PYG{n}{feature\PYGZus{}dim}\PYG{p}{)} |
| 18 | + |
| 19 | + \PYG{k}{def} \PYG{n+nf}{attention}\PYG{p}{(}\PYG{n}{q}\PYG{p}{,} \PYG{n}{k}\PYG{p}{,} \PYG{n}{v}\PYG{p}{):} |
| 20 | + \PYG{n}{attention\PYGZus{}scores} \PYG{o}{=} \PYG{n}{q} \PYG{o}{@} \PYG{n}{k}\PYG{o}{.}\PYG{n}{T} \PYG{o}{/} \PYG{n}{jnp}\PYG{o}{.}\PYG{n}{sqrt}\PYG{p}{(}\PYG{n}{q}\PYG{o}{.}\PYG{n}{shape}\PYG{p}{[}\PYG{o}{\PYGZhy{}}\PYG{l+m+mi}{1}\PYG{p}{])} |
| 21 | + \PYG{n}{attention\PYGZus{}scores} \PYG{o}{=} \PYG{n}{mpx}\PYG{o}{.}\PYG{n}{force\PYGZus{}full\PYGZus{}precision}\PYG{p}{(} |
| 22 | + \PYG{n}{jax}\PYG{o}{.}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{softmax}\PYG{p}{,} \PYG{n}{attention\PYGZus{}scores}\PYG{o}{.}\PYG{n}{dtype}\PYG{p}{)(}\PYG{n}{attention\PYGZus{}scores}\PYG{p}{,} \PYG{n}{axis}\PYG{o}{=\PYGZhy{}}\PYG{l+m+mi}{1}\PYG{p}{)} |
| 23 | + \PYG{k}{return} \PYG{n}{attention\PYGZus{}scores} \PYG{o}{@} \PYG{n}{v} |
| 24 | + |
| 25 | + \PYG{k}{def} \PYG{n+nf+fm}{\PYGZus{}\PYGZus{}call\PYGZus{}\PYGZus{}}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,} \PYG{n}{inputs}\PYG{p}{):} |
| 26 | + \PYG{n}{inputs\PYGZus{}after\PYGZus{}layernorm} \PYG{o}{=} \PYG{n}{jax}\PYG{o}{.}\PYG{n}{vmap}\PYG{p}{(}\PYG{n}{mpx}\PYG{o}{.}\PYG{n}{force\PYGZus{}full\PYGZus{}precision}\PYG{p}{(} |
| 27 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{layer\PYGZus{}norm}\PYG{p}{,} \PYG{n}{inputs}\PYG{o}{.}\PYG{n}{dtype}\PYG{p}{))(}\PYG{n}{inputs}\PYG{p}{)} |
| 28 | + \PYG{n}{qs} \PYG{o}{=} \PYG{n}{jax}\PYG{o}{.}\PYG{n}{vmap}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{dense\PYGZus{}qs}\PYG{p}{)(}\PYG{n}{inputs\PYGZus{}after\PYGZus{}layernorm}\PYG{p}{)} |
| 29 | + \PYG{n}{qs} \PYG{o}{=} \PYG{n}{es}\PYG{o}{.}\PYG{n}{jax\PYGZus{}einshape}\PYG{p}{(}\PYG{l+s+s2}{\PYGZdq{}n(hf)\PYGZhy{}\PYGZgt{}hnf\PYGZdq{}}\PYG{p}{,} \PYG{n}{qs}\PYG{p}{,} \PYG{n}{h}\PYG{o}{=}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{num\PYGZus{}heads}\PYG{p}{)} |
| 30 | + \PYG{c+c1}{\PYGZsh{} same for ks and vs...} |
| 31 | + |
| 32 | + \PYG{n}{outputs} \PYG{o}{=} \PYG{n}{jax}\PYG{o}{.}\PYG{n}{vmap}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{attention}\PYG{p}{,} \PYG{n}{in\PYGZus{}axes}\PYG{o}{=}\PYG{p}{(}\PYG{l+m+mi}{0}\PYG{p}{,} \PYG{l+m+mi}{0}\PYG{p}{,} \PYG{l+m+mi}{0}\PYG{p}{))(}\PYG{n}{qs}\PYG{p}{,} \PYG{n}{ks}\PYG{p}{,} \PYG{n}{vs}\PYG{p}{)} |
| 33 | + \PYG{n}{outputs} \PYG{o}{=} \PYG{n}{es}\PYG{o}{.}\PYG{n}{jax\PYGZus{}einshape}\PYG{p}{(}\PYG{l+s+s2}{\PYGZdq{}hnf\PYGZhy{}\PYGZgt{}n(hf)\PYGZdq{}}\PYG{p}{,} \PYG{n}{outputs}\PYG{p}{)} |
| 34 | + \PYG{n}{outputs} \PYG{o}{=} \PYG{n}{jax}\PYG{o}{.}\PYG{n}{vmap}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{dense\PYGZus{}o}\PYG{p}{)(}\PYG{n}{outputs}\PYG{p}{)} |
| 35 | + \PYG{n}{outputs} \PYG{o}{+=} \PYG{n}{inputs} |
| 36 | + |
| 37 | + \PYG{k}{return} \PYG{n}{outputs} |
| 38 | +\end{Verbatim} |
0 commit comments