
EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models #3527

Open
winglian wants to merge 19 commits into main from ebft

Conversation

@winglian
Collaborator

@winglian winglian commented Mar 21, 2026

Summary by CodeRabbit

  • New Features

    • Added Energy-Based Fine-Tuning (EBFT) with structured (vLLM rollouts) and strided (block-parallel) modes, feature-matching rewards, and multi-turn/reasoning-aware dataset transforms.
    • New dataset transforms and example configs for Llama/Qwen/Opencode to run EBFT and LoRA experiments.
    • Faster GPU ops via fused kernels for similarity/diversity/reinforce computations.
  • Chores

    • Improved vLLM serving: eager-mode option, async prefetch/weight-sync enhancements, and HTTP weight-update endpoint.
  • Documentation

    • Comprehensive EBFT guide and example configs added.

@coderabbitai
Contributor

coderabbitai bot commented Mar 21, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 61fd4dc2-09bc-41da-99d7-2c45ac0a0f76

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Adds Energy-Based Fine-Tuning (EBFT) to axolotl: new trainers (structured/async/strided), feature-matching reward logic, Triton kernels, dataset transforms/prompt strategies, config/schema validation, vLLM serving/weight-sync extensions, example configs and documentation.

Changes

| Cohort | File(s) | Summary |
|---|---|---|
| Documentation | examples/ebft/README.md | New comprehensive README documenting EBFT modes, workflows, configs, metrics, and citation. |
| Example dataset transforms | examples/ebft/ebft_opencode.py, examples/ebft/ebft_pretrain.py, examples/ebft/ebft_strided_structured.py, src/axolotl/prompt_strategies/ebft/... | Multiple transform entrypoints for OpenCodeInstruct, pretrain, strided structured, and chat/multiturn/reasoning variants; normalize outputs and set remove_columns. |
| Example configs | examples/ebft/llama-*.yaml, examples/ebft/qwen*.yaml | Seven new example YAMLs showing structured/strided EBFT runs (vLLM and non-vLLM), LoRA settings, runtime flags, and dataset selections. |
| Trainer dispatch & args | src/axolotl/core/trainers/ebft/__init__.py, src/axolotl/core/trainers/ebft/args.py, src/axolotl/core/builders/rl.py, src/axolotl/core/trainers/__init__.py | EBFTStrategy added to dispatch trainer and training-args classes by mode; dataclasses for structured/async/strided EBFT; HFRLTrainerBuilder wired to EBFT; trainers exported. |
| Structured EBFT trainer & mixin | src/axolotl/core/trainers/ebft/trainer.py | EBFTMixin implementing frozen feature network, feature-matching reward function (alignment/diversity/CFM), optional sequential rollouts, and new AxolotlEBFTTrainer / AxolotlAsyncEBFTTrainer. |
| Strided EBFT trainer | src/axolotl/core/trainers/ebft/strided.py | Large strided block-parallel trainer: strided rollout generation, flex/eager attention mask builders, feature extraction, per-block rewards, advantage computation, and loss assembly. |
| Rewards & utilities | src/axolotl/core/trainers/ebft/rewards.py | Hidden-state extraction, pooling/embed methods (last_token/mean/concat/completion_mean), alignment/diversity reward helpers, batched SVD whitening with fallbacks. |
| Triton fused kernels | src/axolotl/core/trainers/ebft/kernels.py | Four Triton JIT kernels plus Python wrappers for fused log-softmax gather, REINFORCE loss, cosine similarity, and diversity penalty. |
| Prompt strategies | src/axolotl/prompt_strategies/ebft/* | New EBFT prompt strategy package and modules for chat multiturn, opencode, reasoning, and strided chat/structured transforms. |
| vLLM serving & weight sync | src/axolotl/cli/vllm_serve.py, src/axolotl/scripts/vllm_serve_lora.py, src/axolotl/scripts/vllm_worker_ext.py | Added enforce_eager passthrough, HTTP weight-update endpoint and worker HTTP load handlers, middleware logging, and name-mapping/fallbacks for vLLM parameter sets. |
| Async GRPO / trainer sync changes | src/axolotl/core/trainers/grpo/async_trainer.py, src/axolotl/monkeypatch/trainer/trl_vllm.py | Multi-process rollout broadcast, controlled rank-only generation, refined LoRA sync selection, HTTP-based param transport path for non-communicator setups, and timeouts. |
| Data loading & schemas | src/axolotl/common/datasets.py, src/axolotl/utils/data/rl.py, src/axolotl/utils/schemas/* | EBFT routing in dataset loader, long-sequence handling, remove_columns normalization; added EBFTConfig, RLType.EBFT, TRL generation/chat kwargs, vLLM enforce_eager, and EBFT validation mixins. |
| Training integration | src/axolotl/train.py | Treats EBFT like GRPO for reference-model setup and SequenceParallelContextManager gather semantics. |
| Misc & tooling | docker/Dockerfile-cloud-uv, src/axolotl/utils/callbacks/*, src/axolotl/integrations/diffusion/callbacks.py | Small runtime/dev updates: virtualenv activation in bashrc; type-ignore annotations for WandB usages. |

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

enhancement, review requested

Suggested reviewers

  • NanoCode012
  • djsaunde
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 62.61%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (2 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly describes the main feature: Energy-Based Fine-Tuning (EBFT), integrating a feature-matching approach into axolotl, which aligns with the changeset of EBFT training modes, utilities, and configurations. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

Tip

CodeRabbit can use Trivy to scan for security misconfigurations and secrets in Infrastructure as Code files.

Add a .trivyignore file to your project to customize which findings Trivy reports.

@winglian winglian added the scheduled_release ("This PR is slated for the upcoming release") label on Mar 21, 2026
@github-actions
Contributor

github-actions bot commented Mar 21, 2026

📖 Documentation Preview: https://69bfd6d3215bcc9737afcd61--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 567d07e

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 19

🧹 Nitpick comments (15)
src/axolotl/prompt_strategies/ebft/ebft_reasoning.py (2)

242-261: Think-tag masking logic may miss multi-token tags.

The masking logic assumes <think> and </think> each tokenize to a single token. If the tokenizer splits these tags into multiple tokens (e.g., <, think, >), think_open_id will be unk_token_id and masking will be silently skipped.

Consider adding a warning or fallback to text-based span detection when single-token IDs are unavailable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py` around lines 242 - 261,
The current mask_thinking block assumes "<think>" and "</think>" map to single
token IDs (think_open_id/think_close_id) and silently skips masking when they
equal tokenizer.unk_token_id; update it to detect and handle multi-token tags:
if think_open_id or think_close_id == tokenizer.unk_token_id, fall back to
scanning the decoded substring (e.g.,
tokenizer.decode(input_ids[scan_start:end]) or joining
tokenizer.convert_ids_to_tokens(input_ids[scan_start:end])) to locate "<think>"
and "</think>" text spans and then map those spans back to token index ranges to
set labels[i] = -100, and/or emit a warning when single-token IDs are
unavailable; keep the existing behavior when single-token IDs are present.
Ensure you modify the mask_thinking branch and use input_ids, labels, start,
end, tokenizer, think_open_id/think_close_id identifiers.
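The suggested text-based fallback can be sketched in pure Python. Note that `find_tag_token_span` and `token_strs` are illustrative names here (`token_strs` stands in for `tokenizer.convert_ids_to_tokens(...)` output), and a real tokenizer may emit whitespace markers that require offset mappings instead of naive joining:

```python
def find_tag_token_span(token_strs, tag):
    """Locate which token indices cover `tag` once the token strings are joined.

    Returns (first_idx, last_idx) inclusive, or None if the tag is absent.
    """
    text = ""
    offsets = []  # (char_start, char_end) of each token in the joined string
    for tok in token_strs:
        offsets.append((len(text), len(text) + len(tok)))
        text += tok
    start = text.find(tag)
    if start == -1:
        return None
    end = start + len(tag)
    # Any token whose character span overlaps [start, end) covers the tag.
    covering = [i for i, (a, b) in enumerate(offsets) if a < end and b > start]
    return covering[0], covering[-1]
```

A span returned by this helper can then be used to set `labels[i] = -100` for each token index in the range, preserving the existing single-token fast path.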

35-40: Remove unused helper function _extract_thinking.

The function is defined but never called anywhere in the codebase. Since it's prefixed with an underscore (indicating private scope), there's no indication of intended external usage.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py` around lines 35 - 40,
The private helper function _extract_thinking is unused; remove its definition
from ebft_reasoning.py (delete the def _extract_thinking(...) block) and ensure
there are no remaining references to _extract_thinking elsewhere; if the re
import is now unused after removal, also remove the import to avoid linter
warnings and run tests/linters to confirm no residual uses.
examples/ebft/ebft_opencode.py (1)

18-20: Consider using remove_columns: "__all__" to avoid schema drift.

The explicit list is brittle if upstream dataset columns change.

♻️ Suggested simplification
-    return transform_fn, {"remove_columns": ["id", "domain", "generation_algorithm",
-                                              "llm_judgement", "unit_tests",
-                                              "tests_execution_status", "average_test_score"]}
+    return transform_fn, {"remove_columns": "__all__"}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/ebft/ebft_opencode.py` around lines 18 - 20, The return value
currently hardcodes a list of columns to drop which is brittle; update the
second element of the returned tuple (the transformer config returned alongside
transform_fn in ebft_opencode.py) to use "remove_columns": "__all__" instead of
the explicit list so the transformer removes all original dataset columns and
avoids schema drift while preserving transformed outputs from transform_fn.
src/axolotl/utils/schemas/config.py (3)

66-70: Consider using Literal type for embed_method validation.

The embed_method field accepts any string but the description lists specific valid values. Using a Literal type would provide compile-time validation and better IDE support.

♻️ Suggested improvement
+from typing import Literal
+
+EmbedMethod = Literal["last_token", "mean_pooling", "concat"]
+
-    embed_method: str = Field(
+    embed_method: EmbedMethod = Field(
         default="last_token",
         json_schema_extra={
             "description": "Embedding method: 'last_token', 'mean_pooling', or 'concat'"
         },
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/schemas/config.py` around lines 66 - 70, The embed_method
Field currently allows any string even though the docstring lists specific valid
values; change its type from str to typing.Literal with the allowed options
(e.g., Literal["last_token","mean_pooling","concat"]) and update the Field
declaration for embed_method in the Config schema so Pydantic/typing enforces
and IDEs surface the valid choices; keep the same json_schema_extra description
but ensure the attribute name embed_method in the relevant class/schema is
updated to use the Literal type.

132-134: Consider using Literal for advantage_estimator validation.

The advantage_estimator has three defined valid values that could benefit from type-level validation.

♻️ Suggested improvement
-    advantage_estimator: str = Field(
+    advantage_estimator: Literal["rloo", "group_norm", "reinforce"] = Field(
         default="rloo",
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/schemas/config.py` around lines 132 - 134, Change the
advantage_estimator field to use typing.Literal for strict validation: update
the type annotation of advantage_estimator to Literal["rloo", "group_norm",
"reinforce"] (keep the default "rloo"), add the Literal import, and leave the
existing Field(json_schema_extra=...) intact so pydantic/JSON schema will
enforce and document the allowed values; locate and modify the
advantage_estimator declaration and add the import near other typing imports.

98-102: Consider using Literal for mode validation.

Similar to embed_method, the mode field has defined valid values that could be enforced with a Literal type.

♻️ Suggested improvement
-    mode: str = Field(
+    mode: Literal["structured", "strided"] = Field(
         default="structured",
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/schemas/config.py` around lines 98 - 102, The mode field
currently declared as mode: str = Field(...) should enforce allowed values using
typing.Literal like embed_method does; update the type annotation for mode to
use Literal["structured", "strided"] (and import Literal) and remove or keep the
Field default/json_schema_extra as needed so Pydantic validates values at type
level; target the mode declaration in src/axolotl/utils/schemas/config.py (near
the existing embed_method pattern) to make this change.
examples/ebft/llama-1b-ebft-opencode.yaml (1)

64-64: Non-zero lora_dropout may disable LoRA kernel optimizations.

Setting lora_dropout: 0.05 typically disables auto-enabled LoRA kernel optimizations (lora_mlp_kernel, lora_qkv_kernel, lora_o_kernel). If you want the kernel speedups, consider setting lora_dropout: 0.0 or explicitly enabling kernels.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/ebft/llama-1b-ebft-opencode.yaml` at line 64, The lora_dropout: 0.05
setting will likely disable auto-enabled LoRA kernel optimizations; change the
lora_dropout value to 0.0 in the configuration or explicitly enable the kernels
(lora_mlp_kernel, lora_qkv_kernel, lora_o_kernel) so the LoRA kernel speedups
remain active—update the lora_dropout entry or add boolean flags for the three
kernel options to ensure optimizations are not disabled.
src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py (1)

42-43: Use defensive message-field access to avoid hard failures on bad rows.

These transforms directly index message keys. A single malformed record can crash datasets.map with KeyError. Prefer dict.get(...) + skip/fallback handling.

Example hardening pattern
-            if msg["role"] == "assistant" and not found_first:
-                first_gt = msg["content"]
+            role = msg.get("role") if isinstance(msg, dict) else None
+            content = msg.get("content", "") if isinstance(msg, dict) else ""
+            if role == "assistant" and not found_first:
+                first_gt = content
                 found_first = True
             elif found_first:
                 remaining.append(msg)
             else:
                 prompt_msgs.append(msg)

Also applies to: 56-57, 83-85, 116-118

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py` around lines 42 -
43, The loop in ebft_chat_multiturn.py is indexing message dicts directly (e.g.,
msg["role"], msg["content"]) which can raise KeyError on malformed rows; update
the code in the functions/blocks that set first_gt/last_gt and append messages
(references: variables msg, first_gt, last_gt, found_first, the
message-processing loop) to use dict.get("role") and dict.get("content") and
skip or continue when required keys are missing or not strings (e.g., if role is
None or content is None: continue), ensuring all places noted in the review
(around the current first_gt/last_gt assignments and the other mentioned blocks)
perform defensive checks before accessing message fields.
src/axolotl/utils/schemas/validation.py (1)

1572-1575: Guard EBFT strided length math with explicit parameter validation.

Line 1574 can throw a raw ZeroDivisionError (or produce invalid block math) for bad configs. Prefer an explicit ValueError with clear guidance.

Proposed hardening
         stride = ebft.get("stride", 8)
         ctx_len = ebft.get("context_length", 8)
-        max_blocks = (seq_len - gen_len - ctx_len) // stride + 1
+        if stride <= 0:
+            raise ValueError("ebft.stride must be > 0 in strided mode")
+        if seq_len <= gen_len + ctx_len:
+            raise ValueError(
+                "sequence_len must be greater than ebft.generate_max_len + ebft.context_length"
+            )
+        max_blocks = (seq_len - gen_len - ctx_len) // stride + 1
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/schemas/validation.py` around lines 1572 - 1575, The
strided-length calculation using ebft.get("stride") and
ebft.get("context_length") can raise ZeroDivisionError or produce invalid math;
validate parameters first: ensure stride is an int > 0 and context_length
(ctx_len) is an int >= 0 and less than (seq_len - gen_len) so the divisor
(seq_len - gen_len - ctx_len) is non-negative; if any check fails raise a
ValueError with a clear message referencing the offending keys ("stride" /
"context_length") and the values (seq_len, gen_len) so callers know how to
correct the EBFT config before computing max_blocks and full_seq.
src/axolotl/core/trainers/ebft/__init__.py (1)

60-166: Consider reducing repetition in kwargs mapping.

The is not None checks are repetitive. A helper function or loop over a mapping could reduce boilerplate, though this is a stylistic preference.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/core/trainers/ebft/__init__.py` around lines 60 - 166, The
mapping in set_training_args_kwargs repeats many "if X is not None:
kwargs[...]=X" blocks for ebft and trl; refactor by introducing a small helper
(e.g., a local function map_if_present or apply_mapping) that accepts a source
object and an iterable of (attr_name, kwarg_key) pairs and sets
kwargs[kwarg_key] = getattr(source, attr_name) when the value is not None (or
truthy where appropriate), then replace the repeated blocks for ebft and trl
with calls to this helper and explicit handling only for special cases like vllm
colocate logic, async_prefetch, vllm_server_host/port and vllm_enable_sleep_mode
inside set_training_args_kwargs.
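One possible shape for that helper, as a minimal sketch (`map_if_present` and the attribute/kwarg pairs below are illustrative names, not identifiers from the PR):

```python
from types import SimpleNamespace

def map_if_present(kwargs, source, pairs):
    """Set kwargs[key] = getattr(source, attr) for each (attr, key) pair
    whose value is not None, replacing repeated `if X is not None` blocks."""
    for attr_name, kwarg_key in pairs:
        value = getattr(source, attr_name, None)
        if value is not None:
            kwargs[kwarg_key] = value
    return kwargs

# Toy config object standing in for the ebft schema section.
ebft = SimpleNamespace(stride=8, context_length=None)
training_kwargs = map_if_present(
    {}, ebft, [("stride", "ebft_stride"), ("context_length", "ebft_context_length")]
)
```

Special cases (vLLM colocate logic, host/port pairs) would still be handled explicitly after the loop.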
src/axolotl/core/trainers/ebft/strided.py (1)

314-314: Mutable class attribute detected by Ruff (RUF012).

Class-level mutable defaults can cause unexpected sharing. While _tag_names is unlikely to be mutated, using a tuple is safer.

♻️ Use tuple instead of list
-    _tag_names = ["ebft", "strided", "axolotl"]
+    _tag_names = ("ebft", "strided", "axolotl")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/core/trainers/ebft/strided.py` at line 314, Replace the mutable
class attribute _tag_names (currently a list) with an immutable tuple to avoid
accidental shared-state; locate the _tag_names declaration in the Strided
trainer class in strided.py and change ["ebft", "strided", "axolotl"] to
("ebft", "strided", "axolotl") so the class-level constant is immutable.
examples/ebft/qwen35-4b-ebft-structured-async.yaml (1)

70-72: Complex regex for LoRA targeting — verify it matches intended layers.

The regex pattern targets full-attention layers (3,7,11,15,19,23,27,31) and MLP on all layers. This is a careful design for hybrid attention models, but the pattern complexity makes it easy to miss layers.

Consider adding a verification script or test to confirm the pattern matches exactly the intended modules when applied to the model.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/ebft/qwen35-4b-ebft-structured-async.yaml` around lines 70 - 72, The
lora_target_modules regex is complex and may miss or overmatch intended module
names; add a small verification helper that loads the model's state_dict or
module names and tests each key against lora_target_modules to assert that
exactly the intended layer indices (3,7,11,15,19,23,27,31) for
self_attn.(q|k|v|o)_proj and all mlp.(gate|up|down)_proj are matched; implement
this check as a unit test or a CLI validation (e.g., verify_lora_targets or
validate_lora_regex) that prints unmatched expected modules and any unexpected
matches so you can adjust the regex if it misfires.
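A minimal sketch of such a check, with a toy module list standing in for `model.named_modules()` output and a simplified pattern (the real regex and layer set come from the YAML):

```python
import re

def verify_lora_targets(module_names, pattern, expected):
    """Return (missing, unexpected): expected module names the regex misses,
    and matched names that fall outside the expected set."""
    regex = re.compile(pattern)
    matched = {name for name in module_names if regex.search(name)}
    return sorted(expected - matched), sorted(matched - expected)

# Hypothetical module names; a real check would iterate model.named_modules().
modules = [
    "model.layers.3.self_attn.q_proj",
    "model.layers.4.self_attn.q_proj",
    "model.layers.4.mlp.gate_proj",
]
pattern = r"layers\.(3|7)\.self_attn\.(q|k|v|o)_proj|mlp\.(gate|up|down)_proj"
expected = {"model.layers.3.self_attn.q_proj", "model.layers.4.mlp.gate_proj"}
missing, unexpected = verify_lora_targets(modules, pattern, expected)
```

Printing `missing` and `unexpected` when either is non-empty makes regex misfires visible before a training run.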
src/axolotl/core/trainers/ebft/trainer.py (3)

44-44: Mutable class attribute (RUF012).

Same as strided.py — use tuple for immutable tag list.

♻️ Use tuple instead of list
-    _tag_names = ["trl", "ebft", "axolotl"]
+    _tag_names = ("trl", "ebft", "axolotl")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/core/trainers/ebft/trainer.py` at line 44, The class-level
_tag_names is defined as a mutable list; change it to an immutable tuple to
avoid RUF012. Replace _tag_names = ["trl", "ebft", "axolotl"] with _tag_names =
("trl", "ebft", "axolotl") in trainer.py (same pattern as in strided.py) so the
attribute is immutable at class scope.

364-366: Bare Exception catch is overly broad (BLE001).

Catching all exceptions can mask unexpected errors. Consider catching specific vLLM client exceptions or at minimum re-raising after logging for non-recoverable errors.

♻️ Narrow exception handling
-                    except Exception as e:
-                        LOG.warning(f"Multi-turn rollout generation failed: {e}")
-                        gen_text = ""
+                    except (ConnectionError, TimeoutError, RuntimeError) as e:
+                        LOG.warning(f"Multi-turn rollout generation failed: {e}")
+                        gen_text = ""

Or if the vLLM client has specific exception types, catch those instead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/core/trainers/ebft/trainer.py` around lines 364 - 366, The
current broad except Exception block around the multi-turn rollout generation
masks unexpected errors; replace it by catching concrete vLLM/client errors
(e.g., vllm.client.exceptions.VLLMError or the client-specific exception types)
and handle them by logging and setting gen_text as before, and add a separate
fallback that logs with LOG.exception(...) and re-raises for any other
unexpected exceptions; import the client exception types at top and update the
try/except in the multi-turn rollout generation block (the one that currently
sets gen_text = "") to use specific except clauses and a final except Exception
to re-raise after logging.

60-70: Mypy errors on super().__init__() call are expected for mixin pattern.

The Mypy errors about "unexpected keyword arguments for __init__ of object" occur because Mypy doesn't see the full MRO at the mixin level. This is a known limitation with mixin patterns. Consider adding a TYPE_CHECKING block with protocol hints or # type: ignore comments if the errors are noisy in CI.

♻️ Silence Mypy for mixin super() call
-        super().__init__(
+        super().__init__(  # type: ignore[call-arg]
             model=model,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/core/trainers/ebft/trainer.py` around lines 60 - 70, Mypy
complains about unexpected keyword args on the super().__init__() call because
of the mixin pattern; either silence it by adding a TYPE_CHECKING block that
defines a minimal Protocol for the base initializer signature (import
TYPE_CHECKING from typing and declare a Protocol with __init__(..., model,
reward_funcs, args, train_dataset, eval_dataset, processing_class, callbacks,
optimizers, peft_config) and use it only under TYPE_CHECKING), or add a scoped
type ignore on the call (e.g., append # type: ignore[call-arg] to the
super().__init__(...) line) so the mixin pattern no longer fails CI; target the
super().__init__ invocation in trainer.py (the call that passes model,
reward_funcs=[self._feature_matching_reward], args, train_dataset, eval_dataset,
processing_class, callbacks, optimizers, peft_config).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/ebft/ebft_pretrain.py`:
- Around line 27-31: The returned dict aliases labels to encoded["input_ids"],
which can create shared references when the tokenizer returns Python lists;
change the assignment so "labels" is a shallow copy of encoded["input_ids"]
(e.g., use list(...) or .copy()) instead of referencing the same object; update
the return in ebft_pretrain.py that currently returns {"input_ids":
encoded["input_ids"], "attention_mask": encoded["attention_mask"], "labels":
encoded["input_ids"]} to set "labels" to a copy of encoded["input_ids"] to avoid
accidental mutation of input_ids.
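The aliasing pitfall this comment describes is easy to demonstrate with plain lists:

```python
# When a tokenizer returns plain Python lists, assigning the same list to
# "labels" aliases it: mutating labels later also mutates input_ids.
encoded = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}

aliased_labels = encoded["input_ids"]        # shared reference
copied_labels = list(encoded["input_ids"])   # independent shallow copy

aliased_labels[0] = -100                     # also changes encoded["input_ids"]
```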
- Line 17: The variable pad_id assigned from tokenizer.pad_token_id or
tokenizer.eos_token_id is unused; remove the pad_id assignment in
examples/ebft/ebft_pretrain.py (the line creating pad_id) and rely on the
tokenizer's built-in padding (used later with padding="max_length"), ensuring no
other code references pad_id (search for pad_id to confirm) and run tests or
lint to verify no remaining usages.

In `@examples/ebft/ebft_strided_structured.py`:
- Around line 76-77: The file examples/ebft/ebft_strided_structured.py is
missing a trailing newline; open the file and add a single newline character at
the end of the file (ensure it ends with exactly one '\n') so the pre-commit
end-of-file-fixer check passes.

In `@examples/ebft/llama-3b-ebft-strided-fft.yaml`:
- Around line 53-55: The EBFT validator fails because strided EBFT with
gradient_checkpointing enabled is incompatible with torch_compile=true and
requires reentrant checkpointing; update the YAML to set torch_compile: false
(change the torch_compile key) and set
gradient_checkpointing_kwargs.use_reentrant: true (change use_reentrant value)
so the configuration meets the validator's requirements for strided EBFT and
flex-attention checkpointing.

In `@examples/ebft/qwen35-4b-ebft-structured.yaml`:
- Around line 32-33: Replace the non-routable wildcard host value used for the
TRL vllm server with a loopback address: change the vllm_server_host setting
(trl.vllm_server_host) from "0.0.0.0" to "127.0.0.1" so clients connect to a
routable local address; update any corresponding examples or documentation in
the same YAML file to use 127.0.0.1 for vllm_server_host while leaving
vllm_server_port as-is (8000).

In `@examples/ebft/README.md`:
- Around line 187-188: Update the strided-mode performance guidance to stop
recommending "torch_compile: true" because EBFT validation now warns/errors when
torch_compile is enabled with strided mode and gradient checkpointing; replace
that sentence so it either recommends leaving torch_compile disabled for strided
configurations or documents the validation restriction, and keep the explanation
of "flex_attention" behavior and fallback as-is so users know flex_attention is
used when available.

In `@src/axolotl/cli/vllm_serve.py`:
- Around line 82-86: The current precedence uses Python's truthy "or" so an
explicit CLI False (cli_args.get("enforce_eager") == False) is ignored when
cfg.vllm.enforce_eager is truthy; change the logic to prefer an explicitly
provided CLI value by checking presence/None instead of truthiness: if
cli_args.get("enforce_eager") is not None use
bool(cli_args.get("enforce_eager")) else use getattr(cfg.vllm, "enforce_eager",
False). Replace the current assignment to enforce_eager with this pattern and
apply the same change to the other occurrence referenced (the second
enforce_eager assignment at the other location).
- Around line 109-110: The current check calls getattr(cfg.trl,
"vllm_lora_sync", False) and will raise AttributeError if cfg has no trl
attribute; change the guard to safely access trl first (e.g., use getattr(cfg,
"trl", None) and then getattr(..., "vllm_lora_sync", False) or check
hasattr(cfg, "trl") before reading vllm_lora_sync) so that when trl is absent
you fall back to False and still set lora_kwargs["enable_lora"] = False; update
the expression around cfg.trl, "vllm_lora_sync", False to use a safe nested
getattr or an existence check so the code never dereferences a missing trl.
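The nested-getattr guard works because `getattr` on the inner `None` falls through to the default; a quick sketch with stand-in config objects:

```python
from types import SimpleNamespace

def get_vllm_lora_sync(cfg):
    """Read cfg.trl.vllm_lora_sync without raising when cfg has no trl;
    getattr on the inner None returns the False default."""
    return getattr(getattr(cfg, "trl", None), "vllm_lora_sync", False)
```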

In `@src/axolotl/common/datasets.py`:
- Around line 121-124: Normalize cfg.rl to an RLType before the membership check
so string values like "grpo"/"ebft" are treated the same as RLType.GRPO/EBFT;
update the block that computes total_num_steps (referencing cfg.rl and
total_num_steps) to first map/normalize cfg.rl into an RLType (or compare
against lowercased names) and then check membership against {RLType.GRPO,
RLType.EBFT} so GRPO/EBFT are properly excluded whether cfg.rl is provided as a
string or an enum.

In `@src/axolotl/core/trainers/ebft/kernels.py`:
- Around line 256-261: The kernel divides by (N - 1) and when N == 1 this yields
a divide-by-zero; update the Python wrapper fused_diversity_penalty to guard
against N <= 1 by short-circuiting before launching the kernel: detect the input
size (N), and if N <= 1 return an appropriately shaped tensor of zeros (or the
intended neutral penalty) without calling the kernel; otherwise proceed to call
the existing kernel as before. Ensure the check references
fused_diversity_penalty and the N dimension used to compute (N - 1) so the
kernel never receives N == 1.
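A Python-level sketch of the guard, with `kernel_fn` standing in for the actual Triton launch (the real wrapper would return a tensor rather than a list):

```python
def diversity_penalty(embeddings, kernel_fn):
    """Guard a fused diversity kernel against N <= 1: the kernel's (N - 1)
    divisor would be zero, so return a neutral zero penalty per row instead."""
    n = len(embeddings)
    if n <= 1:
        return [0.0] * n
    return kernel_fn(embeddings)
```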

In `@src/axolotl/core/trainers/ebft/rewards.py`:
- Around line 205-213: The whitening currently builds W from left singular
vectors U (producing a (B,B) matrix) but must operate in feature space: use the
right singular vectors V and singular values S to build a (D,D) whitening matrix
W = V @ diag(inv_s) @ V.T (where inv_s is computed with whiten_tol as before),
then apply it to features via phi_w = phi_f @ W.T and phi_gt_w = phi_gt_f @ W.T
(ensure dtype casts remain as phi.dtype / phi_gt.dtype); replace usages of U, W
(B,B) with V and the new (D,D) W to correct the transform.
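The corrected feature-space transform can be checked numerically; with numpy standing in for torch here, the whitening matrix built from the right singular vectors is (D, D), and the whitened features have an approximately identity Gram matrix when B > D and phi is full rank:

```python
import numpy as np

rng = np.random.default_rng(0)
B, D = 8, 3
phi = rng.standard_normal((B, D))

# SVD: phi = U @ diag(s) @ Vt. Whitening must live in feature space,
# so build W from the right singular vectors V, giving a (D, D) matrix.
_, s, vt = np.linalg.svd(phi, full_matrices=False)
inv_s = 1.0 / s                    # a real implementation would clamp with whiten_tol
w = vt.T @ np.diag(inv_s) @ vt     # (D, D) feature-space whitening matrix
phi_w = phi @ w.T                  # whitened features

gram = phi_w.T @ phi_w             # ~ identity when B > D and phi is full rank
```

Building W from U instead would yield a (B, B) matrix that cannot be applied consistently across batches of different sizes.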

In `@src/axolotl/core/trainers/ebft/strided.py`:
- Line 768: The return uses an undefined outputs when return_outputs=True;
update the function (the block that computes loss when backbone is not None) to
always set a value for outputs (e.g., outputs = None or the actual model
outputs) before the final return, or change the final return to return (loss,
None) when outputs were not produced; ensure references to outputs,
return_outputs, and backbone in this function (the strided trainer method
containing the loss computation) are adjusted so outputs is always defined when
return_outputs is True.

In `@src/axolotl/core/trainers/ebft/trainer.py`:
- Around line 181-190: The zip(prompts, ground_truth) usage in the loop that
builds gt_texts should be made strict to avoid silent length mismatches; update
the zip call inside the loop (the one iterating with for i, (p, gt) in
enumerate(...)) to zip(prompts, ground_truth, strict=True) so mismatched lengths
raise an error, keeping the rest of the logic that uses num_gens,
processing_class.apply_chat_template, and gt_texts unchanged.
- Around line 155-166: The loop in trainer.py that iterates over prompts and
completions uses zip(prompts, completions) without the strict parameter; change
it to zip(prompts, completions, strict=True) to enforce equal lengths and
surface mismatches early, keeping the existing handling of list vs. scalar
prompt/completion values and appending combined strings to gen_texts
(references: prompts, completions, gen_texts, and
processing_class.apply_chat_template).

In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py`:
- Around line 61-64: The code uses `"prompt_msgs_snapshot" in dir()` to detect
whether prompt_msgs_snapshot was assigned; replace this unreliable check by
explicitly tracking assignment: either initialize prompt_msgs_snapshot = None at
the top of each function (transform, transform_split_thinking,
transform_answer_only) and test `if prompt_msgs_snapshot is not None` before
using it, or set a boolean flag (e.g., has_prompt_snapshot = False -> True when
you create the snapshot) and test that flag in the return. Update the three
functions (transform, transform_split_thinking, transform_answer_only) to use
the sentinel/flag instead of the dir() check so the branch is deterministic.
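A minimal sketch of the sentinel pattern, with simplified message handling (the real transforms do more than snapshot the prompt prefix):

```python
def transform(messages):
    """Sketch: explicit sentinel instead of the `"..." in dir()` check."""
    prompt_msgs_snapshot = None  # initialized up front, so the branch is deterministic
    for i, msg in enumerate(messages):
        if msg.get("role") == "assistant" and prompt_msgs_snapshot is None:
            prompt_msgs_snapshot = messages[:i]  # everything before the first answer
    if prompt_msgs_snapshot is not None:
        return prompt_msgs_snapshot
    return messages[:-1]  # fallback when no assistant turn was seen
```

The same two-line change (initialize to `None`, test `is not None`) applies to all three transform functions.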

In `@src/axolotl/scripts/vllm_serve_lora.py`:
- Around line 505-526: The http_update_weights handler is using torch
(torch.frombuffer and getattr(torch,...)) but torch is not imported, causing a
NameError; add an import for torch (preferably at module top) so
http_update_weights can reference it. Also address the unused results variable
returned from asyncio.gather: either consume/check results for worker
responses/errors (e.g., inspect the list returned by asyncio.gather) or remove
the assignment and await the gather call solely for synchronization. Target
symbols: http_update_weights, torch.frombuffer, getattr(torch,...), connections,
asyncio.gather, and results.

In `@src/axolotl/train.py`:
- Around line 141-142: The current condition accesses cfg.trl directly which can
raise if cfg has no trl attribute; change it to safely obtain trl first (e.g.
trl = getattr(cfg, 'trl', None)) and then check if trl is not None before
reading beta, e.g. if cfg.rl in {RLType.GRPO, RLType.EBFT} and trl and
getattr(trl, 'beta', 0) == 0: reference_model = False — this ensures accessing
beta is guarded and prevents attribute errors on cfg.trl.

In `@src/axolotl/utils/data/rl.py`:
- Around line 223-229: The removal of columns logic assumes a "train" split and
will break for DatasetDicts without that key; change the DatasetDict branch in
the remove_columns resolution to pick the first available split dynamically
(e.g., use next(iter(dataset)) or list(dataset.keys())[0]) and read its
.column_names instead of dataset["train"].column_names so it works for arbitrary
split names; update the code path that computes ds_columns (the block guarded by
isinstance(dataset, DatasetDict)) to use the first split's column_names and keep
the existing Dataset and fallback behaviors unchanged.
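A sketch of the split-agnostic lookup, using a plain dict and a stub class in place of DatasetDict/Dataset:

```python
class FakeSplit:
    """Stub standing in for a datasets.Dataset split."""
    def __init__(self, column_names):
        self.column_names = column_names

def resolve_columns(dataset):
    """Sketch: read columns from the first available split, not a hardcoded "train"."""
    if isinstance(dataset, dict):          # stand-in for isinstance(dataset, DatasetDict)
        first_split = next(iter(dataset))  # works for any split name, e.g. "validation"
        return dataset[first_split].column_names
    return dataset.column_names            # plain Dataset path unchanged
```

`next(iter(dataset))` picks whichever split the DatasetDict happens to carry, so the branch no longer KeyErrors on non-"train" splits.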

In `@src/axolotl/utils/schemas/validation.py`:
- Line 1579: Replace the Unicode multiplication sign in the EBFT log message
with an ASCII character to satisfy Ruff RUF001: update the f-string containing
"EBFT strided: full_seq_len={full_seq} × n_samples={n_samples} = " to use "x"
(e.g. "full_seq_len={full_seq} x n_samples={n_samples}") or "*" instead; locate
the string that references variables full_seq and n_samples in the validation
code and make this simple character substitution.

---

Nitpick comments:
In `@examples/ebft/ebft_opencode.py`:
- Around line 18-20: The return value currently hardcodes a list of columns to
drop which is brittle; update the second element of the returned tuple (the
transformer config returned alongside transform_fn in ebft_opencode.py) to use
"remove_columns": "__all__" instead of the explicit list so the transformer
removes all original dataset columns and avoids schema drift while preserving
transformed outputs from transform_fn.

In `@examples/ebft/llama-1b-ebft-opencode.yaml`:
- Line 64: The lora_dropout: 0.05 setting will likely disable auto-enabled LoRA
kernel optimizations; change the lora_dropout value to 0.0 in the configuration
or explicitly enable the kernels (lora_mlp_kernel, lora_qkv_kernel,
lora_o_kernel) so the LoRA kernel speedups remain active—update the lora_dropout
entry or add boolean flags for the three kernel options to ensure optimizations
are not disabled.

In `@examples/ebft/qwen35-4b-ebft-structured-async.yaml`:
- Around line 70-72: The lora_target_modules regex is complex and may miss or
overmatch intended module names; add a small verification helper that loads the
model's state_dict or module names and tests each key against
lora_target_modules to assert that exactly the intended layer indices
(3,7,11,15,19,23,27,31) for self_attn.(q|k|v|o)_proj and all
mlp.(gate|up|down)_proj are matched; implement this check as a unit test or a
CLI validation (e.g., verify_lora_targets or validate_lora_regex) that prints
unmatched expected modules and any unexpected matches so you can adjust the
regex if it misfires.
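A possible shape for such a checker (`verify_lora_targets`, the regex, and the module names below are illustrative, not the config's actual values):

```python
import re

def verify_lora_targets(pattern: str, module_names, expected):
    """Hypothetical checker: report expected modules the regex misses
    and unexpected modules it matches."""
    rx = re.compile(pattern)
    matched = {name for name in module_names if rx.search(name)}
    missing = sorted(set(expected) - matched)
    unexpected = sorted(matched - set(expected))
    return missing, unexpected
```

Run it against the model's `named_modules()` keys; non-empty `missing` or `unexpected` lists show exactly where the regex misfires.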

In `@src/axolotl/core/trainers/ebft/__init__.py`:
- Around line 60-166: The mapping in set_training_args_kwargs repeats many "if X
is not None: kwargs[...]=X" blocks for ebft and trl; refactor by introducing a
small helper (e.g., a local function map_if_present or apply_mapping) that
accepts a source object and an iterable of (attr_name, kwarg_key) pairs and sets
kwargs[kwarg_key] = getattr(source, attr_name) when the value is not None (or
truthy where appropriate), then replace the repeated blocks for ebft and trl
with calls to this helper and explicit handling only for special cases like vllm
colocate logic, async_prefetch, vllm_server_host/port and vllm_enable_sleep_mode
inside set_training_args_kwargs.
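The helper could look like this (`apply_mapping` is a sketch; the attribute/kwarg names are placeholders):

```python
from types import SimpleNamespace

def apply_mapping(kwargs: dict, source, pairs):
    """Sketch of the suggested helper: copy non-None attributes into kwargs."""
    for attr_name, kwarg_key in pairs:
        value = getattr(source, attr_name, None)
        if value is not None:
            kwargs[kwarg_key] = value
    return kwargs

# Example: only non-None ebft settings land in kwargs.
ebft = SimpleNamespace(stride=128, context_length=None)
kwargs = apply_mapping({}, ebft, [("stride", "stride"), ("context_length", "context_length")])
```

Special cases (vllm colocate logic, async_prefetch, server host/port) stay as explicit code after the helper calls.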

In `@src/axolotl/core/trainers/ebft/strided.py`:
- Line 314: Replace the mutable class attribute _tag_names (currently a list)
with an immutable tuple to avoid accidental shared-state; locate the _tag_names
declaration in the Strided trainer class in strided.py and change ["ebft",
"strided", "axolotl"] to ("ebft", "strided", "axolotl") so the class-level
constant is immutable.

In `@src/axolotl/core/trainers/ebft/trainer.py`:
- Line 44: The class-level _tag_names is defined as a mutable list; change it to
an immutable tuple to avoid RUF012. Replace _tag_names = ["trl", "ebft",
"axolotl"] with _tag_names = ("trl", "ebft", "axolotl") in trainer.py (same
pattern as in strided.py) so the attribute is immutable at class scope.
- Around line 364-366: The current broad except Exception block around the
multi-turn rollout generation masks unexpected errors; replace it by catching
concrete vLLM/client errors (e.g., vllm.client.exceptions.VLLMError or the
client-specific exception types) and handle them by logging and setting gen_text
as before, and add a separate fallback that logs with LOG.exception(...) and
re-raises for any other unexpected exceptions; import the client exception types
at top and update the try/except in the multi-turn rollout generation block (the
one that currently sets gen_text = "") to use specific except clauses and a
final except Exception to re-raise after logging.
- Around line 60-70: Mypy complains about unexpected keyword args on the
super().__init__() call because of the mixin pattern; either silence it by
adding a TYPE_CHECKING block that defines a minimal Protocol for the base
initializer signature (import TYPE_CHECKING from typing and declare a Protocol
with __init__(..., model, reward_funcs, args, train_dataset, eval_dataset,
processing_class, callbacks, optimizers, peft_config) and use it only under
TYPE_CHECKING), or add a scoped type ignore on the call (e.g., append # type:
ignore[arg-type] to the super().__init__(...) line) so the mixin pattern no
longer fails CI; target the super().__init__ invocation in trainer.py (the call
that passes model, reward_funcs=[self._feature_matching_reward], args,
train_dataset, eval_dataset, processing_class, callbacks, optimizers,
peft_config).

In `@src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py`:
- Around line 42-43: The loop in ebft_chat_multiturn.py is indexing message
dicts directly (e.g., msg["role"], msg["content"]) which can raise KeyError on
malformed rows; update the code in the functions/blocks that set
first_gt/last_gt and append messages (references: variables msg, first_gt,
last_gt, found_first, the message-processing loop) to use dict.get("role") and
dict.get("content") and skip or continue when required keys are missing or not
strings (e.g., if role is None or content is None: continue), ensuring all
places noted in the review (around the current first_gt/last_gt assignments and
the other mentioned blocks) perform defensive checks before accessing message
fields.

In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py`:
- Around line 242-261: The current mask_thinking block assumes "<think>" and
"</think>" map to single token IDs (think_open_id/think_close_id) and silently
skips masking when they equal tokenizer.unk_token_id; update it to detect and
handle multi-token tags: if think_open_id or think_close_id ==
tokenizer.unk_token_id, fall back to scanning the decoded substring (e.g.,
tokenizer.decode(input_ids[scan_start:end]) or joining
tokenizer.convert_ids_to_tokens(input_ids[scan_start:end])) to locate "<think>"
and "</think>" text spans and then map those spans back to token index ranges to
set labels[i] = -100, and/or emit a warning when single-token IDs are
unavailable; keep the existing behavior when single-token IDs are present.
Ensure you modify the mask_thinking branch and use input_ids, labels, start,
end, tokenizer, think_open_id/think_close_id identifiers.
- Around line 35-40: The private helper function _extract_thinking is unused;
remove its definition from ebft_reasoning.py (delete the def
_extract_thinking(...) block) and ensure there are no remaining references to
_extract_thinking elsewhere; if the re import is now unused after removal, also
remove the import to avoid linter warnings and run tests/linters to confirm no
residual uses.
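The span-mapping fallback for multi-token tags can be sketched in pure Python (this simplification assumes the decoded text is the concatenation of the token strings; real code would use the tokenizer's offset mapping or `convert_ids_to_tokens`):

```python
def mask_tag_span(tokens, open_tag="<think>", close_tag="</think>"):
    """Sketch: when tags are not single tokens, locate them in the decoded
    text and map the character span back to token indices to mask."""
    text = "".join(tokens)
    start_c = text.find(open_tag)
    end_c = text.find(close_tag)
    if start_c == -1 or end_c == -1:
        return set()
    end_c += len(close_tag)
    masked, pos = set(), 0
    for i, tok in enumerate(tokens):
        tok_end = pos + len(tok)
        if tok_end > start_c and pos < end_c:  # token overlaps the tag span
            masked.add(i)
        pos = tok_end
    return masked
```

The returned indices are the positions whose labels would be set to -100, covering the case where `<think>` splits across several token ids.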

In `@src/axolotl/utils/schemas/config.py`:
- Around line 66-70: The embed_method Field currently allows any string even
though the docstring lists specific valid values; change its type from str to
typing.Literal with the allowed options (e.g.,
Literal["last_token","mean_pooling","concat"]) and update the Field declaration
for embed_method in the Config schema so Pydantic/typing enforces and IDEs
surface the valid choices; keep the same json_schema_extra description but
ensure the attribute name embed_method in the relevant class/schema is updated
to use the Literal type.
- Around line 132-134: Change the advantage_estimator field to use
typing.Literal for strict validation: update the type annotation of
advantage_estimator to Literal["rloo", "group_norm", "reinforce"] (keep the
default "rloo"), add the Literal import, and leave the existing
Field(json_schema_extra=...) intact so pydantic/JSON schema will enforce and
document the allowed values; locate and modify the advantage_estimator
declaration and add the import near other typing imports.
- Around line 98-102: The mode field currently declared as mode: str =
Field(...) should enforce allowed values using typing.Literal like embed_method
does; update the type annotation for mode to use Literal["structured",
"strided"] (and import Literal) and remove or keep the Field
default/json_schema_extra as needed so Pydantic validates values at type level;
target the mode declaration in src/axolotl/utils/schemas/config.py (near the
existing embed_method pattern) to make this change.
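Pydantic enforces `Literal` annotations automatically; the allowed sets and the check it performs can be sketched with plain typing:

```python
from typing import Literal, get_args

# The three fields the comments above target, as Literal types.
EmbedMethod = Literal["last_token", "mean_pooling", "concat"]
Mode = Literal["structured", "strided"]
AdvantageEstimator = Literal["rloo", "group_norm", "reinforce"]

def check(value: str, allowed) -> str:
    """Mimic what pydantic does when validating a Literal-typed field."""
    if value not in get_args(allowed):
        raise ValueError(f"{value!r} not in {get_args(allowed)}")
    return value
```

With the annotations changed, invalid config strings fail at schema-validation time instead of deep inside the trainer.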

In `@src/axolotl/utils/schemas/validation.py`:
- Around line 1572-1575: The strided-length calculation using ebft.get("stride")
and ebft.get("context_length") can raise ZeroDivisionError or produce invalid
math; validate parameters first: ensure stride is an int > 0 and context_length
(ctx_len) is an int >= 0 and less than (seq_len - gen_len) so the divisor
(seq_len - gen_len - ctx_len) is non-negative; if any check fails raise a
ValueError with a clear message referencing the offending keys ("stride" /
"context_length") and the values (seq_len, gen_len) so callers know how to
correct the EBFT config before computing max_blocks and full_seq.
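A sketch of the suggested checks (the `max_blocks`/`full_seq` arithmetic below is a placeholder, not the actual formula in validation.py):

```python
def validate_strided(seq_len: int, gen_len: int, stride, ctx_len):
    """Sketch: validate stride/context_length before the strided-length math."""
    if not isinstance(stride, int) or stride <= 0:
        raise ValueError(f"'stride' must be a positive int, got {stride!r}")
    if not isinstance(ctx_len, int) or ctx_len < 0 or ctx_len >= seq_len - gen_len:
        raise ValueError(
            f"'context_length' must satisfy 0 <= context_length < "
            f"seq_len - gen_len ({seq_len} - {gen_len}), got {ctx_len!r}"
        )
    # Placeholder arithmetic; the divisor is now guaranteed positive.
    max_blocks = (seq_len - gen_len - ctx_len) // stride + 1
    full_seq = ctx_len + max_blocks * (stride + gen_len)
    return max_blocks, full_seq
```

Raising early with the offending key names gives users an actionable error instead of a ZeroDivisionError.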

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d773aa21-0e02-40e5-853d-1c6fb87e4935

📥 Commits

Reviewing files that changed from the base of the PR and between b0294b3 and 5f978b6.

📒 Files selected for processing (40)
  • examples/ebft/README.md
  • examples/ebft/ebft_opencode.py
  • examples/ebft/ebft_pretrain.py
  • examples/ebft/ebft_strided_structured.py
  • examples/ebft/llama-1b-ebft-opencode-novllm.yaml
  • examples/ebft/llama-1b-ebft-opencode.yaml
  • examples/ebft/llama-1b-ebft-strided-structured.yaml
  • examples/ebft/llama-1b-ebft-strided.yaml
  • examples/ebft/llama-3b-ebft-strided-fft.yaml
  • examples/ebft/llama-8b-ebft-strided-fft.yaml
  • examples/ebft/qwen35-4b-ebft-structured-async.yaml
  • examples/ebft/qwen35-4b-ebft-structured.yaml
  • examples/ebft/qwen35-9b-ebft-structured.yaml
  • src/axolotl/cli/vllm_serve.py
  • src/axolotl/common/datasets.py
  • src/axolotl/core/builders/rl.py
  • src/axolotl/core/trainers/__init__.py
  • src/axolotl/core/trainers/ebft/__init__.py
  • src/axolotl/core/trainers/ebft/args.py
  • src/axolotl/core/trainers/ebft/kernels.py
  • src/axolotl/core/trainers/ebft/rewards.py
  • src/axolotl/core/trainers/ebft/strided.py
  • src/axolotl/core/trainers/ebft/trainer.py
  • src/axolotl/core/trainers/grpo/async_trainer.py
  • src/axolotl/monkeypatch/trainer/trl_vllm.py
  • src/axolotl/prompt_strategies/ebft/__init__.py
  • src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py
  • src/axolotl/prompt_strategies/ebft/ebft_opencode.py
  • src/axolotl/prompt_strategies/ebft/ebft_reasoning.py
  • src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py
  • src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py
  • src/axolotl/scripts/vllm_serve_lora.py
  • src/axolotl/scripts/vllm_worker_ext.py
  • src/axolotl/train.py
  • src/axolotl/utils/data/rl.py
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/utils/schemas/enums.py
  • src/axolotl/utils/schemas/trl.py
  • src/axolotl/utils/schemas/validation.py
  • src/axolotl/utils/schemas/vllm.py

Comment on lines +53 to +55
torch_compile: true
gradient_checkpointing_kwargs:
use_reentrant: false

⚠️ Potential issue | 🔴 Critical

This example currently fails EBFT validation at startup.

With strided EBFT + gradient_checkpointing: true, Line 53 (torch_compile: true) is rejected by the new validator. Also, Line 55 (use_reentrant: false) conflicts with the documented flex-attention checkpointing requirement.

Proposed fix
 gradient_checkpointing: true
-torch_compile: true
+# Keep torch_compile disabled in EBFT strided mode with gradient checkpointing
 gradient_checkpointing_kwargs:
-  use_reentrant: false
+  use_reentrant: true
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/ebft/llama-3b-ebft-strided-fft.yaml` around lines 53 - 55, The EBFT
validator fails because strided EBFT with gradient_checkpointing enabled is
incompatible with torch_compile=true and requires reentrant checkpointing;
update the YAML to set torch_compile: false (change the torch_compile key) and
set gradient_checkpointing_kwargs.use_reentrant: true (change use_reentrant
value) so the configuration meets the validator's requirements for strided EBFT
and flex-attention checkpointing.

Comment on lines +32 to +33
vllm_server_host: 0.0.0.0
vllm_server_port: 8000

⚠️ Potential issue | 🟡 Minor

Use a routable client host for trl.vllm_server_host in this example.

At Line 32, 0.0.0.0 can be problematic as a connect target; 127.0.0.1 is safer for local training-to-server calls.

💡 Proposed fix
-  vllm_server_host: 0.0.0.0
+  vllm_server_host: 127.0.0.1
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/ebft/qwen35-4b-ebft-structured.yaml` around lines 32 - 33, Replace
the non-routable wildcard host value used for the TRL vllm server with a
loopback address: change the vllm_server_host setting (trl.vllm_server_host)
from "0.0.0.0" to "127.0.0.1" so clients connect to a routable local address;
update any corresponding examples or documentation in the same YAML file to use
127.0.0.1 for vllm_server_host while leaving vllm_server_port as-is (8000).

@winglian

@coderabbitai review

@coderabbitai

coderabbitai bot commented Mar 22, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 9

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/axolotl/utils/callbacks/generation.py (1)

28-60: ⚠️ Potential issue | 🔴 Critical

Unindent the sample generation block to fix unreachable code.

Lines 31–60 are incorrectly indented inside the if not getattr(cfg, "generate_samples", False): block. This means the generation code is unreachable regardless of the config value: when generate_samples is false, the return executes; when true, the entire if-block is skipped.

Dedent lines 31–60 by one level (4 spaces) so they execute when generate_samples is true.

Suggested fix
        if not getattr(cfg, "generate_samples", False):
            return

-            dataloader = None
-            try:
-                if getattr(self.trainer, "eval_dataset", None) is not None:
-                    dataloader = self.trainer.get_eval_dataloader()
-                    LOG.info(
-                        f"Using eval dataloader for generation at step {state.global_step}"
-                    )
-            except Exception as e:
-                LOG.warning(f"Could not get eval dataloader: {e}")
-                dataloader = None
-
-            if dataloader is None:
-                dataloader = self.trainer.get_train_dataloader()
-                LOG.info(
-                    f"Using train dataloader for generation at step {state.global_step}"
-                )
-
-            samples = generate_samples(
-                model=self.trainer.model,
-                tokenizer=self.trainer.processing_class,
-                dataloader=dataloader,
-                num_generation_samples=getattr(cfg, "num_generation_samples", 3),
-                max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
-                temperature=getattr(cfg, "generation_temperature", 0.7),
-                top_p=getattr(cfg, "generation_top_p", None),
-                top_k=getattr(cfg, "generation_top_k", None),
-                do_sample=getattr(cfg, "generation_do_sample", True),
-                prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
-            )
-            self._log_samples(samples, state.global_step)
+        dataloader = None
+        try:
+            if getattr(self.trainer, "eval_dataset", None) is not None:
+                dataloader = self.trainer.get_eval_dataloader()
+                LOG.info(
+                    f"Using eval dataloader for generation at step {state.global_step}"
+                )
+        except Exception as e:
+            LOG.warning(f"Could not get eval dataloader: {e}")
+            dataloader = None
+
+        if dataloader is None:
+            dataloader = self.trainer.get_train_dataloader()
+            LOG.info(
+                f"Using train dataloader for generation at step {state.global_step}"
+            )
+
+        samples = generate_samples(
+            model=self.trainer.model,
+            tokenizer=self.trainer.processing_class,
+            dataloader=dataloader,
+            num_generation_samples=getattr(cfg, "num_generation_samples", 3),
+            max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
+            temperature=getattr(cfg, "generation_temperature", 0.7),
+            top_p=getattr(cfg, "generation_top_p", None),
+            top_k=getattr(cfg, "generation_top_k", None),
+            do_sample=getattr(cfg, "generation_do_sample", True),
+            prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
+        )
+        self._log_samples(samples, state.global_step)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/callbacks/generation.py` around lines 28 - 60, The
sample-generation block is incorrectly indented under the early-return for
getattr(cfg, "generate_samples", False), making it unreachable; dedent the
entire block that calls dataloader selection and generate_samples (the
try/except fetching self.trainer.get_eval_dataloader(), the fallback to
self.trainer.get_train_dataloader(), the call to generate_samples with
model=self.trainer.model and tokenizer=self.trainer.processing_class, and the
subsequent self._log_samples(samples, state.global_step)) so it runs only when
generate_samples is True (i.e., move that block out of the if that contains the
return).
src/axolotl/core/trainers/grpo/async_trainer.py (1)

645-670: ⚠️ Potential issue | 🟠 Major

Restore _init_vllm after this trainer finishes initialization.

This mutates VLLMGeneration._init_vllm at class scope and never puts the original method back. After one async/HTTP-only trainer is constructed, later trainers in the same process will also skip communicator init even when they need the stock behavior.

Suggested fix
-        if _skip_nccl:
+        restore_init_vllm = None
+        if _skip_nccl:
             from trl.generation.vllm_generation import VLLMGeneration

             _orig_init_vllm = VLLMGeneration._init_vllm
+            restore_init_vllm = _orig_init_vllm
             ...
             VLLMGeneration._init_vllm = _init_vllm_no_communicator

-        super().__init__(*args, **kwargs)
+        try:
+            super().__init__(*args, **kwargs)
+        finally:
+            if restore_init_vllm is not None:
+                VLLMGeneration._init_vllm = restore_init_vllm
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/core/trainers/grpo/async_trainer.py` around lines 645 - 670, The
patch permanently replaces VLLMGeneration._init_vllm when _skip_nccl is true,
causing later trainers to inherit the no-communicator behavior; instead, save
_orig_init_vllm, assign VLLMGeneration._init_vllm = _init_vllm_no_communicator
only for the duration of this trainer's initialization and restore the original
in a finally/cleanup block (or use a context manager) so that the original
_init_vllm is reinstated whether initialization succeeds or raises; reference
VLLMGeneration._init_vllm, _orig_init_vllm, and _init_vllm_no_communicator and
ensure restoration happens after the trainer finishes initialization.
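A context manager makes the restore guarantee hard to forget (sketch; `patched` is a hypothetical helper wrapping the try/finally):

```python
from contextlib import contextmanager

@contextmanager
def patched(cls, attr, replacement):
    """Sketch: swap a class attribute and guarantee restoration, even on error."""
    original = getattr(cls, attr)
    setattr(cls, attr, replacement)
    try:
        yield
    finally:
        setattr(cls, attr, original)

# Usage in __init__ would look roughly like:
#   with patched(VLLMGeneration, "_init_vllm", _init_vllm_no_communicator):
#       super().__init__(*args, **kwargs)
```

The `finally` block restores the original `_init_vllm` whether initialization succeeds or raises, so later trainers in the same process get stock behavior.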
♻️ Duplicate comments (2)
src/axolotl/prompt_strategies/ebft/ebft_reasoning.py (1)

148-151: ⚠️ Potential issue | 🟠 Major

Replace the remaining dir() guards with an explicit sentinel.

This is the same unresolved issue from the previous round: transform_split_thinking() and transform_answer_only() still use if "prompt_msgs_snapshot" in dir() to decide whether a local was assigned. That keeps the branch dependent on interpreter locals instead of explicit state. Initialize prompt_msgs_snapshot = None before the loop and test is not None in both functions.

Also applies to: 172-175

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py` around lines 148 - 151,
The code uses if "prompt_msgs_snapshot" in dir() to detect whether
prompt_msgs_snapshot was set; instead, initialize prompt_msgs_snapshot = None
before the loop and change both guards in transform_split_thinking and
transform_answer_only to explicit checks (prompt_msgs_snapshot is not None) so
the branch depends on explicit state; update all occurrences (including the
similar checks around lines 172-175) to use the sentinel instead of dir() and
ensure the functions return prompt_msgs_snapshot when not None and fall back to
split_messages[:-1] otherwise.
src/axolotl/core/trainers/ebft/rewards.py (1)

208-236: ⚠️ Potential issue | 🔴 Critical

Fix whitening to operate in feature space before enabling it anywhere.

This still builds W from U, so W is (B, B) instead of (D, D). Besides the math bug from the earlier review, EBFTMixin._feature_matching_reward() calls this with phi.shape == (num_generations, D) and phi_gt.shape == (1, D), so W @ phi_gt_f will shape-mismatch as soon as whitening is enabled with more than one generation.

Suggested fix
-        U, S, _ = torch.linalg.svd(phi_f.unsqueeze(0), full_matrices=False)
+        _, S, Vh = torch.linalg.svd(phi_f, full_matrices=False)
...
-    U, S = U.squeeze(0), S.squeeze(0)
-
     # Safe inverse of singular values
     s_max = S.max()
     inv_s = torch.where(S > whiten_tol * s_max, 1.0 / (S + 1e-12), torch.zeros_like(S))

-    # FIXME
-    # W = U @ diag(inv_S) @ U^T
-    W = (U * inv_s.unsqueeze(0)) @ U.T  # (B, B)
-    phi_w = (W @ phi_f).to(phi.dtype)
-    phi_gt_w = (W @ phi_gt_f).to(phi_gt.dtype)
+    V = Vh.transpose(-2, -1)
+    W = (V * inv_s.unsqueeze(0)) @ Vh  # (D, D)
+    phi_w = (phi_f @ W.T).to(phi.dtype)
+    phi_gt_w = (phi_gt_f @ W.T).to(phi_gt.dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/core/trainers/ebft/rewards.py` around lines 208 - 236, The
whitening builds W in sample space (B,B) because SVD was taken on phi_f (shape
(B,D)), causing a shape mismatch when multiplying with phi_gt_f; fix by
performing SVD in feature space so W is (D,D): compute SVD on phi_f.T (or
equivalently compute eigendecomposition of phi_f.T @ phi_f) to produce U with
shape (D,D), form inv_s from S and build W = U @ diag(inv_s) @ U.T (use
whiten_tol and small eps as before), then apply W @ phi_f.T (or transpose inputs
appropriately) to get phi_w and phi_gt_w in the feature dimension; update the
code paths around U, S, inv_s, W, phi_f, phi_gt_f and ensure
EBFTMixin._feature_matching_reward() (which calls this) receives
correctly-shaped outputs.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/core/trainers/ebft/rewards.py`:
- Around line 92-120: Pooling assumes left-aligned tokens; change pooling to use
attention_mask-derived token positions per sample: in last_token use
attention_mask to compute per-sample last valid index
(torch.where(attention_mask.bool()) grouped by batch or
attention_mask.sum(dim=1)-1) and index hidden_states accordingly instead of raw
indices; in completion_mean build comp_mask by computing valid token positions
per sample from prompt_lengths and attention_mask (i.e., find positions >=
prompt_lengths AND attention_mask==1) before mean-pooling; in concat, for each
sample compute the list of valid token indices from attention_mask (or
prompt-aware valid positions), pick quartile positions relative to that
per-sample valid-length (e.g., floor((valid_len-1)*[0.25,0.5,0.75])) and gather
hidden_states at those indices before concatenation so padding/left-padding is
never sampled. Ensure all indexing handles batched gather safely and uses
hidden_states device/dtypes.
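The mask-aware last-token gather can be sketched as follows (NumPy; assumes a contiguous right-padded attention mask, so the last valid index per sample is `mask.sum(axis=1) - 1`):

```python
import numpy as np

def last_token_pool(hidden: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Sketch: gather each sample's last *valid* hidden state via the mask,
    so padding positions are never selected."""
    last_idx = attention_mask.sum(axis=1) - 1   # (B,) index of last real token
    batch = np.arange(hidden.shape[0])
    return hidden[batch, last_idx]              # (B, D)
```

The same mask-derived valid positions feed the `completion_mean` and `concat` variants: restrict to indices where `attention_mask == 1` (and past `prompt_lengths`) before pooling or picking quartiles.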

In `@src/axolotl/core/trainers/ebft/trainer.py`:
- Around line 395-478: _sequential_rollout currently uses
self.vllm_generation.client and returns only concatenated assistant text; change
it to use self.vllm_generation.vllm_client (replace vllm_client =
self.vllm_generation.client with vllm_client = self.vllm_generation.vllm_client)
and instead of appending full_gen_text to extended_completions append the full
conversation representation (the conv list or a fully rendered conversation that
includes both user and assistant turns) so downstream reward code sees prompt +
interleaved user/assistant messages; ensure the no-remaining-turns branch
returns a full-message list consistent with the new format (e.g., original
prompt_msgs + assistant first turn) and keep vllm_client.chat call and decoding
logic (result, gen_ids, self.processing_class.decode) unchanged.

In `@src/axolotl/core/trainers/grpo/async_trainer.py`:
- Around line 897-902: The code only syncs weights when the computed mod_path
exists in lora_info, which drops trainable parameters stored under
modules_to_save.default.* (e.g., lm_head, embed_tokens). Update the conditional
around vllm_name/mod_path so it also accepts entries that were prefixed by
"modules_to_save.default.": after computing mod_path from vllm_name (and after
calling fix_name with extra_prefixes), check both mod_path and
"modules_to_save.default."+mod_path (or the original un-fixed mod_path) against
lora_info, and only continue if neither is present; this ensures
modules_to_save.default.* parameters are included in the sync.
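The two-key lookup can be sketched as follows (`needs_sync` is hypothetical; the real code derives `mod_path` via `fix_name` with its extra prefixes):

```python
def needs_sync(mod_path: str, lora_info: set) -> bool:
    """Sketch: accept both plain module paths and modules_to_save-prefixed ones,
    so lm_head / embed_tokens saved via modules_to_save are not dropped."""
    return (
        mod_path in lora_info
        or f"modules_to_save.default.{mod_path}" in lora_info
    )
```

Only parameters absent under both keys get skipped, which keeps `modules_to_save.default.*` entries in the weight sync.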

In `@src/axolotl/monkeypatch/trainer/trl_vllm.py`:
- Around line 78-97: The current fallback loop uses MAX_PARAMS_PER_REQUEST to
split by parameter count which can still produce huge base64 JSON bodies; change
the batching in the function that builds payloads (where MAX_PARAMS_PER_REQUEST,
chunk, payload, and the self.session.post call are used) to instead accumulate
parameters until a MAX_BYTES_PER_REQUEST threshold (e.g., ~10MB) would be
exceeded, then send that batch; additionally, detect individual tensors whose
serialized byte size exceeds MAX_BYTES_PER_REQUEST and split them into smaller
slices (preserving name and adding slice metadata such as a shard index or
byte/row range) so each slice is serialized, base64-encoded and included as
separate payload entries, and ensure the shape/dtype fields reflect the slice so
the server can reassemble; replace the fixed-count loop with this byte-aware
accumulator before calling self.session.post.
- Around line 58-63: The POST calls that sync weights (the session.post to
f"{self.base_url}/batch_update_named_params/" and the other session.post to
f"{self.base_url}/http_update_weights/") currently have no timeout and can block
forever; update both calls to pass an explicit timeout (e.g., timeout=30) to
session.post, and keep the existing status_code check/Exception behavior; also
import and optionally catch requests.exceptions.Timeout around the calls in the
same function (or let it propagate) so timeouts surface deterministically.
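The byte-aware accumulator described in the first item above can be sketched as follows; `MAX_BYTES_PER_REQUEST` and the `shard`/`byte_range` metadata keys are assumptions about what the server would need to reassemble slices, not the existing protocol:

```python
import base64

MAX_BYTES_PER_REQUEST = 10 * 1024 * 1024  # ~10MB of raw tensor bytes per POST


def batch_by_bytes(params, max_bytes=MAX_BYTES_PER_REQUEST):
    """Group (name, raw_bytes) pairs into batches bounded by serialized size.

    Tensors larger than max_bytes are split into byte-range slices that
    carry shard metadata so the server can reassemble them.
    """
    batches, current, current_size = [], [], 0
    for name, raw in params:
        if len(raw) > max_bytes:
            # Split one oversized tensor into slices; each slice ships alone.
            for shard, start in enumerate(range(0, len(raw), max_bytes)):
                chunk = raw[start : start + max_bytes]
                batches.append(
                    [
                        {
                            "name": name,
                            "shard": shard,
                            "byte_range": [start, start + len(chunk)],
                            "data": base64.b64encode(chunk).decode("ascii"),
                        }
                    ]
                )
            continue
        if current and current_size + len(raw) > max_bytes:
            batches.append(current)
            current, current_size = [], 0
        current.append({"name": name, "data": base64.b64encode(raw).decode("ascii")})
        current_size += len(raw)
    if current:
        batches.append(current)
    return batches
```

Each returned batch would then be sent as one `session.post` body, keeping every request under the byte budget regardless of parameter count.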

In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py`:
- Around line 225-260: The loop in ebft_reasoning.py currently marks every
assistant turn as trainable (labels[]) which makes prompt_length be derived from
an earlier assistant turn; instead detect and record only the final assistant
turn span: when iterating messages use the existing tokenizer logic to compute
start/end but do not set labels for each assistant turn immediately—instead
store the last assistant span (final_start, final_end), then after the loop set
labels[i]=input_ids[i] only for i in range(final_start, min(final_end,
len(labels))) and compute prompt_length from final_start (before applying any
mask_thinking_ce modifications); apply the same change to the other occurrence
noted around the 285-290 region so only the final assistant span is treated as
the structured completion.
- Line 205: The code treats tokenizer.pad_token_id as missing when it equals 0
by using a falsy fallback expression; change the logic around pad_id in
ebft_reasoning.py so you only fall back to tokenizer.eos_token_id when
pad_token_id is actually None (or not set), e.g. replace the
`tokenizer.pad_token_id or tokenizer.eos_token_id` pattern with an explicit None
check (use tokenizer.pad_token_id if it is not None, otherwise
tokenizer.eos_token_id) where pad_id is assigned so PAD=0 remains respected.

In `@src/axolotl/utils/schemas/config.py`:
- Around line 245-250: Add a pydantic root validator on the same config model
that defines ebft (the model containing the field "ebft: EBFTConfig | None =
Field(...)") to enforce that when rl is set to EBFT (check data.get("rl")
against the RL enum or the string "ebft"), the incoming data contains a non-null
"ebft" entry; if missing or None, raise a ValueError with a clear message (e.g.,
"ebft config is required when rl is EBFT") so parsing fails early instead of
letting downstream validators in validation.py silently bypass the requirement.
- Around line 67-157: Tighten validation on the EBFT config fields by adding
explicit enum/range constraints: restrict embed_method to the allowed values
(e.g., "last_token","mean_pooling","concat") and mode to
("structured","strided") and advantage_estimator to
("rloo","group_norm","reinforce") (use Literal or an Enum for
embed_method/mode/advantage_estimator), enforce top_p between 0.0 and 1.0
(le=1.0, ge=0.0), ensure temperature is non-negative (ge=0.0), and make integer
fields like stride, context_length, generate_max_len, n_samples_per_prompt, and
min_completion_prefix positive or non-negative as appropriate (ge=0 or ge=1).
Update the Field declarations for embed_method, mode, advantage_estimator,
top_p, temperature, stride, context_length, generate_max_len,
n_samples_per_prompt, and min_completion_prefix to include these constraints so
invalid values fail fast during model validation.

---

Outside diff comments:
In `@src/axolotl/core/trainers/grpo/async_trainer.py`:
- Around line 645-670: The patch permanently replaces VLLMGeneration._init_vllm
when _skip_nccl is true, causing later trainers to inherit the no-communicator
behavior; instead, save _orig_init_vllm, assign VLLMGeneration._init_vllm =
_init_vllm_no_communicator only for the duration of this trainer's
initialization and restore the original in a finally/cleanup block (or use a
context manager) so that the original _init_vllm is reinstated whether
initialization succeeds or raises; reference VLLMGeneration._init_vllm,
_orig_init_vllm, and _init_vllm_no_communicator and ensure restoration happens
after the trainer finishes initialization.

In `@src/axolotl/utils/callbacks/generation.py`:
- Around line 28-60: The sample-generation block is incorrectly indented under
the early-return for getattr(cfg, "generate_samples", False), making it
unreachable; dedent the entire block that calls dataloader selection and
generate_samples (the try/except fetching self.trainer.get_eval_dataloader(),
the fallback to self.trainer.get_train_dataloader(), the call to
generate_samples with model=self.trainer.model and
tokenizer=self.trainer.processing_class, and the subsequent
self._log_samples(samples, state.global_step)) so it runs only when
generate_samples is True (i.e., move that block out of the if that contains the
return).

---

Duplicate comments:
In `@src/axolotl/core/trainers/ebft/rewards.py`:
- Around line 208-236: The whitening builds W in sample space (B,B) because SVD
was taken on phi_f (shape (B,D)), causing a shape mismatch when multiplying with
phi_gt_f; fix by performing SVD in feature space so W is (D,D): compute SVD on
phi_f.T (or equivalently compute eigendecomposition of phi_f.T @ phi_f) to
produce U with shape (D,D), form inv_s from S and build W = U @ diag(inv_s) @
U.T (use whiten_tol and small eps as before), then apply W @ phi_f.T (or
transpose inputs appropriately) to get phi_w and phi_gt_w in the feature
dimension; update the code paths around U, S, inv_s, W, phi_f, phi_gt_f and
ensure EBFTMixin._feature_matching_reward() (which calls this) receives
correctly-shaped outputs.
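The intended shape discipline can be illustrated outside torch; here is a NumPy sketch where the whitening matrix is built in feature space, so W is (D, D) and applies cleanly to both phi and phi_gt (the function name and eps default are assumptions):

```python
import numpy as np


def whiten_features(phi: np.ndarray, phi_gt: np.ndarray, eps: float = 1e-5):
    """ZCA-style whitening with W built in feature space.

    The transform is estimated from phi (the policy features, shape (B, D))
    and applied to both phi and phi_gt, so output shapes stay (B, D) and
    (B_gt, D) regardless of the two batch sizes.
    """
    # Feature-space covariance, (D, D) -- not the (B, B) Gram matrix.
    cov = phi.T @ phi / phi.shape[0]
    s, U = np.linalg.eigh(cov)  # U: (D, D), s: (D,)
    inv_s = 1.0 / np.sqrt(np.clip(s, eps, None))
    W = U @ np.diag(inv_s) @ U.T  # (D, D)
    return phi @ W, phi_gt @ W
```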

In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py`:
- Around line 148-151: The code uses if "prompt_msgs_snapshot" in dir() to
detect whether prompt_msgs_snapshot was set; instead, initialize
prompt_msgs_snapshot = None before the loop and change both guards in
transform_split_thinking and transform_answer_only to explicit checks
(prompt_msgs_snapshot is not None) so the branch depends on explicit state;
update all occurrences (including the similar checks around lines 172-175) to
use the sentinel instead of dir() and ensure the functions return
prompt_msgs_snapshot when not None and fall back to split_messages[:-1]
otherwise.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 3bc4db9a-27a7-4043-a5b9-dba87868b9a1

📥 Commits

Reviewing files that changed from the base of the PR and between 5f978b6 and 5134d7c.

📒 Files selected for processing (44)
  • docker/Dockerfile-cloud-uv
  • examples/ebft/README.md
  • examples/ebft/ebft_opencode.py
  • examples/ebft/ebft_pretrain.py
  • examples/ebft/ebft_strided_structured.py
  • examples/ebft/llama-1b-ebft-opencode-novllm.yaml
  • examples/ebft/llama-1b-ebft-opencode.yaml
  • examples/ebft/llama-1b-ebft-strided-structured.yaml
  • examples/ebft/llama-1b-ebft-strided.yaml
  • examples/ebft/llama-3b-ebft-strided-fft.yaml
  • examples/ebft/llama-8b-ebft-strided-fft.yaml
  • examples/ebft/qwen35-4b-ebft-structured-async.yaml
  • examples/ebft/qwen35-4b-ebft-structured.yaml
  • examples/ebft/qwen35-9b-ebft-structured.yaml
  • src/axolotl/cli/vllm_serve.py
  • src/axolotl/common/datasets.py
  • src/axolotl/core/builders/rl.py
  • src/axolotl/core/trainers/__init__.py
  • src/axolotl/core/trainers/ebft/__init__.py
  • src/axolotl/core/trainers/ebft/args.py
  • src/axolotl/core/trainers/ebft/kernels.py
  • src/axolotl/core/trainers/ebft/rewards.py
  • src/axolotl/core/trainers/ebft/strided.py
  • src/axolotl/core/trainers/ebft/trainer.py
  • src/axolotl/core/trainers/grpo/async_trainer.py
  • src/axolotl/integrations/diffusion/callbacks.py
  • src/axolotl/monkeypatch/trainer/trl_vllm.py
  • src/axolotl/prompt_strategies/ebft/__init__.py
  • src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py
  • src/axolotl/prompt_strategies/ebft/ebft_opencode.py
  • src/axolotl/prompt_strategies/ebft/ebft_reasoning.py
  • src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py
  • src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py
  • src/axolotl/scripts/vllm_serve_lora.py
  • src/axolotl/scripts/vllm_worker_ext.py
  • src/axolotl/train.py
  • src/axolotl/utils/callbacks/__init__.py
  • src/axolotl/utils/callbacks/generation.py
  • src/axolotl/utils/data/rl.py
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/utils/schemas/enums.py
  • src/axolotl/utils/schemas/trl.py
  • src/axolotl/utils/schemas/validation.py
  • src/axolotl/utils/schemas/vllm.py
✅ Files skipped from review due to trivial changes (21)
  • docker/Dockerfile-cloud-uv
  • src/axolotl/integrations/diffusion/callbacks.py
  • src/axolotl/utils/schemas/enums.py
  • src/axolotl/prompt_strategies/ebft/__init__.py
  • examples/ebft/ebft_pretrain.py
  • examples/ebft/ebft_opencode.py
  • src/axolotl/utils/callbacks/__init__.py
  • examples/ebft/qwen35-4b-ebft-structured.yaml
  • examples/ebft/qwen35-9b-ebft-structured.yaml
  • examples/ebft/llama-1b-ebft-opencode.yaml
  • src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py
  • examples/ebft/llama-1b-ebft-strided.yaml
  • examples/ebft/README.md
  • src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py
  • examples/ebft/llama-1b-ebft-strided-structured.yaml
  • src/axolotl/cli/vllm_serve.py
  • examples/ebft/qwen35-4b-ebft-structured-async.yaml
  • examples/ebft/llama-3b-ebft-strided-fft.yaml
  • examples/ebft/llama-8b-ebft-strided-fft.yaml
  • examples/ebft/llama-1b-ebft-opencode-novllm.yaml
  • src/axolotl/core/trainers/ebft/kernels.py
🚧 Files skipped from review as they are similar to previous changes (15)
  • src/axolotl/common/datasets.py
  • src/axolotl/utils/schemas/vllm.py
  • src/axolotl/core/trainers/__init__.py
  • src/axolotl/utils/data/rl.py
  • src/axolotl/utils/schemas/trl.py
  • src/axolotl/core/builders/rl.py
  • examples/ebft/ebft_strided_structured.py
  • src/axolotl/scripts/vllm_serve_lora.py
  • src/axolotl/prompt_strategies/ebft/ebft_opencode.py
  • src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py
  • src/axolotl/train.py
  • src/axolotl/scripts/vllm_worker_ext.py
  • src/axolotl/utils/schemas/validation.py
  • src/axolotl/core/trainers/ebft/__init__.py
  • src/axolotl/core/trainers/ebft/args.py

Comment on lines +92 to +120
if method == "last_token":
if attention_mask is not None:
# Find last non-padding position per sample
last_idx = attention_mask.sum(dim=1).long() - 1 # (B,)
return hidden_states[torch.arange(hidden_states.shape[0]), last_idx]
return hidden_states[:, -1, :]

if method == "mean_pooling":
if attention_mask is not None:
mask = attention_mask.unsqueeze(-1).float() # (B, S, 1)
return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
return hidden_states.mean(dim=1)

if method == "completion_mean":
# Mean pool over completion tokens only (exclude prompt)
if prompt_lengths is None:
raise ValueError("completion_mean requires prompt_lengths")
B, S, _ = hidden_states.shape
positions = torch.arange(S, device=hidden_states.device).unsqueeze(0) # (1, S)
comp_mask = positions >= prompt_lengths.unsqueeze(1) # (B, S)
if attention_mask is not None:
comp_mask = comp_mask & attention_mask.bool()
mask = comp_mask.unsqueeze(-1).float() # (B, S, 1)
return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

if method == "concat":
seq_len = hidden_states.shape[1]
positions = [seq_len // 4, seq_len // 2, 3 * seq_len // 4]
return torch.cat([hidden_states[:, p, :] for p in positions], dim=-1)

⚠️ Potential issue | 🟠 Major

Make pooling respect padded sequence layouts.

last_token and completion_mean currently assume tokens start at index 0, and concat samples fixed quartiles from the padded width. Since _feature_matching_reward() tokenizes with padding=True, left-padded batches will pool from padding/prompt positions, and shorter rows can feed pad tokens into concat. Derive positions from attention_mask per sample instead of raw tensor indices.

🧰 Tools
🪛 Ruff (0.15.6)

[warning] 109-109: Unpacked variable B is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/core/trainers/ebft/rewards.py` around lines 92 - 120, Pooling
assumes left-aligned tokens; change pooling to use attention_mask-derived token
positions per sample: in last_token use attention_mask to compute per-sample
last valid index (torch.where(attention_mask.bool()) grouped by batch or
attention_mask.sum(dim=1)-1) and index hidden_states accordingly instead of raw
indices; in completion_mean build comp_mask by computing valid token positions
per sample from prompt_lengths and attention_mask (i.e., find positions >=
prompt_lengths AND attention_mask==1) before mean-pooling; in concat, for each
sample compute the list of valid token indices from attention_mask (or
prompt-aware valid positions), pick quartile positions relative to that
per-sample valid-length (e.g., floor((valid_len-1)*[0.25,0.5,0.75])) and gather
hidden_states at those indices before concatenation so padding/left-padding is
never sampled. Ensure all indexing handles batched gather safely and uses
hidden_states device/dtypes.

Comment on lines +395 to +478
def _sequential_rollout(
self,
prompts: list,
first_completions: list,
remaining_turns: list,
num_gens: int,
) -> list:
"""
Extend single-turn completions into multi-turn conversations.

For each prompt group, takes the first generated assistant turn and
sequentially generates subsequent assistant turns by calling vLLM,
building up a full multi-turn conversation.

Args:
prompts: List of prompt message lists (repeated num_gens times)
first_completions: List of generated first-turn completions
remaining_turns: List of remaining turn pairs after first assistant turn.
Each element is a list of dicts: [{"role": "user", "content": "..."},
{"role": "assistant", "content": "...GT..."}]
num_gens: Number of generations per prompt

Returns:
Extended completions incorporating all generated turns
"""
vllm_client = self.vllm_generation.client
max_tokens = getattr(self.args, "max_completion_length", 256)
temperature = getattr(self.args, "temperature", 0.7)
gen_kwargs = getattr(self.args, "generation_kwargs", None) or {}

extended_completions = []

for idx in range(len(prompts)):
prompt_msgs = prompts[idx] if isinstance(prompts[idx], list) else []
first_comp = first_completions[idx]

# Extract first completion text
if isinstance(first_comp, list):
first_text = first_comp[0].get("content", "") if first_comp else ""
else:
first_text = first_comp

# Get remaining turns for this prompt (same for all num_gens copies)
prompt_idx = idx // num_gens
turns = (
remaining_turns[prompt_idx] if prompt_idx < len(remaining_turns) else []
)

if not turns:
# No remaining turns — just use the first completion
extended_completions.append(first_comp)
continue

# Build conversation with generated first turn
conv = list(prompt_msgs) + [{"role": "assistant", "content": first_text}]
full_gen_text = first_text

# Generate subsequent turns
for turn in turns:
if turn["role"] == "user":
conv.append(turn)
elif turn["role"] == "assistant":
# Generate this assistant turn via vLLM
try:
result = vllm_client.chat(
messages=[conv],
n=1,
max_tokens=max_tokens,
temperature=temperature,
generation_kwargs=gen_kwargs,
)
# Extract generated text
gen_ids = result.get("completion_ids", [[]])[0]
gen_text = self.processing_class.decode(
gen_ids, skip_special_tokens=True
)
except Exception as e:
LOG.warning(f"Multi-turn rollout generation failed: {e}")
gen_text = ""

conv.append({"role": "assistant", "content": gen_text})
full_gen_text += gen_text

extended_completions.append(full_gen_text)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let's examine the file structure and understand the broader context
cd src/axolotl/core/trainers/ebft && wc -l trainer.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 82


🏁 Script executed:

# Check the imports and class definition to understand vllm_client usage patterns
head -50 src/axolotl/core/trainers/ebft/trainer.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 1922


🏁 Script executed:

# Search for all vllm_client usages in the file
rg "vllm_client" src/axolotl/core/trainers/ebft/trainer.py -B 2 -A 2

Repository: axolotl-ai-cloud/axolotl

Length of output: 544


🏁 Script executed:

# Find where _sequential_rollout is called and how its return value is used
rg "_sequential_rollout" src/axolotl/core/trainers/ebft/trainer.py -B 3 -A 5

Repository: axolotl-ai-cloud/axolotl

Length of output: 1164


🏁 Script executed:

# Look for reward model input construction - search for patterns around completion/prompt concatenation
rg "prompt_text|comp_text|reward" src/axolotl/core/trainers/ebft/trainer.py -B 2 -A 2

Repository: axolotl-ai-cloud/axolotl

Length of output: 5202


🏁 Script executed:

# Look at the full _feature_matching_reward function to see how returned completions are used
sed -n '196,320p' src/axolotl/core/trainers/ebft/trainer.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 5731


🏁 Script executed:

# Check if there are other uses of vllm_client in the generate_and_score_completions or calling code
rg "self.vllm_generation" src/axolotl/core/trainers/ebft/trainer.py -B 2 -A 2

Repository: axolotl-ai-cloud/axolotl

Length of output: 1056


🏁 Script executed:

# Search for how the completions variable flows after _sequential_rollout returns
sed -n '285,330p' src/axolotl/core/trainers/ebft/trainer.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 1937


🏁 Script executed:

# Look more carefully at the reward computation to see if multi-turn context should be included
sed -n '305,365p' src/axolotl/core/trainers/ebft/trainer.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 2824


🏁 Script executed:

# Verify if there's a different vllm_client pattern elsewhere in the GRPO trainer or this PR context
rg "vllm_client\s*=" src/axolotl/core/trainers/ -B 3 -A 3 --max-count=10

Repository: axolotl-ai-cloud/axolotl

Length of output: 2694


🏁 Script executed:

# Check if _sequential_rollout should return full message list vs string by looking at ground_truth structure
rg "ground_truth" src/axolotl/core/trainers/ebft/trainer.py -B 2 -A 2 | head -50

Repository: axolotl-ai-cloud/axolotl

Length of output: 1743


Return the full conversation, including user turns, and use self.vllm_generation.vllm_client instead of self.vllm_generation.client.

_sequential_rollout() has two issues:

  1. Uses self.vllm_generation.client, but the async trainer and rest of the codebase use self.vllm_generation.vllm_client.

  2. Returns only concatenated assistant text (full_gen_text), dropping user turns from the conversation. Although conv is built with interleaved user/assistant messages, user turns are never accumulated into the return value. This means the reward model sees prompt + [asst₁ + asst₂ + ...] instead of the actual generated conversation structure with interspersed user turns. Return the full message list (or fully rendered conversation including user turns) so reward computation aligns with the actual multi-turn generation.

🧰 Tools
🪛 Ruff (0.15.6)

[warning] 449-449: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)


[warning] 471-471: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/core/trainers/ebft/trainer.py` around lines 395 - 478,
_sequential_rollout currently uses self.vllm_generation.client and returns only
concatenated assistant text; change it to use self.vllm_generation.vllm_client
(replace vllm_client = self.vllm_generation.client with vllm_client =
self.vllm_generation.vllm_client) and instead of appending full_gen_text to
extended_completions append the full conversation representation (the conv list
or a fully rendered conversation that includes both user and assistant turns) so
downstream reward code sees prompt + interleaved user/assistant messages; ensure
the no-remaining-turns branch returns a full-message list consistent with the
new format (e.g., original prompt_msgs + assistant first turn) and keep
vllm_client.chat call and decoding logic (result, gen_ids,
self.processing_class.decode) unchanged.

Comment on lines 897 to +902
vllm_name = fix_name(vllm_name, extra_prefixes=["modules_to_save.default."])

# Only sync weights that have LoRA adapters
mod_path = vllm_name[: -len(".weight")]
if mod_path not in lora_info:
continue

⚠️ Potential issue | 🟠 Major

modules_to_save weights are still dropped from no-merge sync.

The new filter only syncs weights whose mod_path exists in lora_info. That excludes trainable modules_to_save.default.* parameters such as lm_head or embed_tokens, so server-side generation can drift from training whenever PEFT is configured with modules_to_save.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/core/trainers/grpo/async_trainer.py` around lines 897 - 902, The
code only syncs weights when the computed mod_path exists in lora_info, which
drops trainable parameters stored under modules_to_save.default.* (e.g.,
lm_head, embed_tokens). Update the conditional around vllm_name/mod_path so it
also accepts entries that were prefixed by "modules_to_save.default.": after
computing mod_path from vllm_name (and after calling fix_name with
extra_prefixes), check both mod_path and "modules_to_save.default."+mod_path (or
the original un-fixed mod_path) against lora_info, and only continue if neither
is present; this ensures modules_to_save.default.* parameters are included in
the sync.

Comment on lines +58 to +63
url = f"{self.base_url}/batch_update_named_params/"
response = self.session.post(url, json={"params": param_metadata})
if response.status_code != 200:
raise Exception(
f"Request failed: {response.status_code}, {response.text}"
)

⚠️ Potential issue | 🟠 Major

Add an explicit timeout to both weight-sync POSTs.

These requests currently inherit Requests' infinite timeout. If the vLLM server wedges during metadata sync or /http_update_weights/, training can block forever on the request thread.

Suggested fix
-            response = self.session.post(url, json={"params": param_metadata})
+            response = self.session.post(
+                url,
+                json={"params": param_metadata},
+                timeout=getattr(self, "connection_timeout", None),
+            )
...
-            response = self.session.post(url, json={"params": payload})
+            response = self.session.post(
+                url,
+                json={"params": payload},
+                timeout=getattr(self, "connection_timeout", None),
+            )

Also applies to: 96-101

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/monkeypatch/trainer/trl_vllm.py` around lines 58 - 63, The POST
calls that sync weights (the session.post to
f"{self.base_url}/batch_update_named_params/" and the other session.post to
f"{self.base_url}/http_update_weights/") currently have no timeout and can block
forever; update both calls to pass an explicit timeout (e.g., timeout=30) to
session.post, and keep the existing status_code check/Exception behavior; also
import and optionally catch requests.exceptions.Timeout around the calls in the
same function (or let it propagate) so timeouts surface deterministically.

Comment on lines +78 to +97
MAX_PARAMS_PER_REQUEST = 32 # avoid huge HTTP payloads
for i in range(0, len(params), MAX_PARAMS_PER_REQUEST):
chunk = params[i : i + MAX_PARAMS_PER_REQUEST]
payload = []
for name, weights in chunk:
w_cpu = weights.contiguous().cpu()
# NumPy doesn't support bfloat16; cast to float32 for serialization
if w_cpu.dtype == torch.bfloat16:
w_cpu = w_cpu.float()
raw = w_cpu.numpy().tobytes()
payload.append(
{
"name": name,
"dtype": str(w_cpu.dtype),
"shape": list(weights.shape),
"data": base64.b64encode(raw).decode("ascii"),
}
)
url = f"{self.base_url}/http_update_weights/"
response = self.session.post(url, json={"params": payload})

⚠️ Potential issue | 🟠 Major

Chunk the HTTP fallback by bytes, not just parameter count.

MAX_PARAMS_PER_REQUEST = 32 does not actually cap request size. A single embedding or projection matrix can still turn into a hundreds-of-MB base64 JSON body, which is likely to trip body limits or spike memory on both client and server. Bound chunks by serialized byte size, and split oversized tensors when needed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/monkeypatch/trainer/trl_vllm.py` around lines 78 - 97, The
current fallback loop uses MAX_PARAMS_PER_REQUEST to split by parameter count
which can still produce huge base64 JSON bodies; change the batching in the
function that builds payloads (where MAX_PARAMS_PER_REQUEST, chunk, payload, and
the self.session.post call are used) to instead accumulate parameters until a
MAX_BYTES_PER_REQUEST threshold (e.g., ~10MB) would be exceeded, then send that
batch; additionally, detect individual tensors whose serialized byte size
exceeds MAX_BYTES_PER_REQUEST and split them into smaller slices (preserving
name and adding slice metadata such as a shard index or byte/row range) so each
slice is serialized, base64-encoded and included as separate payload entries,
and ensure the shape/dtype fields reflect the slice so the server can
reassemble; replace the fixed-count loop with this byte-aware accumulator before
calling self.session.post.

return {"prompt": m["content"]}
return {"prompt": str(messages)}

pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id

⚠️ Potential issue | 🟠 Major

Don't treat pad_token_id == 0 as “missing”.

tokenizer.pad_token_id or tokenizer.eos_token_id will fall back to EOS whenever the pad id is 0, which is a valid PAD token for several tokenizers. That silently changes the padded tail from PAD to EOS.

Suggested fix
-        pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
+        pad_id = (
+            tokenizer.pad_token_id
+            if tokenizer.pad_token_id is not None
+            else tokenizer.eos_token_id
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py` at line 205, The code
treats tokenizer.pad_token_id as missing when it equals 0 by using a falsy
fallback expression; change the logic around pad_id in ebft_reasoning.py so you
only fall back to tokenizer.eos_token_id when pad_token_id is actually None (or
not set), e.g. replace the `tokenizer.pad_token_id or tokenizer.eos_token_id`
pattern with an explicit None check (use tokenizer.pad_token_id if it is not
None, otherwise tokenizer.eos_token_id) where pad_id is assigned so PAD=0
remains respected.

Comment on lines +225 to +260
# Find assistant turn boundaries using incremental tokenization
prefix_messages = []
for msg in messages:
if msg["role"] == "assistant":
prefix_text = tokenizer.apply_chat_template(
prefix_messages,
tokenize=False,
add_generation_prompt=True,
)
prefix_ids = tokenizer(
prefix_text,
truncation=True,
max_length=seq_len,
add_special_tokens=False,
return_tensors=None,
)["input_ids"]
start = len(prefix_ids)

prefix_messages.append(msg)
with_turn_text = tokenizer.apply_chat_template(
prefix_messages,
tokenize=False,
add_generation_prompt=False,
)
with_turn_ids = tokenizer(
with_turn_text,
truncation=True,
max_length=seq_len,
add_special_tokens=False,
return_tensors=None,
)["input_ids"]
end = len(with_turn_ids)

# Mark assistant tokens as trainable
for i in range(start, min(end, len(labels))):
labels[i] = input_ids[i]

⚠️ Potential issue | 🟠 Major

Only treat the final assistant span as the structured completion.

This loop labels every assistant turn and then derives prompt_length from the first label that is not -100. On multi-turn chats, earlier assistant replies become part of the "completion" span, so the strided trainer can place anchors and CE loss inside prompt history instead of only the final target answer. If mask_thinking_ce=True, the derived boundary can move even further into the answer. Track the final assistant span explicitly and compute prompt_length from that true start before optional masking.

Also applies to: 285-290
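A toy sketch of tracking only the last assistant span, assuming the per-turn (start, end) spans computed by the incremental tokenization above:

```python
def final_assistant_span(roles, spans):
    """Return the (start, end) token span of the last assistant turn.

    Earlier assistant turns are remembered but never labeled, so
    prompt_length can be taken from the returned start index.
    """
    final = None
    for role, span in zip(roles, spans):
        if role == "assistant":
            final = span  # remember; do not label yet
    return final
```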

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py` around lines 225 - 260,
The loop in ebft_reasoning.py currently marks every assistant turn as trainable
(labels[]) which makes prompt_length be derived from an earlier assistant turn;
instead detect and record only the final assistant turn span: when iterating
messages use the existing tokenizer logic to compute start/end but do not set
labels for each assistant turn immediately—instead store the last assistant span
(final_start, final_end), then after the loop set labels[i]=input_ids[i] only
for i in range(final_start, min(final_end, len(labels))) and compute
prompt_length from final_start (before applying any mask_thinking_ce
modifications); apply the same change to the other occurrence noted around the
285-290 region so only the final assistant span is treated as the structured
completion.

Comment on lines +67 to +157
    embed_method: str = Field(
        default="last_token",
        json_schema_extra={
            "description": "Embedding method: 'last_token', 'mean_pooling', or 'concat'"
        },
    )
    use_whitening: bool = Field(
        default=False,
        json_schema_extra={"description": "Apply SVD whitening to feature embeddings"},
    )
    alignment_coef: float = Field(
        default=1.0,
        json_schema_extra={
            "description": "Coefficient for alignment reward (cosine similarity with ground truth)"
        },
    )
    diversity_coef: float = Field(
        default=1.0,
        json_schema_extra={
            "description": "Coefficient for diversity penalty (pairwise similarity between samples)"
        },
    )
    ce_coef: float = Field(
        default=0.0,
        json_schema_extra={
            "description": "Cross-entropy loss coefficient on ground-truth tokens"
        },
    )
    adaptive_max_tokens: bool = Field(
        default=True,
        json_schema_extra={
            "description": "Set per-batch max_tokens based on ground-truth length"
        },
    )
    gt_length_multiplier: float = Field(
        default=1.5,
        json_schema_extra={
            "description": "Multiplier for ground-truth token count when computing adaptive max_tokens"
        },
    )

    # Strided mode fields (for unstructured text)
    mode: str = Field(
        default="structured",
        json_schema_extra={
            "description": "EBFT mode: 'structured' (QA with vLLM) or 'strided' (unstructured text)"
        },
    )
    stride: int = Field(
        default=8,
        json_schema_extra={"description": "Stride between anchor points (tokens)"},
    )
    context_length: int = Field(
        default=8,
        json_schema_extra={"description": "Context window size per block"},
    )
    generate_max_len: int = Field(
        default=8,
        json_schema_extra={"description": "Tokens to generate per block"},
    )
    n_samples_per_prompt: int = Field(
        default=4,
        json_schema_extra={"description": "Independent rollouts per document"},
    )
    temperature: float = Field(
        default=0.6,
        json_schema_extra={
            "description": "Sampling temperature for strided generation"
        },
    )
    top_p: float = Field(
        default=1.0,
        json_schema_extra={"description": "Top-p nucleus sampling threshold"},
    )
    rl_coef: float = Field(
        default=1.0,
        json_schema_extra={"description": "RL policy gradient loss coefficient"},
    )
    advantage_estimator: str = Field(
        default="rloo",
        json_schema_extra={
            "description": "Advantage estimator: 'rloo', 'group_norm', 'reinforce'"
        },
    )
    min_completion_prefix: int = Field(
        default=0,
        json_schema_extra={
            "description": "Minimum tokens into completion before placing anchors. "
            "Skips anchors too close to the prompt boundary where features are dominated by prompt context."
        },
    )
⚠️ Potential issue | 🟠 Major

Tighten EBFT field validation to fail fast on invalid values.

Several EBFT fields (e.g., embed_method, mode, advantage_estimator, top_p, temperature, stride) currently accept invalid inputs at schema level, which can push failures to runtime. Please add enum/range constraints in this schema.

Proposed schema hardening
 class EBFTConfig(BaseModel):
@@
-    embed_method: str = Field(
+    embed_method: Literal["last_token", "mean_pooling", "concat"] = Field(
@@
-    gt_length_multiplier: float = Field(
+    gt_length_multiplier: float = Field(
         default=1.5,
+        gt=0.0,
@@
-    mode: str = Field(
+    mode: Literal["structured", "strided", "async"] = Field(
         default="structured",
@@
-    stride: int = Field(
+    stride: int = Field(
         default=8,
+        ge=1,
@@
-    context_length: int = Field(
+    context_length: int = Field(
         default=8,
+        ge=1,
@@
-    generate_max_len: int = Field(
+    generate_max_len: int = Field(
         default=8,
+        ge=1,
@@
-    n_samples_per_prompt: int = Field(
+    n_samples_per_prompt: int = Field(
         default=4,
+        ge=1,
@@
-    temperature: float = Field(
+    temperature: float = Field(
         default=0.6,
+        gt=0.0,
@@
-    top_p: float = Field(
+    top_p: float = Field(
         default=1.0,
+        ge=0.0,
+        le=1.0,
@@
-    advantage_estimator: str = Field(
+    advantage_estimator: Literal["rloo", "group_norm", "reinforce"] = Field(
         default="rloo",
@@
-    min_completion_prefix: int = Field(
+    min_completion_prefix: int = Field(
         default=0,
+        ge=0,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/schemas/config.py` around lines 67-157: tighten
validation on the EBFT config fields by adding explicit enum/range constraints.
Restrict embed_method to the allowed values ("last_token", "mean_pooling",
"concat"), mode to ("structured", "strided"), and advantage_estimator to
("rloo", "group_norm", "reinforce"), using Literal or an Enum for each; enforce
top_p between 0.0 and 1.0 (ge=0.0, le=1.0); ensure temperature is non-negative
(ge=0.0); and make integer fields such as stride, context_length,
generate_max_len, n_samples_per_prompt, and min_completion_prefix positive or
non-negative as appropriate (ge=1 or ge=0). Update the Field declarations for
these fields to include the constraints so invalid values fail fast during
model validation.

Comment on lines +245 to +250
    ebft: EBFTConfig | None = Field(
        default=None,
        json_schema_extra={
            "description": "Configuration for Energy-Based Fine-Tuning (EBFT)"
        },
    )
⚠️ Potential issue | 🟠 Major

Require ebft config when rl is set to EBFT.

At Line 245, ebft is optional, so rl: ebft can pass parsing with no ebft block. Cross-file validators in src/axolotl/utils/schemas/validation.py use data.get("ebft", {}), which silently bypasses required EBFT checks and shifts failures downstream. Add a schema validator that enforces presence of ebft when rl is ebft.

Proposed guard validator
 class AxolotlInputConfig(
@@
     ebft: EBFTConfig | None = Field(
         default=None,
         json_schema_extra={
             "description": "Configuration for Energy-Based Fine-Tuning (EBFT)"
         },
     )
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_ebft_block_present_when_rl_ebft(cls, data):
+        rl = data.get("rl")
+        if rl in ("ebft", RLType.EBFT) and not data.get("ebft"):
+            raise ValueError("`ebft` configuration is required when `rl: ebft`.")
+        return data
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/schemas/config.py` around lines 245-250: add a pydantic
root validator on the config model that defines the ebft field (the model
containing "ebft: EBFTConfig | None = Field(...)") to enforce that when rl is
set to EBFT (check data.get("rl") against the RL enum or the string "ebft"),
the incoming data contains a non-null "ebft" entry. If it is missing or None,
raise a ValueError with a clear message (e.g., "ebft config is required when rl
is EBFT") so parsing fails early instead of letting downstream validators in
validation.py silently bypass the requirement.


Labels

scheduled_release This PR is slated for the upcoming release


1 participant