
Commit 7b720cb

committed
added pypi configs
1 parent ff8e2b9 commit 7b720cb

6 files changed: +137 additions, -28 deletions


README.md

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 ## Introduction
 
 This repository offers a tool for training JAX models using mixed precision, called **mpx**. It builds upon [JMP](https://github.com/google-deepmind/jmp)—another mixed precision library for JAX—but extends its capabilities.
-I discovered that JMP does not support arbitrary PyTrees and is particularly incompatible with models developed using [Equinox](https://docs.kidger.site/equinox/). To overcome these limitations, I created mpx, which leverages Equinox's flexibility to work with any PyTree.
+JMP does not support arbitrary PyTrees and is particularly incompatible with models developed using [Equinox](https://docs.kidger.site/equinox/). mpx overcomes these limitations by leveraging Equinox's flexibility to work with any PyTree.
 
 ## Basics of Mixed Precision Training

doc/paper/main.tex

Lines changed: 35 additions & 26 deletions
@@ -2,15 +2,17 @@
 \usepackage[left=2cm,right=2cm]{geometry}
 
 \usepackage[utf8]{inputenc}
-\usepackage{hyperref}
+\usepackage[hidelinks]{hyperref}
 \usepackage{amsmath}
 \usepackage{amssymb}
 \usepackage{graphicx}
 \usepackage{listings}
 \usepackage{xcolor}
 \usepackage{enumitem}
 
-\title{MPX: Mixed Precision Training for JAX}
+\newcommand{\mpx}{\textsc{MPX}}
+
+\title{\mpx{}: Mixed Precision Training for JAX}
 \author{}
 \date{}
 
@@ -19,7 +21,7 @@
 \maketitle
 
 \section{Introduction}
-This paper presents \textbf{mpx}, a tool for training JAX models using mixed precision. The library extends the capabilities of JMP (JAX Mixed Precision) \cite{jmp}, addressing its limitations in handling arbitrary PyTrees and compatibility with models developed using Equinox \cite{equinox}. By leveraging Equinox's flexibility, mpx provides a solution that works with any PyTree structure.
+This paper presents \mpx{}, a tool for training JAX models using mixed precision. The library extends the capabilities of JMP (JAX Mixed Precision) \cite{jmp}, addressing its limitations in handling arbitrary PyTrees and compatibility with models developed using Equinox \cite{kidger2021equinox}. By leveraging Equinox's flexibility, \mpx{} provides a solution that works with any PyTree structure.
 
 \section{Basics of Mixed Precision Training}
 This section summarizes the original Mixed Precision method from NVIDIA's Automatic Mixed Precision \cite{nvidia_amp} and the paper by Micikevicius et al. \cite{mixed_precision_paper}.
@@ -46,15 +48,15 @@ \subsection{Loss Scaling}
 \end{itemize}
 
 \section{Implementation Details}
-mpx is to provides transformations that allow users to transform their existing training pipeline into mixed precision.
+\mpx{} provides transformations that allow users to convert their existing training pipelines to mixed precision.
 For this, it provides several functions for casting PyTrees, functions, and gradients.
 
-The mpx library provides essential transformations for mixed precision training while maintaining JAX's low-level approach. Key components include:
+The \mpx{} library provides essential transformations for mixed precision training while maintaining JAX's low-level approach. Key components include:
 
 \begin{enumerate}
-\item \textbf{Transformations to Cast PyTrees}: mpx features the following functions to cast arbitrary PyTrees: \texttt{cast\_tree(tree, dtype)} \texttt{cast\_to\_half\_precision(x)}, \texttt{cast\_to\_half\_precision(x)}, \texttt{cast\_to\_float16(x)}, \texttt{cast\_to\_bfloat16(x)}, \texttt{cast\_to\_float32(x)}. All these functions cast all leaves of the input that are JAX arrays and of type float to the corresponding float datatype. All other leaves, including arrays that are of non-float types, like int32, remain unchanged.
-\item \textbf{Transformations to Cast Functions}: mpx contains a transformation \texttt{cast\_function(func, dtype, return\_dtype=None)} for functions. This transformation returns a function that casts all its inputs to the desired input datatype (using \texttt{cast\_tree(tree, dtype)}), calls the function and then casts the outputs of the function. Moreover, mpx contains \texttt{force\_full\_precision(func, return\_dtype)}, which forces a function to perform its computations with full precision. This is important as some operations, such as sum, mean or softmax, are sensitive to overflows when calculated in float16.
-\item \textbf{Transformations to Cast Gradients}: mpx contains the Equinox equivalents \texttt{filter\_grad(func, scaling, has\_aux=False, use\_mixed\_precision=True)} and \texttt{filter\_value\_and\_grad(func, scaling, has\_aux=False, use\_mixed\_precision=True)} that calculate the gradient of a function using mixed precision with loss scaling (as described above). Additional to calculating the gradient, the functions also perform the automatic adaption of the loss scaling value.
+\item \textbf{Transformations to Cast PyTrees}: \mpx{} features the following functions to cast arbitrary PyTrees: \texttt{cast\_tree(tree, dtype)}, \texttt{cast\_to\_half\_precision(x)}, \texttt{cast\_to\_float16(x)}, \texttt{cast\_to\_bfloat16(x)}, and \texttt{cast\_to\_float32(x)}. All these functions cast every leaf of the input that is a JAX array of float type to the corresponding float datatype. All other leaves, including arrays of non-float types such as int32, remain unchanged.
+\item \textbf{Transformations to Cast Functions}: \mpx{} contains a transformation \texttt{cast\_function(func, dtype, return\_dtype=None)} for functions. This transformation returns a function that casts all its inputs to the desired input datatype (using \texttt{cast\_tree(tree, dtype)}), calls the function, and then casts its outputs. Moreover, \mpx{} contains \texttt{force\_full\_precision(func, return\_dtype)}, which forces a function to perform its computations in full precision. This is important, as some operations, such as sum, mean, or softmax, are sensitive to overflows when calculated in float16.
+\item \textbf{Transformations to Cast Gradients}: \mpx{} contains the Equinox equivalents \texttt{filter\_grad(func, scaling, has\_aux=False, use\_mixed\_precision=True)} and \texttt{filter\_value\_and\_grad(func, scaling, has\_aux=False, use\_mixed\_precision=True)}, which calculate the gradient of a function using mixed precision with loss scaling (as described above). In addition to calculating the gradient, these functions also perform the automatic adaptation of the loss scaling value.
 These drop-in replacements allow users to reuse their existing Equinox training pipelines without major changes to their structure (cf. Section todo).
 \end{enumerate}
 
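The casting rule described for \texttt{cast\_tree} (float array leaves are cast, everything else passes through) can be sketched outside the diff. The following is a hypothetical plain-NumPy stand-in written for illustration only, not mpx's actual code (mpx operates on arbitrary JAX PyTrees, not just nested dicts/lists):

```python
import numpy as np

def cast_tree(tree, dtype):
    """Hypothetical stand-in for mpx.cast_tree's semantics:
    cast every float-typed array leaf to `dtype`; leave all
    other leaves (int arrays, scalars, None, ...) unchanged."""
    if isinstance(tree, dict):
        return {k: cast_tree(v, dtype) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(cast_tree(v, dtype) for v in tree)
    if isinstance(tree, np.ndarray) and np.issubdtype(tree.dtype, np.floating):
        return tree.astype(dtype)
    return tree  # e.g. an int32 array passes through untouched

params = {"w": np.ones((2, 2), dtype=np.float32),
          "steps": np.array([1, 2], dtype=np.int32)}
half = cast_tree(params, np.float16)  # "w" becomes float16, "steps" stays int32
```

Keeping integer leaves untouched matters because counters and indices would lose meaning (or overflow differently) if blindly cast to float16.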

@@ -68,35 +70,42 @@ \subsection{Automatic Loss Scaling Implementation}
 \end{itemize}
 
 \subsection{Optimizer}
-mpx works with all optax optimizers. However, as explained above, one might need to skip optimizer updates if gradients became infinite due to loss scaling.
+\mpx{} works with all optax optimizers. However, as explained above, one might need to skip optimizer updates if gradients become infinite due to loss scaling.
 The \texttt{optimizer\_update(model, optimizer, optimizer\_state, grads, grads\_finite)} function handles model updates based on gradient finiteness.
-This means, instead of calling \texttt{optimizer.update}, followed by \texttt{eqx.apply\_updates} as done in regular Equinox training pipelines, one just have to call \texttt{mpx.optimizer\_update}.
+This means that, instead of calling \texttt{optimizer.update} followed by \texttt{eqx.apply\_updates} as in regular Equinox training pipelines, one just has to call \texttt{mpx.optimizer\_update}.
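The skip-on-non-finite-gradients behaviour of \texttt{optimizer\_update} can be illustrated with a small stand-alone sketch. This is not mpx's implementation: the real function takes an optax optimizer and its state, whereas this toy version hard-codes a plain SGD step just to isolate the gating logic:

```python
import numpy as np

def optimizer_update(params, grads, grads_finite, lr=1e-3):
    """Toy illustration of the gating idea: apply the step only when
    all gradients are finite; otherwise skip the update entirely
    (a plain SGD step stands in for the real optax update)."""
    if not grads_finite:
        return params  # skip: the loss scaling will be lowered instead
    return {k: p - lr * grads[k] for k, p in params.items()}

def all_finite(grads):
    return all(np.isfinite(g).all() for g in grads.values())

params = {"w": np.ones(3, dtype=np.float32)}
good = {"w": np.full(3, 0.5, dtype=np.float32)}
bad = {"w": np.array([np.inf, 0.0, 0.0], dtype=np.float32)}

p1 = optimizer_update(params, good, grads_finite=all_finite(good))  # step applied
p2 = optimizer_update(params, bad, grads_finite=all_finite(bad))    # step skipped
```

Skipping the step (rather than, say, zeroing the gradients) keeps the optimizer state consistent with parameters that were actually updated.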
 
 \section{Example}
-I t
+Here, we provide an example and show which parts of a training pipeline need to be changed for mixed precision training.
+
+\section{Model Implementation}
+For the most part, the implementation of the model does not need to be changed.
+As \mpx{} works with arbitrary PyTrees, every toolbox that defines its models/parameters as PyTrees, such as Flax~\cite{flax2020github} or Equinox~\cite{kidger2021equinox}, can be used in conjunction with \mpx{}.
 
 \section{Acknowledgements}
 We express our gratitude to Patrick Kidger for Equinox and Google DeepMind for JMP, which served as the foundation for this implementation.
 
 The authors acknowledge the computing time provided by the NHR Center NHR4CES at RWTH Aachen University (project number p0021919), funded by the Federal Ministry of Education and Research, and participating state governments through the GWK resolutions for national high performance computing at universities.
 
-\begin{thebibliography}{9}
-\bibitem{jmp}
-JMP: JAX Mixed Precision
-\newblock \url{https://github.com/google-deepmind/jmp}
+\bibliographystyle{plain}
+\bibliography{references}
+
+% \begin{thebibliography}{9}
+% \bibitem{jmp}
+% JMP: JAX Mixed Precision
+% \newblock \url{https://github.com/google-deepmind/jmp}
 
-\bibitem{equinox}
-Equinox: Neural Networks in JAX
-\newblock \url{https://docs.kidger.site/equinox/}
+% \bibitem{equinox}
+% Equinox: Neural Networks in JAX
+% \newblock \url{https://docs.kidger.site/equinox/}
 
-\bibitem{nvidia_amp}
-NVIDIA Automatic Mixed Precision
-\newblock \url{https://developer.nvidia.com/automatic-mixed-precision}
+% \bibitem{nvidia_amp}
+% NVIDIA Automatic Mixed Precision
+% \newblock \url{https://developer.nvidia.com/automatic-mixed-precision}
 
-\bibitem{mixed_precision_paper}
-P. Micikevicius et al.
-\newblock ``Mixed Precision Training''
-\newblock arXiv:1710.03740, 2017
-\end{thebibliography}
+% \bibitem{mixed_precision_paper}
+% P. Micikevicius et al.
+% \newblock ``Mixed Precision Training''
+% \newblock arXiv:1710.03740, 2017
+% \end{thebibliography}
 
 \end{document}

doc/paper/references.bib

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+@misc{jmp,
+  title = {JMP: JAX Mixed Precision},
+  howpublished = {\url{https://github.com/google-deepmind/jmp}},
+  note = {Accessed: 2024-06-09}
+}
+@article{mixed_precision_paper,
+  title={Mixed precision training},
+  author={Micikevicius, Paulius and Narang, Sharan and Alben, Jonah and Diamos, Gregory and Elsen, Erich and Garcia, David and Ginsburg, Boris and Houston, Michael and Kuchaiev, Oleksii and Venkatesh, Ganesh and others},
+  journal={arXiv preprint arXiv:1710.03740},
+  year={2017}
+}
+
+@software{flax2020github,
+  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
+  title = {{F}lax: A neural network library and ecosystem for {JAX}},
+  url = {http://github.com/google/flax},
+  version = {0.10.6},
+  year = {2024},
+}
+
+
+@article{kidger2021equinox,
+  author={Patrick Kidger and Cristian Garcia},
+  title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
+  year={2021},
+  journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
+}

mpx/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@
 Mixed Precision for JAX - A library for mixed precision training in JAX
 """
 
+__version__ = "0.1.2"
+
 from .cast import (
     cast_tree,
     cast_to_float32,

mpx/loss_scaling.py

Lines changed: 23 additions & 1 deletion
@@ -73,7 +73,29 @@ def wrapper(*_args, **_kwargs):
 
 
 class DynamicLossScaling(eqx.Module):
-    """Basic structure taken from jmp."""
+    """
+    Implements dynamic loss scaling for mixed precision training in JAX.
+    The basic structure is taken from jmp.
+    This class automatically adjusts the loss scaling factor during training to prevent
+    numerical underflow/overflow when using reduced precision (e.g., float16). The scaling
+    factor is increased periodically if gradients are finite, and decreased if non-finite
+    gradients are detected, within specified bounds.
+    Attributes:
+        loss_scaling (jnp.ndarray): Current loss scaling factor.
+        min_loss_scaling (jnp.ndarray): Minimum allowed loss scaling factor.
+        counter (jnp.ndarray): Counter for tracking update periods.
+        factor (int): Multiplicative factor for adjusting loss scaling.
+        period (int): Number of steps between potential increases of loss scaling.
+    Methods:
+        scale(tree):
+            Scales all leaves of a pytree by the current loss scaling factor.
+        unscale(tree):
+            Unscales all leaves of a pytree by the inverse of the current loss scaling factor,
+            casting the result to float32.
+        adjust(grads_finite: jnp.ndarray) -> 'DynamicLossScaling':
+            Returns a new DynamicLossScaling instance with updated loss scaling and counter,
+            depending on whether the gradients are finite.
+    """
     loss_scaling: jnp.ndarray
     min_loss_scaling: jnp.ndarray
     counter: jnp.ndarray
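The adjustment policy this docstring describes can be sketched in plain Python. This is a mutable toy for illustration (the real class is an immutable eqx.Module whose adjust returns a new instance); the counter reset on non-finite gradients and the factor/period values below are assumptions, not taken from mpx's source:

```python
import numpy as np

class DynamicLossScaling:
    """Mutable toy version of the dynamic loss scaling policy:
    grow the scale every `period` finite steps, shrink it (down to
    a minimum) whenever non-finite gradients appear."""
    def __init__(self, loss_scaling, min_loss_scaling, factor=2, period=2000):
        self.loss_scaling = float(loss_scaling)
        self.min_loss_scaling = float(min_loss_scaling)
        self.factor = factor
        self.period = period
        self.counter = 0

    def scale(self, x):
        return x * self.loss_scaling

    def unscale(self, x):
        return np.asarray(x, dtype=np.float32) / np.float32(self.loss_scaling)

    def adjust(self, grads_finite):
        if grads_finite:
            self.counter += 1
            if self.counter % self.period == 0:
                # periodically try a larger scale while training is stable
                self.loss_scaling *= self.factor
        else:
            # non-finite grads: back off, but never below the minimum
            self.loss_scaling = max(self.loss_scaling / self.factor,
                                    self.min_loss_scaling)
            self.counter = 0  # assumed: restart the growth period
        return self

scaling = DynamicLossScaling(loss_scaling=2**15, min_loss_scaling=1.0,
                             factor=2, period=3)
scaling.adjust(False)            # overflow detected: halve the scale
after_backoff = scaling.loss_scaling
for _ in range(3):
    scaling.adjust(True)         # three finite steps: period reached, double again
after_growth = scaling.loss_scaling
```

The back-and-forth keeps the scale as large as numerically possible, which maximizes how much of the float16 gradient range is used.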

pyproject.toml

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+[build-system]
+build-backend = "hatchling.build"
+requires = ["hatchling"]
+
+
+[project]
+name = "mixed-precision-for-JAX"
+dynamic = ["version"]
+dependencies = [
+    "equinox",
+    "optax",
+    "jax>=0.4.38",
+    "jaxtyping>=0.2.20",
+    "typing_extensions>=4.5.0",
+    "wadler_lindig>=0.1.0"
+]
+requires-python = ">=3.10"
+authors = [
+    {name = "Alexander Graefe", email = "alexander.graefe@dsme.rwth-aachen.de"},
+]
+maintainers = [
+    {name = "Alexander Graefe", email = "alexander.graefe@dsme.rwth-aachen.de"},
+]
+description = "A toolbox for mixed precision training via JAX."
+readme = "README.md"
+license = "MIT"
+license-files = ["LICENCSE"]
+keywords = ["JAX", "Neural Network", "Mixed Precision"]
+classifiers = [
+    "Programming Language :: Python",
+    "Development Status :: 3 - Alpha",
+    "Intended Audience :: Education",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Information Technology",
+    "Intended Audience :: Science/Research",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "Topic :: Scientific/Engineering :: Information Analysis",
+    "Topic :: Scientific/Engineering :: Mathematics"
+]
+
+[tool.hatch.build]
+include = ["mpx/*"]
+
+[tool.hatch.version]
+path = "mpx/__init__.py"
+
+[project.urls]
+Repository = "https://github.com/AlexGraefe/mixed_precision_for_JAX"
+"Bug Tracker" = "https://github.com/AlexGraefe/mixed_precision_for_JAX/issues"
