**Review skipped.** Auto incremental reviews are disabled on this repository; check the settings in the CodeRabbit UI or the ⚙️ Run configuration. Configuration used: `.coderabbit.yaml` · Review profile: CHILL · Plan: Pro
📝 Walkthrough
Adds Energy-Based Fine-Tuning (EBFT) to axolotl: new trainers (structured/async/strided), feature-matching reward logic, Triton kernels, dataset transforms/prompt strategies, config/schema validation, vLLM serving/weight-sync extensions, example configs, and documentation.
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
📖 Documentation Preview: https://69bfd6d3215bcc9737afcd61--resonant-treacle-0fd729.netlify.app (deployed on Netlify from commit 567d07e)
Actionable comments posted: 19
🧹 Nitpick comments (15)
src/axolotl/prompt_strategies/ebft/ebft_reasoning.py (2)
**242-261: Think-tag masking logic may miss multi-token tags.**
The masking logic assumes `<think>` and `</think>` each tokenize to a single token. If the tokenizer splits these tags into multiple tokens (e.g., `<`, `think`, `>`), `think_open_id` will be `unk_token_id` and masking will be silently skipped. Consider adding a warning or a fallback to text-based span detection when single-token IDs are unavailable.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py` around lines 242 - 261, The current mask_thinking block assumes "<think>" and "</think>" map to single token IDs (think_open_id/think_close_id) and silently skips masking when they equal tokenizer.unk_token_id; update it to detect and handle multi-token tags: if think_open_id or think_close_id == tokenizer.unk_token_id, fall back to scanning the decoded substring (e.g., tokenizer.decode(input_ids[scan_start:end]) or joining tokenizer.convert_ids_to_tokens(input_ids[scan_start:end])) to locate "<think>" and "</think>" text spans and then map those spans back to token index ranges to set labels[i] = -100, and/or emit a warning when single-token IDs are unavailable; keep the existing behavior when single-token IDs are present. Ensure you modify the mask_thinking branch and use input_ids, labels, start, end, tokenizer, think_open_id/think_close_id identifiers.
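A runnable sketch of the text-based fallback suggested above: concatenate the token strings and map character offsets back to token indices. `find_tag_token_span` and the toy token list are illustrative, not part of the axolotl codebase (a real tokenizer's token strings may carry prefix markers that need normalizing first).

```python
from __future__ import annotations


def find_tag_token_span(tokens: list[str], tag: str) -> tuple[int, int] | None:
    """Return the half-open token range [start, end) covering `tag`, or None."""
    text = ""
    starts = []  # character offset at which each token begins
    for tok in tokens:
        starts.append(len(text))
        text += tok
    pos = text.find(tag)
    if pos == -1:
        return None
    end_pos = pos + len(tag)
    start_idx = max(i for i, s in enumerate(starts) if s <= pos)
    end_idx = max(i for i, s in enumerate(starts) if s < end_pos)
    return start_idx, end_idx + 1


# e.g. a tokenizer that splits the tag into ["<", "think", ">"]
tokens = ["<", "think", ">", "hello", "</", "think", ">"]
span = find_tag_token_span(tokens, "</think>")  # token indices to mask
```

Token ranges found this way can then be used to set `labels[i] = -100` over the span, matching the existing single-token-ID path.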
**35-40: Remove unused helper function `_extract_thinking`.**
The function is defined but never called anywhere in the codebase. Since it's prefixed with an underscore (indicating private scope), there's no indication of intended external usage.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py` around lines 35 - 40, The private helper function _extract_thinking is unused; remove its definition from ebft_reasoning.py (delete the def _extract_thinking(...) block) and ensure there are no remaining references to _extract_thinking elsewhere; if the re import is now unused after removal, also remove the import to avoid linter warnings and run tests/linters to confirm no residual uses.

examples/ebft/ebft_opencode.py (1)
**18-20: Consider using `remove_columns: "__all__"` to avoid schema drift.**
The explicit list is brittle if upstream dataset columns change.
♻️ Suggested simplification

```diff
-    return transform_fn, {"remove_columns": ["id", "domain", "generation_algorithm",
-                                             "llm_judgement", "unit_tests",
-                                             "tests_execution_status", "average_test_score"]}
+    return transform_fn, {"remove_columns": "__all__"}
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/ebft/ebft_opencode.py` around lines 18 - 20, The return value currently hardcodes a list of columns to drop which is brittle; update the second element of the returned tuple (the transformer config returned alongside transform_fn in ebft_opencode.py) to use "remove_columns": "__all__" instead of the explicit list so the transformer removes all original dataset columns and avoids schema drift while preserving transformed outputs from transform_fn.

src/axolotl/utils/schemas/config.py (3)
**66-70: Consider using a `Literal` type for `embed_method` validation.**
The `embed_method` field accepts any string, but the description lists specific valid values. Using a `Literal` type would provide compile-time validation and better IDE support.

♻️ Suggested improvement

```diff
+from typing import Literal
+
+EmbedMethod = Literal["last_token", "mean_pooling", "concat"]
+
-    embed_method: str = Field(
+    embed_method: EmbedMethod = Field(
         default="last_token",
         json_schema_extra={
             "description": "Embedding method: 'last_token', 'mean_pooling', or 'concat'"
         },
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/schemas/config.py` around lines 66 - 70, The embed_method Field currently allows any string even though the docstring lists specific valid values; change its type from str to typing.Literal with the allowed options (e.g., Literal["last_token","mean_pooling","concat"]) and update the Field declaration for embed_method in the Config schema so Pydantic/typing enforces and IDEs surface the valid choices; keep the same json_schema_extra description but ensure the attribute name embed_method in the relevant class/schema is updated to use the Literal type.
**132-134: Consider using `Literal` for `advantage_estimator` validation.**
The `advantage_estimator` has three defined valid values that could benefit from type-level validation.

♻️ Suggested improvement

```diff
-    advantage_estimator: str = Field(
+    advantage_estimator: Literal["rloo", "group_norm", "reinforce"] = Field(
         default="rloo",
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/schemas/config.py` around lines 132 - 134, Change the advantage_estimator field to use typing.Literal for strict validation: update the type annotation of advantage_estimator to Literal["rloo", "group_norm", "reinforce"] (keep the default "rloo"), add the Literal import, and leave the existing Field(json_schema_extra=...) intact so pydantic/JSON schema will enforce and document the allowed values; locate and modify the advantage_estimator declaration and add the import near other typing imports.
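For illustration, `typing.get_args` recovers the allowed values from a `Literal` annotation at runtime — a minimal sketch of the check pydantic would perform automatically once the annotation is a `Literal` (the validator function is hypothetical):

```python
from typing import Literal, get_args

EmbedMethod = Literal["last_token", "mean_pooling", "concat"]


def validate_embed_method(value: str) -> str:
    # get_args recovers the allowed values straight from the annotation,
    # so the docstring and the validation can never drift apart
    allowed = get_args(EmbedMethod)
    if value not in allowed:
        raise ValueError(f"embed_method must be one of {allowed}, got {value!r}")
    return value
```

The same pattern applies to `advantage_estimator` and `mode`: a single `Literal` alias keeps the schema, the error message, and the IDE hints in sync.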
**98-102: Consider using `Literal` for `mode` validation.**
Similar to `embed_method`, the `mode` field has defined valid values that could be enforced with a `Literal` type.

♻️ Suggested improvement

```diff
-    mode: str = Field(
+    mode: Literal["structured", "strided"] = Field(
         default="structured",
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/schemas/config.py` around lines 98 - 102, The mode field currently declared as mode: str = Field(...) should enforce allowed values using typing.Literal like embed_method does; update the type annotation for mode to use Literal["structured", "strided"] (and import Literal) and remove or keep the Field default/json_schema_extra as needed so Pydantic validates values at type level; target the mode declaration in src/axolotl/utils/schemas/config.py (near the existing embed_method pattern) to make this change.

examples/ebft/llama-1b-ebft-opencode.yaml (1)
**64: Non-zero `lora_dropout` may disable LoRA kernel optimizations.**
Setting `lora_dropout: 0.05` typically disables auto-enabled LoRA kernel optimizations (`lora_mlp_kernel`, `lora_qkv_kernel`, `lora_o_kernel`). If you want the kernel speedups, consider setting `lora_dropout: 0.0` or explicitly enabling the kernels.
Verify each finding against the current code and only fix it if needed. In `@examples/ebft/llama-1b-ebft-opencode.yaml` at line 64, The lora_dropout: 0.05 setting will likely disable auto-enabled LoRA kernel optimizations; change the lora_dropout value to 0.0 in the configuration or explicitly enable the kernels (lora_mlp_kernel, lora_qkv_kernel, lora_o_kernel) so the LoRA kernel speedups remain active—update the lora_dropout entry or add boolean flags for the three kernel options to ensure optimizations are not disabled.

src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py (1)
**42-43: Use defensive message-field access to avoid hard failures on bad rows.**
These transforms directly index message keys. A single malformed record can crash `datasets.map` with `KeyError`. Prefer `dict.get(...)` + skip/fallback handling.

Example hardening pattern

```diff
-            if msg["role"] == "assistant" and not found_first:
-                first_gt = msg["content"]
+            role = msg.get("role") if isinstance(msg, dict) else None
+            content = msg.get("content", "") if isinstance(msg, dict) else ""
+            if role == "assistant" and not found_first:
+                first_gt = content
                 found_first = True
             elif found_first:
                 remaining.append(msg)
             else:
                 prompt_msgs.append(msg)
```

Also applies to: 56-57, 83-85, 116-118
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py` around lines 42 - 43, The loop in ebft_chat_multiturn.py is indexing message dicts directly (e.g., msg["role"], msg["content"]) which can raise KeyError on malformed rows; update the code in the functions/blocks that set first_gt/last_gt and append messages (references: variables msg, first_gt, last_gt, found_first, the message-processing loop) to use dict.get("role") and dict.get("content") and skip or continue when required keys are missing or not strings (e.g., if role is None or content is None: continue), ensuring all places noted in the review (around the current first_gt/last_gt assignments and the other mentioned blocks) perform defensive checks before accessing message fields.

src/axolotl/utils/schemas/validation.py (1)
**1572-1575: Guard EBFT strided length math with explicit parameter validation.**
Line 1574 can throw a raw `ZeroDivisionError` (or produce invalid block math) for bad configs. Prefer an explicit `ValueError` with clear guidance.

Proposed hardening

```diff
     stride = ebft.get("stride", 8)
     ctx_len = ebft.get("context_length", 8)
-    max_blocks = (seq_len - gen_len - ctx_len) // stride + 1
+    if stride <= 0:
+        raise ValueError("ebft.stride must be > 0 in strided mode")
+    if seq_len <= gen_len + ctx_len:
+        raise ValueError(
+            "sequence_len must be greater than ebft.generate_max_len + ebft.context_length"
+        )
+    max_blocks = (seq_len - gen_len - ctx_len) // stride + 1
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/schemas/validation.py` around lines 1572 - 1575, The strided-length calculation using ebft.get("stride") and ebft.get("context_length") can raise ZeroDivisionError or produce invalid math; validate parameters first: ensure stride is an int > 0 and context_length (ctx_len) is an int >= 0 and less than (seq_len - gen_len) so the divisor (seq_len - gen_len - ctx_len) is non-negative; if any check fails raise a ValueError with a clear message referencing the offending keys ("stride" / "context_length") and the values (seq_len, gen_len) so callers know how to correct the EBFT config before computing max_blocks and full_seq.

src/axolotl/core/trainers/ebft/__init__.py (1)
**60-166: Consider reducing repetition in kwargs mapping.**
The `is not None` checks are repetitive. A helper function or loop over a mapping could reduce boilerplate, though this is a stylistic preference.
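One possible shape for such a helper, with `SimpleNamespace` standing in for the real config objects (`map_if_present` and the attribute names are illustrative, not axolotl APIs):

```python
from types import SimpleNamespace
from typing import Any


def map_if_present(kwargs: dict, source: Any, pairs) -> None:
    # copy attr -> kwarg only when the source value is actually set
    for attr, key in pairs:
        value = getattr(source, attr, None)
        if value is not None:
            kwargs[key] = value


cfg = SimpleNamespace(beta=0.1, stride=None)
kwargs: dict = {}
map_if_present(kwargs, cfg, [("beta", "ebft_beta"), ("stride", "ebft_stride")])
# unset attributes (stride) are simply skipped
```

Special cases (vLLM colocate logic, `async_prefetch`, server host/port) would still be handled explicitly after the bulk mapping.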
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/trainers/ebft/__init__.py` around lines 60 - 166, The mapping in set_training_args_kwargs repeats many "if X is not None: kwargs[...]=X" blocks for ebft and trl; refactor by introducing a small helper (e.g., a local function map_if_present or apply_mapping) that accepts a source object and an iterable of (attr_name, kwarg_key) pairs and sets kwargs[kwarg_key] = getattr(source, attr_name) when the value is not None (or truthy where appropriate), then replace the repeated blocks for ebft and trl with calls to this helper and explicit handling only for special cases like vllm colocate logic, async_prefetch, vllm_server_host/port and vllm_enable_sleep_mode inside set_training_args_kwargs.

src/axolotl/core/trainers/ebft/strided.py (1)
**314: Mutable class attribute detected by Ruff (RUF012).**
Class-level mutable defaults can cause unexpected sharing. While `_tag_names` is unlikely to be mutated, using a tuple is safer.

♻️ Use tuple instead of list

```diff
-    _tag_names = ["ebft", "strided", "axolotl"]
+    _tag_names = ("ebft", "strided", "axolotl")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/trainers/ebft/strided.py` at line 314, Replace the mutable class attribute _tag_names (currently a list) with an immutable tuple to avoid accidental shared-state; locate the _tag_names declaration in the Strided trainer class in strided.py and change ["ebft", "strided", "axolotl"] to ("ebft", "strided", "axolotl") so the class-level constant is immutable.

examples/ebft/qwen35-4b-ebft-structured-async.yaml (1)
**70-72: Complex regex for LoRA targeting — verify it matches intended layers.**
The regex pattern targets full-attention layers (3, 7, 11, 15, 19, 23, 27, 31) and MLP on all layers. This is a careful design for hybrid attention models, but the pattern's complexity makes it easy to miss layers.
Consider adding a verification script or test to confirm the pattern matches exactly the intended modules when applied to the model.
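A minimal sketch of such a verification helper — the pattern and module names below are toy stand-ins, not the actual Qwen module tree:

```python
import re


def check_lora_pattern(pattern: str, module_names: list) -> list:
    # fullmatch so partial hits (e.g. layer 13 matching "3") are not counted
    rx = re.compile(pattern)
    return [name for name in module_names if rx.fullmatch(name)]


# toy pattern: attention projections on layers 3 and 7 only
pattern = r"model\.layers\.(3|7)\.self_attn\.(q|k|v|o)_proj"
modules = [
    "model.layers.3.self_attn.q_proj",
    "model.layers.4.self_attn.q_proj",   # should not match
    "model.layers.13.self_attn.q_proj",  # should not match (no partial "3")
    "model.layers.7.self_attn.o_proj",
]
matched = check_lora_pattern(pattern, modules)
```

Run against `model.named_modules()` names, print both the unmatched expected modules and any unexpected matches, and adjust the regex until both lists are empty.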
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/ebft/qwen35-4b-ebft-structured-async.yaml` around lines 70 - 72, The lora_target_modules regex is complex and may miss or overmatch intended module names; add a small verification helper that loads the model's state_dict or module names and tests each key against lora_target_modules to assert that exactly the intended layer indices (3,7,11,15,19,23,27,31) for self_attn.(q|k|v|o)_proj and all mlp.(gate|up|down)_proj are matched; implement this check as a unit test or a CLI validation (e.g., verify_lora_targets or validate_lora_regex) that prints unmatched expected modules and any unexpected matches so you can adjust the regex if it misfires.

src/axolotl/core/trainers/ebft/trainer.py (3)
**44: Mutable class attribute (RUF012).**
Same as `strided.py` — use a tuple for an immutable tag list.

♻️ Use tuple instead of list

```diff
-    _tag_names = ["trl", "ebft", "axolotl"]
+    _tag_names = ("trl", "ebft", "axolotl")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/trainers/ebft/trainer.py` at line 44, The class-level _tag_names is defined as a mutable list; change it to an immutable tuple to avoid RUF012. Replace _tag_names = ["trl", "ebft", "axolotl"] with _tag_names = ("trl", "ebft", "axolotl") in trainer.py (same pattern as in strided.py) so the attribute is immutable at class scope.
**364-366: Bare `Exception` catch is overly broad (BLE001).**
Catching all exceptions can mask unexpected errors. Consider catching specific vLLM client exceptions, or at minimum re-raising after logging for non-recoverable errors.

♻️ Narrow exception handling

```diff
-        except Exception as e:
-            LOG.warning(f"Multi-turn rollout generation failed: {e}")
-            gen_text = ""
+        except (ConnectionError, TimeoutError, RuntimeError) as e:
+            LOG.warning(f"Multi-turn rollout generation failed: {e}")
+            gen_text = ""
```

Or if the vLLM client has specific exception types, catch those instead.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/trainers/ebft/trainer.py` around lines 364 - 366, The current broad except Exception block around the multi-turn rollout generation masks unexpected errors; replace it by catching concrete vLLM/client errors (e.g., vllm.client.exceptions.VLLMError or the client-specific exception types) and handle them by logging and setting gen_text as before, and add a separate fallback that logs with LOG.exception(...) and re-raises for any other unexpected exceptions; import the client exception types at top and update the try/except in the multi-turn rollout generation block (the one that currently sets gen_text = "") to use specific except clauses and a final except Exception to re-raise after logging.
**60-70: Mypy errors on the `super().__init__()` call are expected for the mixin pattern.**
The Mypy errors about "unexpected keyword arguments for `__init__` of `object`" occur because Mypy doesn't see the full MRO at the mixin level. This is a known limitation with mixin patterns. Consider adding a `TYPE_CHECKING` block with protocol hints or `# type: ignore` comments if the errors are noisy in CI.

♻️ Silence Mypy for mixin super() call

```diff
-        super().__init__(
+        super().__init__(  # type: ignore[call-arg]
             model=model,
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/trainers/ebft/trainer.py` around lines 60 - 70, Mypy complains about unexpected keyword args on the super().__init__() call because of the mixin pattern; either silence it by adding a TYPE_CHECKING block that defines a minimal Protocol for the base initializer signature (import TYPE_CHECKING from typing and declare a Protocol with __init__(..., model, reward_funcs, args, train_dataset, eval_dataset, processing_class, callbacks, optimizers, peft_config) and use it only under TYPE_CHECKING), or add a scoped type ignore on the call (e.g., append # type: ignore[arg-type] to the super().__init__(...) line) so the mixin pattern no longer fails CI; target the super().__init__ invocation in trainer.py (the call that passes model, reward_funcs=[self._feature_matching_reward], args, train_dataset, eval_dataset, processing_class, callbacks, optimizers, peft_config).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/ebft/ebft_pretrain.py`:
- Around line 27-31: The returned dict aliases labels to encoded["input_ids"],
which can create shared references when the tokenizer returns Python lists;
change the assignment so "labels" is a shallow copy of encoded["input_ids"]
(e.g., use list(...) or .copy()) instead of referencing the same object; update
the return in ebft_pretrain.py that currently returns {"input_ids":
encoded["input_ids"], "attention_mask": encoded["attention_mask"], "labels":
encoded["input_ids"]} to set "labels" to a copy of encoded["input_ids"] to avoid
accidental mutation of input_ids.
- Line 17: The variable pad_id assigned from tokenizer.pad_token_id or
tokenizer.eos_token_id is unused; remove the pad_id assignment in
examples/ebft/ebft_pretrain.py (the line creating pad_id) and rely on the
tokenizer's built-in padding (used later with padding="max_length"), ensuring no
other code references pad_id (search for pad_id to confirm) and run tests or
lint to verify no remaining usages.
In `@examples/ebft/ebft_strided_structured.py`:
- Around line 76-77: The file examples/ebft/ebft_strided_structured.py is
missing a trailing newline; open the file and add a single newline character at
the end of the file (ensure it ends with exactly one '\n') so the pre-commit
end-of-file-fixer check passes.
In `@examples/ebft/llama-3b-ebft-strided-fft.yaml`:
- Around line 53-55: The EBFT validator fails because strided EBFT with
gradient_checkpointing enabled is incompatible with torch_compile=true and
requires reentrant checkpointing; update the YAML to set torch_compile: false
(change the torch_compile key) and set
gradient_checkpointing_kwargs.use_reentrant: true (change use_reentrant value)
so the configuration meets the validator's requirements for strided EBFT and
flex-attention checkpointing.
In `@examples/ebft/qwen35-4b-ebft-structured.yaml`:
- Around line 32-33: Replace the non-routable wildcard host value used for the
TRL vllm server with a loopback address: change the vllm_server_host setting
(trl.vllm_server_host) from "0.0.0.0" to "127.0.0.1" so clients connect to a
routable local address; update any corresponding examples or documentation in
the same YAML file to use 127.0.0.1 for vllm_server_host while leaving
vllm_server_port as-is (8000).
In `@examples/ebft/README.md`:
- Around line 187-188: Update the strided-mode performance guidance to stop
recommending "torch_compile: true" because EBFT validation now warns/errors when
torch_compile is enabled with strided mode and gradient checkpointing; replace
that sentence so it either recommends leaving torch_compile disabled for strided
configurations or documents the validation restriction, and keep the explanation
of "flex_attention" behavior and fallback as-is so users know flex_attention is
used when available.
In `@src/axolotl/cli/vllm_serve.py`:
- Around line 82-86: The current precedence uses Python's truthy "or" so an
explicit CLI False (cli_args.get("enforce_eager") == False) is ignored when
cfg.vllm.enforce_eager is truthy; change the logic to prefer an explicitly
provided CLI value by checking presence/None instead of truthiness: if
cli_args.get("enforce_eager") is not None use
bool(cli_args.get("enforce_eager")) else use getattr(cfg.vllm, "enforce_eager",
False). Replace the current assignment to enforce_eager with this pattern and
apply the same change to the other occurrence referenced (the second
enforce_eager assignment at the other location).
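The presence-over-truthiness precedence described in this item can be sketched as follows (the function name and the plain-dict stand-ins are hypothetical):

```python
def resolve_enforce_eager(cli_args: dict, cfg_default: bool) -> bool:
    # check presence (is not None), not truthiness, so that an explicit
    # CLI --enforce-eager=false overrides a truthy config default
    cli_value = cli_args.get("enforce_eager")
    if cli_value is not None:
        return bool(cli_value)
    return cfg_default
```

With `or`-based precedence, `{"enforce_eager": False}` against a truthy config default would silently resolve to `True`; the presence check resolves it to `False` as the user asked.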
- Around line 109-110: The current check calls getattr(cfg.trl,
"vllm_lora_sync", False) and will raise AttributeError if cfg has no trl
attribute; change the guard to safely access trl first (e.g., use getattr(cfg,
"trl", None) and then getattr(..., "vllm_lora_sync", False) or check
hasattr(cfg, "trl") before reading vllm_lora_sync) so that when trl is absent
you fall back to False and still set lora_kwargs["enable_lora"] = False; update
the expression around cfg.trl, "vllm_lora_sync", False to use a safe nested
getattr or an existence check so the code never dereferences a missing trl.
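The safe nested access pattern can be sketched as below (`lora_sync_enabled` and the `SimpleNamespace` configs are illustrative):

```python
from types import SimpleNamespace


def lora_sync_enabled(cfg) -> bool:
    # fetch trl safely first; getattr on None then falls back to False,
    # so a cfg without a trl attribute never raises AttributeError
    trl = getattr(cfg, "trl", None)
    return bool(getattr(trl, "vllm_lora_sync", False))
```

The one-liner equivalent is `getattr(getattr(cfg, "trl", None), "vllm_lora_sync", False)`.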
In `@src/axolotl/common/datasets.py`:
- Around line 121-124: Normalize cfg.rl to an RLType before the membership check
so string values like "grpo"/"ebft" are treated the same as RLType.GRPO/EBFT;
update the block that computes total_num_steps (referencing cfg.rl and
total_num_steps) to first map/normalize cfg.rl into an RLType (or compare
against lowercased names) and then check membership against {RLType.GRPO,
RLType.EBFT} so GRPO/EBFT are properly excluded whether cfg.rl is provided as a
string or an enum.
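A sketch of the normalization step, with a stand-in `RLType` (the real enum lives in axolotl and has more members):

```python
from enum import Enum


class RLType(str, Enum):  # stand-in for axolotl's RLType
    GRPO = "grpo"
    EBFT = "ebft"


def normalize_rl(value):
    # enum members first: they are also str instances, so order matters
    if isinstance(value, RLType):
        return value
    if isinstance(value, str):
        try:
            return RLType(value.lower())
        except ValueError:
            return None
    return None
```

After normalizing, the membership check `normalize_rl(cfg.rl) in {RLType.GRPO, RLType.EBFT}` behaves the same whether `cfg.rl` arrived as `"GRPO"` or `RLType.GRPO`.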
In `@src/axolotl/core/trainers/ebft/kernels.py`:
- Around line 256-261: The kernel divides by (N - 1) and when N == 1 this yields
a divide-by-zero; update the Python wrapper fused_diversity_penalty to guard
against N <= 1 by short-circuiting before launching the kernel: detect the input
size (N), and if N <= 1 return an appropriately shaped tensor of zeros (or the
intended neutral penalty) without calling the kernel; otherwise proceed to call
the existing kernel as before. Ensure the check references
fused_diversity_penalty and the N dimension used to compute (N - 1) so the
kernel never receives N == 1.
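The guard can be sketched with a torch-free stand-in; `diversity_penalty` and its mean-similarity body are illustrative — only the `n <= 1` short-circuit before the `(n - 1)` division is the point:

```python
def diversity_penalty(vectors: list) -> list:
    n = len(vectors)
    if n <= 1:
        # nothing to compare against: return a neutral penalty and
        # never let the (n - 1) divisor reach zero
        return [0.0] * n

    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    # each sample's penalty: mean similarity to the other n - 1 samples
    return [
        sum(dot(vectors[i], vectors[j]) for j in range(n) if j != i) / (n - 1)
        for i in range(n)
    ]
```

In the real wrapper the short-circuit would return a zeros tensor of the expected shape before the Triton kernel launch.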
In `@src/axolotl/core/trainers/ebft/rewards.py`:
- Around line 205-213: The whitening currently builds W from left singular
vectors U (producing a (B,B) matrix) but must operate in feature space: use the
right singular vectors V and singular values S to build a (D,D) whitening matrix
W = V @ diag(inv_s) @ V.T (where inv_s is computed with whiten_tol as before),
then apply it to features via phi_w = phi_f @ W.T and phi_gt_w = phi_gt_f @ W.T
(ensure dtype casts remain as phi.dtype / phi_gt.dtype); replace usages of U, W
(B,B) with V and the new (D,D) W to correct the transform.
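The corrected transform can be sketched in numpy (the trainer itself operates on torch tensors; `whiten_features` is illustrative). Building `W` from the right singular vectors gives a `(D, D)` matrix, and on full-rank input the whitened features have an approximately identity Gram matrix:

```python
import numpy as np


def whiten_features(phi: np.ndarray, tol: float = 1e-6) -> np.ndarray:
    centered = phi - phi.mean(axis=0, keepdims=True)
    # economy SVD: centered = U @ diag(S) @ Vt; rows of Vt are the
    # right singular vectors spanning feature space
    _, S, Vt = np.linalg.svd(centered, full_matrices=False)
    inv_s = np.where(S > tol, 1.0 / S, 0.0)
    W = Vt.T @ np.diag(inv_s) @ Vt  # (D, D) whitening matrix in feature space
    return centered @ W.T
```

Because `W` acts on the feature dimension, the same `W` can be applied to both `phi` and `phi_gt`, which the original `(B, B)` left-singular-vector construction could not do.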
In `@src/axolotl/core/trainers/ebft/strided.py`:
- Line 768: The return uses an undefined outputs when return_outputs=True;
update the function (the block that computes loss when backbone is not None) to
always set a value for outputs (e.g., outputs = None or the actual model
outputs) before the final return, or change the final return to return (loss,
None) when outputs were not produced; ensure references to outputs,
return_outputs, and backbone in this function (the strided trainer method
containing the loss computation) are adjusted so outputs is always defined when
return_outputs is True.
In `@src/axolotl/core/trainers/ebft/trainer.py`:
- Around line 181-190: The zip(prompts, ground_truth) usage in the loop that
builds gt_texts should be made strict to avoid silent length mismatches; update
the zip call inside the loop (the one iterating with for i, (p, gt) in
enumerate(...)) to zip(prompts, ground_truth, strict=True) so mismatched lengths
raise an error, keeping the rest of the logic that uses num_gens,
processing_class.apply_chat_template, and gt_texts unchanged.
- Around line 155-166: The loop in trainer.py that iterates over prompts and
completions uses zip(prompts, completions) without the strict parameter; change
it to zip(prompts, completions, strict=True) to enforce equal lengths and
surface mismatches early, keeping the existing handling of list vs. scalar
prompt/completion values and appending combined strings to gen_texts
(references: prompts, completions, gen_texts, and
processing_class.apply_chat_template).
In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py`:
- Around line 61-64: The code uses `"prompt_msgs_snapshot" in dir()` to detect
whether prompt_msgs_snapshot was assigned; replace this unreliable check by
explicitly tracking assignment: either initialize prompt_msgs_snapshot = None at
the top of each function (transform, transform_split_thinking,
transform_answer_only) and test `if prompt_msgs_snapshot is not None` before
using it, or set a boolean flag (e.g., has_prompt_snapshot = False -> True when
you create the snapshot) and test that flag in the return. Update the three
functions (transform, transform_split_thinking, transform_answer_only) to use
the sentinel/flag instead of the dir() check so the branch is deterministic.
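The sentinel pattern can be sketched as below (function and variable names illustrative):

```python
def snapshot_prompt(messages: list) -> list:
    # explicit sentinel replaces the fragile `"name" in dir()` check
    prompt_snapshot = None
    seen = []
    for msg in messages:
        if msg.get("role") == "assistant" and prompt_snapshot is None:
            prompt_snapshot = list(seen)  # snapshot taken exactly once
        seen.append(msg)
    # deterministic branch: test the sentinel, not dir()
    return prompt_snapshot if prompt_snapshot is not None else seen
```

Unlike `"prompt_msgs_snapshot" in dir()`, `is not None` cannot be confused by other names in scope and reads as an intentional state check.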
In `@src/axolotl/scripts/vllm_serve_lora.py`:
- Around line 505-526: The http_update_weights handler is using torch
(torch.frombuffer and getattr(torch,...)) but torch is not imported, causing a
NameError; add an import for torch (preferably at module top) so
http_update_weights can reference it. Also address the unused results variable
returned from asyncio.gather: either consume/check results for worker
responses/errors (e.g., inspect the list returned by asyncio.gather) or remove
the assignment and await the gather call solely for synchronization. Target
symbols: http_update_weights, torch.frombuffer, getattr(torch,...), connections,
asyncio.gather, and results.
In `@src/axolotl/train.py`:
- Around line 141-142: The current condition accesses cfg.trl directly which can
raise if cfg has no trl attribute; change it to safely obtain trl first (e.g.
trl = getattr(cfg, 'trl', None)) and then check if trl is not None before
reading beta, e.g. if cfg.rl in {RLType.GRPO, RLType.EBFT} and trl and
getattr(trl, 'beta', 0) == 0: reference_model = False — this ensures accessing
beta is guarded and prevents attribute errors on cfg.trl.
In `@src/axolotl/utils/data/rl.py`:
- Around line 223-229: The removal of columns logic assumes a "train" split and
will break for DatasetDicts without that key; change the DatasetDict branch in
the remove_columns resolution to pick the first available split dynamically
(e.g., use next(iter(dataset)) or list(dataset.keys())[0]) and read its
.column_names instead of dataset["train"].column_names so it works for arbitrary
split names; update the code path that computes ds_columns (the block guarded by
isinstance(dataset, DatasetDict)) to use the first split's column_names and keep
the existing Dataset and fallback behaviors unchanged.
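A sketch of the split-agnostic lookup, with `SimpleNamespace` standing in for a `Dataset` (only the `column_names` attribute is assumed):

```python
from types import SimpleNamespace


def first_split_columns(dataset_dict: dict) -> list:
    # pick the first available split instead of assuming a "train" key;
    # dict iteration order is insertion order, so this is deterministic
    first_split = next(iter(dataset_dict.values()))
    return list(first_split.column_names)


splits = {"validation": SimpleNamespace(column_names=["prompt", "answer"])}
```

`next(iter(dataset))` on a real `DatasetDict` yields the first split name the same way, so `dataset[next(iter(dataset))].column_names` is an equivalent form.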
In `@src/axolotl/utils/schemas/validation.py`:
- Line 1579: Replace the Unicode multiplication sign in the EBFT log message
with an ASCII character to satisfy Ruff RUF001: update the f-string containing
"EBFT strided: full_seq_len={full_seq} × n_samples={n_samples} = " to use "x"
(e.g. "full_seq_len={full_seq} x n_samples={n_samples}") or "*" instead; locate
the string that references variables full_seq and n_samples in the validation
code and make this simple character substitution.
places noted in the review (around the current first_gt/last_gt assignments and
the other mentioned blocks) perform defensive checks before accessing message
fields.
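A sketch of the defensive `dict.get` loop described above, with hypothetical helper and variable names:

```python
def collect_turns(messages):
    """Skip malformed rows instead of raising KeyError on msg["role"]."""
    turns = []
    for msg in messages:
        role = msg.get("role") if isinstance(msg, dict) else None
        content = msg.get("content") if isinstance(msg, dict) else None
        if not isinstance(role, str) or not isinstance(content, str):
            continue  # malformed row: missing or non-string fields
        turns.append((role, content))
    return turns


rows = [
    {"role": "user", "content": "hi"},
    {"content": "orphaned"},                 # missing role -> skipped
    {"role": "assistant", "content": None},  # non-string content -> skipped
    {"role": "assistant", "content": "hello"},
]
```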
In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py`:
- Around line 242-261: The current mask_thinking block assumes "<think>" and
"</think>" map to single token IDs (think_open_id/think_close_id) and silently
skips masking when they equal tokenizer.unk_token_id; update it to detect and
handle multi-token tags: if think_open_id or think_close_id ==
tokenizer.unk_token_id, fall back to scanning the decoded substring (e.g.,
tokenizer.decode(input_ids[scan_start:end]) or joining
tokenizer.convert_ids_to_tokens(input_ids[scan_start:end])) to locate "<think>"
and "</think>" text spans and then map those spans back to token index ranges to
set labels[i] = -100, and/or emit a warning when single-token IDs are
unavailable; keep the existing behavior when single-token IDs are present.
Ensure you modify the mask_thinking branch and use input_ids, labels, start,
end, tokenizer, think_open_id/think_close_id identifiers.
- Around line 35-40: The private helper function _extract_thinking is unused;
remove its definition from ebft_reasoning.py (delete the def
_extract_thinking(...) block) and ensure there are no remaining references to
_extract_thinking elsewhere; if the re import is now unused after removal, also
remove the import to avoid linter warnings and run tests/linters to confirm no
residual uses.
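To illustrate the multi-token fallback for `<think>` masking, a toy sketch that rebuilds character offsets from decoded token strings and masks any token overlapping the thinking span. The offset bookkeeping is deliberately simplified; real tokenizers expose offset mappings for this:

```python
def mask_think_spans(tokens, ignore_index=-100):
    """Fallback masking when "<think>" does not map to a single token id.

    `tokens` is the list of decoded token strings for the completion; we
    rebuild character offsets, find the <think>...</think> span in the text,
    and mask every token that overlaps it.
    """
    text, offsets = "", []
    for tok in tokens:
        offsets.append((len(text), len(text) + len(tok)))
        text += tok
    labels = list(range(len(tokens)))  # stand-in labels: token indices
    start = text.find("<think>")
    if start == -1:
        return labels
    close = text.find("</think>", start)
    end = close + len("</think>") if close != -1 else len(text)
    for i, (s, e) in enumerate(offsets):
        if s < end and e > start:  # token overlaps the thinking span
            labels[i] = ignore_index
    return labels


toks = ["Sure.", "<th", "ink>", "plan", "</think>", "Answer"]
```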
In `@src/axolotl/utils/schemas/config.py`:
- Around line 66-70: The embed_method Field currently allows any string even
though the docstring lists specific valid values; change its type from str to
typing.Literal with the allowed options (e.g.,
Literal["last_token","mean_pooling","concat"]) and update the Field declaration
for embed_method in the Config schema so Pydantic/typing enforces and IDEs
surface the valid choices; keep the same json_schema_extra description but
ensure the attribute name embed_method in the relevant class/schema is updated
to use the Literal type.
- Around line 132-134: Change the advantage_estimator field to use
typing.Literal for strict validation: update the type annotation of
advantage_estimator to Literal["rloo", "group_norm", "reinforce"] (keep the
default "rloo"), add the Literal import, and leave the existing
Field(json_schema_extra=...) intact so pydantic/JSON schema will enforce and
document the allowed values; locate and modify the advantage_estimator
declaration and add the import near other typing imports.
- Around line 98-102: The mode field currently declared as mode: str =
Field(...) should enforce allowed values using typing.Literal like embed_method
does; update the type annotation for mode to use Literal["structured",
"strided"] (and import Literal) and remove or keep the Field
default/json_schema_extra as needed so Pydantic validates values at type level;
target the mode declaration in src/axolotl/utils/schemas/config.py (near the
existing embed_method pattern) to make this change.
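A stdlib-only sketch of what the `Literal` annotations buy: pydantic performs this membership check automatically once the fields use `Literal` types, so `check_choice` here is illustrative, not pydantic API:

```python
from typing import Literal, get_args

EmbedMethod = Literal["last_token", "mean_pooling", "concat"]
Mode = Literal["structured", "strided"]
AdvantageEstimator = Literal["rloo", "group_norm", "reinforce"]


def check_choice(name, value, literal_type):
    """Reject any value outside the Literal's allowed options."""
    allowed = get_args(literal_type)
    if value not in allowed:
        raise ValueError(f"{name} must be one of {allowed}, got {value!r}")
    return value
```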
In `@src/axolotl/utils/schemas/validation.py`:
- Around line 1572-1575: The strided-length calculation using ebft.get("stride")
and ebft.get("context_length") can raise ZeroDivisionError or produce invalid
math; validate parameters first: ensure stride is an int > 0 and context_length
(ctx_len) is an int >= 0 and less than (seq_len - gen_len) so the divisor
(seq_len - gen_len - ctx_len) is non-negative; if any check fails raise a
ValueError with a clear message referencing the offending keys ("stride" /
"context_length") and the values (seq_len, gen_len) so callers know how to
correct the EBFT config before computing max_blocks and full_seq.
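A sketch of the guarded strided-length math; the exact `max_blocks`/`full_seq` formulas are illustrative reconstructions, not the validator's actual code:

```python
def validate_strided_lengths(stride, ctx_len, seq_len, gen_len):
    """Guard the stride math against ZeroDivisionError and negative spans."""
    if not isinstance(stride, int) or stride <= 0:
        raise ValueError(f"'stride' must be a positive int, got {stride!r}")
    if not isinstance(ctx_len, int) or ctx_len < 0:
        raise ValueError(f"'context_length' must be >= 0, got {ctx_len!r}")
    usable = seq_len - gen_len - ctx_len
    if usable < 0:
        raise ValueError(
            f"'context_length' ({ctx_len}) must be < seq_len - gen_len "
            f"({seq_len} - {gen_len})"
        )
    max_blocks = usable // stride + 1  # illustrative block count
    full_seq = ctx_len + (max_blocks - 1) * stride + gen_len
    return max_blocks, full_seq
```

With stride=512, context_length=1024, seq_len=4096, gen_len=512 this yields 6 blocks whose last window ends exactly at the 4096-token boundary.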
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d773aa21-0e02-40e5-853d-1c6fb87e4935
📒 Files selected for processing (40)
- examples/ebft/README.md
- examples/ebft/ebft_opencode.py
- examples/ebft/ebft_pretrain.py
- examples/ebft/ebft_strided_structured.py
- examples/ebft/llama-1b-ebft-opencode-novllm.yaml
- examples/ebft/llama-1b-ebft-opencode.yaml
- examples/ebft/llama-1b-ebft-strided-structured.yaml
- examples/ebft/llama-1b-ebft-strided.yaml
- examples/ebft/llama-3b-ebft-strided-fft.yaml
- examples/ebft/llama-8b-ebft-strided-fft.yaml
- examples/ebft/qwen35-4b-ebft-structured-async.yaml
- examples/ebft/qwen35-4b-ebft-structured.yaml
- examples/ebft/qwen35-9b-ebft-structured.yaml
- src/axolotl/cli/vllm_serve.py
- src/axolotl/common/datasets.py
- src/axolotl/core/builders/rl.py
- src/axolotl/core/trainers/__init__.py
- src/axolotl/core/trainers/ebft/__init__.py
- src/axolotl/core/trainers/ebft/args.py
- src/axolotl/core/trainers/ebft/kernels.py
- src/axolotl/core/trainers/ebft/rewards.py
- src/axolotl/core/trainers/ebft/strided.py
- src/axolotl/core/trainers/ebft/trainer.py
- src/axolotl/core/trainers/grpo/async_trainer.py
- src/axolotl/monkeypatch/trainer/trl_vllm.py
- src/axolotl/prompt_strategies/ebft/__init__.py
- src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py
- src/axolotl/prompt_strategies/ebft/ebft_opencode.py
- src/axolotl/prompt_strategies/ebft/ebft_reasoning.py
- src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py
- src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py
- src/axolotl/scripts/vllm_serve_lora.py
- src/axolotl/scripts/vllm_worker_ext.py
- src/axolotl/train.py
- src/axolotl/utils/data/rl.py
- src/axolotl/utils/schemas/config.py
- src/axolotl/utils/schemas/enums.py
- src/axolotl/utils/schemas/trl.py
- src/axolotl/utils/schemas/validation.py
- src/axolotl/utils/schemas/vllm.py
In `examples/ebft/llama-3b-ebft-strided-fft.yaml` (lines 53–55):

```yaml
torch_compile: true
gradient_checkpointing_kwargs:
  use_reentrant: false
```

This example currently fails EBFT validation at startup.

With strided EBFT + `gradient_checkpointing: true`, Line 53 (`torch_compile: true`) is rejected by the new validator. Also, Line 55 (`use_reentrant: false`) conflicts with the documented flex-attention checkpointing requirement.

Proposed fix:

```diff
 gradient_checkpointing: true
-torch_compile: true
+# Keep torch_compile disabled in EBFT strided mode with gradient checkpointing
 gradient_checkpointing_kwargs:
-  use_reentrant: false
+  use_reentrant: true
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/ebft/llama-3b-ebft-strided-fft.yaml` around lines 53 - 55, The EBFT
validator fails because strided EBFT with gradient_checkpointing enabled is
incompatible with torch_compile=true and requires reentrant checkpointing;
update the YAML to set torch_compile: false (change the torch_compile key) and
set gradient_checkpointing_kwargs.use_reentrant: true (change use_reentrant
value) so the configuration meets the validator's requirements for strided EBFT
and flex-attention checkpointing.
In `examples/ebft/qwen35-4b-ebft-structured.yaml` (lines 32–33):

```yaml
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
```

Use a routable client host for `trl.vllm_server_host` in this example.

At Line 32, `0.0.0.0` can be problematic as a connect target; `127.0.0.1` is safer for local training-to-server calls.

💡 Proposed fix:

```diff
-vllm_server_host: 0.0.0.0
+vllm_server_host: 127.0.0.1
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/ebft/qwen35-4b-ebft-structured.yaml` around lines 32 - 33, Replace
the non-routable wildcard host value used for the TRL vllm server with a
loopback address: change the vllm_server_host setting (trl.vllm_server_host)
from "0.0.0.0" to "127.0.0.1" so clients connect to a routable local address;
update any corresponding examples or documentation in the same YAML file to use
127.0.0.1 for vllm_server_host while leaving vllm_server_port as-is (8000).
Codecov Report: ❌ Patch coverage is …
@coderabbitai review |
✅ Actions performed: review triggered.
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/axolotl/utils/callbacks/generation.py (1)
28-60: ⚠️ Potential issue | 🔴 Critical

Unindent the sample generation block to fix unreachable code.

Lines 31–60 are incorrectly indented inside the `if not getattr(cfg, "generate_samples", False):` block. This means the generation code is unreachable regardless of the config value: when `generate_samples` is false, the `return` executes; when true, the entire if-block is skipped.

Dedent lines 31–60 by one level (4 spaces) so they execute when `generate_samples` is true.

Suggested fix (the removed and re-added bodies are identical apart from indentation; the middle is elided with `...`):

```diff
         if not getattr(cfg, "generate_samples", False):
             return
-            dataloader = None
-            try:
-                if getattr(self.trainer, "eval_dataset", None) is not None:
-                    dataloader = self.trainer.get_eval_dataloader()
-                    LOG.info(
-                        f"Using eval dataloader for generation at step {state.global_step}"
-                    )
-            except Exception as e:
-                LOG.warning(f"Could not get eval dataloader: {e}")
-                dataloader = None
-            ...
-            self._log_samples(samples, state.global_step)
+        dataloader = None
+        try:
+            if getattr(self.trainer, "eval_dataset", None) is not None:
+                dataloader = self.trainer.get_eval_dataloader()
+                LOG.info(
+                    f"Using eval dataloader for generation at step {state.global_step}"
+                )
+        except Exception as e:
+            LOG.warning(f"Could not get eval dataloader: {e}")
+            dataloader = None
+        ...
+        self._log_samples(samples, state.global_step)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/callbacks/generation.py` around lines 28 - 60, The sample-generation block is incorrectly indented under the early-return for getattr(cfg, "generate_samples", False), making it unreachable; dedent the entire block that calls dataloader selection and generate_samples (the try/except fetching self.trainer.get_eval_dataloader(), the fallback to self.trainer.get_train_dataloader(), the call to generate_samples with model=self.trainer.model and tokenizer=self.trainer.processing_class, and the subsequent self._log_samples(samples, state.global_step)) so it runs only when generate_samples is True (i.e., move that block out of the if that contains the return).src/axolotl/core/trainers/grpo/async_trainer.py (1)
645-670: ⚠️ Potential issue | 🟠 Major

Restore `_init_vllm` after this trainer finishes initialization.

This mutates `VLLMGeneration._init_vllm` at class scope and never puts the original method back. After one async/HTTP-only trainer is constructed, later trainers in the same process will also skip communicator init even when they need the stock behavior.

Suggested fix:

```diff
-    if _skip_nccl:
+    restore_init_vllm = None
+    if _skip_nccl:
         from trl.generation.vllm_generation import VLLMGeneration
         _orig_init_vllm = VLLMGeneration._init_vllm
+        restore_init_vllm = _orig_init_vllm
         ...
         VLLMGeneration._init_vllm = _init_vllm_no_communicator
-    super().__init__(*args, **kwargs)
+    try:
+        super().__init__(*args, **kwargs)
+    finally:
+        if restore_init_vllm is not None:
+            VLLMGeneration._init_vllm = restore_init_vllm
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/trainers/grpo/async_trainer.py` around lines 645 - 670, The patch permanently replaces VLLMGeneration._init_vllm when _skip_nccl is true, causing later trainers to inherit the no-communicator behavior; instead, save _orig_init_vllm, assign VLLMGeneration._init_vllm = _init_vllm_no_communicator only for the duration of this trainer's initialization and restore the original in a finally/cleanup block (or use a context manager) so that the original _init_vllm is reinstated whether initialization succeeds or raises; reference VLLMGeneration._init_vllm, _orig_init_vllm, and _init_vllm_no_communicator and ensure restoration happens after the trainer finishes initialization.
♻️ Duplicate comments (2)
src/axolotl/prompt_strategies/ebft/ebft_reasoning.py (1)
148-151: ⚠️ Potential issue | 🟠 Major

Replace the remaining `dir()` guards with an explicit sentinel.

This is the same unresolved issue from the previous round: `transform_split_thinking()` and `transform_answer_only()` still use `if "prompt_msgs_snapshot" in dir()` to decide whether a local was assigned. That keeps the branch dependent on interpreter locals instead of explicit state. Initialize `prompt_msgs_snapshot = None` before the loop and test `is not None` in both functions.

Also applies to: 172-175
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py` around lines 148 - 151, The code uses if "prompt_msgs_snapshot" in dir() to detect whether prompt_msgs_snapshot was set; instead, initialize prompt_msgs_snapshot = None before the loop and change both guards in transform_split_thinking and transform_answer_only to explicit checks (prompt_msgs_snapshot is not None) so the branch depends on explicit state; update all occurrences (including the similar checks around lines 172-175) to use the sentinel instead of dir() and ensure the functions return prompt_msgs_snapshot when not None and fall back to split_messages[:-1] otherwise.src/axolotl/core/trainers/ebft/rewards.py (1)
208-236: ⚠️ Potential issue | 🔴 Critical

Fix whitening to operate in feature space before enabling it anywhere.

This still builds `W` from `U`, so `W` is `(B, B)` instead of `(D, D)`. Besides the math bug from the earlier review, `EBFTMixin._feature_matching_reward()` calls this with `phi.shape == (num_generations, D)` and `phi_gt.shape == (1, D)`, so `W @ phi_gt_f` will shape-mismatch as soon as whitening is enabled with more than one generation.

Suggested fix:

```diff
-    U, S, _ = torch.linalg.svd(phi_f.unsqueeze(0), full_matrices=False)
+    _, S, Vh = torch.linalg.svd(phi_f, full_matrices=False)
     ...
-    U, S = U.squeeze(0), S.squeeze(0)
-    # Safe inverse of singular values
     s_max = S.max()
     inv_s = torch.where(S > whiten_tol * s_max, 1.0 / (S + 1e-12), torch.zeros_like(S))
-    # FIXME
-    # W = U @ diag(inv_S) @ U^T
-    W = (U * inv_s.unsqueeze(0)) @ U.T  # (B, B)
-    phi_w = (W @ phi_f).to(phi.dtype)
-    phi_gt_w = (W @ phi_gt_f).to(phi_gt.dtype)
+    V = Vh.transpose(-2, -1)
+    W = (V * inv_s.unsqueeze(0)) @ Vh  # (D, D)
+    phi_w = (phi_f @ W.T).to(phi.dtype)
+    phi_gt_w = (phi_gt_f @ W.T).to(phi_gt.dtype)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/trainers/ebft/rewards.py` around lines 208 - 236, The whitening builds W in sample space (B,B) because SVD was taken on phi_f (shape (B,D)), causing a shape mismatch when multiplying with phi_gt_f; fix by performing SVD in feature space so W is (D,D): compute SVD on phi_f.T (or equivalently compute eigendecomposition of phi_f.T @ phi_f) to produce U with shape (D,D), form inv_s from S and build W = U @ diag(inv_s) @ U.T (use whiten_tol and small eps as before), then apply W @ phi_f.T (or transpose inputs appropriately) to get phi_w and phi_gt_w in the feature dimension; update the code paths around U, S, inv_s, W, phi_f, phi_gt_f and ensure EBFTMixin._feature_matching_reward() (which calls this) receives correctly-shaped outputs.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/core/trainers/ebft/rewards.py`:
- Around line 92-120: Pooling assumes left-aligned tokens; change pooling to use
attention_mask-derived token positions per sample: in last_token use
attention_mask to compute per-sample last valid index
(torch.where(attention_mask.bool()) grouped by batch or
attention_mask.sum(dim=1)-1) and index hidden_states accordingly instead of raw
indices; in completion_mean build comp_mask by computing valid token positions
per sample from prompt_lengths and attention_mask (i.e., find positions >=
prompt_lengths AND attention_mask==1) before mean-pooling; in concat, for each
sample compute the list of valid token indices from attention_mask (or
prompt-aware valid positions), pick quartile positions relative to that
per-sample valid-length (e.g., floor((valid_len-1)*[0.25,0.5,0.75])) and gather
hidden_states at those indices before concatenation so padding/left-padding is
never sampled. Ensure all indexing handles batched gather safely and uses
hidden_states device/dtypes.
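The mask-derived pooling this prompt asks for can be sketched with plain lists (hidden size collapsed to scalars); note the left-padded second row, which naive `[:, -1]` or `sum(mask)-1` indexing would get wrong:

```python
def valid_positions(mask_row):
    """Indices of real tokens (mask uses 1 = token, 0 = padding)."""
    return [i for i, m in enumerate(mask_row) if m == 1]


def pool_last_token(hidden, mask):
    """Per-sample last *valid* token, correct even with left padding."""
    return [h[valid_positions(m)[-1]] for h, m in zip(hidden, mask)]


def pool_completion_mean(hidden, mask, prompt_lengths):
    """Mean over completion tokens only: position >= prompt_len AND mask == 1."""
    out = []
    for h, m, p in zip(hidden, mask, prompt_lengths):
        idx = [i for i in valid_positions(m) if i >= p]
        out.append(sum(h[i] for i in idx) / max(len(idx), 1))
    return out


# Row 0 is right-padded, row 1 is left-padded.
hidden = [[1.0, 2.0, 3.0, 4.0], [9.0, 5.0, 6.0, 7.0]]
mask = [[1, 1, 1, 0], [0, 1, 1, 1]]
prompt_lens = [1, 2]
```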
In `@src/axolotl/core/trainers/ebft/trainer.py`:
- Around line 395-478: _sequential_rollout currently uses
self.vllm_generation.client and returns only concatenated assistant text; change
it to use self.vllm_generation.vllm_client (replace vllm_client =
self.vllm_generation.client with vllm_client = self.vllm_generation.vllm_client)
and instead of appending full_gen_text to extended_completions append the full
conversation representation (the conv list or a fully rendered conversation that
includes both user and assistant turns) so downstream reward code sees prompt +
interleaved user/assistant messages; ensure the no-remaining-turns branch
returns a full-message list consistent with the new format (e.g., original
prompt_msgs + assistant first turn) and keep vllm_client.chat call and decoding
logic (result, gen_ids, self.processing_class.decode) unchanged.
In `@src/axolotl/core/trainers/grpo/async_trainer.py`:
- Around line 897-902: The code only syncs weights when the computed mod_path
exists in lora_info, which drops trainable parameters stored under
modules_to_save.default.* (e.g., lm_head, embed_tokens). Update the conditional
around vllm_name/mod_path so it also accepts entries that were prefixed by
"modules_to_save.default.": after computing mod_path from vllm_name (and after
calling fix_name with extra_prefixes), check both mod_path and
"modules_to_save.default."+mod_path (or the original un-fixed mod_path) against
lora_info, and only continue if neither is present; this ensures
modules_to_save.default.* parameters are included in the sync.
In `@src/axolotl/monkeypatch/trainer/trl_vllm.py`:
- Around line 78-97: The current fallback loop uses MAX_PARAMS_PER_REQUEST to
split by parameter count which can still produce huge base64 JSON bodies; change
the batching in the function that builds payloads (where MAX_PARAMS_PER_REQUEST,
chunk, payload, and the self.session.post call are used) to instead accumulate
parameters until a MAX_BYTES_PER_REQUEST threshold (e.g., ~10MB) would be
exceeded, then send that batch; additionally, detect individual tensors whose
serialized byte size exceeds MAX_BYTES_PER_REQUEST and split them into smaller
slices (preserving name and adding slice metadata such as a shard index or
byte/row range) so each slice is serialized, base64-encoded and included as
separate payload entries, and ensure the shape/dtype fields reflect the slice so
the server can reassemble; replace the fixed-count loop with this byte-aware
accumulator before calling self.session.post.
- Around line 58-63: The POST calls that sync weights (the session.post to
f"{self.base_url}/batch_update_named_params/" and the other session.post to
f"{self.base_url}/http_update_weights/") currently have no timeout and can block
forever; update both calls to pass an explicit timeout (e.g., timeout=30) to
session.post, and keep the existing status_code check/Exception behavior; also
import and optionally catch requests.exceptions.Timeout around the calls in the
same function (or let it propagate) so timeouts surface deterministically.
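A byte-budget accumulator of the kind the batching fix above describes; the 10MB default and the shard-metadata shape are assumptions for illustration, not the server's actual protocol:

```python
def batch_by_bytes(params, max_bytes=10 * 1024 * 1024):
    """Group (name, payload_bytes) params into batches under a byte cap.

    A param larger than the cap is split into slices, each tagged with a
    shard index so the server can reassemble it.
    """
    batches, current, current_size = [], [], 0
    for name, data in params:
        if len(data) > max_bytes:
            if current:  # flush what we have before emitting shards
                batches.append(current)
                current, current_size = [], 0
            for shard, off in enumerate(range(0, len(data), max_bytes)):
                batches.append([(name, data[off : off + max_bytes], {"shard": shard})])
            continue
        if current_size + len(data) > max_bytes and current:
            batches.append(current)
            current, current_size = [], 0
        current.append((name, data, None))
        current_size += len(data)
    if current:
        batches.append(current)
    return batches


params = [("a", b"x" * 4), ("b", b"y" * 4), ("c", b"z" * 10)]
out = batch_by_bytes(params, max_bytes=6)
```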
In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py`:
- Around line 225-260: The loop in ebft_reasoning.py currently marks every
assistant turn as trainable (labels[]) which makes prompt_length be derived from
an earlier assistant turn; instead detect and record only the final assistant
turn span: when iterating messages use the existing tokenizer logic to compute
start/end but do not set labels for each assistant turn immediately—instead
store the last assistant span (final_start, final_end), then after the loop set
labels[i]=input_ids[i] only for i in range(final_start, min(final_end,
len(labels))) and compute prompt_length from final_start (before applying any
mask_thinking_ce modifications); apply the same change to the other occurrence
noted around the 285-290 region so only the final assistant span is treated as
the structured completion.
- Line 205: The code treats tokenizer.pad_token_id as missing when it equals 0
by using a falsy fallback expression; change the logic around pad_id in
ebft_reasoning.py so you only fall back to tokenizer.eos_token_id when
pad_token_id is actually None (or not set), e.g. replace the
`tokenizer.pad_token_id or tokenizer.eos_token_id` pattern with an explicit None
check (use tokenizer.pad_token_id if it is not None, otherwise
tokenizer.eos_token_id) where pad_id is assigned so PAD=0 remains respected.
In `@src/axolotl/utils/schemas/config.py`:
- Around line 245-250: Add a pydantic root validator on the same config model
that defines ebft (the model containing the field "ebft: EBFTConfig | None =
Field(...)") to enforce that when rl is set to EBFT (check data.get("rl")
against the RL enum or the string "ebft"), the incoming data contains a non-null
"ebft" entry; if missing or None, raise a ValueError with a clear message (e.g.,
"ebft config is required when rl is EBFT") so parsing fails early instead of
letting downstream validators in validation.py silently bypass the requirement.
- Around line 67-157: Tighten validation on the EBFT config fields by adding
explicit enum/range constraints: restrict embed_method to the allowed values
(e.g., "last_token","mean_pooling","concat") and mode to
("structured","strided") and advantage_estimator to
("rloo","group_norm","reinforce") (use Literal or an Enum for
embed_method/mode/advantage_estimator), enforce top_p between 0.0 and 1.0
(le=1.0, ge=0.0), ensure temperature is non-negative (ge=0.0), and make integer
fields like stride, context_length, generate_max_len, n_samples_per_prompt, and
min_completion_prefix positive or non-negative as appropriate (ge=0 or ge=1).
Update the Field declarations for embed_method, mode, advantage_estimator,
top_p, temperature, stride, context_length, generate_max_len,
n_samples_per_prompt, and min_completion_prefix to include these constraints so
invalid values fail fast during model validation.
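The cross-field requirement from the root-validator suggestion, reduced to plain logic — a pydantic `model_validator(mode="before")` would wrap a check like this:

```python
def check_rl_requires_ebft(data):
    """Fail early when rl is EBFT but no ebft config was supplied."""
    if data.get("rl") == "ebft" and data.get("ebft") is None:
        raise ValueError("ebft config is required when rl is EBFT")
    return data
```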
---
Outside diff comments:
In `@src/axolotl/core/trainers/grpo/async_trainer.py`:
- Around line 645-670: The patch permanently replaces VLLMGeneration._init_vllm
when _skip_nccl is true, causing later trainers to inherit the no-communicator
behavior; instead, save _orig_init_vllm, assign VLLMGeneration._init_vllm =
_init_vllm_no_communicator only for the duration of this trainer's
initialization and restore the original in a finally/cleanup block (or use a
context manager) so that the original _init_vllm is reinstated whether
initialization succeeds or raises; reference VLLMGeneration._init_vllm,
_orig_init_vllm, and _init_vllm_no_communicator and ensure restoration happens
after the trainer finishes initialization.
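The save/patch/restore discipline described above can be packaged as a context manager; `Gen` and the replacement function are stand-ins for `VLLMGeneration._init_vllm`:

```python
from contextlib import contextmanager


@contextmanager
def patched_attr(obj, name, replacement):
    """Temporarily replace `obj.name`, restoring it even on error."""
    original = getattr(obj, name)
    setattr(obj, name, replacement)
    try:
        yield
    finally:
        setattr(obj, name, original)


class Gen:  # stand-in for VLLMGeneration
    def _init_vllm(self):
        return "nccl"


def _init_vllm_no_communicator(self):
    return "http-only"


with patched_attr(Gen, "_init_vllm", _init_vllm_no_communicator):
    inside = Gen()._init_vllm()
after = Gen()._init_vllm()
```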
In `@src/axolotl/utils/callbacks/generation.py`:
- Around line 28-60: The sample-generation block is incorrectly indented under
the early-return for getattr(cfg, "generate_samples", False), making it
unreachable; dedent the entire block that calls dataloader selection and
generate_samples (the try/except fetching self.trainer.get_eval_dataloader(),
the fallback to self.trainer.get_train_dataloader(), the call to
generate_samples with model=self.trainer.model and
tokenizer=self.trainer.processing_class, and the subsequent
self._log_samples(samples, state.global_step)) so it runs only when
generate_samples is True (i.e., move that block out of the if that contains the
return).
---
Duplicate comments:
In `@src/axolotl/core/trainers/ebft/rewards.py`:
- Around line 208-236: The whitening builds W in sample space (B,B) because SVD
was taken on phi_f (shape (B,D)), causing a shape mismatch when multiplying with
phi_gt_f; fix by performing SVD in feature space so W is (D,D): compute SVD on
phi_f.T (or equivalently compute eigendecomposition of phi_f.T @ phi_f) to
produce U with shape (D,D), form inv_s from S and build W = U @ diag(inv_s) @
U.T (use whiten_tol and small eps as before), then apply W @ phi_f.T (or
transpose inputs appropriately) to get phi_w and phi_gt_w in the feature
dimension; update the code paths around U, S, inv_s, W, phi_f, phi_gt_f and
ensure EBFTMixin._feature_matching_reward() (which calls this) receives
correctly-shaped outputs.
In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py`:
- Around line 148-151: The code uses if "prompt_msgs_snapshot" in dir() to
detect whether prompt_msgs_snapshot was set; instead, initialize
prompt_msgs_snapshot = None before the loop and change both guards in
transform_split_thinking and transform_answer_only to explicit checks
(prompt_msgs_snapshot is not None) so the branch depends on explicit state;
update all occurrences (including the similar checks around lines 172-175) to
use the sentinel instead of dir() and ensure the functions return
prompt_msgs_snapshot when not None and fall back to split_messages[:-1]
otherwise.
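The sentinel pattern this prompt describes, in miniature — initialize to `None`, assign in the loop, and test `is not None` instead of probing `dir()`:

```python
def last_user_snapshot(messages):
    """Sentinel-based version of the prompt_msgs_snapshot guard."""
    prompt_msgs_snapshot = None
    for i, msg in enumerate(messages):
        if msg.get("role") == "assistant":
            prompt_msgs_snapshot = messages[:i]
    if prompt_msgs_snapshot is not None:
        return prompt_msgs_snapshot
    return messages[:-1]  # fallback mirrors split_messages[:-1]
```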
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 3bc4db9a-27a7-4043-a5b9-dba87868b9a1
📒 Files selected for processing (44)
- docker/Dockerfile-cloud-uv
- examples/ebft/README.md
- examples/ebft/ebft_opencode.py
- examples/ebft/ebft_pretrain.py
- examples/ebft/ebft_strided_structured.py
- examples/ebft/llama-1b-ebft-opencode-novllm.yaml
- examples/ebft/llama-1b-ebft-opencode.yaml
- examples/ebft/llama-1b-ebft-strided-structured.yaml
- examples/ebft/llama-1b-ebft-strided.yaml
- examples/ebft/llama-3b-ebft-strided-fft.yaml
- examples/ebft/llama-8b-ebft-strided-fft.yaml
- examples/ebft/qwen35-4b-ebft-structured-async.yaml
- examples/ebft/qwen35-4b-ebft-structured.yaml
- examples/ebft/qwen35-9b-ebft-structured.yaml
- src/axolotl/cli/vllm_serve.py
- src/axolotl/common/datasets.py
- src/axolotl/core/builders/rl.py
- src/axolotl/core/trainers/__init__.py
- src/axolotl/core/trainers/ebft/__init__.py
- src/axolotl/core/trainers/ebft/args.py
- src/axolotl/core/trainers/ebft/kernels.py
- src/axolotl/core/trainers/ebft/rewards.py
- src/axolotl/core/trainers/ebft/strided.py
- src/axolotl/core/trainers/ebft/trainer.py
- src/axolotl/core/trainers/grpo/async_trainer.py
- src/axolotl/integrations/diffusion/callbacks.py
- src/axolotl/monkeypatch/trainer/trl_vllm.py
- src/axolotl/prompt_strategies/ebft/__init__.py
- src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py
- src/axolotl/prompt_strategies/ebft/ebft_opencode.py
- src/axolotl/prompt_strategies/ebft/ebft_reasoning.py
- src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py
- src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py
- src/axolotl/scripts/vllm_serve_lora.py
- src/axolotl/scripts/vllm_worker_ext.py
- src/axolotl/train.py
- src/axolotl/utils/callbacks/__init__.py
- src/axolotl/utils/callbacks/generation.py
- src/axolotl/utils/data/rl.py
- src/axolotl/utils/schemas/config.py
- src/axolotl/utils/schemas/enums.py
- src/axolotl/utils/schemas/trl.py
- src/axolotl/utils/schemas/validation.py
- src/axolotl/utils/schemas/vllm.py
✅ Files skipped from review due to trivial changes (21)
- docker/Dockerfile-cloud-uv
- src/axolotl/integrations/diffusion/callbacks.py
- src/axolotl/utils/schemas/enums.py
- src/axolotl/prompt_strategies/ebft/__init__.py
- examples/ebft/ebft_pretrain.py
- examples/ebft/ebft_opencode.py
- src/axolotl/utils/callbacks/__init__.py
- examples/ebft/qwen35-4b-ebft-structured.yaml
- examples/ebft/qwen35-9b-ebft-structured.yaml
- examples/ebft/llama-1b-ebft-opencode.yaml
- src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py
- examples/ebft/llama-1b-ebft-strided.yaml
- examples/ebft/README.md
- src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py
- examples/ebft/llama-1b-ebft-strided-structured.yaml
- src/axolotl/cli/vllm_serve.py
- examples/ebft/qwen35-4b-ebft-structured-async.yaml
- examples/ebft/llama-3b-ebft-strided-fft.yaml
- examples/ebft/llama-8b-ebft-strided-fft.yaml
- examples/ebft/llama-1b-ebft-opencode-novllm.yaml
- src/axolotl/core/trainers/ebft/kernels.py
🚧 Files skipped from review as they are similar to previous changes (15)
- src/axolotl/common/datasets.py
- src/axolotl/utils/schemas/vllm.py
- src/axolotl/core/trainers/__init__.py
- src/axolotl/utils/data/rl.py
- src/axolotl/utils/schemas/trl.py
- src/axolotl/core/builders/rl.py
- examples/ebft/ebft_strided_structured.py
- src/axolotl/scripts/vllm_serve_lora.py
- src/axolotl/prompt_strategies/ebft/ebft_opencode.py
- src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py
- src/axolotl/train.py
- src/axolotl/scripts/vllm_worker_ext.py
- src/axolotl/utils/schemas/validation.py
- src/axolotl/core/trainers/ebft/__init__.py
- src/axolotl/core/trainers/ebft/args.py
    if method == "last_token":
        if attention_mask is not None:
            # Find last non-padding position per sample
            last_idx = attention_mask.sum(dim=1).long() - 1  # (B,)
            return hidden_states[torch.arange(hidden_states.shape[0]), last_idx]
        return hidden_states[:, -1, :]

    if method == "mean_pooling":
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()  # (B, S, 1)
            return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        return hidden_states.mean(dim=1)

    if method == "completion_mean":
        # Mean pool over completion tokens only (exclude prompt)
        if prompt_lengths is None:
            raise ValueError("completion_mean requires prompt_lengths")
        B, S, _ = hidden_states.shape
        positions = torch.arange(S, device=hidden_states.device).unsqueeze(0)  # (1, S)
        comp_mask = positions >= prompt_lengths.unsqueeze(1)  # (B, S)
        if attention_mask is not None:
            comp_mask = comp_mask & attention_mask.bool()
        mask = comp_mask.unsqueeze(-1).float()  # (B, S, 1)
        return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    if method == "concat":
        seq_len = hidden_states.shape[1]
        positions = [seq_len // 4, seq_len // 2, 3 * seq_len // 4]
        return torch.cat([hidden_states[:, p, :] for p in positions], dim=-1)
Make pooling respect padded sequence layouts.
last_token and completion_mean currently assume tokens start at index 0, and concat samples fixed quartiles from the padded width. Since _feature_matching_reward() tokenizes with padding=True, left-padded batches will pool from padding/prompt positions, and shorter rows can feed pad tokens into concat. Derive positions from attention_mask per sample instead of raw tensor indices.
🧰 Tools
🪛 Ruff (0.15.6)
[warning] 109-109: Unpacked variable B is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/core/trainers/ebft/rewards.py` around lines 92 - 120, Pooling
assumes left-aligned tokens; change pooling to use attention_mask-derived token
positions per sample: in last_token use attention_mask to compute per-sample
last valid index (torch.where(attention_mask.bool()) grouped by batch or
attention_mask.sum(dim=1)-1) and index hidden_states accordingly instead of raw
indices; in completion_mean build comp_mask by computing valid token positions
per sample from prompt_lengths and attention_mask (i.e., find positions >=
prompt_lengths AND attention_mask==1) before mean-pooling; in concat, for each
sample compute the list of valid token indices from attention_mask (or
prompt-aware valid positions), pick quartile positions relative to that
per-sample valid-length (e.g., floor((valid_len-1)*[0.25,0.5,0.75])) and gather
hidden_states at those indices before concatenation so padding/left-padding is
never sampled. Ensure all indexing handles batched gather safely and uses
hidden_states device/dtypes.
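The mask-derived indexing the prompt describes can be sketched in plain Python (an illustrative stand-in for the torch code; the function names and nested-list layout are invented for the example):

```python
def last_token_pool(hidden, attention_mask):
    # hidden: [B][S][H] nested lists; pick each sample's hidden state at its
    # last *valid* position, derived from the attention mask (robust to
    # left- or right-padded batches).
    pooled = []
    for states, mask in zip(hidden, attention_mask):
        valid = [i for i, m in enumerate(mask) if m]
        pooled.append(states[valid[-1]])
    return pooled


def completion_mean_pool(hidden, attention_mask, prompt_lengths):
    # Mean over completion tokens only: count the prompt in *valid* tokens,
    # then average the remaining valid positions.
    pooled = []
    for states, mask, p in zip(hidden, attention_mask, prompt_lengths):
        valid = [i for i, m in enumerate(mask) if m]
        comp = valid[p:]  # drop the first p valid (prompt) tokens
        dim = len(states[0])
        denom = max(len(comp), 1)
        pooled.append(
            [sum(states[i][d] for i in comp) / denom for d in range(dim)]
        )
    return pooled
```

The same idea carries over to the tensorized code: compute per-sample valid positions from the mask first, then index or average over those positions rather than raw sequence indices.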
    def _sequential_rollout(
        self,
        prompts: list,
        first_completions: list,
        remaining_turns: list,
        num_gens: int,
    ) -> list:
        """
        Extend single-turn completions into multi-turn conversations.

        For each prompt group, takes the first generated assistant turn and
        sequentially generates subsequent assistant turns by calling vLLM,
        building up a full multi-turn conversation.

        Args:
            prompts: List of prompt message lists (repeated num_gens times)
            first_completions: List of generated first-turn completions
            remaining_turns: List of remaining turn pairs after first assistant turn.
                Each element is a list of dicts: [{"role": "user", "content": "..."},
                {"role": "assistant", "content": "...GT..."}]
            num_gens: Number of generations per prompt

        Returns:
            Extended completions incorporating all generated turns
        """
        vllm_client = self.vllm_generation.client
        max_tokens = getattr(self.args, "max_completion_length", 256)
        temperature = getattr(self.args, "temperature", 0.7)
        gen_kwargs = getattr(self.args, "generation_kwargs", None) or {}

        extended_completions = []

        for idx in range(len(prompts)):
            prompt_msgs = prompts[idx] if isinstance(prompts[idx], list) else []
            first_comp = first_completions[idx]

            # Extract first completion text
            if isinstance(first_comp, list):
                first_text = first_comp[0].get("content", "") if first_comp else ""
            else:
                first_text = first_comp

            # Get remaining turns for this prompt (same for all num_gens copies)
            prompt_idx = idx // num_gens
            turns = (
                remaining_turns[prompt_idx] if prompt_idx < len(remaining_turns) else []
            )

            if not turns:
                # No remaining turns — just use the first completion
                extended_completions.append(first_comp)
                continue

            # Build conversation with generated first turn
            conv = list(prompt_msgs) + [{"role": "assistant", "content": first_text}]
            full_gen_text = first_text

            # Generate subsequent turns
            for turn in turns:
                if turn["role"] == "user":
                    conv.append(turn)
                elif turn["role"] == "assistant":
                    # Generate this assistant turn via vLLM
                    try:
                        result = vllm_client.chat(
                            messages=[conv],
                            n=1,
                            max_tokens=max_tokens,
                            temperature=temperature,
                            generation_kwargs=gen_kwargs,
                        )
                        # Extract generated text
                        gen_ids = result.get("completion_ids", [[]])[0]
                        gen_text = self.processing_class.decode(
                            gen_ids, skip_special_tokens=True
                        )
                    except Exception as e:
                        LOG.warning(f"Multi-turn rollout generation failed: {e}")
                        gen_text = ""

                    conv.append({"role": "assistant", "content": gen_text})
                    full_gen_text += gen_text

            extended_completions.append(full_gen_text)
🧩 Analysis chain
🏁 Script executed:

    # First, let's examine the file structure and understand the broader context
    cd src/axolotl/core/trainers/ebft && wc -l trainer.py

Repository: axolotl-ai-cloud/axolotl
Length of output: 82

🏁 Script executed:

    # Check the imports and class definition to understand vllm_client usage patterns
    head -50 src/axolotl/core/trainers/ebft/trainer.py | cat -n

Length of output: 1922

🏁 Script executed:

    # Search for all vllm_client usages in the file
    rg "vllm_client" src/axolotl/core/trainers/ebft/trainer.py -B 2 -A 2

Length of output: 544

🏁 Script executed:

    # Find where _sequential_rollout is called and how its return value is used
    rg "_sequential_rollout" src/axolotl/core/trainers/ebft/trainer.py -B 3 -A 5

Length of output: 1164

🏁 Script executed:

    # Look for reward model input construction - search for patterns around completion/prompt concatenation
    rg "prompt_text|comp_text|reward" src/axolotl/core/trainers/ebft/trainer.py -B 2 -A 2

Length of output: 5202

🏁 Script executed:

    # Look at the full _feature_matching_reward function to see how returned completions are used
    sed -n '196,320p' src/axolotl/core/trainers/ebft/trainer.py | cat -n

Length of output: 5731

🏁 Script executed:

    # Check if there are other uses of vllm_client in the generate_and_score_completions or calling code
    rg "self.vllm_generation" src/axolotl/core/trainers/ebft/trainer.py -B 2 -A 2

Length of output: 1056

🏁 Script executed:

    # Search for how the completions variable flows after _sequential_rollout returns
    sed -n '285,330p' src/axolotl/core/trainers/ebft/trainer.py | cat -n

Length of output: 1937

🏁 Script executed:

    # Look more carefully at the reward computation to see if multi-turn context should be included
    sed -n '305,365p' src/axolotl/core/trainers/ebft/trainer.py | cat -n

Length of output: 2824

🏁 Script executed:

    # Verify if there's a different vllm_client pattern elsewhere in the GRPO trainer or this PR context
    rg "vllm_client\s*=" src/axolotl/core/trainers/ -B 3 -A 3 --max-count=10

Length of output: 2694

🏁 Script executed:

    # Check if _sequential_rollout should return full message list vs string by looking at ground_truth structure
    rg "ground_truth" src/axolotl/core/trainers/ebft/trainer.py -B 2 -A 2 | head -50

Length of output: 1743
Return only the full conversation with user turns and use vllm_client.vllm_client instead of vllm_client.client.
_sequential_rollout() has two issues:
1. Uses `self.vllm_generation.client`, but the async trainer and the rest of the codebase use `self.vllm_generation.vllm_client`.
2. Returns only concatenated assistant text (`full_gen_text`), dropping user turns from the conversation. Although `conv` is built with interleaved user/assistant messages, user turns are never accumulated into the return value. This means the reward model sees `prompt + [asst₁ + asst₂ + ...]` instead of the actual generated conversation structure with interspersed user turns. Return the full message list (or fully rendered conversation including user turns) so reward computation aligns with the actual multi-turn generation.
🧰 Tools
🪛 Ruff (0.15.6)
[warning] 449-449: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
[warning] 471-471: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/core/trainers/ebft/trainer.py` around lines 395 - 478,
_sequential_rollout currently uses self.vllm_generation.client and returns only
concatenated assistant text; change it to use self.vllm_generation.vllm_client
(replace vllm_client = self.vllm_generation.client with vllm_client =
self.vllm_generation.vllm_client) and instead of appending full_gen_text to
extended_completions append the full conversation representation (the conv list
or a fully rendered conversation that includes both user and assistant turns) so
downstream reward code sees prompt + interleaved user/assistant messages; ensure
the no-remaining-turns branch returns a full-message list consistent with the
new format (e.g., original prompt_msgs + assistant first turn) and keep
vllm_client.chat call and decoding logic (result, gen_ids,
self.processing_class.decode) unchanged.
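The suggested return shape reduces to a small sketch, with `generate` as a hypothetical stand-in for the vLLM chat call: the conversation is accumulated as a message list, so user turns survive into the returned value.

```python
def sequential_rollout(prompt_msgs, first_text, remaining_turns, generate):
    """Sketch: build the full multi-turn conversation, keeping user turns.

    `generate(conv)` is a hypothetical stand-in for the vLLM chat call;
    it returns the next assistant reply as text.
    """
    conv = [*prompt_msgs, {"role": "assistant", "content": first_text}]
    for turn in remaining_turns:
        if turn["role"] == "user":
            conv.append(turn)  # user turns stay in the returned conversation
        elif turn["role"] == "assistant":
            conv.append({"role": "assistant", "content": generate(conv)})
    return conv  # full message list instead of concatenated assistant text
```

Downstream reward code can then render the message list (or feed it to the chat template) so the structure it scores matches what was actually generated.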
    vllm_name = fix_name(vllm_name, extra_prefixes=["modules_to_save.default."])

    # Only sync weights that have LoRA adapters
    mod_path = vllm_name[: -len(".weight")]
    if mod_path not in lora_info:
        continue
modules_to_save weights are still dropped from no-merge sync.
The new filter only syncs weights whose mod_path exists in lora_info. That excludes trainable modules_to_save.default.* parameters such as lm_head or embed_tokens, so server-side generation can drift from training whenever PEFT is configured with modules_to_save.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/core/trainers/grpo/async_trainer.py` around lines 897 - 902, The
code only syncs weights when the computed mod_path exists in lora_info, which
drops trainable parameters stored under modules_to_save.default.* (e.g.,
lm_head, embed_tokens). Update the conditional around vllm_name/mod_path so it
also accepts entries that were prefixed by "modules_to_save.default.": after
computing mod_path from vllm_name (and after calling fix_name with
extra_prefixes), check both mod_path and "modules_to_save.default."+mod_path (or
the original un-fixed mod_path) against lora_info, and only continue if neither
is present; this ensures modules_to_save.default.* parameters are included in
the sync.
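The acceptance check could look like this sketch (the helper name and the exact prefix handling are assumptions for illustration, not the PR's code):

```python
def should_sync(vllm_name, lora_info):
    # Accept weights whose module has a LoRA adapter, plus weights that came
    # from a PEFT modules_to_save wrapper (e.g. lm_head, embed_tokens).
    # `lora_info` is assumed to be a set-like container of module paths.
    mod_path = vllm_name[: -len(".weight")]
    return (
        mod_path in lora_info
        or "modules_to_save.default." + mod_path in lora_info
    )
```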
    url = f"{self.base_url}/batch_update_named_params/"
    response = self.session.post(url, json={"params": param_metadata})
    if response.status_code != 200:
        raise Exception(
            f"Request failed: {response.status_code}, {response.text}"
        )
Add an explicit timeout to both weight-sync POSTs.
These requests currently inherit Requests' infinite timeout. If the vLLM server wedges during metadata sync or /http_update_weights/, training can block forever on the request thread.
Suggested fix
- response = self.session.post(url, json={"params": param_metadata})
+ response = self.session.post(
+ url,
+ json={"params": param_metadata},
+ timeout=getattr(self, "connection_timeout", None),
+ )
...
- response = self.session.post(url, json={"params": payload})
+ response = self.session.post(
+ url,
+ json={"params": payload},
+ timeout=getattr(self, "connection_timeout", None),
+ )Also applies to: 96-101
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/monkeypatch/trainer/trl_vllm.py` around lines 58 - 63, The POST
calls that sync weights (the session.post to
f"{self.base_url}/batch_update_named_params/" and the other session.post to
f"{self.base_url}/http_update_weights/") currently have no timeout and can block
forever; update both calls to pass an explicit timeout (e.g., timeout=30) to
session.post, and keep the existing status_code check/Exception behavior; also
import and optionally catch requests.exceptions.Timeout around the calls in the
same function (or let it propagate) so timeouts surface deterministically.
    MAX_PARAMS_PER_REQUEST = 32  # avoid huge HTTP payloads
    for i in range(0, len(params), MAX_PARAMS_PER_REQUEST):
        chunk = params[i : i + MAX_PARAMS_PER_REQUEST]
        payload = []
        for name, weights in chunk:
            w_cpu = weights.contiguous().cpu()
            # NumPy doesn't support bfloat16; cast to float32 for serialization
            if w_cpu.dtype == torch.bfloat16:
                w_cpu = w_cpu.float()
            raw = w_cpu.numpy().tobytes()
            payload.append(
                {
                    "name": name,
                    "dtype": str(w_cpu.dtype),
                    "shape": list(weights.shape),
                    "data": base64.b64encode(raw).decode("ascii"),
                }
            )
        url = f"{self.base_url}/http_update_weights/"
        response = self.session.post(url, json={"params": payload})
Chunk the HTTP fallback by bytes, not just parameter count.
MAX_PARAMS_PER_REQUEST = 32 does not actually cap request size. A single embedding or projection matrix can still turn into a hundreds-of-MB base64 JSON body, which is likely to trip body limits or spike memory on both client and server. Bound chunks by serialized byte size, and split oversized tensors when needed.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/monkeypatch/trainer/trl_vllm.py` around lines 78 - 97, The
current fallback loop uses MAX_PARAMS_PER_REQUEST to split by parameter count
which can still produce huge base64 JSON bodies; change the batching in the
function that builds payloads (where MAX_PARAMS_PER_REQUEST, chunk, payload, and
the self.session.post call are used) to instead accumulate parameters until a
MAX_BYTES_PER_REQUEST threshold (e.g., ~10MB) would be exceeded, then send that
batch; additionally, detect individual tensors whose serialized byte size
exceeds MAX_BYTES_PER_REQUEST and split them into smaller slices (preserving
name and adding slice metadata such as a shard index or byte/row range) so each
slice is serialized, base64-encoded and included as separate payload entries,
and ensure the shape/dtype fields reflect the slice so the server can
reassemble; replace the fixed-count loop with this byte-aware accumulator before
calling self.session.post.
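A byte-aware accumulator along those lines might look like this sketch (field names such as `offset` are illustrative; the server would need matching reassembly logic):

```python
def chunk_by_bytes(params, max_bytes):
    """Group (name, raw_bytes) pairs so each request stays under max_bytes.

    Oversized single tensors are split into byte-range shards carrying
    `offset` metadata (field names are illustrative) so the server can
    reassemble them.
    """
    batches, current, size = [], [], 0
    for name, raw in params:
        if len(raw) > max_bytes:  # split one huge tensor across requests
            if current:
                batches.append(current)
                current, size = [], 0
            for off in range(0, len(raw), max_bytes):
                piece = raw[off : off + max_bytes]
                batches.append([{"name": name, "offset": off, "data": piece}])
            continue
        if size + len(raw) > max_bytes and current:  # flush before overflow
            batches.append(current)
            current, size = [], 0
        current.append({"name": name, "data": raw})
        size += len(raw)
    if current:
        batches.append(current)
    return batches
```

In the real code the threshold would apply to the base64-encoded JSON size (roughly 4/3 of the raw bytes), so the cap should leave headroom for encoding overhead.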
        return {"prompt": m["content"]}
    return {"prompt": str(messages)}

pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
Don't treat pad_token_id == 0 as “missing”.
tokenizer.pad_token_id or tokenizer.eos_token_id will fall back to EOS whenever the pad id is 0, which is a valid PAD token for several tokenizers. That silently changes the padded tail from PAD to EOS.
Suggested fix
- pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
+ pad_id = (
+ tokenizer.pad_token_id
+ if tokenizer.pad_token_id is not None
+ else tokenizer.eos_token_id
+ )📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
+ pad_id = (
+     tokenizer.pad_token_id
+     if tokenizer.pad_token_id is not None
+     else tokenizer.eos_token_id
+ )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py` at line 205, The code
treats tokenizer.pad_token_id as missing when it equals 0 by using a falsy
fallback expression; change the logic around pad_id in ebft_reasoning.py so you
only fall back to tokenizer.eos_token_id when pad_token_id is actually None (or
not set), e.g. replace the `tokenizer.pad_token_id or tokenizer.eos_token_id`
pattern with an explicit None check (use tokenizer.pad_token_id if it is not
None, otherwise tokenizer.eos_token_id) where pad_id is assigned so PAD=0
remains respected.
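The pitfall is easy to demonstrate: Python's `or` falls through on any falsy value, including a pad id of 0. A small sketch:

```python
def resolve_pad_id(pad_token_id, eos_token_id):
    # An explicit None check keeps a legitimate PAD id of 0, which the
    # `pad or eos` pattern would silently replace with EOS.
    return pad_token_id if pad_token_id is not None else eos_token_id
```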
| # Find assistant turn boundaries using incremental tokenization | ||
| prefix_messages = [] | ||
| for msg in messages: | ||
| if msg["role"] == "assistant": | ||
| prefix_text = tokenizer.apply_chat_template( | ||
| prefix_messages, | ||
| tokenize=False, | ||
| add_generation_prompt=True, | ||
| ) | ||
| prefix_ids = tokenizer( | ||
| prefix_text, | ||
| truncation=True, | ||
| max_length=seq_len, | ||
| add_special_tokens=False, | ||
| return_tensors=None, | ||
| )["input_ids"] | ||
| start = len(prefix_ids) | ||
|
|
||
| prefix_messages.append(msg) | ||
| with_turn_text = tokenizer.apply_chat_template( | ||
| prefix_messages, | ||
| tokenize=False, | ||
| add_generation_prompt=False, | ||
| ) | ||
| with_turn_ids = tokenizer( | ||
| with_turn_text, | ||
| truncation=True, | ||
| max_length=seq_len, | ||
| add_special_tokens=False, | ||
| return_tensors=None, | ||
| )["input_ids"] | ||
| end = len(with_turn_ids) | ||
|
|
||
| # Mark assistant tokens as trainable | ||
| for i in range(start, min(end, len(labels))): | ||
| labels[i] = input_ids[i] |
Only treat the final assistant span as the structured completion.
This loop labels every assistant turn and then derives prompt_length from the first label that is not -100. On multi-turn chats, earlier assistant replies become part of the “completion” span, so the strided trainer can place anchors and CE loss inside prompt history instead of only the final target answer. If mask_thinking_ce=True, the derived boundary can move even further into the answer. Track the final assistant span explicitly and compute prompt_length from that true start before optional masking.
Also applies to: 285-290
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py` around lines 225 - 260,
The loop in ebft_reasoning.py currently marks every assistant turn as trainable
(labels[]) which makes prompt_length be derived from an earlier assistant turn;
instead detect and record only the final assistant turn span: when iterating
messages use the existing tokenizer logic to compute start/end but do not set
labels for each assistant turn immediately—instead store the last assistant span
(final_start, final_end), then after the loop set labels[i]=input_ids[i] only
for i in range(final_start, min(final_end, len(labels))) and compute
prompt_length from final_start (before applying any mask_thinking_ce
modifications); apply the same change to the other occurrence noted around the
285-290 region so only the final assistant span is treated as the structured
completion.
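One way to implement the suggested fix, sketched with precomputed `(start, end)` token spans (the span-collection step is omitted; the helper name is invented for the example):

```python
def label_final_assistant(input_ids, assistant_spans):
    """Label only the last assistant span for CE loss.

    `assistant_spans` is a list of (start, end) token ranges, one per
    assistant turn (collected elsewhere); returns (labels, prompt_length).
    """
    labels = [-100] * len(input_ids)
    if not assistant_spans:
        return labels, len(input_ids)
    start, end = assistant_spans[-1]  # keep only the final turn
    for i in range(start, min(end, len(labels))):
        labels[i] = input_ids[i]
    return labels, start  # prompt_length = true start of the final span
```

Because `prompt_length` is taken from the span boundary rather than from the first trainable label, any later masking (such as think-tag CE masking) can no longer shift the derived prompt boundary.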
    embed_method: str = Field(
        default="last_token",
        json_schema_extra={
            "description": "Embedding method: 'last_token', 'mean_pooling', or 'concat'"
        },
    )
    use_whitening: bool = Field(
        default=False,
        json_schema_extra={"description": "Apply SVD whitening to feature embeddings"},
    )
    alignment_coef: float = Field(
        default=1.0,
        json_schema_extra={
            "description": "Coefficient for alignment reward (cosine similarity with ground truth)"
        },
    )
    diversity_coef: float = Field(
        default=1.0,
        json_schema_extra={
            "description": "Coefficient for diversity penalty (pairwise similarity between samples)"
        },
    )
    ce_coef: float = Field(
        default=0.0,
        json_schema_extra={
            "description": "Cross-entropy loss coefficient on ground-truth tokens"
        },
    )
    adaptive_max_tokens: bool = Field(
        default=True,
        json_schema_extra={
            "description": "Set per-batch max_tokens based on ground-truth length"
        },
    )
    gt_length_multiplier: float = Field(
        default=1.5,
        json_schema_extra={
            "description": "Multiplier for ground-truth token count when computing adaptive max_tokens"
        },
    )

    # Strided mode fields (for unstructured text)
    mode: str = Field(
        default="structured",
        json_schema_extra={
            "description": "EBFT mode: 'structured' (QA with vLLM) or 'strided' (unstructured text)"
        },
    )
    stride: int = Field(
        default=8,
        json_schema_extra={"description": "Stride between anchor points (tokens)"},
    )
    context_length: int = Field(
        default=8,
        json_schema_extra={"description": "Context window size per block"},
    )
    generate_max_len: int = Field(
        default=8,
        json_schema_extra={"description": "Tokens to generate per block"},
    )
    n_samples_per_prompt: int = Field(
        default=4,
        json_schema_extra={"description": "Independent rollouts per document"},
    )
    temperature: float = Field(
        default=0.6,
        json_schema_extra={
            "description": "Sampling temperature for strided generation"
        },
    )
    top_p: float = Field(
        default=1.0,
        json_schema_extra={"description": "Top-p nucleus sampling threshold"},
    )
    rl_coef: float = Field(
        default=1.0,
        json_schema_extra={"description": "RL policy gradient loss coefficient"},
    )
    advantage_estimator: str = Field(
        default="rloo",
        json_schema_extra={
            "description": "Advantage estimator: 'rloo', 'group_norm', 'reinforce'"
        },
    )
    min_completion_prefix: int = Field(
        default=0,
        json_schema_extra={
            "description": "Minimum tokens into completion before placing anchors. "
            "Skips anchors too close to the prompt boundary where features are dominated by prompt context."
        },
    )
Tighten EBFT field validation to fail fast on invalid values.
Several EBFT fields (e.g., embed_method, mode, advantage_estimator, top_p, temperature, stride) currently accept invalid inputs at schema level, which can push failures to runtime. Please add enum/range constraints in this schema.
Proposed schema hardening
class EBFTConfig(BaseModel):
@@
- embed_method: str = Field(
+ embed_method: Literal["last_token", "mean_pooling", "concat"] = Field(
@@
- gt_length_multiplier: float = Field(
+ gt_length_multiplier: float = Field(
default=1.5,
+ gt=0.0,
@@
- mode: str = Field(
+ mode: Literal["structured", "strided", "async"] = Field(
default="structured",
@@
- stride: int = Field(
+ stride: int = Field(
default=8,
+ ge=1,
@@
- context_length: int = Field(
+ context_length: int = Field(
default=8,
+ ge=1,
@@
- generate_max_len: int = Field(
+ generate_max_len: int = Field(
default=8,
+ ge=1,
@@
- n_samples_per_prompt: int = Field(
+ n_samples_per_prompt: int = Field(
default=4,
+ ge=1,
@@
- temperature: float = Field(
+ temperature: float = Field(
default=0.6,
+ gt=0.0,
@@
- top_p: float = Field(
+ top_p: float = Field(
default=1.0,
+ ge=0.0,
+ le=1.0,
@@
- advantage_estimator: str = Field(
+ advantage_estimator: Literal["rloo", "group_norm", "reinforce"] = Field(
default="rloo",
@@
- min_completion_prefix: int = Field(
+ min_completion_prefix: int = Field(
default=0,
+ ge=0,🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/utils/schemas/config.py` around lines 67 - 157, Tighten
validation on the EBFT config fields by adding explicit enum/range constraints:
restrict embed_method to the allowed values (e.g.,
"last_token","mean_pooling","concat") and mode to ("structured","strided") and
advantage_estimator to ("rloo","group_norm","reinforce") (use Literal or an Enum
for embed_method/mode/advantage_estimator), enforce top_p between 0.0 and 1.0
(le=1.0, ge=0.0), ensure temperature is non-negative (ge=0.0), and make integer
fields like stride, context_length, generate_max_len, n_samples_per_prompt, and
min_completion_prefix positive or non-negative as appropriate (ge=0 or ge=1).
Update the Field declarations for embed_method, mode, advantage_estimator,
top_p, temperature, stride, context_length, generate_max_len,
n_samples_per_prompt, and min_completion_prefix to include these constraints so
invalid values fail fast during model validation.
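As a plain-Python illustration of what these constraints buy (the real change would use pydantic `Literal` types and `ge`/`le` bounds on `Field`; this helper is only a stand-in):

```python
ALLOWED_EMBED = ("last_token", "mean_pooling", "concat")
ALLOWED_ESTIMATORS = ("rloo", "group_norm", "reinforce")


def validate_ebft(cfg):
    # Fail-fast checks mirroring the proposed schema constraints (sketch).
    errors = []
    if cfg.get("embed_method", "last_token") not in ALLOWED_EMBED:
        errors.append("embed_method must be one of %s" % (ALLOWED_EMBED,))
    if cfg.get("advantage_estimator", "rloo") not in ALLOWED_ESTIMATORS:
        errors.append("advantage_estimator must be one of %s" % (ALLOWED_ESTIMATORS,))
    if not 0.0 <= cfg.get("top_p", 1.0) <= 1.0:
        errors.append("top_p must be in [0, 1]")
    if cfg.get("stride", 8) < 1:
        errors.append("stride must be >= 1")
    if cfg.get("temperature", 0.6) <= 0.0:
        errors.append("temperature must be > 0")
    return errors
```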
    ebft: EBFTConfig | None = Field(
        default=None,
        json_schema_extra={
            "description": "Configuration for Energy-Based Fine-Tuning (EBFT)"
        },
    )
Require ebft config when rl is set to EBFT.
At Line 245, ebft is optional, so rl: ebft can pass parsing with no ebft block. Cross-file validators in src/axolotl/utils/schemas/validation.py use data.get("ebft", {}), which silently bypasses required EBFT checks and shifts failures downstream. Add a schema validator that enforces presence of ebft when rl is ebft.
Proposed guard validator
class AxolotlInputConfig(
@@
ebft: EBFTConfig | None = Field(
default=None,
json_schema_extra={
"description": "Configuration for Energy-Based Fine-Tuning (EBFT)"
},
)
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_ebft_block_present_when_rl_ebft(cls, data):
+ rl = data.get("rl")
+ if rl in ("ebft", RLType.EBFT) and not data.get("ebft"):
+ raise ValueError("`ebft` configuration is required when `rl: ebft`.")
+ return data🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/utils/schemas/config.py` around lines 245 - 250, Add a pydantic
root validator on the same config model that defines ebft (the model containing
the field "ebft: EBFTConfig | None = Field(...)") to enforce that when rl is set
to EBFT (check data.get("rl") against the RL enum or the string "ebft"), the
incoming data contains a non-null "ebft" entry; if missing or None, raise a
ValueError with a clear message (e.g., "ebft config is required when rl is
EBFT") so parsing fails early instead of letting downstream validators in
validation.py silently bypass the requirement.
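The proposed validator reduces to a simple guard; here is a stdlib sketch of the same check outside pydantic (string comparison stands in for the `RLType` enum check):

```python
def check_rl_config(data):
    # Mirrors the proposed model_validator: `rl: ebft` requires an
    # `ebft` block to be present and non-empty.
    if data.get("rl") == "ebft" and not data.get("ebft"):
        raise ValueError("`ebft` configuration is required when `rl: ebft`.")
    return data
```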
Summary by CodeRabbit
New Features
Chores
Documentation