Skip to content
Open
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

Documenting changes which affect configuration usage patterns (added/moved/removed/renamed fields, notable logic changes).

- **`model.ac.mode`** and **`model.ac.targets`**: Added selective activation checkpointing configuration. `model.ac.mode` accepts `full` (default) or `selective`. When `selective`, `model.ac.targets` selects subcomponents to checkpoint. Supported public targets are currently `norm`, `attention_sdpa`, `mla_up_proj`, and `routed_experts`; runtime validation remains the source of truth. `model.ac.targets` defaults to `["norm"]`, and selective mode requires at least one target. (2026-03-20)
- **`model.optimization_dtype` / `model.reduce_dtype` (VLM models)**: Added validation that VLM models must use `optimization_dtype='bfloat16'` and `reduce_dtype='bfloat16'` to match vLLM inference. Previously valid configs with `float32` (the default) are now rejected for VLM model names. Set both fields to `"bfloat16"` when training VLMs. (2026-03-21)
- **`orchestrator.advantage.length_weighted_mean`**: Removed. The default advantage now always uses the plain per-problem mean baseline unless `orchestrator.advantage.length_shaping_alpha` is set. Existing configs must delete this field. (2026-03-19)
- **`orchestrator.advantage.length_shaping_alpha`**: Added Group Relative Reward Rescaling coefficient to the default advantage config. When set, applies length-based GR3 shaping during advantage computation and requires `orchestrator.buffer.online_difficulty_filtering = true` (default: `None`). (2026-03-18)
Expand Down
16 changes: 13 additions & 3 deletions benchmarks/scripts/run_single_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,15 @@ class BenchmarkConfig(BaseConfig):

seq_len: Annotated[int, Field(ge=1, description="Sequence length")] = 512

ac: Annotated[Literal["Recompute", "Offload", "None"], Field(description="Activation checkpointing type")] = (
"Recompute"
)
ac: Annotated[
Literal["Recompute", "Selective", "Offload"] | None,
Field(description="Activation checkpointing type"),
] = "Recompute"

selective_targets: Annotated[
list[str] | None,
Field(description="Selective activation checkpoint targets when ac=Selective"),
] = None

attention: Annotated[
Literal["sdpa", "flash_attention_2", "flash_attention_3", "flash_attention_4"],
Expand Down Expand Up @@ -141,6 +147,10 @@ def build_command(config: BenchmarkConfig) -> list[str]:
# Add activation checkpointing if enabled
if config.ac == "Recompute":
cmd.append("--model.ac")
elif config.ac == "Selective":
cmd.extend(["--model.ac", "--model.ac.mode", "selective"])
if config.selective_targets:
cmd.extend(["--model.ac.targets", json.dumps(config.selective_targets)])
elif config.ac == "Offload":
cmd.append("--model.ac-offloading")

Expand Down
28 changes: 27 additions & 1 deletion src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,34 @@
class ActivationCheckpointConfig(BaseConfig):
    """Configures activation checkpointing.

    `mode='full'` wraps whole transformer blocks; `mode='selective'`
    checkpoints only the configured `targets` inside supported custom decoder
    layers. `freq` thins checkpointing to every `freq`-th layer in both modes.
    """

    mode: Annotated[
        Literal["full", "selective"],
        Field(
            description="Whether to checkpoint whole transformer blocks (`full`) or selected subcomponents inside supported custom decoder layers (`selective`).",
        ),
    ] = "full"

    # NOTE: the duplicate (pre-change) `description=` kwarg was removed here —
    # two `description=` keywords in one Field call is a syntax error.
    freq: Annotated[
        int,
        Field(
            ge=1,
            description="Applies activation checkpointing to every `freq` layers. Defaults to 1.",
        ),
    ] = 1

    targets: Annotated[
        list[str],
        Field(
            description="Selective checkpoint targets. `norm` checkpoints every norm module executed inside selected layers, including decoder, attention, MLA, and other model-specific norm blocks. `attention_sdpa` checkpoints the attention-kernel stage regardless of backend. `mla_up_proj` checkpoints MLA Q/KV up-projection work where supported, and `routed_experts` checkpoints routed expert compute in MoE layers.",
        ),
    ] = ["norm"]

    @model_validator(mode="after")
    def validate_selective_targets(self):
        # `selective` mode is meaningless without at least one target.
        if self.mode == "selective" and not self.targets:
            raise ValueError("Selective activation checkpointing requires at least one target.")
        return self


class ActivationOffloadingConfig(BaseConfig):
"""Configures the activation offloading."""
Expand Down Expand Up @@ -325,6 +345,12 @@ def ac_offloading_requires_ac(self):
self.ac = ActivationCheckpointConfig()
return self

@model_validator(mode="after")
def selective_ac_only_with_custom_impl(self):
if self.ac is not None and self.ac.mode == "selective" and self.impl != "custom":
raise ValueError("Selective activation checkpointing requires model.impl='custom'")
return self

@model_validator(mode="after")
def cpu_offload_mutual_exclusion(self):
if self.fsdp_cpu_offload and self.optim_cpu_offload:
Expand Down
53 changes: 51 additions & 2 deletions src/prime_rl/trainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
get_custom_vlm_cls,
supports_custom_impl,
)
from prime_rl.trainer.models.layers.checkpointing import (
get_supported_targets,
set_selective_activation_checkpointing,
supports_selective_activation_checkpointing,
)
from prime_rl.trainer.models.layers.lm_head import inject_prime_lm_head
from prime_rl.trainer.models.layers.moe import MoE
from prime_rl.trainer.parallel_dims import ParallelDims
Expand Down Expand Up @@ -638,12 +643,49 @@ def reshard_module(model: nn.Module):


