README.md: 1 addition, 1 deletion
@@ -3,7 +3,7 @@
 ## Introduction
 
 This repository offers a tool for training JAX models using mixed precision, called **mpx**. It builds upon [JMP](https://github.com/google-deepmind/jmp)—another mixed precision library for JAX—but extends its capabilities.
-I discovered that JMP does not support arbitrary PyTrees and is particularly incompatible with models developed using [Equinox](https://docs.kidger.site/equinox/). To overcome these limitations, I created mpx, which leverages Equinox's flexibility to work with any PyTree.
+JMP does not support arbitrary PyTrees and is particularly incompatible with models developed using [Equinox](https://docs.kidger.site/equinox/). mpx overcomes these limitations by leveraging Equinox's flexibility to work with any PyTree.
doc/paper/main.tex: 35 additions, 26 deletions
@@ -2,15 +2,17 @@
 \usepackage[left=2cm,right=2cm]{geometry}
 \usepackage[utf8]{inputenc}
-\usepackage{hyperref}
+\usepackage[hidelinks]{hyperref}
 \usepackage{amsmath}
 \usepackage{amssymb}
 \usepackage{graphicx}
 \usepackage{listings}
 \usepackage{xcolor}
 \usepackage{enumitem}
 
-\title{MPX: Mixed Precision Training for JAX}
+\newcommand{\mpx}{\textsc{MPX}}
+
+\title{\mpx{}: Mixed Precision Training for JAX}
 \author{}
 \date{}
@@ -19,7 +21,7 @@
 \maketitle
 
 \section{Introduction}
-This paper presents \textbf{mpx}, a tool for training JAX models using mixed precision. The library extends the capabilities of JMP (JAX Mixed Precision) \cite{jmp}, addressing its limitations in handling arbitrary PyTrees and compatibility with models developed using Equinox \cite{equinox}. By leveraging Equinox's flexibility, mpx provides a solution that works with any PyTree structure.
+This paper presents \mpx{}, a tool for training JAX models using mixed precision. The library extends the capabilities of JMP (JAX Mixed Precision) \cite{jmp}, addressing its limitations in handling arbitrary PyTrees and compatibility with models developed using Equinox \cite{kidger2021equinox}. By leveraging Equinox's flexibility, \mpx{} provides a solution that works with any PyTree structure.
 
 \section{Basics of Mixed Precision Training}
 This section summarizes the original Mixed Precision method from NVIDIA's Automatic Mixed Precision \cite{nvidia_amp} and the paper by Micikevicius et al. \cite{mixed_precision_paper}.
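The central trick of the cited method, loss scaling, can be illustrated with a small runnable toy. The sketch below is not part of mpx: `to_fp16` is a crude stand-in for float16 storage (it only models the underflow-to-zero behaviour), and the scale value is illustrative.

```python
# Toy model of why loss scaling is needed: small float16 gradients
# underflow to zero, but scaling the loss (and therefore all gradients)
# moves them back into the representable range.

FP16_MIN_SUBNORMAL = 2.0 ** -24   # smallest positive float16 value

def to_fp16(x):
    # Crude stand-in for storing x in float16: flush tiny values to zero.
    return 0.0 if abs(x) < FP16_MIN_SUBNORMAL else x

grad = 2.0 ** -27                 # true gradient, too small for float16
lost = to_fp16(grad)              # stored without scaling: becomes 0.0

SCALE = 2.0 ** 12                 # loss scale S
kept = to_fp16(SCALE * grad)      # gradient of S * loss survives storage
recovered = kept / SCALE          # unscaling yields the true gradient
```

Because powers of two are exact in floating point, the unscaling step recovers the gradient without rounding error.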
@@ -46,15 +48,15 @@ \subsection{Loss Scaling}
 \end{itemize}
 
 \section{Implementation Details}
-mpx is to provides transformations that allow users to transform their existing training pipeline into mixed precision.
+\mpx{} provides transformations that allow users to convert their existing training pipeline to mixed precision.
 For this, it provides several functions that allow casting PyTrees, functions, and gradients.
 
-The mpx library provides essential transformations for mixed precision training while maintaining JAX's low-level approach. Key components include:
+The \mpx{} library provides essential transformations for mixed precision training while maintaining JAX's low-level approach. Key components include:
 
 \begin{enumerate}
-\item\textbf{Transformations to Cast PyTrees}: mpx features the following functions to cast arbitrary PyTrees: \texttt{cast\_tree(tree, dtype)} \texttt{cast\_to\_half\_precision(x)}, \texttt{cast\_to\_half\_precision(x)}, \texttt{cast\_to\_float16(x)}, \texttt{cast\_to\_bfloat16(x)}, \texttt{cast\_to\_float32(x)}. All these functions cast all leaves of the input that are JAX arrays and of type float to the corresponding float datatype. All other leaves, including arrays that are of non-float types, like int32, remain unchanged.
-\item\textbf{Transformations to Cast Functions}: mpx contains a transformation \texttt{cast\_function(func, dtype, return\_dtype=None)} for functions. This transformation returns a function that casts all its inputs to the desired input datatype (using \texttt{cast\_tree(tree, dtype)}), calls the function and then casts the outputs of the function. Moreover, mpx contains \texttt{force\_full\_precision(func, return\_dtype)}, which forces a function to perform its computations with full precision. This is important as some operations, such as sum, mean or softmax, are sensitive to overflows when calculated in float16.
-\item\textbf{Transformations to Cast Gradients}: mpx contains the Equinox equivalents \texttt{filter\_grad(func, scaling, has\_aux=False, use\_mixed\_precision=True)} and \texttt{filter\_value\_and\_grad(func, scaling, has\_aux=False, use\_mixed\_precision=True)} that calculate the gradient of a function using mixed precision with loss scaling (as described above). Additional to calculating the gradient, the functions also perform the automatic adaption of the loss scaling value.
+\item \textbf{Transformations to Cast PyTrees}: \mpx{} features the following functions to cast arbitrary PyTrees: \texttt{cast\_tree(tree, dtype)}, \texttt{cast\_to\_half\_precision(x)}, \texttt{cast\_to\_float16(x)}, \texttt{cast\_to\_bfloat16(x)}, and \texttt{cast\_to\_float32(x)}. All these functions cast every leaf of the input that is a JAX array of float type to the corresponding float datatype. All other leaves, including arrays of non-float types such as int32, remain unchanged.
+\item \textbf{Transformations to Cast Functions}: \mpx{} contains a transformation \texttt{cast\_function(func, dtype, return\_dtype=None)} for functions. This transformation returns a function that casts all its inputs to the desired input datatype (using \texttt{cast\_tree(tree, dtype)}), calls the function, and then casts its outputs. Moreover, \mpx{} contains \texttt{force\_full\_precision(func, return\_dtype)}, which forces a function to perform its computations in full precision. This is important because some operations, such as sum, mean, or softmax, are sensitive to overflows when calculated in float16.
+\item \textbf{Transformations to Cast Gradients}: \mpx{} contains the Equinox equivalents \texttt{filter\_grad(func, scaling, has\_aux=False, use\_mixed\_precision=True)} and \texttt{filter\_value\_and\_grad(func, scaling, has\_aux=False, use\_mixed\_precision=True)}, which calculate the gradient of a function using mixed precision with loss scaling (as described above). In addition to calculating the gradient, these functions also perform the automatic adaptation of the loss-scaling value.
 
 These drop-in replacements allow users to reuse their existing Equinox training pipelines without major changes to their structure (cf. Section todo).
 \end{enumerate}
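The casting rule from item 1 (cast float leaves, leave everything else alone) can be sketched in plain Python. This is an illustrative stand-in for mpx's `cast_tree`, not its implementation: nested dicts/lists play the role of JAX PyTrees, and a callable plays the role of the target dtype.

```python
# Illustrative stand-in for the PyTree casting rule described above:
# cast only float leaves; leave integer and other leaves unchanged.

def cast_tree(tree, cast):
    if isinstance(tree, dict):
        return {k: cast_tree(v, cast) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(cast_tree(v, cast) for v in tree)
    if isinstance(tree, float):        # float leaf: apply the cast
        return cast(tree)
    return tree                        # int/bool/other leaf: untouched

# `round` stands in for a dtype cast; the integer "step" leaf survives.
demo = cast_tree({"weights": [0.5, 1.25], "step": 3}, round)
```

The real function additionally has to distinguish JAX arrays by dtype, but the traversal-and-filter shape is the same.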
@@ -68,35 +70,42 @@ \subsection{Automatic Loss Scaling Implementation}
 \end{itemize}
 
 \subsection{Optimizer}
-mpx works with all optax optimizers. However, as explained above, one might need to skip optimizer updates if gradients became infinite due to loss scaling.
+\mpx{} works with all optax optimizers. However, as explained above, one might need to skip optimizer updates if gradients become infinite due to loss scaling.
 The \texttt{optimizer\_update(model, optimizer, optimizer\_state, grads, grads\_finite)} function handles model updates based on gradient finiteness.
-This means, instead of calling \texttt{optimizer.update}, followed by \texttt{eqx.apply\_updates} as done in regular Equinox training pipelines, one just have to call \texttt{mpx.optimizer\_update}.
+This means that instead of calling \texttt{optimizer.update} followed by \texttt{eqx.apply\_updates}, as done in regular Equinox training pipelines, one just has to call \texttt{mpx.optimizer\_update}.
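The skip-on-overflow behaviour, together with a common dynamic loss-scaling rule (halve the scale on overflow, grow it after a long run of finite steps), can be sketched as follows. This is a hypothetical stand-in for `mpx.optimizer_update` and the scaling adaptation, not the library's code: the growth interval is an illustrative constant, and plain lists stand in for parameter PyTrees.

```python
import math

# Sketch of skipping the optimizer step when scaled gradients overflowed,
# plus a common dynamic loss-scaling rule: halve on overflow, double
# after a long stretch of finite steps.

GROWTH_INTERVAL = 2000   # illustrative, not mpx's actual constant

def optimizer_update(params, updates, grads_finite):
    # Apply the already-computed optimizer updates only if every
    # gradient was finite; otherwise leave the parameters unchanged.
    if not grads_finite:
        return params
    return [p + u for p, u in zip(params, updates)]

def adapt_scaling(scale, good_steps, grads_finite):
    if not grads_finite:
        return scale / 2.0, 0          # overflow: shrink, restart count
    if good_steps + 1 >= GROWTH_INTERVAL:
        return scale * 2.0, 0          # long finite stretch: try larger
    return scale, good_steps + 1

def all_finite(grads):
    return all(math.isfinite(g) for g in grads)
```

Skipping the occasional step is harmless in practice; the same batch contributes again once the scale has been reduced to a safe value.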
 \section{Example}
-I t
+Here, we provide an example and show which parts of a training pipeline need to be changed for mixed precision training.
+
+\subsection{Model Implementation}
+For the most part, the implementation of the model does not need to be changed.
+As \mpx{} works with arbitrary PyTrees, every toolbox that defines its models/parameters as PyTrees, such as Flax~\cite{flax2020github} or Equinox~\cite{kidger2021equinox}, can be used in conjunction with \mpx{}.
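To make the pipeline shape concrete, here is a runnable toy mock of a training step in this style. Everything below is a stand-in: plain lists play the role of PyTrees, finite differences play the role of autodiff, and the finiteness check mirrors the role of `mpx.optimizer_update`; none of it is the real mpx/JAX API.

```python
import math

# Mock of the training-step shape: compute value and gradients, check
# finiteness, and apply the parameter update only when finite.

def loss_fn(params, batch):
    return sum((p - x) ** 2 for p, x in zip(params, batch))  # toy loss

def value_and_grad(fn, params, batch, eps=1e-6):
    # Forward differences stand in for automatic differentiation.
    value = fn(params, batch)
    grads = []
    for i in range(len(params)):
        bumped = list(params)
        bumped[i] += eps
        grads.append((fn(bumped, batch) - value) / eps)
    return value, grads

def train_step(params, batch, lr=0.1):
    value, grads = value_and_grad(loss_fn, params, batch)
    grads_finite = all(math.isfinite(g) for g in grads)
    if grads_finite:  # mirrors the role of mpx.optimizer_update
        params = [p - lr * g for p, g in zip(params, grads)]
    return params, value

params, value = train_step([0.0, 0.0], [1.0, 2.0])  # one gradient step
```

In a real pipeline only the two calls change: the gradient transformation and the guarded update; the model and loss definitions stay as they are.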
 \section{Acknowledgements}
 We express our gratitude to Patrick Kidger for Equinox and Google DeepMind for JMP, which served as the foundation for this implementation.
 
 The authors acknowledge the computing time provided by the NHR Center NHR4CES at RWTH Aachen University (project number p0021919), funded by the Federal Ministry of Education and Research, and participating state governments through the GWK resolutions for national high performance computing at universities.
+@article{mixed_precision_paper,
+  title={Mixed precision training},
+  author={Micikevicius, Paulius and Narang, Sharan and Alben, Jonah and Diamos, Gregory and Elsen, Erich and Garcia, David and Ginsburg, Boris and Houston, Michael and Kuchaiev, Oleksii and Venkatesh, Ganesh and others},
+  journal={arXiv preprint arXiv:1710.03740},
+  year={2017}
+}
+
+@software{flax2020github,
+  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
+  title = {{F}lax: A neural network library and ecosystem for {JAX}},
+  url = {http://github.com/google/flax},
+  version = {0.10.6},
+  year = {2024},
+}
+
+@article{kidger2021equinox,
+  author={Patrick Kidger and Cristian Garcia},
+  title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
+  year={2021},
+  journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
+}