vjp3 + array refs fails when some of the gradients are masked #32468

@qGentry

Description

Hi, I'm still experimenting with vjp3 + array refs for zero-memory-overhead gradient accumulation. I recently reported a problem with scan + reverse:
#32411 (comment)
I brought the fix to JAX 0.7.2 via monkey patching, and full transformer training now works properly. I have since moved on to more advanced training setups, such as training extra heads on top of the transformer while freezing the backbone (via optax). In the forward pass, we run the full model, take its output embeddings, and feed them to a smaller model that predicts the main model's distribution. By masking some keys of the model state in the optimizer, we ensure that the backward pass trains only the heads and not the backbone, so JAX optimizes away the entire backward pass of the main model.

The backbone models do not use scan to iterate over layers, but the heads do.
We get the following error:

 File "/papyrax/papyrax/training/trainer.py", line 128, in _loop_body
    grad_fn.with_refs(grads_accum)(jnp.ones_like(loss))
ValueError: Invalid shape for `addupdate`. Ref shape: (14336, 4096). Expected shape: (14336, 4096). Value shape: (4096, 14336). Transforms: ().

The error happens here, in the backward_pass3 function:

            for x, ct in zip(primals, cts_out):
              if isinstance(x, GradAccum):
                x.accum(ct)  # <- error is raised here
primals
[JitTracer<bfloat16[8,8192,4096]>, <jax._src.interpreters.ad.RefAccum object at 0x7fa7e0342bd0>]
cts_out
(None, JitTracer<bfloat16[4096,14336]>)

My guess is that at some point the training state tree is flattened into a list and some entries are filtered out, while the gradient update still expects the original flat tree.
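To make that guess concrete, here is a small sketch of the suspected failure mode. It is not a reproduction of the bug and does not use JAX internals: the parameter tree, the shapes, and the helper `find_shape_mismatches` are all hypothetical. It only shows that dropping entries from a flat list of cotangents shifts the remaining ones against the flattened accumulator leaves, which can produce a transposed-looking shape mismatch like the one in the error above.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree; names and (small) shapes are made up.
params = {
    "backbone": {"w": jnp.zeros((2, 3))},  # frozen, receives no cotangent
    "head": {"w_in": jnp.zeros((4, 6)), "w_out": jnp.zeros((6, 4))},
}

# Flat list of accumulator leaves, in tree-flattening order:
# backbone/w, head/w_in, head/w_out.
accum_leaves, treedef = jax.tree_util.tree_flatten(params)


def find_shape_mismatches(accum_leaves, cotangents):
    """Return (index, accum_shape, ct_shape) for each positional mismatch."""
    return [
        (i, a.shape, ct.shape)
        for i, (a, ct) in enumerate(zip(accum_leaves, cotangents))
        if ct is not None and a.shape != ct.shape
    ]


# One cotangent slot per leaf; None marks the frozen leaf.
aligned_cts = [None, jnp.ones((4, 6)), jnp.ones((6, 4))]
# The same cotangents after the None entry has been filtered out.
filtered_cts = [ct for ct in aligned_cts if ct is not None]

mismatches_aligned = find_shape_mismatches(accum_leaves, aligned_cts)
mismatches_filtered = find_shape_mismatches(accum_leaves, filtered_cts)
print(mismatches_aligned)
print(mismatches_filtered)
```

With the `None` placeholder kept, every cotangent lines up with its accumulator. Once the placeholder is dropped, the `(6, 4)` cotangent is paired with the `(4, 6)` accumulator.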

Interestingly, the error seems to occur during tracing, before lowering and compiling.

I tried to reproduce the error with a simple script, but no luck yet.

System info (python version, jaxlib version, accelerator, etc.)

root@computeinstance-e00p9411jfafcfq06n:/papyrax# python3 -c 'import jax; jax.print_environment_info()'
jax:    0.7.2
jaxlib: 0.7.2
numpy:  2.1.3
python: 3.11.13 (main, Jun  4 2025, 08:57:30) [GCC 13.3.0]
device info: NVIDIA H100 80GB HBM3-8, 8 local devices"
process_count: 1
platform: uname_result(system='Linux', node='computeinstance-e00p9411jfafcfq06n', release='5.15.0-152-generic', version='#162-Ubuntu SMP Wed Jul 23 09:48:42 UTC 2025', machine='x86_64')
XLA_PYTHON_CLIENT_MEM_FRACTION=.85
JAX_COMPILATION_CACHE_DIR=/mnt/llm/cache/jax_compilation_cache

$ nvidia-smi
Tue Oct  7 15:18:00 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.163.01             Driver Version: 550.163.01     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:8D:00.0 Off |                    0 |
| N/A   36C    P0            120W /  700W |     558MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:91:00.0 Off |                    0 |
| N/A   32C    P0            122W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  |   00000000:95:00.0 Off |                    0 |
| N/A   31C    P0            120W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  |   00000000:99:00.0 Off |                    0 |
| N/A   36C    P0            123W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  |   00000000:AB:00.0 Off |                    0 |
| N/A   37C    P0            123W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  |   00000000:AF:00.0 Off |                    0 |
| N/A   32C    P0            120W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  |   00000000:B3:00.0 Off |                    0 |
| N/A   31C    P0            121W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  |   00000000:B7:00.0 Off |                    0 |
| N/A   37C    P0            124W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+

Labels: bug
