Description
When the DTensor policy is run with activation_checkpointing: true on a Qwen3 MoE model, training fails with torch.utils.checkpoint.CheckpointError (a different number of tensors is saved during the original forward than during recomputation). Disabling checkpointing, or switching to a deterministic attention implementation, resolves the crash.
Steps to Reproduce
Enable activation checkpointing via policy.dtensor_cfg.activation_checkpointing: true. (third_party/NeMo-RL/nemo_rl/models/dtensor/parallelize.py wraps each decoder layer's submodules with checkpoint_wrapper; see parallelize.py:660.)
Launch src/flywheel/example_training_bixbench.py (GRPO) against that config.
After a few microbatches, training aborts while backpropagating inside DTensorPolicyWorkerV2.train.
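For reference, the setting in step 1 corresponds to a config fragment like the following (a sketch assuming the usual NeMo-RL YAML layout implied by the dotted key path above):

```yaml
policy:
  dtensor_cfg:
    activation_checkpointing: true   # wraps decoder-layer submodules in checkpoint_wrapper
```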
Observed Behavior
```
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: A different number of tensors was saved during the original forward and recomputation.
Number of tensors saved during forward: 2038
Number of tensors saved during recomputation: 1638
```
Stack trace ends at third_party/NeMo-RL/nemo_rl/models/policy/dtensor_policy_worker_v2.py:787 (loss.backward).
Analysis
The checkpoint wrapper is applied to layers[i].mlp, but in Qwen3 MoE that module is the full sparse MoE block, router plus experts (transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock, see modeling_qwen3_moe.py:211-260). The router selects experts with softmax + topk and executes only the activated ones. During the recomputed forward pass, minor numeric differences (likely from fused Flash Attention kernels under bf16 + TP) flip which experts are selected; the expert loop then runs a different number of expert MLPs, and autograd detects a mismatch in the number of saved tensors.
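The expert-flip mechanism can be illustrated without torch: when router logits are nearly tied, a perturbation on the order of bf16 rounding changes the top-k winner, so a different number of expert MLPs executes during recomputation. A pure-Python sketch (the logit values are illustrative, not taken from the failing run):

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def top1_expert(logits):
    """Index of the highest-probability expert (top-k routing with k=1)."""
    probs = softmax(logits)
    return max(range(len(probs)), key=lambda i: probs[i])

# Two tokens routed over three experts. Token 1's router logits are nearly
# tied, so bf16-scale noise in the recomputed forward flips its selection.
logits_fwd   = [[0.9, 0.1, -1.0], [0.4999, 0.5001, -1.0]]  # original forward
logits_recmp = [[0.9, 0.1, -1.0], [0.5001, 0.4999, -1.0]]  # recomputation

experts_fwd   = {top1_expert(t) for t in logits_fwd}    # both experts 0 and 1 run
experts_recmp = {top1_expert(t) for t in logits_recmp}  # only expert 0 runs
print(len(experts_fwd), len(experts_recmp))  # 2 1
```

Since the MoE block only executes experts that received tokens, the two passes build different autograd graphs, which is exactly the saved-tensor count mismatch the checkpoint error reports.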
Expected Behavior
Checkpoint recomputation should hit the same autograd graph and succeed; activation checkpointing should not break GRPO training on Qwen3 MoE.
Request
Could NeMo-RL skip wrapping the MoE router in parallelize.py or otherwise ensure determinism (e.g., use a deterministic attention path, or add a determinism_check="none" gate with logging)? Even a config hook to opt out of checkpointing specific submodules would help.
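A config-level opt-out might look like the following sketch (the activation_checkpointing_exclude key is hypothetical, not an existing NeMo-RL option; it is shown only to illustrate the requested hook):

```yaml
policy:
  dtensor_cfg:
    activation_checkpointing: true
    # Hypothetical key: skip checkpoint_wrapper for submodules whose
    # qualified name matches one of these patterns (e.g. the MoE block).
    activation_checkpointing_exclude:
      - "layers.*.mlp"
```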