def apply_ac(model: nn.Module, ac_config: ActivationCheckpointConfig):
    """Apply activation checkpointing to the language model's transformer layers.

    Every `ac_config.freq`-th layer is processed. In `full` mode the layer is
    wrapped in a checkpoint wrapper. In `selective` mode, layers that support
    selective AC checkpoint only `ac_config.targets`; unsupported layers fall
    back to full checkpointing (logged as a warning).

    Raises:
        ValueError: in selective mode, if any configured target is not
            supported by any of the selected layers.
    """
    # NOTE(review): stale pre-change lines (old `== 0` condition and old final
    # log call) left over from the diff were removed — only the new control
    # flow below is kept.
    logger = get_logger()
    language_model = get_language_model(model)
    selective_layers = 0
    full_layers = 0
    fallback_layer_types: set[str] = set()  # layer classes that fell back to full AC
    model_supported_targets: set[str] = set()  # union of targets supported by visited layers

    for layer_id, (layer_name, transformer_block) in enumerate(language_model.layers.named_children()):
        # Only checkpoint every `freq`-th layer.
        if layer_id % ac_config.freq != 0:
            continue

        if ac_config.mode == "selective" and supports_selective_activation_checkpointing(transformer_block):
            model_supported_targets.update(get_supported_targets(transformer_block))
            set_selective_activation_checkpointing(transformer_block, ac_config.targets)
            selective_layers += 1
        else:
            if ac_config.mode == "selective":
                fallback_layer_types.add(type(transformer_block).__name__)
            transformer_block = checkpoint_wrapper(transformer_block, preserve_rng_state=False)
            full_layers += 1

        # Re-register so the (possibly wrapped) module replaces the original child.
        language_model.layers.register_module(layer_name, transformer_block)

    if ac_config.mode == "selective":
        # Validate targets against the union of what the visited layers support.
        unsupported_targets = frozenset(ac_config.targets) - model_supported_targets
        if unsupported_targets:
            raise ValueError(
                f"Selective activation checkpoint targets {sorted(unsupported_targets)} are not supported "
                f"by the selected model layers. Supported targets across the model: {sorted(model_supported_targets)}"
            )
        if fallback_layer_types:
            logger.warning(
                "Selective activation checkpointing is not supported for layer types "
                f"{sorted(fallback_layer_types)}; falling back to full checkpointing for those layers."
            )
        logger.info(
            "Applied selective activation checkpointing "
            f"(freq={ac_config.freq}, targets={ac_config.targets}, selective_layers={selective_layers}, "
            f"full_fallback_layers={full_layers})"
        )
        return

    logger.info(f"Applied activation checkpointing (freq={ac_config.freq})")


def apply_compile(model: nn.Module, compile_config: CompileConfig):
Expand Down Expand Up @@ -675,6 +717,12 @@ def _move_buffers_to_cuda(model: nn.Module, config: ModelConfig) -> None:
buffer.data = buffer.data.to("cuda")


def _reset_runtime_moe_buffers(model: nn.Module) -> None:
    """Zero the `tokens_per_expert` counter on every materialized MoE module.

    Buffers still on the meta device are skipped — they have no storage to
    zero yet.
    """
    moe_modules = (m for m in model.modules() if isinstance(m, MoE))
    for moe in moe_modules:
        if moe.tokens_per_expert.device.type != "meta":
            moe.tokens_per_expert.zero_()


def _validate_flash_attn_4_installed() -> None:
"""Validate that flash-attn-cute is installed and not overwritten by flash-attn.

Expand Down Expand Up @@ -797,6 +845,7 @@ def setup_model(
else:
load_dcp_from_hf(model, config, parallel_dims)

_reset_runtime_moe_buffers(model)
return model


Expand Down
45 changes: 38 additions & 7 deletions src/prime_rl/trainer/models/afmoe/modeling_afmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,24 @@ def _finalize_output(
class AfmoeSDPAAttention(AfmoeAttentionBase):
"""AFMoE attention using PyTorch's scaled_dot_product_attention."""

def _attention_core(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
) -> torch.Tensor:
return F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=dropout_p,
is_causal=attention_mask is None,
scale=self.scaling,
)

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -159,14 +177,12 @@ def forward(
)

dropout_p = self.attention_dropout if self.training else 0.0
attn_output = F.scaled_dot_product_attention(
attn_output = self._attention_core(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
attention_mask=attention_mask,
dropout_p=dropout_p,
is_causal=attention_mask is None, # Use causal if no explicit mask
scale=self.scaling,
)

return self._finalize_output(attn_output, gate_states, input_shape)
Expand Down Expand Up @@ -202,6 +218,16 @@ def _compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out = out[0]
return out

def _attention_core(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
cu_seqlens: torch.LongTensor | None = None,
max_seqlen: int | None = None,
) -> torch.Tensor:
return self._compute_attention(query_states[0], key_states[0], value_states[0], cu_seqlens, max_seqlen)

def forward(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -229,9 +255,14 @@ def forward(
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)

out = self._compute_attention(query_states[0], key_states[0], value_states[0], cu_seqlens, max_seqlen)

attn_output = out.contiguous().view(*input_shape, -1)
attn_output = self._attention_core(
query_states,
key_states,
value_states,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
attn_output = attn_output.contiguous().view(*input_shape, -1)
attn_output = attn_output * torch.sigmoid(gate_states)
attn_output = self.o_proj(attn_output)
return attn_output, None
Expand Down
Loading
Loading