Description
Hi, I keep playing around with vjp3 + array refs for zero-memory-overhead gradient accumulation. I recently reported a problem with scan + reverse:
#32411 (comment)
I brought it into JAX 0.7.2 via monkey patching, and full transformer training now works properly. I have since moved on to more advanced training setups, such as training extra heads on top of the transformer while freezing the backbone (via optax). In the forward pass we run the full model, take its output embeddings, and feed them into a smaller model that predicts the main model's distribution. By masking some keys in the model state's optimizer, we ensure that during the backward pass only the heads are trained and not the backbone, so JAX optimizes away the entire backbone backward pass.
For the backbone models we do not use scan to iterate over layers, but the heads do use scan. A rough sketch of this kind of setup is below.
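For context, here is a minimal sketch of the kind of setup described above. Everything in it (backbone_apply, head_apply, the shapes, and the optax.multi_transform masking) is illustrative only and not the actual training code, which uses the vjp3 + refs path rather than plain value_and_grad:

```python
import jax
import jax.numpy as jnp
import optax

def backbone_apply(params, x):
    # Backbone: explicit Python loop over per-layer weights (no scan).
    for w in params:
        x = jnp.tanh(x @ w)
    return x

def head_apply(params, x):
    # Head: layers stacked along a leading axis and iterated with lax.scan.
    def layer(h, w):
        return jnp.tanh(h @ w), None
    h, _ = jax.lax.scan(layer, x, params)
    return h

def loss_fn(params, x):
    emb = backbone_apply(params["backbone"], x)
    return jnp.mean(head_apply(params["head"], emb) ** 2)

params = {
    "backbone": [jnp.ones((16, 16)) * 0.1 for _ in range(2)],
    "head": jnp.ones((3, 16, 16)) * 0.1,  # stacked head layers for scan
}

# Freeze the backbone by routing its updates through set_to_zero; only the
# head is actually trained. Under jit, the unused backbone gradients can then
# be dead-code-eliminated.
labels = jax.tree.map(lambda _: "train", params)
labels["backbone"] = jax.tree.map(lambda _: "freeze", params["backbone"])
tx = optax.multi_transform(
    {"train": optax.adam(1e-3), "freeze": optax.set_to_zero()}, labels
)
opt_state = tx.init(params)

@jax.jit
def step(params, opt_state, x):
    # The real setup accumulates gradients into array refs via vjp3 instead.
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    updates, opt_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss

params, opt_state, loss = step(params, opt_state, jnp.ones((4, 16)))
```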
We got the following error:
File "/papyrax/papyrax/training/trainer.py", line 128, in _loop_body
grad_fn.with_refs(grads_accum)(jnp.ones_like(loss))
ValueError: Invalid shape for `addupdate`. Ref shape: (14336, 4096). Expected shape: (14336, 4096). Value shape: (4096, 14336). Transforms: ().
The error happens here, in the backward_pass3 function:
```python
for x, ct in zip(primals, cts_out):
    if isinstance(x, GradAccum):
        x.accum(ct)  # <- error raised here
```
primals:
[JitTracer<bfloat16[8,8192,4096]>, <jax._src.interpreters.ad.RefAccum object at 0x7fa7e0342bd0>]
cts_out:
(None, JitTracer<bfloat16[4096,14336]>)
My guess would be that at some point the training state tree is flattened into a list and some entries are filtered out, while the gradient update still expects the original flat tree, so the remaining cotangents end up paired with the wrong refs. A toy illustration of this kind of misalignment is sketched below.
It is also interesting that I seem to get this error during tracing, even before lowering and compiling.
I tried reproducing the error with a simple script, but no luck yet.
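Purely to illustrate the suspected misalignment (this is not JAX's actual internals, just a toy example of zipping a filtered flat list against unfiltered cotangents):

```python
import jax

params = {"backbone": 1.0, "head_w0": 2.0, "head_w1": 3.0}
flat, _ = jax.tree_util.tree_flatten(params)   # [1.0, 2.0, 3.0]
cts = [None, "ct_head_w0", "ct_head_w1"]       # aligned with the full flat list

# If the frozen entry is dropped from one list but not the other,
# zip pairs the remaining cotangents with the wrong primals:
filtered = [p for p in flat if p != 1.0]       # backbone filtered out -> [2.0, 3.0]
print(list(zip(filtered, cts)))                # [(2.0, None), (3.0, 'ct_head_w0')]
```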
System info (python version, jaxlib version, accelerator, etc.)
root@computeinstance-e00p9411jfafcfq06n:/papyrax# python3 -c 'import jax; jax.print_environment_info()'
jax: 0.7.2
jaxlib: 0.7.2
numpy: 2.1.3
python: 3.11.13 (main, Jun 4 2025, 08:57:30) [GCC 13.3.0]
device info: NVIDIA H100 80GB HBM3-8, 8 local devices
process_count: 1
platform: uname_result(system='Linux', node='computeinstance-e00p9411jfafcfq06n', release='5.15.0-152-generic', version='#162-Ubuntu SMP Wed Jul 23 09:48:42 UTC 2025', machine='x86_64')
XLA_PYTHON_CLIENT_MEM_FRACTION=.85
JAX_COMPILATION_CACHE_DIR=/mnt/llm/cache/jax_compilation_cache
$ nvidia-smi
Tue Oct 7 15:18:00 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 |
| N/A 36C P0 120W / 700W | 558MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 |
| N/A 32C P0 122W / 700W | 538MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 |
| N/A 31C P0 120W / 700W | 538MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 |
| N/A 36C P0 123W / 700W | 538MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 |
| N/A 37C P0 123W / 700W | 538MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 |
| N/A 32C P0 120W / 700W | 538MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 |
| N/A 31C P0 121W / 700W | 538MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 |
| N/A 37C P0 124W / 700W | 538MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+