Activation checkpointing crashes Qwen3 MoE DTensor policy with CheckpointError #1203

@kwanUm

Description

When the DTensor policy is run with `activation_checkpointing: true` on a Qwen3 MoE model, training fails with `torch.utils.checkpoint.CheckpointError` (a different number of tensors saved during the original forward vs. the recomputation). Disabling checkpointing, or switching to deterministic attention kernels, resolves the crash.
Steps to Reproduce

1. Enable activation checkpointing via `policy.dtensor_cfg.activation_checkpointing: true`. (`third_party/NeMo-RL/nemo_rl/models/dtensor/parallelize.py` wraps each decoder layer's submodules with `checkpoint_wrapper`; see `parallelize.py:660`.)
2. Launch `src/flywheel/example_training_bixbench.py` (GRPO) against that config.
3. After a few microbatches, training aborts while backpropagating inside `DTensorPolicyWorkerV2.train`.
Observed Behavior

```
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: A different number of tensors was saved during the original forward and recomputation.
Number of tensors saved during forward: 2038
Number of tensors saved during recomputation: 1638
```
Stack trace ends at third_party/NeMo-RL/nemo_rl/models/policy/dtensor_policy_worker_v2.py:787 (loss.backward).

Analysis
The checkpoint wrapper is applied to `layers[i].mlp`, but in Qwen3 MoE that module is the sparse MoE block (`transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock`, see `modeling_qwen3_moe.py:211-260`). It routes tokens with softmax + topk and then executes only the activated experts. During the recomputed forward pass, minor numeric differences (likely from fused Flash Attention kernels under bf16 + TP) flip which experts are selected, so the expert loop runs a different number of MLPs and autograd detects a mismatch in the count of saved tensors.

Expected Behavior
Checkpoint recomputation should hit the same autograd graph and succeed; activation checkpointing should not break GRPO training on Qwen3 MoE.

Request
Could NeMo-RL skip wrapping the MoE router in parallelize.py or otherwise ensure determinism (e.g., use a deterministic attention path, or add a determinism_check="none" gate with logging)? Even a config hook to opt out of checkpointing specific submodules would help.
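One possible shape for such an opt-out, sketched as a hypothetical helper (the name `wrap_decoder_layer` and the class-name set are illustrative, not NeMo-RL API): wrap every decoder-layer submodule except sparse MoE blocks, which are skipped by class name so their data-dependent routing is never recomputed.

```python
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    CheckpointWrapper,
    checkpoint_wrapper,
)

# Assumed class name from transformers; extend for other sparse MoE models.
MOE_BLOCK_NAMES = {"Qwen3MoeSparseMoeBlock"}

def wrap_decoder_layer(layer: torch.nn.Module) -> torch.nn.Module:
    """Checkpoint each submodule except data-dependent sparse MoE routers."""
    for name, child in list(layer.named_children()):
        if type(child).__name__ in MOE_BLOCK_NAMES:
            continue  # skip: topk routing makes recomputation non-replayable
        setattr(layer, name, checkpoint_wrapper(
            child, checkpoint_impl=CheckpointImpl.NO_REENTRANT))
    return layer

# Demo on a toy layer (stand-ins for the real transformers modules):
class Qwen3MoeSparseMoeBlock(torch.nn.Module):  # name mirrors the HF class
    def forward(self, x):
        return x

layer = torch.nn.Module()
layer.self_attn = torch.nn.Linear(4, 4)
layer.mlp = Qwen3MoeSparseMoeBlock()
wrap_decoder_layer(layer)
print(type(layer.self_attn).__name__, type(layer.mlp).__name__)
```

This trades some activation memory on the MoE block for correctness; the attention and dense submodules still get the full checkpointing benefit.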
