
selective ac #2055

Open

faresobeid wants to merge 16 commits into main from selective-ac-new

Conversation

faresobeid (Contributor) commented Mar 20, 2026

Can be used with:
[model.ac]
freq = 1
mode = "selective"
targets = ["norm", "attention_sdpa", "routed_experts"]


Note

Medium Risk
Changes core model execution and memory behavior by introducing selective activation checkpointing and refactoring attention/MoE internals, which could affect training correctness/perf on specific custom model layers.

Overview
Adds selective activation checkpointing via new model.ac.mode (full/selective) and model.ac.targets, including validation (non-empty targets; requires model.impl='custom') and benchmark-script support for the new CLI flags.

Implements selective checkpointing by patching specific subcomponent methods (norm, _attention_core, _mla_up_proj, _run_routed_experts) with non-reentrant checkpoints, updating apply_ac to mix selective + full-block fallback per layer and to error on unsupported targets.

Refactors attention implementations to expose _attention_core for SDPA/Flash paths (AFMoE, Qwen3.5-MoE, and shared attention layers), extracts GLM MoE DSA sparse MLA attention into a new mla_attn.py, and splits MoE routed expert compute into _run_routed_experts; also resets MoE tokens_per_expert buffers after model setup to avoid stale runtime stats.
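The per-subcomponent patching described above can be sketched roughly as follows. This is a minimal illustration assuming PyTorch's non-reentrant checkpoint API; the names (`wrap_method_with_checkpoint`, `Block`, `_core`) are invented for the example and are not the PR's actual code:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def wrap_method_with_checkpoint(module: nn.Module, method_name: str) -> None:
    """Illustrative helper: replace `module.<method_name>` with a version that
    runs under a non-reentrant checkpoint, so activations inside it are
    recomputed during backward instead of being stored."""
    orig = getattr(module, method_name)

    def checkpointed(*args, **kwargs):
        return checkpoint(orig, *args, use_reentrant=False, **kwargs)

    setattr(module, method_name, checkpointed)

class Block(nn.Module):
    """Toy stand-in for a custom decoder layer exposing a patchable core."""
    def __init__(self, dim: int = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Linear(dim, dim)

    def _core(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(self.norm(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._core(x)

block = Block()
wrap_method_with_checkpoint(block, "_core")  # checkpoint only this subcomponent
x = torch.randn(2, 8, requires_grad=True)
out = block(x)
out.sum().backward()  # activations inside _core are recomputed here
```

Patching a plain method (rather than reassigning a registered submodule) sidesteps `nn.Module.__setattr__`'s child-module type check, which may be one reason the refactor extracts `_attention_core` and `_run_routed_experts` as methods.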

Written by Cursor Bugbot for commit c4a87ca. This will update automatically on new commits.

samsja (Member) left a comment


I am not sure how I feel about the run_with_optional_checkpoint pattern; wondering if we can do the same with a hook instead. It definitely makes the code harder to understand. How does torchtitan do this?
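For context, a minimal reconstruction of what such a pattern might look like (the helper name comes from the comment above; the body and the `checkpoint_enabled` flag are assumptions for illustration, not the PR's actual code):

```python
import torch
from torch.utils.checkpoint import checkpoint

def run_with_optional_checkpoint(fn, *args, checkpoint_enabled: bool = False):
    # Run `fn` directly, or under a non-reentrant activation checkpoint when
    # the flag is set and gradients are being tracked.
    if checkpoint_enabled and torch.is_grad_enabled():
        return checkpoint(fn, *args, use_reentrant=False)
    return fn(*args)

y = run_with_optional_checkpoint(
    torch.nn.functional.gelu,
    torch.randn(4, requires_grad=True),
    checkpoint_enabled=True,
)
```

The readability cost raised in the comment is that this pattern threads a flag through every call site inside the forward pass, whereas a hook-based approach would attach the checkpointing behavior once, at module setup time.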

cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

Autofix Details

Bugbot Autofix prepared a fix for the issue found in the latest run.

  • ✅ Fixed: Missing changelog for config schema changes
    • Added a CHANGELOG.md entry documenting the new model.ac.mode and model.ac.targets fields.


Or push these changes by commenting:

@cursor push f77b587704
Preview (f77b587704)
diff --git a/CHANGELOG.md b/CHANGELOG.md
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,8 @@
 
 Documenting changes which affect configuration usage patterns (added/moved/removed/renamed fields, notable logic changes).
 
+- **`model.ac.mode`** and **`model.ac.targets`**: Added selective activation checkpointing controls. `mode` selects `"full"` (whole blocks) vs `"selective"` (subcomponents on supported custom decoder layers). When `"selective"`, `targets` chooses from `["norm", "attention_sdpa", "mla_up_proj", "routed_experts"]`. Defaults: `mode="full"`, `targets=["norm"]`. (2026-03-20)
+
 - **`orchestrator.advantage.length_weighted_mean`**: Removed. The default advantage now always uses the plain per-problem mean baseline unless `orchestrator.advantage.length_shaping_alpha` is set. Existing configs must delete this field. (2026-03-19)
 - **`orchestrator.advantage.length_shaping_alpha`**: Added Group Relative Reward Rescaling coefficient to the default advantage config. When set, applies length-based GR3 shaping during advantage computation and requires `orchestrator.buffer.online_difficulty_filtering = true` (default: `None`) (2026-03-18)
 - **`prime_monitor.log_extras.sample_ratio`**: Added ratio-based rollout sampling (0.0–1.0, default: None). When set, caps the number of rollouts logged per step to `len(rollouts) * sample_ratio`. `None` preserves current behavior (log all rollouts). Interacts with existing `interval` gate which still runs first. (2026-03-12)


cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 2 total unresolved issues (including 1 from previous review).



faresobeid and others added 2 commits March 22, 2026 06:01
Signed-off-by: faresobeid <111092724+faresobeid@users.noreply.github.com>
S1ro1 (Collaborator) left a comment


Just nits on styling, else lgtm

return dq, dkv, None, None


class LayerNorm(nn.Module):
Collaborator:

This shouldn't live in this file imo, create a new layers/norms.py maybe?

return indices.view(1, total_tokens, 1, index_topk)


class GlmMoeDsaAttention(nn.Module):
Collaborator:

I don't like this being in this file either, or at least it being called GlmMoeDsaAttention if it has to live here.



class GlmMoeDsaAttention(nn.Module):
def __init__(self, config: GlmMoeDsaConfig):
Collaborator:

It also depends on GlmMoeDsaConfig, which introduces an ugly circular dependency pattern.

3 participants