
Commit a13b618

Document cudaMallocAsync as an experimental feature.
1 parent e707ede commit a13b618

File tree

1 file changed: +28 -0 lines changed

docs/gpu_memory_allocation.rst

Lines changed: 28 additions & 0 deletions
@@ -70,3 +70,31 @@ Common causes of OOM failures
memory. Note, however, that the algorithm is basic and you can often get a better
trade-off between compute and memory by disabling the automatic remat pass and doing
it manually with `the jax.remat API <https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html>`_
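
For instance, a minimal sketch of manual rematerialization with
``jax.checkpoint`` (``jax.remat`` is an alias for the same function; the
function ``f`` below is just a placeholder)::

  import jax
  import jax.numpy as jnp

  @jax.checkpoint  # recompute f's intermediates on the backward pass
  def f(x):
      return jnp.sin(jnp.sin(x))

  # Gradients recompute the checkpointed intermediates instead of storing them.
  jax.grad(lambda x: f(x).sum())(jnp.ones(8))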

Experimental features
---------------------

The features described here are experimental and should be used with caution.

``TF_GPU_ALLOCATOR=cuda_malloc_async``
  This replaces XLA's own BFC memory allocator with `cudaMallocAsync
  <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html>`_.
  This removes the large fixed pre-allocation and instead uses a memory pool
  that grows on demand. The expected benefit is that there is no need to set
  ``XLA_PYTHON_CLIENT_MEM_FRACTION``.
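
  For example, a minimal sketch of enabling it from Python (the variable must
  be set before JAX initializes its GPU backend, so set it before importing
  ``jax``)::

    import os

    # Swap XLA's BFC allocator for cudaMallocAsync; this must happen
    # before the GPU backend is initialized.
    os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

    import jax
    import jax.numpy as jnp

    # GPU allocations are now served from a growing memory pool.
    x = jnp.ones((1024, 1024))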

  The risks are:

  - Memory fragmentation behaves differently, so if you are close to the
    limit, the exact OOM failures caused by fragmentation will differ.
  - The allocation cost is not all paid at startup; it is incurred whenever
    the memory pool needs to grow. Performance can therefore be less stable
    at the start, and for benchmarks it is even more important to ignore the
    first few iterations (as sketched below).
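
  For instance, a hypothetical timing loop that discards warm-up iterations
  (``f`` and ``x``, as well as the iteration counts, are placeholders)::

    import time

    # Warm-up: let JIT compilation finish and the memory pool grow.
    for _ in range(5):
        f(x).block_until_ready()

    start = time.perf_counter()
    for _ in range(100):
        f(x).block_until_ready()
    print("mean step time:", (time.perf_counter() - start) / 100)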

  These risks can be mitigated by pre-allocating a significant chunk while
  still keeping the benefit of a growing memory pool. This can be done with
  ``TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=N``. If N is ``-1``, it
  pre-allocates the same amount of memory as is allocated by default.
  Otherwise, N is the size in bytes that you want to pre-allocate.
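
  For example, a sketch combining both variables (the 4 GiB figure is an
  arbitrary illustration, not a recommendation)::

    import os

    os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
    # Pre-allocate 4 GiB up front to reduce fragmentation and pool-growth
    # pauses; the pool can still grow beyond this if needed.
    os.environ["TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC"] = str(4 * 1024**3)

    import jax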
