
Commit 9574ff1

Authored by cathalobrien, pre-commit-ci[bot], floriankrb and anaprietonem
feat(training): performance docs (#696)
## Description

Adds a documentation page which explains some strategies to improve the throughput and memory usage of your model.

*As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*

By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)

📚 Documentation previews 📚:
- anemoi-training: https://anemoi-training--696.org.readthedocs.build/en/696/
- anemoi-graphs: https://anemoi-graphs--696.org.readthedocs.build/en/696/
- anemoi-models: https://anemoi-models--696.org.readthedocs.build/en/696/

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Florian Pinault <[email protected]>
Co-authored-by: Ana Prieto Nemesio <[email protected]>
1 parent cce8763 commit 9574ff1

File tree

5 files changed: +334 −0 lines changed
Three binary image files added under training/docs/images/performance-guide/ (150 KB, 235 KB, 461 KB; previews not shown).

training/docs/index.rst

Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@ This package provides the *Anemoi* training functionality.
    user-guide/benchmarking
    user-guide/distributed
    user-guide/download-era5-o96
+   user-guide/performance-optimisation

 .. toctree::
    :maxdepth: 1
training/docs/user-guide/performance-optimisation.rst

Lines changed: 333 additions & 0 deletions

@@ -0,0 +1,333 @@
.. _usage-performance:

.. role:: bash(code)
   :language: bash

##########################
 Performance Optimisation
##########################

Is your model running too slowly? Or does your model not fit in your
device's memory? Anemoi contains numerous settings to tweak the
throughput and memory behaviour of your model.

This guide introduces the settings you can change to improve your
model's performance. It is structured as a flowchart you can follow when
debugging performance issues with your models.

.. image:: ../images/performance-guide/performance-flowchart.png

.. note::

   This guide assumes a batch size of 1, i.e. a single PyTorch
   DistributedDataParallel (DDP) model instance. It is recommended to
   follow this guide to work out the optimal performance settings for a
   single model instance. The total number of model instances can then
   be scaled up via DDP; the optimal settings and runtime should not
   change.

********
 Memory
********

Memory issues typically appear as a "CUDA Out Of Memory" error, usually
within the first few iterations of your model.

.. code::

   torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.60 GiB. GPU 0 has a total capacity of 39.56 GiB of which 4.79 GiB is free.

If Out Of Memory errors occur much later on in your run, this could
indicate a memory leak. The `memory profiler`_ can be used to identify a
memory leak.

Reduce Memory Fragmentation
===========================

The first step to getting past an out-of-memory error is to reduce
memory fragmentation. Over the course of a run, blocks of GPU memory are
allocated and freed many times. This can leave relatively small gaps
between allocated blocks of memory. Taken altogether, these gaps might
be sufficient to store a large tensor, but because they are fragmented
they cannot be used. Instead, a CUDA out-of-memory error is raised.

The easiest way to tell if your memory is fragmented is to read the CUDA
out-of-memory error.

.. code::

   torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.15 GiB. GPU 0 has a total capacity of 39.56 GiB of which 6.80 GiB is free. Including non-PyTorch memory, this process has 36.61 GiB memory in use. Of the allocated memory 31.66 GiB is allocated by PyTorch, and 4.11 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.

The error message states that the allocation of a 3 GiB tensor failed,
yet 4 GiB of memory is reserved but not allocated. This is memory which
is unusable due to memory fragmentation.

To resolve memory fragmentation, the following environment variable can
be set:

.. code:: bash

   export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

If you are launching jobs via SLURM, this line can be put at the top of
your SLURM batch script. This environment variable works for most GPUs
tested (Nvidia A100, H100, GH200 and AMD MI250x). It is currently not
supported on AMD MI300A (due to its unified physical memory) or on CPUs
(which do not use the CUDA caching memory allocator).
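
As a minimal sketch (the resource requests and launch line are
illustrative only, not part of the Anemoi defaults), the export can sit
near the top of a SLURM batch script so that every task inherits it:

.. code:: bash

   #!/bin/bash
   #SBATCH --gpus-per-node=4      # illustrative resource request
   #SBATCH --ntasks-per-node=4

   # Reduce fragmentation in the CUDA caching allocator
   export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

   srun anemoi-training train    # hypothetical launch command; use your usual one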

For a technical explanation of the CUDA caching memory allocator, and
how memory fragmentation occurs, you can read `this blog post`_.

.. _this blog post: https://zdevito.github.io/2022/08/04/cuda-caching-allocator.html

Chunking
========

Memory usage in Anemoi varies greatly across a run. Memory usage
typically peaks during the encoder and decoder phases, as the model must
iterate over many edge connections to compute the mapping between source
and latent grids.

The image below shows memory usage in a single iteration (forward and
backward pass). The 4 large peaks represent, in order: fwd-encoder,
fwd-decoder, bwd-decoder, bwd-encoder.

.. image:: ../images/performance-guide/mem-snapshot-1-mapper-chunk.png

Peak memory usage in the mappers can be greatly reduced by computing the
mappers sequentially in smaller chunks.

.. image:: ../images/performance-guide/mem-snapshot-4-mapper-chunks.png

In the example above the number of mapper chunks has been increased from
1 to 4. Consequently, the peak memory usage has decreased from ~22 GB to
~9 GB.

Chunking defaults to 4 for the encoder and decoder and 2 for the
processor (which works on the lower-resolution latent grid). Chunking
can be controlled using the following config parameters:

.. code:: bash

   model.encoder.num_chunks=... model.processor.num_chunks=... model.decoder.num_chunks=...

The number of chunks should always be a power of two.

Processor chunks can be increased up to the number of layers in the
model, by default 16. The number of processor chunks has no impact on
performance.

There is no hard limit on how much the mappers can be chunked. However,
there is typically a small (~10%) performance penalty from 16 chunks
onwards, and the memory savings of higher chunk counts begin to drop
off. It is therefore recommended to use between 4 and 16 chunks in the
mappers.

It is often possible to determine from the CUDA out-of-memory stacktrace
in which component of the model the OOM occurred. This can inform you
about which num_chunks parameter to change.
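
For example, a run with more aggressive mapper chunking could be
launched with Hydra-style command-line overrides (a sketch; this assumes
the ``anemoi-training train`` entry point and otherwise uses your usual
config):

.. code:: bash

   # chunk the encoder and decoder mappers into 8 pieces each
   anemoi-training train model.encoder.num_chunks=8 model.decoder.num_chunks=8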

.. note::

   Chunking the mappers requires ``model.encoder/decoder.shard_strategy:
   'edges'``. This is the default since anemoi-models v0.6.0.

.. note::

   At inference time the chunking can be changed using the following
   environment variables:

   .. code:: bash

      export ANEMOI_INFERENCE_NUM_CHUNKS_MAPPER=... ANEMOI_INFERENCE_NUM_CHUNKS_PROCESSOR=...
Shard model over more GPUs
==========================

If your model instance still does not fit in memory, you can shard the
model over multiple GPUs.

.. code::

   hardware:
     num_gpus_per_model: 2

This will reduce memory usage by sharding the input batch and the model
channels across GPUs.

The number of GPUs per model should be a power of two and is limited by
the number of heads in your model, by default 16.

Sharding a model over multiple GPUs can also increase performance, as
the compute workload is divided over more GPUs. However, sharding a
model over too many GPUs can decrease performance, because of the
additional collective communication operations required to keep the
GPUs in sync. Additionally, model sharding increases the total number of
GPUs required, which can lead to longer queue times on shared HPC
systems.

The GPUs within a node are typically connected via a faster interconnect
than GPUs across nodes. For this reason, model sharding typically
performs less efficiently once a model instance is sharded across
multiple nodes.
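
As a worked example (the ``num_nodes`` and ``num_gpus_per_node`` keys
are assumptions about your hardware config; only ``num_gpus_per_model``
is discussed above), sharding each model instance over 2 of the 4 GPUs
on a single node leaves 2 DDP model instances:

.. code::

   hardware:
     num_nodes: 1           # assumed key: nodes in the job
     num_gpus_per_node: 4   # assumed key: GPUs per node
     num_gpus_per_model: 2  # each model instance is sharded over 2 GPUs
   # 1 node x 4 GPUs / 2 GPUs per model = 2 DDP model instances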

Memory Profiling
================

For further insights into the memory usage of your model, you can use
the `memory profiler`_.

.. _memory profiler: https://anemoi.readthedocs.io/projects/training/en/latest/user-guide/benchmarking.html#memorysnapshotrecorder

*************
 Performance
*************

Optimise the dataloading
========================

One reason for slow performance could be that your CPU and filesystem
cannot load input data fast enough to keep up with your GPU. This
results in your GPU stalling at the start of an iteration while it waits
for the CPU to provide the next input batch.

By default, each GPU will spawn 8 dataloader workers, which load data in
parallel. You can increase this number until you run out of CPU memory.
A CPU out-of-memory error looks like:

.. code::

   slurmstepd: error: Detected 4 oom_kill events in StepId=39701120.0. Some of the step tasks have been OOM Killed.
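
The number of workers is set in the dataloader config. As a sketch (the
exact key layout, a ``num_workers`` block with per-stage entries, is an
assumption based on the default dataloader config):

.. code::

   #training/config/dataloader/native_grid.yaml (assumed layout)
   num_workers:
     training: 8     # increase until CPU memory becomes the limit
     validation: 8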

Below are some other settings which impact dataloader performance, and
their recommended values:

.. code::

   #training/config/dataloader/native_grid.yaml

   # prefetch_factor > 1 only seems to increase memory required by dataloader processes without giving a speedup.
   prefetch_factor: 1
   # Reduce the time needed to transfer data from CPU to GPU by copying the input batch into a pinned memory buffer on the CPU.
   pin_memory: True

   # dataloaders read in parallel.
   # Only impactful if hardware.num_gpus_per_model > 1
   read_group_size: ${hardware.num_gpus_per_model}

.. note::

   Dataloader workers run on the CPU and require CPU cores and memory.
   If you are running on SLURM, you should ensure you have allocated as
   many CPU cores and as much memory as possible. For example, on a node
   with 4 GPUs and 128 CPU cores:

   .. code:: bash

      #SBATCH --ntasks-per-node=4
      #SBATCH --gpus-per-node=4
      #SBATCH --cpus-per-task=32 # 128 cores / 4 tasks = 32 cores per task
      #SBATCH --mem=0 # '0' is shorthand to request a CPU node's entire memory

.. note::

   Longer rollout increases the CPU memory required by the dataloaders.
   It can be beneficial to break rollout runs into multiple runs (e.g.
   rollout 1->6 and rollout 7->12) and tune the number of workers for
   both runs accordingly.

Change attention backend
========================

The processor is a large component of the overall runtime. Both the
GraphTransformer and Transformer processors support multiple backends,
which have different performance characteristics.

For the Transformer processor, the 'flash attention' backend is the
fastest. Flash attention can be selected in the config like so:

.. code::

   model.processor.attention_implementation: 'flash_attention'

Flash attention is currently available on Nvidia and AMD GPUs only. On
Nvidia GPUs, there are multiple versions of the flash attention library
(2, 3 and 4) corresponding to different hardware generations (Ampere,
Hopper and Blackwell), which take advantage of hardware-specific
features for further speedups.

Flash attention is not the default because it must be compiled from
source.
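
As a rough sketch (the package name is the upstream ``flash-attn``
distribution; the exact version and build flags for your CUDA or ROCm
toolchain are not covered here, so treat this as an assumption rather
than the supported install route):

.. code:: bash

   # flash-attn is compiled against the local GPU toolkit; this can take a while
   pip install flash-attn --no-build-isolation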

For the GraphTransformer processor, the 'triton' backend is the fastest.
To use the 'triton' backend, set the following config option:

.. code::

   model.processor.graph_attention_backend: "triton"

Triton is the default backend when using the GraphTransformer processor.
However, it requires the 'triton' library to be installed. On AMD
systems the library is called 'pytorch-triton-rocm'. Triton is not
officially supported on CPUs.
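
To confirm the library is available in your environment, a quick check
such as the following can be used (a minimal sketch):

.. code:: bash

   # prints the installed triton version, or fails if triton is missing
   python -c "import triton; print(triton.__version__)"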

Compiling
=========

PyTorch can improve performance by compiling PyTorch code into Triton
code at runtime.

Anemoi supports compilation via the 'models.compile' keyword, which
takes a list of modules to be compiled.

.. code::

   #training/config/models/graphtransformer.yaml
   compile:
     - module: anemoi.models.layers.conv.GraphTransformerConv
     - module: anemoi.models.layers.normalization.ConditionalLayerNorm
       options:
         dynamic: False

For information on how to compile, see the `compilation documentation`_
for Anemoi.

The following modules have been found to give a speedup from
compilation:

- anemoi.models.layers.conv.GraphTransformerConv (when not using the
  triton backend)
- anemoi.models.layers.normalization.ConditionalLayerNorm (when using
  the ensemble model)
- torch.nn.LayerNorm
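
For example, ``torch.nn.LayerNorm`` could be added to the compile list
alongside the entries shown above (a sketch following the config format
above; whether this helps will depend on your model):

.. code::

   #training/config/models/graphtransformer.yaml
   compile:
     - module: anemoi.models.layers.conv.GraphTransformerConv
     - module: anemoi.models.layers.normalization.ConditionalLayerNorm
       options:
         dynamic: False
     - module: torch.nn.LayerNorm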

Compiling can also decrease the peak memory required, by fusing multiple
functions into a single one, which reduces the intermediate activations
that must be stored.

Not all modules can be compiled, and some compilation errors can be
difficult to debug.

.. note::

   Compiling the triton backend of the GraphTransformer will have no
   effect, since it is already implemented in Triton.

.. note::

   The triton backend currently uses more memory than the compiled pyg
   backend, due to the need to store edges in an intermediate CSC form
   during the forward pass. If memory is a limiting factor, it might be
   worthwhile to switch to the compiled pyg attention backend once other
   fixes such as chunking have been exhausted.

.. _compilation documentation: https://anemoi.readthedocs.io/projects/training/en/latest/user-guide/models.html#compilation

Performance Profiling
=====================

For further insights into your runtime performance, you can take the
traces produced by the `pytorch profiler`_ and upload them to perfetto_.

.. _perfetto: https://ui.perfetto.dev/

.. _pytorch profiler: https://anemoi.readthedocs.io/projects/training/en/latest/user-guide/benchmarking.html#memory-profiler
