Efficient Eagle3 training with eagle KV cache and flex attention #350
Conversation
Walkthrough
Introduces optional max_length propagation through Eagle data collators and call sites, enables flex attention in EagleModule/HF model with block-mask-based TTT handling, refactors the Eagle forward pass to loop over three TTT steps with cached masks and an eagle_cache, simplifies the loss to classification-only, updates the main caller, and removes a heavy unit test.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Trainer
participant Main
participant DataModule as Eagle Data Module
participant Model as EagleModel
participant FlexAttn as Flex Attention
rect rgba(230,240,255,0.6)
note right of Main: Configure training
Trainer->>Main: Build training args (training_seq_len)
Main->>DataModule: make_eagle_supervised_data_module(max_length=training_seq_len)
DataModule-->>Main: Collators with fixed max_length
end
rect rgba(240,255,240,0.6)
note right of Model: Forward with TTT loop
Main->>Model: forward(inputs, eagle_cache=None)
loop ttt_step in [0,1,2]
Model->>Model: _get_ttt_attention_mask(seq_len, step)
Model->>FlexAttn: apply BlockMask (cached/compiled)
Model->>Model: _eagle_forward(..., eagle_cache)
Model->>Model: compute classification_loss, accuracy
Model-->>Model: update eagle_cache
end
Model-->>Main: logits, losses, train_accs
end
rect rgba(255,245,230,0.6)
note over Model: Loss = classification only
end
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning), ✅ Passed checks (2 passed).
Force-pushed from 8938a09 to 6759b45 (Compare)
Force-pushed from 6759b45 to f5835f9 (Compare)
What does "base model only" mean in the figure?
Can you also do a correctness check on online training?
Force-pushed from e4a0b13 to 3415d15 (Compare)
Force-pushed from d364c25 to 3415d15 (Compare)
Actionable comments posted: 3
🧹 Nitpick comments (6)
modelopt/torch/speculative/plugins/transformers.py (6)
39-39: Flex Attention import: ensure environment compatibility. torch.nn.attention.flex_attention is version-gated; add a clear error if unavailable.
Apply:
-from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+try:
+    from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+except Exception as e:
+    raise ImportError(
+        "Flex Attention is required for EAGLE TTT. Please upgrade PyTorch to a version that "
+        "provides torch.nn.attention.flex_attention."
+    ) from e
187-189: Attn backend set in two places (flex here, sdpa elsewhere). Unify to avoid confusion. HFEagleModel.modify later forces sdpa on self.eagle_config; pick one source of truth.
Apply either:
- Remove/override the sdpa assignment in HFEagleModel.modify, or
- Gate the choice behind a single config flag.
Example (modify HFEagleModel.modify outside this hunk):
-        self.eagle_config._attn_implementation = "sdpa"
+        self.eagle_config._attn_implementation = "flex_attention"  # keep consistent with EagleModule
454-460: Hardcoded num_ttt_steps=3 and eager compile rely on training_seq_len; make configurable and validate. Generalize steps and fail early if training_seq_len is missing.
Apply:
-        self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
-        # compile and cach flex attention masks
-        self.cached_attn_blk_masks = [
-            self._compile_ttt_block_mask(self.eagle_config.training_seq_len, i)
-            for i in range(self.num_ttt_steps)
-        ]
+        self.num_ttt_steps = getattr(self.eagle_config, "num_ttt_steps", 3)
+        if not hasattr(self.eagle_config, "training_seq_len") or self.eagle_config.training_seq_len is None:
+            raise ValueError("eagle_config.training_seq_len must be set when using flex attention TTT.")
+        # compile and cache flex attention masks
+        self.cached_attn_blk_masks = [
+            self._compile_ttt_block_mask(self.eagle_config.training_seq_len, i)
+            for i in range(self.num_ttt_steps)
+        ]
538-572: Mask compiler: add minimal docstrings and boundary assertions. Helps future maintenance and guards Q/KV bounds for each step.
Apply:
     def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
-        """Compile TTT attention_masks with symbolic masks and return a BlockMask object for flex attention."""
+        """Compile per-step TTT BlockMask for flex attention.
+        Assumes fixed training seq length and KV cache layout per step."""
+        assert seq_length > 0, "seq_length must be positive"
720-721: Lazily recompile masks if the actual sequence length drifts from configured training_seq_len. Avoids hard failure when users change max_length without rebuilding the model.
Apply:
-        b, seq_length, h = base_model_hidden_states.shape
+        b, seq_length, h = base_model_hidden_states.shape
+        if getattr(self.eagle_config, "training_seq_len", None) != seq_length:
+            # Recompile masks on-the-fly
+            self.cached_attn_blk_masks = [
+                self._compile_ttt_block_mask(seq_length, i) for i in range(self.num_ttt_steps)
+            ]
+            self.eagle_config.training_seq_len = int(seq_length)
785-803: Loss masking alignment looks correct. Normalize mask dtype to avoid type-promotion surprises.
Apply:
-            classification_loss, acc = self._eagle_loss(
+            classification_loss, acc = self._eagle_loss(
                 # base model predict +1 tok, while eagle predict +2
                 # so we shift base model outputs compared to eagle outputs
                 base_model_logits[:, 1:],
                 eagle_logits[:, :-1],
                 # additionally, we mask the first n tok of eagle outputs at nth TTT step
-                torch.cat(
+                torch.cat(
                     (
                         torch.zeros(
-                            b, 1 + ttt_step, dtype=loss_mask.dtype, device=loss_mask.device
+                            b, 1 + ttt_step, dtype=loss_mask.dtype, device=loss_mask.device
                         ),
                         loss_mask[:, 2 + ttt_step :],
                     ),
                     dim=1,
                 ),
             )
And inside _eagle_loss (outside this hunk), cast the mask once:
-        loss_mask = loss_mask[:, :, None]
+        loss_mask = loss_mask.to(eagle_logits.dtype)[:, :, None]
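As a toy illustration of that masking arithmetic (made-up shapes, not code from the PR): at TTT step n, the mask gains one more leading zero on top of the usual +1 shift between base-model and eagle outputs.

```python
# Toy demo of the per-step loss-mask shift described above (illustration only).
import torch

b, seq_len = 1, 8
loss_mask = torch.ones(b, seq_len, dtype=torch.long)

for ttt_step in range(3):
    shifted = torch.cat(
        (
            torch.zeros(b, 1 + ttt_step, dtype=loss_mask.dtype, device=loss_mask.device),
            loss_mask[:, 2 + ttt_step:],
        ),
        dim=1,
    )
    # each step keeps seq_len - 1 positions, with one more leading zero than the last
    print(ttt_step, shifted.tolist())
```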
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/speculative_decoding/eagle_utils.py (5 hunks)
- examples/speculative_decoding/main.py (2 hunks)
- modelopt/torch/speculative/plugins/transformers.py (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/speculative_decoding/eagle_utils.py (1)
- modelopt/torch/utils/logging.py (1): print_rank_0 (92-95)
modelopt/torch/speculative/plugins/transformers.py (1)
- modelopt/torch/speculative/plugins/megatron_eagle.py (1): _eagle_forward (1156-1179)
examples/speculative_decoding/main.py (1)
- examples/speculative_decoding/eagle_utils.py (1): make_eagle_supervised_data_module (238-315)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
🔇 Additional comments (7)
examples/speculative_decoding/eagle_utils.py (3)
239-243: API extension looks good and is backward-compatible. Optional max_length propagated cleanly.
319-321: Constructor OK. Consider documenting that sequences longer than max_length will be truncated (see below fix).
364-366: OK to forward max_length to offline collator. Matches online path behavior.
modelopt/torch/speculative/plugins/transformers.py (3)
710-714: Good: normalize legacy caches. Prevents attribute errors downstream.
771-784: OK: reuse inputs_embeds across TTT passes; verify intended semantics. If the first layer concatenates normalized embeds with shifted prenorm states, keeping inputs_embeds fixed is correct; otherwise, recompute per pass.
821-822: train_acc API change (list) — check downstream consumers. Found only this assignment in the repo: modelopt/torch/speculative/plugins/transformers.py:819-823 — train_acc = train_accs. Search for any callers/metrics/logging that assume a scalar and either aggregate the list here or update callers to accept a list.
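If the aggregation route is taken, a minimal sketch could look like the following (the helper name and log-key format are hypothetical, not part of the PR):

```python
# Hypothetical helper illustrating the "aggregate the list here" option: collapse the
# per-TTT-step accuracies to a scalar for the model output, while keeping step-wise
# values available under separate keys for logging.
def summarize_train_accs(train_accs: list[float]) -> tuple[float, dict[str, float]]:
    mean_acc = sum(train_accs) / max(len(train_accs), 1)
    per_step = {f"train_acc/ttt_step_{i}": acc for i, acc in enumerate(train_accs)}
    return mean_acc, per_step

mean_acc, per_step = summarize_train_accs([0.71, 0.64, 0.58])
print(mean_acc, per_step)
```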
examples/speculative_decoding/main.py (1)
232-234: Approve: max_length is passed and used for Eagle runs. Confirmed: main.py passes training_seq_len into make_eagle_supervised_data_module; examples/speculative_decoding/eagle_utils.py instantiates DataCollatorForOffline(max_length=...) and DataCollatorWithPadding(max_length=...) (DataCollatorForOffline forwards max_length). No other Eagle collator bypass found.
Locations: examples/speculative_decoding/main.py:232-234; examples/speculative_decoding/eagle_utils.py:301,309,363
        max_hs_length = (
            self.max_length
            if self.max_length is not None
            else max(item["base_model_hidden_states"].shape[0] for item in features)
        )
Apply the same truncation guard for hidden states in offline mode.
Mirror the online clamp to avoid negative padding on preprocessed HS.
Apply:
- max_hs_length = (
+ max_hs_length = (
self.max_length
if self.max_length is not None
else max(item["base_model_hidden_states"].shape[0] for item in features)
)
+ # Clamp oversize hidden states
+ for item in features:
+ if item["base_model_hidden_states"].shape[0] > max_hs_length:
+ item["base_model_hidden_states"] = item["base_model_hidden_states"][:max_hs_length]
+ item["aux_hidden_states"] = item["aux_hidden_states"][:max_hs_length]
🤖 Prompt for AI Agents
In examples/speculative_decoding/eagle_utils.py around lines 373 to 377, the
offline path sets max_hs_length to self.max_length (if not None) which can yield
negative padding when preprocessed hidden states are longer; compute the actual
maximum hidden-state length from features and ensure the final max_hs_length is
at least that value by replacing the assignment with a guarded clamp (e.g.,
compute computed_max = max(item["base_model_hidden_states"].shape[0] for item in
features) and then set max_hs_length = max(self.max_length or 0, computed_max)),
so offline mirrors the online clamp and avoids negative pad lengths.
                # pass in the seq length for flex attention mask compilation
                "training_seq_len": training_args.training_seq_len,
            }
Good: plumbs fixed seq length into Eagle config for flex masks. Add a hard guard.
Protect users from silent shape mismatches by validating against base model limits here.
Apply:
config["eagle_architecture_config"].update(
{
"hidden_size": model.config.hidden_size,
"vocab_size": model.config.vocab_size,
# we also overwrite max_pos_embedding for deployment compatibility
"max_position_embeddings": model.config.max_position_embeddings,
"draft_vocab_size": custom_config["draft_vocab_size"]
if eagle_args.eagle_config and "draft_vocab_size" in custom_config
else model.config.vocab_size,
# pass in the seq length for flex attention mask compilation
- "training_seq_len": training_args.training_seq_len,
+ "training_seq_len": int(training_args.training_seq_len),
}
)
+ # Sanity: training_seq_len must not exceed base model's max positions
+ if training_args.training_seq_len > model.config.max_position_embeddings:
+ raise ValueError(
+ f"training_seq_len ({training_args.training_seq_len}) exceeds "
+ f"base model max_position_embeddings ({model.config.max_position_embeddings})."
+ )
🤖 Prompt for AI Agents
In examples/speculative_decoding/main.py around lines 206 to 208, the code
passes training_args.training_seq_len into the Eagle config for flex attention
masks but lacks a hard guard against exceeding the base model's maximum context
length; add a validation right before building the config that reads the base
model's max sequence length (e.g., model.config.max_position_embeddings or
model.config.max_length depending on model type), verify
training_args.training_seq_len is not None and <= that max, and if it is greater
(or missing) raise a clear exception (ValueError) describing the mismatch and
advising the user to reduce training_seq_len or use a larger model so silent
shape mismatches are prevented.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/torch/speculative/plugins/transformers.py (2)
731-737
: Mixing 4‑D attn masks with flex attention can break; use BlockMask for the first pass too.Step‑0 currently builds a 4‑D mask via _prepare_decoder_attention_mask while the subsequent TTT passes use BlockMask. Under flex_attention, unify to BlockMask for consistency and perf.
- eagle_input_ids, attention_mask_0, position_ids = self._get_eagle_module_inputs( + eagle_input_ids, _attention_mask_unused, position_ids = self._get_eagle_module_inputs( input_ids, eagle_input_hidden_states, attention_mask, position_ids, eagle_cache, ) @@ - _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward( + # Use the precompiled BlockMask for step 0 + blk_masks = _get_blk_masks(seq_length) + _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward( eagle_input_hidden_states, inputs_embeds, - attention_mask_0, + blk_masks[0], position_ids, position_embeddings, eagle_cache, )Also applies to: 743-750
503-507
: Avoid hardcoding token id 0 when shifting input_ids. Use pad/eos.Using 0 risks feeding BOS for models where pad_token_id ≠ 0. Prefer pad_token_id or eos as fallback.
- zeropadding = torch.zeros( - input_ids.shape[0], 1, dtype=input_ids.dtype, device=input_ids.device - ) - eagle_input_ids = torch.cat((input_ids[:, 1:], zeropadding), dim=1) + pad_id = getattr(self.model.config, "pad_token_id", None) + if pad_id is None: + pad_id = getattr(self.model.config, "eos_token_id", 0) + pad_col = torch.full((input_ids.shape[0], 1), pad_id, dtype=input_ids.dtype, device=input_ids.device) + eagle_input_ids = torch.cat((input_ids[:, 1:], pad_col), dim=1)
🧹 Nitpick comments (5)
examples/speculative_decoding/main.py (1)
232-234
: Pass max_length to the Eagle data module — necessary for fixed-length training.This ensures batches are padded/truncated to the mask-compile length. One ask: guard against accidental mismatch.
Consider asserting consistency at runtime to fail fast:
@@ - data_module = make_eagle_supervised_data_module( - tokenizer, data_args, use_offline_training, max_length=training_args.training_seq_len - ) + data_module = make_eagle_supervised_data_module( + tokenizer, data_args, use_offline_training, max_length=training_args.training_seq_len + ) + assert training_args.dataloader_drop_last, \ + "Set TrainingArguments.dataloader_drop_last=True to keep fixed batch shapes for flex attention."examples/speculative_decoding/eagle_utils.py (2)
319-321
: Small safety: clamp rather than assume max_length ≥ per-sample length.Defensive guard avoids negative-size paddings if a caller passes a too-small max_length.
@@ class DataCollatorWithPadding: - def __init__(self, max_length=None): - self.max_length = max_length + def __init__(self, max_length=None): + self.max_length = max_length @@ - max_length = ( + max_length = ( self.max_length if self.max_length is not None else max(item["input_ids"].shape[0] for item in features) ) + # Clamp to the longest in batch to avoid negative paddings if misconfigured + if self.max_length is not None: + max_length = max(max_length, max(item["input_ids"].shape[0] for item in features))Also applies to: 334-338
364-366
: Mirror clamp for hidden-state padding in offline mode.Avoids negative paddings if hidden_state length > max_length due to stale .pt files.
@@ class DataCollatorForOffline(DataCollatorWithPadding): - max_hs_length = ( + max_hs_length = ( self.max_length if self.max_length is not None else max(item["base_model_hidden_states"].shape[0] for item in features) ) + if self.max_length is not None: + max_hs_length = max( + max_hs_length, max(item["base_model_hidden_states"].shape[0] for item in features) + )Also applies to: 373-377
modelopt/torch/speculative/plugins/transformers.py (2)
538-572
: Document and validate mask semantics; add basic bounds checks.The mask lambdas are dense; at minimum assert shapes and add comments clarifying which KV block each equality term targets. Also sanity check KV_LEN multiples.
Add inline comments like:
- def msk(b, h, q_idx, kv_idx): - # attention mask of shape [seq_len, 3* seq_len] for TTT step 1 + def msk(b, h, q_idx, kv_idx): + # Step 1: allow causal prefix (<= q-2), and cross links to positions + # (q + L - 1) and (q + 2L) in the concatenated KV stream ([L | L | L])And before returning:
+ assert seq_length > 0, "seq_length must be positive for BlockMask"
712-714
: Safe conversion of eagle_cache only when not None.DynamicCache.from_legacy_cache(None) can allocate an empty cache and mask logic may assume “no cache”. Keep None semantics until the first _eagle_forward.
- if not isinstance(eagle_cache, Cache): - eagle_cache = DynamicCache.from_legacy_cache(eagle_cache) + if eagle_cache is not None and not isinstance(eagle_cache, Cache): + eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
examples/speculative_decoding/eagle_utils.py
(5 hunks)examples/speculative_decoding/main.py
(2 hunks)modelopt/torch/speculative/plugins/transformers.py
(11 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/speculative_decoding/main.py (1)
examples/speculative_decoding/eagle_utils.py (1)
make_eagle_supervised_data_module
(238-315)
modelopt/torch/speculative/plugins/transformers.py (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
_eagle_forward
(1156-1179)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
🔇 Additional comments (4)
examples/speculative_decoding/eagle_utils.py (2)
239-243
: API: Adding max_length to make_eagle_supervised_data_module is the right call.Keeps padding logic and mask compilation in sync.
301-301
: Correct: Collators now honor a fixed max_length.This avoids shape drift across batches and matches flex-attn BlockMask expectations.
Also applies to: 309-309
modelopt/torch/speculative/plugins/transformers.py (1)
821-822
: train_acc returned as a list; confirm logging/serialization.HF Trainer won’t auto-log lists. Either aggregate (mean) or expose step‑wise metrics under distinct keys.
Consider returning a scalar avg to avoid JSON serialization surprises:
- train_acc=train_accs, + train_acc=sum(train_accs) / max(len(train_accs), 1),If you want per‑step, emit in logs within training loop instead.
examples/speculative_decoding/main.py (1)
206-208
: Approve — propagate training_seq_len into eagle_architecture_config (verify remaining tokenizer usage)
- Mask compilation uses eagle_config.training_seq_len (modelopt/torch/speculative/plugins/transformers.py:457).
- examples/speculative_decoding/eagle_utils.py still uses tokenizer.model_max_length for offline-data cropping (examples/speculative_decoding/eagle_utils.py:215); confirm if that should use eagle_config.training_seq_len instead.
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
Version‑gating for flex_attention import missing.
torch.nn.attention.flex_attention is only available in recent PyTorch; importing unguarded will crash older envs. Add a soft fallback to SDPA.
-from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+try:
+ from torch.nn.attention.flex_attention import BlockMask, create_block_mask # PyTorch >= supports flex attention
+ _FLEX_ATTENTION_AVAILABLE = True
+except Exception:
+ BlockMask = None
+ create_block_mask = None
+ _FLEX_ATTENTION_AVAILABLE = False
And down below (see next comments) gate usage accordingly.
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/transformers.py around line 39, the
unguarded import from torch.nn.attention.flex_attention will crash on older
PyTorch; wrap that import in a try/except ImportError (or AttributeError) and
set a module-level flag like HAS_FLEX_ATTENTION = True/False, and on failure
import or reference the standard SDPA utilities (e.g.,
torch.nn.functional.scaled_dot_product_attention) and provide minimal fallback
implementations or aliases for BlockMask/create_block_mask (or set them to None
and document behavior). Then update all downstream usage to check
HAS_FLEX_ATTENTION before calling flex_attention APIs and branch to the SDPA
fallback implementation when False.
Force-pushed from 8740824 to 3531467 (Compare)
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/speculative_decoding/eagle_utils.py (5 hunks)
- examples/speculative_decoding/main.py (2 hunks)
- modelopt/torch/speculative/plugins/transformers.py (11 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/speculative_decoding/main.py
- examples/speculative_decoding/eagle_utils.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/transformers.py (1)
- modelopt/torch/speculative/plugins/megatron_eagle.py (1): _eagle_forward (1156-1179)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (12)
modelopt/torch/speculative/plugins/transformers.py (12)
319-319: Correct use of the past_key_value parameter. This matches HF LlamaDecoderLayer's signature.
620-637: Eagle cache threading through _eagle_forward looks correct.
742-750: Confirm the flex path accepts a 4D attention_mask for the first pass. With _attn_implementation="flex_attention", ensure HF LlamaAttention supports a 4D float mask here; otherwise, convert the first pass to a BlockMask too. If needed, we can add a simple step-(-1) BlockMask equivalent of the standard causal mask for the first pass.
755-762
: Loss/acc aggregation change LGTM.
821-821: Returning train_acc as a list is fine; callers should expect a per-step vector. Ensure downstream logging expects a list; otherwise, summarize (e.g., mean) before returning.
846-846
: Loss return shape/signature change LGTM.
39-39: Version-gate flex_attention import and provide a fallback flag. Unguarded import will crash on older PyTorch lacking flex_attention. Add a try/except and a feature flag; downstream usage should branch accordingly.
-from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+try:
+    from torch.nn.attention.flex_attention import BlockMask, create_block_mask  # PyTorch >= 2.5
+    _FLEX_ATTENTION_AVAILABLE = True
+except Exception:
+    BlockMask = None
+    create_block_mask = None
+    _FLEX_ATTENTION_AVAILABLE = False
186-189: Don't force flex_attention when unavailable; fall back to SDPA. This hard-sets flex attention and will crash if the import failed. Gate on the feature flag.
-        # Use flex attention for efficient TTT
-        config._attn_implementation = "flex_attention"
+        # Prefer flex attention for efficient TTT when available
+        config._attn_implementation = "flex_attention" if _FLEX_ATTENTION_AVAILABLE else "sdpa"
454-460: Make TTT steps configurable and mask compilation resilient to seq_len mismatches. Hardcoding steps and precompiling a single seq_len can break when batches end up shorter (e.g., last batch) or configs change. Cache masks per observed seq_len and read steps from config with a default.
-        self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
-        # compile and cach flex attention masks
-        self.cached_attn_blk_masks = [
-            self._compile_ttt_block_mask(self.eagle_config.training_seq_len, i)
-            for i in range(self.num_ttt_steps)
-        ]
+        self.num_ttt_steps = getattr(self.eagle_config, "num_ttt_steps", 3)
+        # Precompile for configured training_seq_len; keep a runtime cache for other lengths.
+        self._blk_mask_cache: dict[int, list] = {}
+        def _compile_set(L: int):
+            self._blk_mask_cache[L] = [
+                self._compile_ttt_block_mask(L, i) for i in range(self.num_ttt_steps)
+            ]
+        if _FLEX_ATTENTION_AVAILABLE:
+            _compile_set(self.eagle_config.training_seq_len)
Add this helper in the class (nearby is fine):
    def _get_blk_masks(self, L: int):
        if not _FLEX_ATTENTION_AVAILABLE:
            return None
        if L not in self._blk_mask_cache:
            # Compile on-demand for unexpected lengths (e.g., short last batch)
            self._blk_mask_cache[L] = [
                self._compile_ttt_block_mask(L, i) for i in range(self.num_ttt_steps)
            ]
        return self._blk_mask_cache[L]
538-571: Explain the TTT BlockMask semantics and guard when flex_attention is unavailable. This is non-trivial logic; add a concise explanation and an availability check to fail fast with a clear error when flex isn't present.
- def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask: - """Compile TTT attention_masks with symbolic masks and return a BlockMask object for flex attention.""" + def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask: + """Compile symbolic BlockMask for TTT step. + + Semantics: + - Q_LEN = seq_length (current step tokens). + - KV_LEN = past_len + seq_length, where past_len accumulates per TTT step. + - We allow sparse connections that emulate the concat+shift behavior without materializing buffers: + * A limited causal window into the previous KV blocks. + * A small number of aligned "skip" positions into recent blocks to propagate predictions. + """ + if not _FLEX_ATTENTION_AVAILABLE: + raise RuntimeError( + "flex_attention is required for TTT BlockMask; install a compatible PyTorch or set _attn_implementation='sdpa'." + ) if ttt_step == 0: def msk(b, h, q_idx, kv_idx): # symbolic attention mask of shape [seq_len, 2* seq_len] for TTT step 0 return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length) return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2) elif ttt_step == 1: def msk(b, h, q_idx, kv_idx): # attention mask of shape [seq_len, 3* seq_len] for TTT step 1 return ( (kv_idx <= (q_idx - 2)) | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length)) | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2)) ) return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3) elif ttt_step == 2: def msk(b, h, q_idx, kv_idx): # attention mask of shape [seq_len, 4* seq_len] for TTT step 2 return ( (kv_idx <= (q_idx - 3)) | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length)) | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2)) | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3)) ) return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4) else: raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")Also, please add a short inline comment mapping kv_idx ranges to blocks (past/current/next) to address the prior reviewer ask.
764-803: Guard masks by the actual seq_length at runtime. Use masks matching the observed sequence length; fallback-compile on-the-fly when it differs from the configured training_seq_len.
Apply:
- for ttt_step in range(self.num_ttt_steps): + for ttt_step in range(self.num_ttt_steps): eagle_input_hidden_states = torch.cat( ( torch.zeros( (b, 1, h), dtype=eagle_input_hidden_states.dtype, device=eagle_input_hidden_states.device, ), eagle_prenorm_h[:, :-1, :], ), dim=1, ) - attention_mask = self.cached_attn_blk_masks[ttt_step] + # Fetch a BlockMask matching the current seq length (compile on-demand if needed) + seq_length = eagle_input_hidden_states.size(1) + blk_masks = self._get_blk_masks(seq_length) if hasattr(self, "_get_blk_masks") else ( + self.cached_attn_blk_masks + if seq_length == self.eagle_config.training_seq_len + else [self._compile_ttt_block_mask(seq_length, i) for i in range(self.num_ttt_steps)] + ) + attention_mask = blk_masks[ttt_step] _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward( eagle_input_hidden_states, inputs_embeds, attention_mask, position_ids, position_embeddings, eagle_cache, )
454-460
: Confirmed: training_seq_len is propagated and dataloader drop_last enforced for EAGLE training
- training_seq_len (TrainingArguments) is passed to tokenizer.model_max_length and injected into eagle_config during mtsp.convert (examples/speculative_decoding/main.py).
- DataCollatorWithPadding / DataCollatorForOffline pad/truncate to the provided max_length so batches are fixed-length (examples/speculative_decoding/eagle_utils.py).
- Offline samples are truncated to tokenizer.model_max_length when loaded.
- Trainer is created with training_args (dataloader_drop_last default True and launch scripts pass True), so incomplete batches are dropped — precompiled cached_attn_blk_masks (modelopt/torch/speculative/plugins/transformers.py) will match the padded batch length.
        if not isinstance(past_key_values, Cache):
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        if not isinstance(eagle_cache, Cache):
            eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
🧹 Nitpick
Avoid converting a None eagle_cache via DynamicCache.from_legacy_cache(None).
Guard the conversion; some HF versions don't treat None as a legacy cache.
- if not isinstance(eagle_cache, Cache):
- eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
+ if eagle_cache is not None and not isinstance(eagle_cache, Cache):
+ eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/transformers.py around lines 710 to 714,
avoid calling DynamicCache.from_legacy_cache(None) for eagle_cache; update the
guard so you only convert when eagle_cache is not None and not already a Cache
(e.g. if eagle_cache is not None and not isinstance(eagle_cache, Cache):
eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)), leaving None
untouched; apply the same None check for past_key_values if needed.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- examples/speculative_decoding/main.py (2 hunks)
- modelopt/torch/speculative/config.py (1 hunk)
- modelopt/torch/speculative/eagle/conversion.py (1 hunk)
- modelopt/torch/speculative/eagle/eagle_model.py (2 hunks)
- modelopt/torch/speculative/plugins/megatron_eagle.py (2 hunks)
- modelopt/torch/speculative/plugins/transformers.py (13 hunks)
- tests/unit/torch/speculative/plugins/test_hf_speculative.py (0 hunks)
💤 Files with no reviewable changes (1)
- tests/unit/torch/speculative/plugins/test_hf_speculative.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-05T19:10:36.393Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
modelopt/torch/speculative/eagle/conversion.py
modelopt/torch/speculative/plugins/transformers.py
🧬 Code graph analysis (3)
modelopt/torch/speculative/config.py (1)
modelopt/torch/opt/config.py (1)
ModeloptField
(50-53)
modelopt/torch/speculative/plugins/transformers.py (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
_eagle_forward
(1158-1181)
examples/speculative_decoding/main.py (1)
examples/speculative_decoding/eagle_utils.py (1)
make_eagle_supervised_data_module
(238-315)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (11)
modelopt/torch/speculative/eagle/conversion.py (1)
50-52
: LGTM: threads training_seq_len into Eagle modify path. No issues spotted; signature alignment with downstream modify implementations looks correct.
modelopt/torch/speculative/eagle/eagle_model.py (1)
38-39
: LGTM: persists eagle_training_seq_len on the model. Storage and signature change are consistent with conversion and plugins.
Also applies to: 49-49
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
724-748
: LGTM: correctly forwards eagle_training_seq_len through modify. No functional concerns here.
examples/speculative_decoding/main.py (2)
234-236
: Ensure collators enforce fixed length (pad/truncate) for flex/TTT. You pass max_length; verify both online and offline collators pad to exactly training_seq_len. If either collator only "caps" the length without right-padding to the exact size, flex masks will misalign. Please confirm DataCollatorWithPadding and DataCollatorForOffline always output sequences of length == max_length.
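One way to confirm this is a quick spot-check on a collated batch. The helper below is a hypothetical sketch (the key names follow the example scripts, but none of this is in the PR):

```python
# Hypothetical spot-check: assert that a collated batch really is padded/truncated to
# the fixed training_seq_len the precompiled flex-attention masks expect.
import torch

def check_fixed_length(batch: dict, training_seq_len: int) -> None:
    for key in ("input_ids", "attention_mask", "loss_mask"):
        if key in batch and torch.is_tensor(batch[key]):
            actual = batch[key].shape[1]
            assert actual == training_seq_len, (
                f"{key} has sequence length {actual}, expected {training_seq_len}"
            )

# e.g. run once on the first batch produced by the data module's collator
check_fixed_length({"input_ids": torch.zeros(2, 1024, dtype=torch.long)}, 1024)
```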
188-193
: Hard‑guard training_seq_len against base model max positions.
Without a check, flex/TTT masks can mismatch and crash. Validate before updating config.
Apply:
-    # overwrite config with custom config
+    # overwrite config with custom config
+    # Sanity: training_seq_len must not exceed base model's max positions
+    max_pos = getattr(model.config, "max_position_embeddings", None)
+    if max_pos is not None and training_args.training_seq_len > max_pos:
+        raise ValueError(
+            f"training_seq_len ({training_args.training_seq_len}) exceeds "
+            f"base model max_position_embeddings ({max_pos})."
+        )
     config.update(
         {
             "eagle_offline": use_offline_training,
             "eagle_training_seq_len": training_args.training_seq_len,
         }
     )
modelopt/torch/speculative/plugins/transformers.py (6)
540-574
: Generalize masks to arbitrary ttt_step or document limit.
Either parametrize to N steps or state the 3‑step assumption in config/docs.If you want, I can refactor _compile_ttt_block_mask to generate the pattern programmatically for any N.
456-462
: Precompiled masks need seq‑len awareness; add a runtime cache.
Current code precompiles for a single length; any drift causes mismatch. Cache by observed seq_len.Also fix the comment typo (“cach” -> “cache”).
Apply:- self.num_ttt_steps = 3 # NOTE: (hg) hardcoded for now. Might add to config later. - # compile and cach flex attention masks - self.cached_attn_blk_masks = [ - self._compile_ttt_block_mask(eagle_training_seq_len, i) - for i in range(self.num_ttt_steps) - ] + self.num_ttt_steps = getattr(self.eagle_config, "num_ttt_steps", 3) # default 3 + # Precompile for configured training_seq_len; keep a runtime cache for other lengths. + self._blk_mask_cache = { + eagle_training_seq_len: [ + self._compile_ttt_block_mask(eagle_training_seq_len, i) + for i in range(self.num_ttt_steps) + ] + }Add this helper (outside the shown hunk):
def _get_blk_masks(self, seq_len): if seq_len not in self._blk_mask_cache: self._blk_mask_cache[seq_len] = [ self._compile_ttt_block_mask(seq_len, i) for i in range(self.num_ttt_steps) ] return self._blk_mask_cache[seq_len]
714-716
: Avoid converting None with DynamicCache.from_legacy_cache.
Guard None to prevent HF cache utils from erroring.Apply:
- if not isinstance(eagle_cache, Cache): - eagle_cache = DynamicCache.from_legacy_cache(eagle_cache) + if eagle_cache is not None and not isinstance(eagle_cache, Cache): + eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
766-786
: Fetch TTT masks for the actual seq_len during training‑time‑testing.
Use the runtime cache to avoid mismatches when seq_len varies.Apply:
- for ttt_step in range(self.num_ttt_steps): + for ttt_step in range(self.num_ttt_steps): eagle_input_hidden_states = torch.cat( ( torch.zeros( (b, 1, h), dtype=eagle_input_hidden_states.dtype, device=eagle_input_hidden_states.device, ), eagle_prenorm_h[:, :-1, :], ), dim=1, ) - attention_mask = self.cached_attn_blk_masks[ttt_step] + attention_mask = self._get_blk_masks(seq_length)[ttt_step] _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward( eagle_input_hidden_states, inputs_embeds, attention_mask, position_ids, position_embeddings, eagle_cache, )Also applies to: 787-805
39-39
: Guard flex_attention import; fall back gracefully when unavailable.
Unconditional import will crash on older PyTorch. Add a try/except and a feature flag.Apply:
-from torch.nn.attention.flex_attention import BlockMask, create_block_mask +try: + from torch.nn.attention.flex_attention import BlockMask, create_block_mask # PyTorch w/ flex attention + _FLEX_ATTENTION_AVAILABLE = True +except Exception: + BlockMask = None + create_block_mask = None + _FLEX_ATTENTION_AVAILABLE = False
187-189
: Don’t force flex_attention if it’s not available.
Prefer flex when present; otherwise use SDPA.Apply:
- # Use flex attention for efficient TTT - config._attn_implementation = "flex_attention" + # Prefer flex attention for efficient TTT; fall back to SDPA if unavailable + config._attn_implementation = "flex_attention" if _FLEX_ATTENTION_AVAILABLE else "sdpa"
modelopt/torch/speculative/config.py
Outdated
    eagle_training_seq_len: int = ModeloptField(
        default=1024, description=("The training sequence length.")
    )
🧹 Nitpick
Validate eagle_training_seq_len is positive (and small integers only).
Add a lower bound to avoid accidental 0/negative values.
Apply:
- eagle_training_seq_len: int = ModeloptField(
- default=1024, description=("The training sequence length.")
- )
+ eagle_training_seq_len: int = ModeloptField(
+ default=1024,
+ description=("The training sequence length."),
+ ge=1,
+ )
🤖 Prompt for AI Agents
In modelopt/torch/speculative/config.py around lines 98 to 101, the
eagle_training_seq_len field lacks validation allowing 0/negative or
unreasonably large values; add constraints so it's an integer >= 1 and capped to
a reasonable upper bound (e.g. <= 65536 or 16384) to prevent accidental misuse.
Update the ModeloptField/Field call to include numeric constraints (e.g. gt/ ge
and le/lt or equivalent parameters) or add an explicit validator that checks
type int, verifies value >= 1 and <= chosen_max, and raises a clear ValueError
if violated.
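For reference, a standalone sketch of such a validator in pydantic v2 style (the real field uses ModeloptField on the modelopt config class, and the 16384 cap here is an arbitrary illustrative bound):

```python
# Standalone pydantic sketch of the requested bounds; illustrative only.
from pydantic import BaseModel, Field, field_validator

class EagleTrainingConfigSketch(BaseModel):
    eagle_training_seq_len: int = Field(default=1024, description="The training sequence length.")

    @field_validator("eagle_training_seq_len")
    @classmethod
    def _check_seq_len(cls, v: int) -> int:
        # reject zero/negative values and unreasonably large lengths
        if v < 1 or v > 16384:
            raise ValueError(f"eagle_training_seq_len must be in [1, 16384], got {v}")
        return v

EagleTrainingConfigSketch(eagle_training_seq_len=2048)  # ok
# EagleTrainingConfigSketch(eagle_training_seq_len=0)   # raises a ValidationError
```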
Force-pushed from 138c806 to c9bf27a (Compare)
Signed-off-by: h-guo18 <[email protected]>
Force-pushed from 47cddea to 94cbb2a (Compare)
Signed-off-by: h-guo18 <[email protected]>
Codecov Report
✅ All modified and coverable lines are covered by tests.

@@           Coverage Diff            @@
##             main     #350    +/-   ##
========================================
  Coverage   73.46%   73.46%
========================================
  Files         172      172
  Lines       17640    17640
========================================
+ Hits        12959    12960     +1
+ Misses       4681     4680     -1

☔ View full report in Codecov by Sentry.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)
899-916
: Mirror the flex BlockMask change in pseudo generation.
pseudo_speculative_generate
also passes a dense 4‑D mask; switch to BlockMask when flex is available to avoid runtime errors.Add after computing
eagle_attention_mask
:if _FLEX_ATTENTION_AVAILABLE: seq_length = eagle_input_hidden_states.size(1) eagle_attention_mask = self._get_ttt_attention_mask(seq_length, -1)
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)
454-454
: Optional: make num_ttt_steps configurable (default 3).Read from config to keep flexibility and avoid hardcoding.
Already covered in the cache fix above with
getattr(self.eagle_config, "num_ttt_steps", 3)
.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
examples/speculative_decoding/eagle_utils.py
(5 hunks)examples/speculative_decoding/main.py
(1 hunks)modelopt/torch/speculative/plugins/transformers.py
(11 hunks)tests/unit/torch/speculative/plugins/test_hf_speculative.py
(0 hunks)
💤 Files with no reviewable changes (1)
- tests/unit/torch/speculative/plugins/test_hf_speculative.py
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/main.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/speculative/plugins/transformers.py (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
_eagle_forward
(1156-1179)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
🔇 Additional comments (7)
modelopt/torch/speculative/plugins/transformers.py (7)
314-323
: HF compatibility: confirm the past_key_value arg name across versions. Recent HF uses past_key_value; older ones differ. Verify the target HF version and gate if needed. If you need to support older HF too, consider a kwarg alias:
- layer_outputs = decoder_layer( + layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_value=past_key_values, # alias for newer HF + # past_key_values=past_key_values, # uncomment if targeting older HF that requires this name output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, )
820-826
: LGTM on surfacing train_acc. Returning per-step train accuracy is useful; leaving it empty during eval is fine. Confirm consumers tolerate train_acc=[] at eval.
at eval.
850-850
: Loss/accuracy computation looks correct.Soft target CE and masked aggregation are implemented as expected.
39-39
: Version‑gate flex_attention import and provide a clean fallback.Unguarded import will crash on older PyTorch. Gate the import and expose a module‑level flag.
-from torch.nn.attention.flex_attention import BlockMask, create_block_mask +try: + from torch.nn.attention.flex_attention import BlockMask, create_block_mask + _FLEX_ATTENTION_AVAILABLE = True +except Exception: + BlockMask = None + create_block_mask = None + _FLEX_ATTENTION_AVAILABLE = False
187-188
: Don’t force flex_attention; fall back to SDPA when unavailable.This will explode in environments without flex attention.
- # Use flex attention for efficient TTT - config._attn_implementation = "flex_attention" + # Prefer flex attention for efficient TTT when available + config._attn_implementation = ( + "flex_attention" if _FLEX_ATTENTION_AVAILABLE else "sdpa" + )
454-464
: Block mask cache keyed only by ttt_step → wrong when seq_length varies.If seq_length changes (even accidentally), a mismatched mask will be reused. Cache by seq_length as well.
- self.num_ttt_steps = 3 # NOTE: (hg) hardcoded for now. Might add to config later. - self._cached_attn_blk_masks = [] + self.num_ttt_steps = getattr(self.eagle_config, "num_ttt_steps", 3) + # Cache: {seq_len: {ttt_step: BlockMask}} + self._blk_mask_cache = {} - def _get_ttt_attention_mask(self, seq_length, ttt_step): - # compile and cached flex attention masks in first call - if ttt_step >= len(self._cached_attn_blk_masks): - self._cached_attn_blk_masks.append(self._compile_ttt_block_mask(seq_length, ttt_step)) - - # return cached flex attention mask - return self._cached_attn_blk_masks[ttt_step] + def _get_ttt_attention_mask(self, seq_length, ttt_step): + L = int(seq_length) + cache_for_len = self._blk_mask_cache.setdefault(L, {}) + if ttt_step not in cache_for_len: + cache_for_len[ttt_step] = self._compile_ttt_block_mask(L, ttt_step) + return cache_for_len[ttt_step]
716-718
: Guard against convertingNone
eagle_cache to DynamicCache.
DynamicCache.from_legacy_cache(None)
can error; add a None guard.- if not isinstance(eagle_cache, Cache): - eagle_cache = DynamicCache.from_legacy_cache(eagle_cache) + if eagle_cache is not None and not isinstance(eagle_cache, Cache): + eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
This hunk removes the old dense TTT mask construction in _get_eagle_module_inputs (the concatenated cat_eagle_input_hidden_states / cat_input_ids / cat_attention_mask / cat_position_ids buffers built from the mask_2_*, mask_3_*, and mask_4_* blocks, selected by eagle_generated_hs.shape[1]) and replaces it with symbolic per-step block masks:

def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
    """Compile TTT attention_masks with symbolic masks and return a BlockMask object for flex attention."""
    if ttt_step == 0:

        def msk(b, h, q_idx, kv_idx):
            # symbolic attention mask of shape [seq_len, 2* seq_len] for TTT step 0
            return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)

        return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2)
    elif ttt_step == 1:

        def msk(b, h, q_idx, kv_idx):
            # attention mask of shape [seq_len, 3* seq_len] for TTT step 1
            return (
                (kv_idx <= (q_idx - 2))
                | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
                | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
            )

        return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3)
    elif ttt_step == 2:

        def msk(b, h, q_idx, kv_idx):
            # attention mask of shape [seq_len, 4* seq_len] for TTT step 2
            return (
                (kv_idx <= (q_idx - 3))
                | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length))
                | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
                | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
            )

        return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4)
    else:
        raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")
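To sanity-check what these predicates allow, the mask_mod can be broadcast over a toy sequence length and printed as a dense grid. A small illustration (not part of the PR), shown for the step-0 mask and requiring only plain torch:

```python
# Visualize the step-0 predicate as a dense [Q_LEN, KV_LEN] boolean grid. Column block 0
# is the causal prefix (kv <= q-1); column block 1 holds the aligned "+ seq_length" slot.
import torch

seq_length = 4  # toy value; real training uses the configured training sequence length

def msk(b, h, q_idx, kv_idx):
    return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)

q = torch.arange(seq_length).unsqueeze(1)        # [Q_LEN, 1]
kv = torch.arange(seq_length * 2).unsqueeze(0)   # [1, KV_LEN]
print(msk(None, None, q, kv).int())
```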
Provide a causal BlockMask for the base pass and document semantics.
The first EAGLE pass uses a dense 4D mask, which won’t be consumed by flex attention. Add a base causal mask variant and reuse it when flex is enabled.
```diff
 def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
     """Compile TTT attention_masks with symbolic masks and return a BlockMask object for flex attention."""
+    # ttt_step == -1 is the base pass: standard causal self-attention (allow attending to <= current position).
+    if ttt_step == -1:
+
+        def msk(b, h, q_idx, kv_idx):
+            # [seq_len, seq_len]: standard causal mask (k <= q)
+            return kv_idx <= q_idx
+
+        return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length)
     if ttt_step == 0:

         def msk(b, h, q_idx, kv_idx):
             # symbolic attention mask of shape [seq_len, 2 * seq_len] for TTT step 0
             return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)

         return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2)
     elif ttt_step == 1:

         def msk(b, h, q_idx, kv_idx):
             # attention mask of shape [seq_len, 3 * seq_len] for TTT step 1
             return (
                 (kv_idx <= (q_idx - 2))
                 | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
                 | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
             )

         return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3)
     elif ttt_step == 2:

         def msk(b, h, q_idx, kv_idx):
             # attention mask of shape [seq_len, 4 * seq_len] for TTT step 2
             return (
                 (kv_idx <= (q_idx - 3))
                 | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length))
                 | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
                 | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
             )

         return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4)
     else:
         raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")
```
```python
        _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
            eagle_input_hidden_states,
            inputs_embeds,
            attention_mask_0,
            position_ids,
            position_embeddings,
            eagle_cache,
        )
```
With flex enabled, the first EAGLE pass must use a BlockMask, not a 4-D dense mask.
Currently `attention_mask_0` is 4-D, while HF flex attention expects a `BlockMask`. Switch to the causal `BlockMask` for the base pass when flex is available.
```diff
-        # Then, we run eagle forward
-        _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
+        # Then, we run eagle forward
+        if _FLEX_ATTENTION_AVAILABLE:
+            attention_mask_0 = self._get_ttt_attention_mask(seq_length, -1)
+        _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
             eagle_input_hidden_states,
             inputs_embeds,
             attention_mask_0,
             position_ids,
             position_embeddings,
             eagle_cache,
         )
```
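If that change is adopted, the base-pass mask can be spot-checked against plain causal SDPA. A minimal standalone sketch (an illustration, not this repository's code; assumes PyTorch >= 2.5 with flex attention and a CUDA device):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

S, H, D = 128, 4, 64
q, k, v = (torch.randn(1, H, S, D, device="cuda") for _ in range(3))

# same predicate as the proposed ttt_step == -1 branch: plain causal attention
causal = create_block_mask(
    lambda b, h, q_idx, kv_idx: kv_idx <= q_idx, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda"
)

flex_out = flex_attention(q, k, v, block_mask=causal)
sdpa_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(flex_out, sdpa_out, rtol=1e-3, atol=1e-3)  # should agree closely
```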
The quoted hunk removes the unrolled second/third/fourth eagle forward steps (each of which rebuilt concatenated inputs via `_concat_eagle_inputs`, re-embedded them, called `_eagle_loss`, and accumulated `regression_loss` and `classification_loss` into `eagle_loss`) and replaces them with a single loop over TTT steps that reuses `eagle_cache`:

```python
        for ttt_step in range(self.num_ttt_steps):
            eagle_input_hidden_states = torch.cat(
                (
                    torch.zeros(
                        (b, 1, h),
                        dtype=eagle_input_hidden_states.dtype,
                        device=eagle_input_hidden_states.device,
                    ),
                    eagle_prenorm_h[:, :-1, :],
                ),
                dim=1,
            )
            attention_mask = self._get_ttt_attention_mask(seq_length, ttt_step)
            _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
                eagle_input_hidden_states,
                inputs_embeds,
                attention_mask,
                position_ids,
                position_embeddings,
                eagle_cache,
            )
```
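To make the `eagle_cache` reuse concrete: keys and values from the base pass and earlier TTT passes stay cached, so each pass only adds `seq_length` new queries. A standalone sketch using Hugging Face's `DynamicCache` (illustrative only; the cache type, shapes, and single layer are assumptions, and transformers >= 4.36 is assumed):

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
seq_length, n_heads, head_dim = 8, 4, 16

for ttt_step in range(3):
    k = torch.randn(1, n_heads, seq_length, head_dim)
    v = torch.randn(1, n_heads, seq_length, head_dim)
    cache.update(k, v, layer_idx=0)            # append this pass's keys/values
    print(ttt_step, cache.get_seq_length(0))   # 8, 16, 24: KV length grows, query length stays 8
```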
Recompute RoPE positions for TTT passes when cache grows.
Position embeddings are computed once and reused in the TTT loop; when `eagle_cache` grows, the rotary positions must be advanced by `past_len` or RoPE distances become incorrect.
Location: `modelopt/torch/speculative/plugins/transformers.py`, in the TTT loop (`for ttt_step in range(self.num_ttt_steps):`).
```diff
-            _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
+            past_len = eagle_cache.get_seq_length() if eagle_cache is not None else 0
+            pos_ids_ttt = position_ids + past_len
+            pos_emb_ttt = self.eagle_rotary_emb(eagle_input_hidden_states, pos_ids_ttt)
+            _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
                 eagle_input_hidden_states,
                 inputs_embeds,
-                attention_mask,
-                position_ids,
-                position_embeddings,
+                attention_mask,
+                position_ids,
+                pos_emb_ttt,
                 eagle_cache,
             )
```
Please sanity-check numerics on a short batch to confirm equivalence before/after this change.
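For intuition on why the offset matters, here is a self-contained toy (not this repository's code; a standard rotary-embedding formulation and arbitrary toy sizes are assumed). With RoPE, the query/key score depends on the position difference, so reusing positions 0..S-1 for a later pass misstates the distance to cached keys:

```python
import torch


def rope(x, pos, theta=10000.0):
    # minimal rotary embedding: rotate (even, odd) channel pairs by position-dependent angles
    dim = x.shape[-1]
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    ang = pos[:, None].float() * inv_freq[None, :]
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)


torch.manual_seed(0)
d, new_len, past_len = 64, 8, 8                     # head dim, new tokens, cached tokens
q = torch.randn(new_len, d)
k_cached = rope(torch.randn(past_len, d), torch.arange(past_len))

q_correct = rope(q, torch.arange(past_len, past_len + new_len))  # positions advanced by past_len
q_stale = rope(q, torch.arange(new_len))                         # positions reused from 0

# score of the first new query against the most recent cached key differs between the two
print((q_correct @ k_cached.T)[0, -1].item())
print((q_stale @ k_cached.T)[0, -1].item())
```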
/ok to test 71c657e
Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: Ye Yu <[email protected]>
What does this PR do?
Type of change: new feature
Overview:
This PR improves eagle3 training efficiency with cross attention and flex attention. Main changes include:

- Previously `torch.sdpa` was used for attention; this PR changes to `torch.flex_attention` to use its sparse computation and symbolic TTT masks. This gives us additional speedup but requires a fixed training length.

See the effect of the two optimizations above in the plots below.
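As a concrete illustration of the fixed-training-length requirement (a hedged sketch; the tokenizer choice and wiring are assumptions, not this repo's collator code), every batch can simply be padded and truncated to the same length:

```python
from transformers import AutoTokenizer

training_seq_len = 2048  # must match the length the masks are compiled for

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # LLaMA-style tokenizers often lack a pad token

batch = tokenizer(
    ["a short prompt", "another, slightly longer prompt"],
    padding="max_length",        # every sample padded to the same fixed length
    truncation=True,
    max_length=training_seq_len,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # torch.Size([2, 2048]) for every batch
```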
Usage
The API is unchanged, except that we additionally need to pass `training_seq_len` to the model before training starts. See the change in `main.py`.

Testing
Efficiency Test
On `Llama3.2-1B` we are seeing a 2x+ speedup and a 1.5x memory saving:

*(W&B charts omitted)*
Correctness Test
Test setting: daring-anteater, tinyllama.

We tried both online and offline training and got an equivalent final AR before and after the optimization. After the optimization it is actually slightly higher, which could be due to the minor numeric error introduced by the padded matmul in the previous attention implementation.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Refactor
Tests