Skip to content
16 changes: 13 additions & 3 deletions benchmarks/scripts/run_single_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,15 @@ class BenchmarkConfig(BaseConfig):

seq_len: Annotated[int, Field(ge=1, description="Sequence length")] = 512

ac: Annotated[Literal["Recompute", "Offload", "None"], Field(description="Activation checkpointing type")] = (
"Recompute"
)
ac: Annotated[
Literal["Recompute", "Selective", "Offload"] | None,
Field(description="Activation checkpointing type"),
] = "Recompute"

selective_targets: Annotated[
list[Literal["norm", "attention_sdpa", "mla_up_proj", "routed_experts"]] | None,
Field(description="Selective activation checkpoint targets when ac=Selective"),
] = None

attention: Annotated[
Literal["sdpa", "flash_attention_2", "flash_attention_3", "flash_attention_4"],
Expand Down Expand Up @@ -141,6 +147,10 @@ def build_command(config: BenchmarkConfig) -> list[str]:
# Add activation checkpointing if enabled
if config.ac == "Recompute":
cmd.append("--model.ac")
elif config.ac == "Selective":
cmd.extend(["--model.ac", "--model.ac.mode", "selective"])
if config.selective_targets:
cmd.extend(["--model.ac.targets", json.dumps(config.selective_targets)])
elif config.ac == "Offload":
cmd.append("--model.ac-offloading")

Expand Down
28 changes: 27 additions & 1 deletion src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
# -- Shared trainer configs (used by both SFT and RL trainers) --

AttnImplementation: TypeAlias = Literal["sdpa", "flash_attention_2", "flash_attention_3", "fa4"]
ActivationCheckpointTarget: TypeAlias = Literal[
"norm",
"attention_sdpa",
"mla_up_proj",
"routed_experts",
]

# User-facing name -> internal name. Users set `flash_attention_4` in configs,
# which gets rewritten to `fa4` before pydantic validation.
Expand All @@ -28,14 +34,34 @@
class ActivationCheckpointConfig(BaseConfig):
"""Configures activation checkpointing."""

mode: Annotated[
Literal["full", "selective"],
Field(
description="Whether to checkpoint whole transformer blocks (`full`) or selected subcomponents inside supported custom decoder layers (`selective`).",
),
] = "full"

freq: Annotated[
int,
Field(
ge=1,
description="Applies activation checkpointing to every `freq` layers. Defaults to 1, which will is full activation checkpointing.",
description="Applies activation checkpointing to every `freq` layers. Defaults to 1.",
),
] = 1

targets: Annotated[
list[ActivationCheckpointTarget],
Field(
description="Selective checkpoint targets. `norm` checkpoints decoder RMSNorm stages, MLA latent RMSNorm stages, and folds in QK norm plus RoPE when available. `attention_sdpa` checkpoints the attention-kernel stage regardless of backend. `mla_up_proj` checkpoints MLA Q/KV up-projection work where supported, and `routed_experts` checkpoints routed expert compute in MoE layers.",
),
] = ["norm"]

@model_validator(mode="after")
def validate_selective_targets(self):
    """Reject a configuration that enables `selective` mode with no targets to checkpoint."""
    selective_without_targets = self.mode == "selective" and not self.targets
    if selective_without_targets:
        raise ValueError("Selective activation checkpointing requires at least one target.")
    return self


class ActivationOffloadingConfig(BaseConfig):
"""Configures the activation offloading."""
Expand Down
52 changes: 50 additions & 2 deletions src/prime_rl/trainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
get_custom_vlm_cls,
supports_custom_impl,
)
from prime_rl.trainer.models.layers.checkpointing import get_supported_targets, set_selective_activation_checkpointing
from prime_rl.trainer.models.layers.lm_head import inject_prime_lm_head
from prime_rl.trainer.models.layers.moe import MoE
from prime_rl.trainer.parallel_dims import ParallelDims
Expand Down Expand Up @@ -332,6 +333,7 @@ def get_model(
assert model.lm_head.weight.dtype == dtype, (
f"LM head dtype wasnt loaded correctly {model.lm_head.weight.dtype} != {dtype}"
)
_reset_runtime_moe_buffers(model)
return model


Expand Down Expand Up @@ -638,12 +640,51 @@ def reshard_module(model: nn.Module):


def apply_ac(model: nn.Module, ac_config: ActivationCheckpointConfig):
logger = get_logger()
language_model = get_language_model(model)
selective_layers = 0
full_layers = 0
fallback_layer_types: set[str] = set()
model_supported_targets: set[str] = set()

for layer_id, (layer_name, transformer_block) in enumerate(language_model.layers.named_children()):
if layer_id % ac_config.freq == 0:
if layer_id % ac_config.freq != 0:
continue

if ac_config.mode == "selective" and getattr(
transformer_block, "supports_selective_activation_checkpointing", False
):
model_supported_targets.update(get_supported_targets(transformer_block))
set_selective_activation_checkpointing(transformer_block, ac_config.targets)
selective_layers += 1
else:
if ac_config.mode == "selective":
fallback_layer_types.add(type(transformer_block).__name__)
transformer_block = checkpoint_wrapper(transformer_block, preserve_rng_state=False)
full_layers += 1

language_model.layers.register_module(layer_name, transformer_block)
get_logger().info(f"Applied activation checkpointing (freq={ac_config.freq})")

if ac_config.mode == "selective":
unsupported_targets = frozenset(ac_config.targets) - model_supported_targets
if unsupported_targets:
raise ValueError(
f"Selective activation checkpoint targets {sorted(unsupported_targets)} are not supported "
f"by the selected model layers. Supported targets across the model: {sorted(model_supported_targets)}"
)
if fallback_layer_types:
logger.warning(
"Selective activation checkpointing is not supported for layer types "
f"{sorted(fallback_layer_types)}; falling back to full checkpointing for those layers."
)
logger.info(
"Applied selective activation checkpointing "
f"(freq={ac_config.freq}, targets={ac_config.targets}, selective_layers={selective_layers}, "
f"full_fallback_layers={full_layers})"
)
return

logger.info(f"Applied activation checkpointing (freq={ac_config.freq})")


def apply_compile(model: nn.Module, compile_config: CompileConfig):
Expand Down Expand Up @@ -675,6 +716,12 @@ def _move_buffers_to_cuda(model: nn.Module, config: ModelConfig) -> None:
buffer.data = buffer.data.to("cuda")


def _reset_runtime_moe_buffers(model: nn.Module) -> None:
    """Zero the per-expert token counters of every materialized MoE module in `model`.

    Meta-device buffers are left untouched: they carry no real storage to reset.
    """
    moe_modules = (m for m in model.modules() if isinstance(m, MoE))
    for moe in moe_modules:
        counter = moe.tokens_per_expert
        if counter.device.type != "meta":
            counter.zero_()


def _validate_flash_attn_4_installed() -> None:
"""Validate that flash-attn-cute is installed and not overwritten by flash-attn.

Expand Down Expand Up @@ -797,6 +844,7 @@ def setup_model(
else:
load_dcp_from_hf(model, config, parallel_dims)

_reset_runtime_moe_buffers(model)
return model


Expand Down
21 changes: 18 additions & 3 deletions src/prime_rl/trainer/models/afmoe/modeling_afmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from transformers.utils import TransformersKwargs

from prime_rl.trainer.models.base import PreTrainedModelPrimeRL
from prime_rl.trainer.models.layers.checkpointing import run_with_optional_checkpoint, should_checkpoint
from prime_rl.trainer.models.layers.lm_head import PrimeLmOutput
from prime_rl.trainer.models.layers.mlp import MLP, MLPConfig
from prime_rl.trainer.models.layers.moe import MoE, MoEArgs
Expand Down Expand Up @@ -287,6 +288,8 @@ def _get_afmoe_attention(config: AfmoeConfig, layer_idx: int) -> nn.Module:


class AfmoeDecoderLayer(GradientCheckpointingLayer):
supports_selective_activation_checkpointing = True

def __init__(self, config: AfmoeConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
Expand Down Expand Up @@ -333,9 +336,14 @@ def forward(
max_seqlen: int | None = None,
routed_experts: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
checkpoint_attn_norm = should_checkpoint(self, "attn_norm")
checkpoint_ffn_norm = should_checkpoint(self, "ffn_norm")
checkpoint_routed_experts = should_checkpoint(self, "routed_experts")

residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
hidden_states = run_with_optional_checkpoint(checkpoint_attn_norm, self.input_layernorm, hidden_states)

hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
Expand All @@ -347,8 +355,15 @@ def forward(
hidden_states = residual + hidden_states

residual = hidden_states
hidden_states = self.pre_mlp_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states, routed_experts=routed_experts)
hidden_states = run_with_optional_checkpoint(checkpoint_ffn_norm, self.pre_mlp_layernorm, hidden_states)
if isinstance(self.mlp, MoE):
hidden_states = self.mlp(
hidden_states,
routed_experts=routed_experts,
checkpoint_routed_experts=checkpoint_routed_experts,
)
else:
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_mlp_layernorm(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
Expand Down
42 changes: 37 additions & 5 deletions src/prime_rl/trainer/models/glm4_moe/modeling_glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
convert_tt_to_hf_moe,
)
from prime_rl.trainer.models.layers.attn import ATTN_IMPL2CLASS, AttentionConfig
from prime_rl.trainer.models.layers.checkpointing import (
ATTENTION_SELECTIVE_AC_TARGETS,
DEFAULT_SELECTIVE_AC_TARGETS,
MOE_SELECTIVE_AC_TARGETS,
run_with_optional_checkpoint,
should_checkpoint,
)
from prime_rl.trainer.models.layers.lm_head import PrimeLmOutput
from prime_rl.trainer.models.layers.mlp import MLP, MLPConfig
from prime_rl.trainer.models.layers.moe import MoE, MoEArgs
Expand All @@ -42,6 +49,8 @@


class Glm4MoeDecoderLayer(GradientCheckpointingLayer):
supports_selective_activation_checkpointing = True

def __init__(self, config: Glm4MoeConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
Expand Down Expand Up @@ -83,6 +92,13 @@ def __init__(self, config: Glm4MoeConfig, layer_idx: int):
self.input_layernorm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
self.post_attention_layernorm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))

@property
def supported_selective_activation_checkpoint_targets(self) -> frozenset[str]:
    """Selective-AC targets this layer can honor; MoE targets are included only when the MLP is an `MoE`."""
    base = DEFAULT_SELECTIVE_AC_TARGETS | ATTENTION_SELECTIVE_AC_TARGETS
    moe_extra = MOE_SELECTIVE_AC_TARGETS if isinstance(self.mlp, MoE) else frozenset()
    return base | moe_extra

@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
Expand All @@ -92,21 +108,37 @@ def forward(
max_seqlen: Optional[int] = None,
routed_experts: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
checkpoint_attn_norm = should_checkpoint(self, "attn_norm")
checkpoint_ffn_norm = should_checkpoint(self, "ffn_norm")
checkpoint_qk_norm_rope = should_checkpoint(self, "qk_norm_rope")
checkpoint_attention_sdpa = should_checkpoint(self, "attention_sdpa")
checkpoint_routed_experts = should_checkpoint(self, "routed_experts")

residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = run_with_optional_checkpoint(checkpoint_attn_norm, self.input_layernorm, hidden_states)

# Self Attention
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
hidden_states, _ = self.self_attn.forward_selective(
hidden_states,
position_embeddings=position_embeddings,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
checkpoint_qk_norm_rope=checkpoint_qk_norm_rope,
checkpoint_attention_sdpa=checkpoint_attention_sdpa,
)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states, routed_experts=routed_experts)
hidden_states = run_with_optional_checkpoint(checkpoint_ffn_norm, self.post_attention_layernorm, hidden_states)
if isinstance(self.mlp, MoE):
hidden_states = self.mlp(
hidden_states,
routed_experts=routed_experts,
checkpoint_routed_experts=checkpoint_routed_experts,
)
else:
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states

Expand Down
Loading
Loading