📝 Walkthrough
This PR introduces async GRPO trainer variants with replay buffers and parallel reward workers, adds vLLM native LoRA adapter serving capability, extends FP8 quantization support for LoRA parameters, updates vLLM serve module selection logic to support LoRA, and adds comprehensive async GRPO configuration fields to schemas while removing RL-kernel compatibility validation.
Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
✨ Finishing Touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
1a56321 to 33ee2d6 (Compare)
📖 Documentation Preview: https://69b22aeca021762a2313828f--resonant-treacle-0fd729.netlify.app
Deployed on Netlify from commit c590a53
Actionable comments posted: 14
🧹 Nitpick comments (2)
src/axolotl/core/builders/rl.py (1)
57-67: Centralize the `async_grpo` predicate.
The same enablement check is duplicated for trainer-class selection and training-args selection. Once another async knob is added, those two branches can drift and produce mismatched trainer/args pairs.
Also applies to: 162-171
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/builders/rl.py` around lines 57 - 67, The async enablement check for GRPO is duplicated (computed as async_grpo and used when calling GRPOStrategy.get_trainer_class and again when selecting training args), which can drift; extract that predicate into a single shared accessor (e.g., a static/helper method or property such as GRPOStrategy.is_async_enabled(cfg) or a property on the configuration like cfg.trl.async_grpo_enabled) and replace both uses (the async_grpo variable near the call to GRPOStrategy.get_trainer_class and the separate training-args selection branch) to call that single accessor so trainer-class selection and training-args selection always use the same logic.
src/axolotl/scripts/vllm_serve_lora.py (1)
131-168: Use `pydantic.Field` instead of `dataclasses.field` in BaseModel classes.
`GenerateRequest` and `ChatRequest` mix `dataclasses.field` with Pydantic `BaseModel`. While this currently works with Pydantic 2.12.5, it's non-idiomatic and can cause confusion. Pydantic models should use `pydantic.Field` for field defaults.
Suggested fix
- from pydantic import BaseModel
+ from pydantic import BaseModel, Field
  ...
- generation_kwargs: dict = field(default_factory=dict)
+ generation_kwargs: dict[str, Any] = Field(default_factory=dict)
  ...
- generation_kwargs: dict = field(default_factory=dict)
- chat_template_kwargs: dict = field(default_factory=dict)
+ generation_kwargs: dict[str, Any] = Field(default_factory=dict)
+ chat_template_kwargs: dict[str, Any] = Field(default_factory=dict)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/scripts/vllm_serve_lora.py` around lines 131 - 168, The GenerateRequest and ChatRequest Pydantic models use dataclasses.field for generation_kwargs and chat_template_kwargs; replace dataclasses.field with pydantic.Field and import Field from pydantic so the default_factory is passed via Field(default_factory=...), e.g. update the models GenerateRequest (generation_kwargs) and ChatRequest (generation_kwargs, chat_template_kwargs) to use Field instead of field.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/core/builders/rl.py`:
- Around line 237-248: The monkeypatch replacing
transformers.trainer.validate_quantization_for_training is currently applied
globally and not restored, so if trainer_cls(...) raises the no-op remains;
modify the block around the import and trainer instantiation (referencing
validate_quantization_for_training and trainer_cls) to save the original
function, set the no-op, then call trainer_cls(...) inside a try block and
restore the original function in a finally block so the original
validate_quantization_for_training is always reinstated even on errors.
In `@src/axolotl/core/trainers/grpo/__init__.py`:
- Around line 39-42: The current selection logic silently prefers
sequence_parallel over async_grpo; add an explicit guard that rejects the
unsupported combination (sequence_parallel and async_grpo both true) by raising
a clear exception (e.g., ValueError or custom config error) before returning
trainers. Update the conditional block that currently returns
AxolotlGRPOSequenceParallelTrainer or AxolotlAsyncGRPOTrainer so it first checks
if sequence_parallel and async_grpo are both set and raises an error referencing
those flags to fail fast.
In `@src/axolotl/core/trainers/grpo/fast_async_trainer.py`:
- Around line 361-376: The current logic treats non-Module/non-async reward
funcs as safe for background workers, but model-path strings are non-callable
and cause a single worker to return None and desynchronize the pool; update the
check that sets reward_can_bg to also require callable(rf) for every rf in
self.reward_funcs (so any string or non-callable disables BG workers), and
modify _collect_reward_workers (the worker-collection/reset path) to on any
worker failure or unexpected None reply fully drain and reset all worker pipes
and set self._reward_workers_used = None and keep the batch args in
self._pending_reward_args to ensure no stale replies remain; reference symbols:
reward_can_bg, self.reward_funcs, num_workers/reward_num_workers,
_collect_reward_workers, _reward_workers_used, _pending_reward_args.
- Around line 744-758: The zero-advantage fallback in _compute_loss incorrectly
hardcodes torch.amp.autocast(device_type="cuda") and only passes
input_ids/attention_mask, causing failures on CPU or multimodal models; change
to use
torch.autocast(device_type=torch.device(inputs["prompt_ids"].device).type,
dtype=torch.bfloat16) (or omit device_type to let torch decide) and build the
minimal input dict from inputs by conditionally including keys the model may
expect (e.g., "input_ids", "attention_mask", "pixel_values", "image_grid_thw",
etc.) before calling model(...), so the tiny forward pass works on CPU and with
multimodal models while preserving grad_fn for DeepSpeed/ZeRO.
In `@src/axolotl/core/trainers/grpo/replay_buffer.py`:
- Around line 14-33: The add() method currently assumes max_size>0 and accesses
self._heap[0], which crashes when ReplayBuffer is disabled (max_size<=0); modify
add(self, score, data) to early-return (no-op) when self.max_size <= 0 or raise
a clear, documented exception, and likewise ensure sample(self, num_samples)
returns None or empty list immediately if self.max_size <= 0 or self._heap is
empty; update logic around heapq.heappush/heapq.heapreplace to only run when
self.max_size > 0 and when the heap is non-empty to avoid indexing
self._heap[0].
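The guard described in this prompt can be sketched as a minimal score-ordered buffer. This is a standalone illustration of the no-op-when-disabled behavior, not the repo's actual `ReplayBuffer` class; the method names and tie-breaker counter are assumptions.

```python
import heapq
import random


class ReplayBuffer:
    """Keeps the top-`max_size` entries by score; a no-op when disabled."""

    def __init__(self, max_size: int):
        self.max_size = max_size
        self._heap: list[tuple[float, int, object]] = []
        self._counter = 0  # tie-breaker so heapq never compares `data` objects

    def add(self, score: float, data) -> None:
        if self.max_size <= 0:
            return  # buffer disabled: ignore instead of indexing _heap[0]
        entry = (score, self._counter, data)
        self._counter += 1
        if len(self._heap) < self.max_size:
            heapq.heappush(self._heap, entry)
        elif score > self._heap[0][0]:
            # Heap is full: evict the lowest-scoring entry only if ours beats it.
            heapq.heapreplace(self._heap, entry)

    def sample(self, num_samples: int) -> list:
        if self.max_size <= 0 or not self._heap:
            return []  # disabled or empty: nothing to replay
        k = min(num_samples, len(self._heap))
        return [data for _, _, data in random.sample(self._heap, k)]
```

With `max_size=0` both `add()` and `sample()` return immediately, so the indexing crash the finding describes cannot occur.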
In `@src/axolotl/kernels/lora.py`:
- Around line 49-57: The code unconditionally falls back to
base_layer.weight_scale_inv for quant_state which can pass non-FP8 metadata into
dequantize and break the NF4 path; change the fallback so you only assign
quant_state = base_layer.weight_scale_inv when W.dtype indicates an FP8 type
(i.e., guard by dtype check on W inside matmul_lora / the block using W), e.g.
check W.dtype is an FP8 dtype before using base_layer.weight_scale_inv,
otherwise leave quant_state None; update the logic around W, quant_state, and
base_layer.weight_scale_inv (the existing attributes proj.disable_adapters,
proj.merged, W, base_layer, quant_state) so tensor metadata only flows for FP8
weights.
In `@src/axolotl/kernels/quantize.py`:
- Around line 37-43: The code assumes W.shape is exactly divisible by the scale
grid and reshapes W_float into (sr, br, sc, bc) which breaks for tail blocks;
instead, tile/expand scale_inv to cover W's full row/column dimensions and then
crop to W.shape before applying it. Concretely, compute sr, sc =
scale_inv.shape, create scale_expanded = scale_inv.to(dtype)[:, None, :, None]
then use torch.repeat_interleave (or expand + repeat) to repeat each scale cell
the needed number of rows/cols to reach at least W.shape, and finally slice/crop
scale_expanded to match W.shape (or W_float.reshape(-1, W.shape[1]) shape)
before doing elementwise multiply and reshaping; update the code around the
reshape/multiply that references scale_inv, W, and W_float to use this
tiling-and-cropping approach so partial/tail FP8 blocks are handled correctly.
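The tiling-and-cropping idea above can be expressed with equivalent index arithmetic: each element of `W` looks up the scale cell its block maps to, which handles partial tail blocks without assuming even divisibility. A dependency-free sketch on plain nested lists (the real code operates on torch tensors; names here are illustrative):

```python
from math import ceil


def dequant_blockwise(w, scale_inv):
    """Apply a per-block inverse scale to matrix `w` (list of lists).

    `scale_inv` has one cell per (block_rows x block_cols) tile of `w`.
    Tiles at the right/bottom edge may be partial, so block sizes are
    derived with ceil() and indices are clamped, rather than reshaping
    `w` into an exact (sr, br, sc, bc) grid.
    """
    rows, cols = len(w), len(w[0])
    sr, sc = len(scale_inv), len(scale_inv[0])
    br, bc = ceil(rows / sr), ceil(cols / sc)  # block sizes incl. tail blocks
    out = []
    for i in range(rows):
        # i // br picks the scale row even for a partial tail block.
        row_scales = scale_inv[min(i // br, sr - 1)]
        out.append(
            [w[i][j] * row_scales[min(j // bc, sc - 1)] for j in range(cols)]
        )
    return out
```

For a 3x3 `w` and a 2x2 scale grid, the tail row and column reuse the last scale cells instead of raising a reshape error, which is the failure mode the finding targets.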
In `@src/axolotl/loaders/adapter.py`:
- Around line 163-167: The cast for FP8 LoRA params currently falls back to
hardcoded torch.bfloat16 when cfg.torch_dtype is unset; change it to choose an
effective training dtype instead: compute effective_train_dtype =
cfg.torch_dtype if set, otherwise select torch.bfloat16 when the device supports
BF16 (torch.cuda.is_bf16_supported()) and fall back to torch.float16 when BF16
is not supported (or when CUDA is available but BF16 unsupported), then use that
effective_train_dtype in the loop where model.named_parameters() checks
param.requires_grad and param.dtype == torch.float8_e4m3fn and call param.data =
param.data.to(effective_train_dtype).
In `@src/axolotl/loaders/model.py`:
- Around line 216-218: The current logic unconditionally sets
self.model_kwargs["allow_all_kernels"] when self.cfg.use_kernels is true,
overwriting any explicit value supplied via cfg.overrides_of_model_kwargs;
change it so you only set allow_all_kernels when it is not already present in
self.model_kwargs (i.e., check "allow_all_kernels" in self.model_kwargs and skip
assignment if found), while still setting self.model_kwargs["use_kernels"] based
on self.cfg.use_kernels.
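The "only set when not already present" pattern is just `dict.setdefault`. A hypothetical helper (the function name and `True` default are assumptions, not the repo's code) showing how an explicit override survives:

```python
def merge_kernel_kwargs(model_kwargs: dict, use_kernels: bool) -> dict:
    """Mirror of the suggested fix: `use_kernels` is always driven by the
    config, but `allow_all_kernels` is only defaulted when the caller has
    not already supplied it via overrides."""
    model_kwargs["use_kernels"] = use_kernels
    if use_kernels:
        # setdefault leaves an explicit override (even False) untouched.
        model_kwargs.setdefault("allow_all_kernels", True)
    return model_kwargs
```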
In `@src/axolotl/scripts/vllm_serve_lora.py`:
- Around line 155-168: ChatRequest declares structured_outputs_regex and
chat_template_kwargs but the /chat/ handler never reads or forwards them; update
the chat endpoint handler (the function handling the /chat/ route) to extract
request.structured_outputs_regex and request.chat_template_kwargs and pass them
into the downstream generation/template call (where generation_kwargs and the
chat template are applied), ensuring None/defaults are honored and the fields
are included in any structured-output parsing or template override logic; also
apply the same fix to the other handler block referenced around the 363-388
region so both code paths consume these fields.
- Around line 209-223: The lifespan() loop busy-waits for all workers to report
"ready" without any timeout or liveness checks; update the readiness logic in
lifespan to track elapsed time (use a configurable startup timeout, e.g.,
script_args.startup_timeout or a reasonable default), periodically sleep to
avoid tight spinning, and during each iteration check each Process in processes
with p.is_alive()—if any process has exited or the timeout is exceeded, break
the wait loop and raise/return an explicit startup error (including info from
conn.recv() if available) rather than hanging; keep the existing checks on
conn.poll()/conn.recv() and the ready set but add the timeout and process
liveness handling to fail fast.
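The readiness loop with timeout and liveness checks can be sketched as follows. The worker objects are duck-typed (`is_alive()`, `poll()`, `recv()`), so this runs against stubs as easily as against `multiprocessing` primitives; the function name and defaults are assumptions, not the script's API.

```python
import time


def wait_for_workers(workers, num_ready_expected, timeout=60.0, poll=0.1):
    """Wait until every worker reports "ready", failing fast on death or timeout.

    `workers` is a list of (process, conn) pairs. Raises RuntimeError if a
    process dies before reporting ready, TimeoutError if the deadline passes.
    """
    ready = set()
    deadline = time.monotonic() + timeout
    while True:
        for idx, (proc, conn) in enumerate(workers):
            if idx in ready:
                continue
            if not proc.is_alive():
                raise RuntimeError(f"worker {idx} died before reporting ready")
            if conn.poll():
                if conn.recv() == "ready":
                    ready.add(idx)
        if len(ready) >= num_ready_expected:
            return ready
        if time.monotonic() > deadline:
            raise TimeoutError(
                f"only {len(ready)}/{num_ready_expected} workers ready after {timeout}s"
            )
        time.sleep(poll)  # avoid tight spinning between readiness checks
```

Compared with the busy-wait the finding flags, a dead process or a stuck startup now surfaces as an explicit error instead of a hang.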
In `@src/axolotl/utils/schemas/trl.py`:
- Around line 206-257: The schema must validate numeric ranges for the new async
GRPO knobs: update the Pydantic Field declarations in the TRL schema (the class
containing prefetch_depth, vllm_sync_interval, streaming_min_groups,
vllm_importance_sampling_cap, off_policy_mask_threshold,
vllm_importance_sampling_mode, streaming_partial_batch, etc.) and add
appropriate ge/le constraints (or use conint/confloat) so invalid values are
rejected at validation time; e.g. set prefetch_depth, vllm_sync_interval,
streaming_min_groups, reward_num_workers, replay_buffer_size to have minimum 1
(Field(..., ge=1)), set reroll_start_fraction to a float range 0.0..1.0
(Field(..., ge=0.0, le=1.0)), and set vllm_importance_sampling_cap and
off_policy_mask_threshold to non-negative floats (Field(..., ge=0.0)); apply the
same pattern to the other numeric knobs referenced in lines ~260-313.
---
Nitpick comments:
In `@src/axolotl/core/builders/rl.py`:
- Around line 57-67: The async enablement check for GRPO is duplicated (computed
as async_grpo and used when calling GRPOStrategy.get_trainer_class and again
when selecting training args), which can drift; extract that predicate into a
single shared accessor (e.g., a static/helper method or property such as
GRPOStrategy.is_async_enabled(cfg) or a property on the configuration like
cfg.trl.async_grpo_enabled) and replace both uses (the async_grpo variable near
the call to GRPOStrategy.get_trainer_class and the separate training-args
selection branch) to call that single accessor so trainer-class selection and
training-args selection always use the same logic.
In `@src/axolotl/scripts/vllm_serve_lora.py`:
- Around line 131-168: The GenerateRequest and ChatRequest Pydantic models use
dataclasses.field for generation_kwargs and chat_template_kwargs; replace
dataclasses.field with pydantic.Field and import Field from pydantic so the
default_factory is passed via Field(default_factory=...), e.g. update the models
GenerateRequest (generation_kwargs) and ChatRequest (generation_kwargs,
chat_template_kwargs) to use Field instead of field.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: ab4abe54-079f-4218-96d9-3b60a2f17aac
📒 Files selected for processing (17)
- src/axolotl/cli/vllm_serve.py
- src/axolotl/core/builders/rl.py
- src/axolotl/core/trainers/grpo/__init__.py
- src/axolotl/core/trainers/grpo/args.py
- src/axolotl/core/trainers/grpo/async_trainer.py
- src/axolotl/core/trainers/grpo/fast_async_trainer.py
- src/axolotl/core/trainers/grpo/replay_buffer.py
- src/axolotl/core/trainers/grpo/trainer.py
- src/axolotl/kernels/lora.py
- src/axolotl/kernels/quantize.py
- src/axolotl/loaders/adapter.py
- src/axolotl/loaders/model.py
- src/axolotl/scripts/__init__.py
- src/axolotl/scripts/vllm_serve_lora.py
- src/axolotl/utils/schemas/trl.py
- src/axolotl/utils/schemas/validation.py
- src/axolotl/utils/schemas/vllm.py
💤 Files with no reviewable changes (1)
- src/axolotl/utils/schemas/validation.py
# Allow FP8-quantized models to be fine-tuned with LoRA adapters.
# transformers' validate_quantization_for_training blocks FP8 because
# hf_quantizer.is_trainable is False, but LoRA only trains the adapters
# (base weights stay frozen in FP8).
if (
    self.cfg.adapter
    and hasattr(self.model, "is_quantized")
    and self.model.is_quantized
):
    import transformers.trainer as _trainer_module

    _trainer_module.validate_quantization_for_training = lambda model: None
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Expect to see the reassignment below, and no restoration in this file.
rg -nC2 'validate_quantization_for_training' src/axolotl/core/builders/rl.py
rg -n 'original_validate|finally' src/axolotl/core/builders/rl.py
Repository: axolotl-ai-cloud/axolotl
Length of output: 564
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Read broader context around the patch to see trainer_cls call and any error handling
sed -n '237,270p' src/axolotl/core/builders/rl.py | cat -n
Repository: axolotl-ai-cloud/axolotl
Length of output: 1538
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Get the function signature and broader context - find the function name
sed -n '200,250p' src/axolotl/core/builders/rl.py | head -60
Repository: axolotl-ai-cloud/axolotl
Length of output: 2132
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Get the end of the build function to see if there's try/finally wrapper
rg -n 'def build' src/axolotl/core/builders/rl.py
# Then get the structure of the function
sed -n '206,300p' src/axolotl/core/builders/rl.py | cat -n
Repository: axolotl-ai-cloud/axolotl
Length of output: 3030
🏁 Script executed:
#!/bin/bash
set -euo pipefabio
# Find all calls to .build() method in the codebase
rg -n '\.build\(' src/axolotl --type py | head -20
# Also check if validate_quantization_for_training is patched elsewhere
rg -n 'validate_quantization_for_training' src/axolotl --type py
Repository: axolotl-ai-cloud/axolotl
Length of output: 122
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Find all calls to .build() method in the codebase
rg -n '\.build\(' src/axolotl --type py | head -20
# Also check if validate_quantization_for_training is patched elsewhere
rg -n 'validate_quantization_for_training' src/axolotl --type py
Repository: axolotl-ai-cloud/axolotl
Length of output: 384
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Check the context around the build() call in trainer.py
sed -n '710,730p' src/axolotl/utils/trainer.py | cat -n
Repository: axolotl-ai-cloud/axolotl
Length of output: 577
Scope the Transformers quantization patch to a single trainer build.
The monkeypatch at line 248 modifies transformers.trainer.validate_quantization_for_training process-wide for any adapter + quantized model, not just the FP8 case. If trainer_cls(...) raises, the no-op stays installed and later trainer builds will silently skip the library safety check. Wrap the monkeypatch and trainer instantiation in a try/finally block to restore the original function.
🧰 Tools
🪛 Ruff (0.15.5)
[warning] 248-248: Unused lambda argument: model
(ARG005)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/core/builders/rl.py` around lines 237 - 248, The monkeypatch
replacing transformers.trainer.validate_quantization_for_training is currently
applied globally and not restored, so if trainer_cls(...) raises the no-op
remains; modify the block around the import and trainer instantiation
(referencing validate_quantization_for_training and trainer_cls) to save the
original function, set the no-op, then call trainer_cls(...) inside a try block
and restore the original function in a finally block so the original
validate_quantization_for_training is always reinstated even on errors.
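The save/patch/restore discipline the prompt asks for is a plain try/finally. A dependency-free sketch where `module` stands in for `transformers.trainer` (any object exposing the attribute works, so the pattern itself is what's shown, not the repo's exact code):

```python
def build_with_patch(module, trainer_factory):
    """Temporarily no-op `module.validate_quantization_for_training` while
    constructing the trainer, restoring the original even on failure."""
    original = module.validate_quantization_for_training
    module.validate_quantization_for_training = lambda model: None
    try:
        return trainer_factory()
    finally:
        # Always reinstate the library's safety check, even when the
        # trainer constructor raises.
        module.validate_quantization_for_training = original
```

Because the restore sits in `finally`, a failing `trainer_factory()` can no longer leave the process-wide no-op installed for later trainer builds.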
if sequence_parallel:
    return AxolotlGRPOSequenceParallelTrainer
if async_grpo:
    return AxolotlAsyncGRPOTrainer
Fail fast on sequence_parallel + async_grpo.
This branch silently drops the async path when both flags are true, so an async config can be accepted but executed with the sync sequence-parallel trainer. Please reject the unsupported combination instead of changing trainer semantics implicitly.
Suggested guard
- if sequence_parallel:
- return AxolotlGRPOSequenceParallelTrainer
+ if sequence_parallel:
+ if async_grpo:
+ raise ValueError("Async GRPO is not supported with sequence parallel")
+ return AxolotlGRPOSequenceParallelTrainer
if async_grpo:
return AxolotlAsyncGRPOTrainer📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if sequence_parallel:
    if async_grpo:
        raise ValueError("Async GRPO is not supported with sequence parallel")
    return AxolotlGRPOSequenceParallelTrainer
if async_grpo:
    return AxolotlAsyncGRPOTrainer
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/core/trainers/grpo/__init__.py` around lines 39 - 42, The current
selection logic silently prefers sequence_parallel over async_grpo; add an
explicit guard that rejects the unsupported combination (sequence_parallel and
async_grpo both true) by raising a clear exception (e.g., ValueError or custom
config error) before returning trainers. Update the conditional block that
currently returns AxolotlGRPOSequenceParallelTrainer or AxolotlAsyncGRPOTrainer
so it first checks if sequence_parallel and async_grpo are both set and raises
an error referencing those flags to fail fast.
reward_can_bg = all(
    not isinstance(rf, nn.Module) and not asyncio.iscoroutinefunction(rf)
    for rf in self.reward_funcs
)
num_workers = getattr(self.args, "reward_num_workers", 1)

if not reward_can_bg or num_workers <= 1:
    # Can't parallelize — store args for sync fallback in collect
    self._reward_workers_used = None
    self._pending_reward_args = (
        inputs,
        prompts,
        completions,
        completion_ids_list,
    )
    return
A single reward-worker failure can desynchronize the pool.
reward_can_bg accepts any non-nn.Module / non-async object, but GRPOStrategy.get_reward_func() can also supply model-path strings. Those are not callable, so the worker returns None; _collect_reward_workers() then bails out without draining/resetting the other pipes, which leaves stale replies to be consumed by later batches.
Suggested fix
reward_can_bg = all(
- not isinstance(rf, nn.Module) and not asyncio.iscoroutinefunction(rf)
+ callable(rf)
+ and not isinstance(rf, nn.Module)
+ and not asyncio.iscoroutinefunction(rf)
for rf in self.reward_funcs
)
...
all_worker_results = []
any_failed = False
for conn in workers_used:
- result = conn.recv()
+ try:
+ result = conn.recv()
+ except EOFError:
+ any_failed = True
+ continue
if result is None:
any_failed = True
- break
- all_worker_results.append(result)
+ else:
+ all_worker_results.append(result)
...
# Fallback to main thread on failure
+ self._shutdown_reward_workers()
if args is not None:
return self._calculate_rewards(*args)
Also applies to: 433-459
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/core/trainers/grpo/fast_async_trainer.py` around lines 361 - 376,
The current logic treats non-Module/non-async reward funcs as safe for
background workers, but model-path strings are non-callable and cause a single
worker to return None and desynchronize the pool; update the check that sets
reward_can_bg to also require callable(rf) for every rf in self.reward_funcs (so
any string or non-callable disables BG workers), and modify
_collect_reward_workers (the worker-collection/reset path) to on any worker
failure or unexpected None reply fully drain and reset all worker pipes and set
self._reward_workers_used = None and keep the batch args in
self._pending_reward_args to ensure no stale replies remain; reference symbols:
reward_can_bg, self.reward_funcs, num_workers/reward_num_workers,
_collect_reward_workers, _reward_workers_used, _pending_reward_args.
r_fwd_kwargs = {}
for fk in (
    "pixel_values",
    "image_grid_thw",
    "pixel_attention_mask",
    "image_sizes",
    "token_type_ids",
    "mm_token_type_ids",
):
    if fk in data:
        r_fwd_kwargs[fk] = data[fk]
r_logps, _ = self._get_per_token_logps_and_entropies(
    self.model,
    r_ids,
    r_mask,
    r_logits_to_keep,
    r_end - r_start,
    **r_fwd_kwargs,
)
Replay logprob recompute needs sliced multimodal inputs.
This path slices prompt_ids / completion_ids to the replayed range, but forwards full-batch pixel_values, image_grid_thw, pixel_attention_mask, and friends. Any multimodal batch will either misalign prompts to side inputs or fail on batch-size checks.
Suggested fix
for fk in (
"pixel_values",
"image_grid_thw",
"pixel_attention_mask",
"image_sizes",
"token_type_ids",
"mm_token_type_ids",
):
if fk in data:
- r_fwd_kwargs[fk] = data[fk]
+ value = data[fk]
+ if (
+ isinstance(value, torch.Tensor)
+ and value.dim() > 0
+ and value.size(0) >= r_end
+ ):
+ r_fwd_kwargs[fk] = value[r_start:r_end]
+ else:
+ r_fwd_kwargs[fk] = value📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
r_fwd_kwargs = {}
for fk in (
    "pixel_values",
    "image_grid_thw",
    "pixel_attention_mask",
    "image_sizes",
    "token_type_ids",
    "mm_token_type_ids",
):
    if fk in data:
        value = data[fk]
        if (
            isinstance(value, torch.Tensor)
            and value.dim() > 0
            and value.size(0) >= r_end
        ):
            r_fwd_kwargs[fk] = value[r_start:r_end]
        else:
            r_fwd_kwargs[fk] = value
r_logps, _ = self._get_per_token_logps_and_entropies(
    self.model,
    r_ids,
    r_mask,
    r_logits_to_keep,
    r_end - r_start,
    **r_fwd_kwargs,
)
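The slicing rule in the suggested fix can be distilled to a small helper: slice a side input to the replayed range only when it is batch-aligned, otherwise pass it through. Plain sequences stand in for tensors here, and the alignment test is by exact batch length rather than the diff's `size(0) >= r_end` check, so treat this as an illustrative variant, not the proposed code.

```python
def slice_batch_inputs(data, keys, r_start, r_end, batch_size):
    """Slice batch-aligned side inputs to a replayed sub-range.

    Entries whose leading dimension equals the batch size are sliced;
    anything else (e.g. packed pixel buffers whose first dim is a token
    count, or scalar metadata) is forwarded unchanged.
    """
    out = {}
    for key in keys:
        if key not in data:
            continue
        value = data[key]
        if hasattr(value, "__len__") and len(value) == batch_size:
            out[key] = value[r_start:r_end]
        else:
            out[key] = value
    return out
```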
def _compute_loss(self, model, inputs):
    if self.args.skip_zero_advantage_batches and torch.all(
        inputs["advantages"] == 0
    ):
        mode = "train" if self.model.training else "eval"
        self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
        # Create zero loss with grad_fn. DeepSpeed requires grad_fn != None.
        # With ZeRO-3, parameters are partitioned (shape=[0], requires_grad=False)
        # so we can't just do `(p * 0).sum()`. Instead, do a tiny forward pass
        # with a single token to create a proper computation graph.
        prompt_ids = inputs["prompt_ids"][:1, :1]  # (1, 1)
        attn = torch.ones_like(prompt_ids)
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            out = model(input_ids=prompt_ids, attention_mask=attn)
        return out.logits.sum() * 0
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Show whether GRPO inputs in this repo can be multimodal and how autocast is handled elsewhere.
rg -nC2 'pixel_values|image_grid_thw|pixel_attention_mask|image_sizes|token_type_ids|mm_token_type_ids' src/axolotl/core/trainers src/axolotl
rg -nC2 'accelerator\.autocast|torch\.amp\.autocast' src/axolotl/core/trainers src/axolotl
Repository: axolotl-ai-cloud/axolotl
Length of output: 33265
🏁 Script executed:
#!/bin/bash
# Get context around _compute_loss to see class definition and accelerator availability
sed -n '700,770p' src/axolotl/core/trainers/grpo/fast_async_trainer.py | head -100
# Check class hierarchy and accelerator usage in FastAsyncGRPOTrainer
grep -n "class.*FastAsyncGRPOTrainer\|self.accelerator" src/axolotl/core/trainers/grpo/fast_async_trainer.py | head -20
# Check if accelerator.autocast is used elsewhere in this file
rg -n "accelerator\.autocast" src/axolotl/core/trainers/grpo/fast_async_trainer.py
Repository: axolotl-ai-cloud/axolotl
Length of output: 3075
🏁 Script executed:
#!/bin/bash
# Check AsyncGRPOTrainer parent class for accelerator availability
rg -n "accelerator\.autocast|torch\.amp\.autocast" src/axolotl/core/trainers/grpo/async_trainer.py | head -20
# Check if AsyncGRPOTrainer is a Trainer subclass and has accelerator
grep -n "class AsyncGRPOTrainer\|from.*Trainer\|self.accelerator" src/axolotl/core/trainers/grpo/async_trainer.py | head -30
# Check the exact imports in fast_async_trainer to see if accelerator is standard
head -80 src/axolotl/core/trainers/grpo/fast_async_trainer.py | grep -n "import\|from"
Repository: axolotl-ai-cloud/axolotl
Length of output: 2585
🏁 Script executed:
#!/bin/bash
# Check if accelerator.autocast is used at all in the trainer files
rg -n "accelerator\.autocast\(\)" src/axolotl/core/trainers/
# Check the pattern of how autocast is done elsewhere in fast_async_trainer
rg -n "torch\.amp\.autocast" src/axolotl/core/trainers/grpo/fast_async_trainer.py
# Check what pattern is used in async_trainer.py
rg -n "torch\.amp\.autocast" src/axolotl/core/trainers/grpo/async_trainer.pyRepository: axolotl-ai-cloud/axolotl
Length of output: 150
🏁 Script executed:
#!/bin/bash
# Check if there's any example of accelerator.autocast in the entire codebase
rg -rn "accelerator\.autocast" src/
# Check what the Accelerator class provides - look at imports from accelerate
rg -n "from accelerate\|import.*Accelerator" src/axolotl/core/trainers/grpo/async_trainer.py
# Look for how autocast is handled in other training contexts in axolotl
rg -n "torch\.amp\.autocast\|autocast" src/axolotl/core/trainers/ | grep -v "fast_async_trainer.py:756"
Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Check what accelerator actually provides - search for how it's used in the trainer
grep -n "self.accelerator\." src/axolotl/core/trainers/grpo/async_trainer.py | head -20
# Check GRPOTrainer base class by looking at what methods exist in FastAsyncGRPOTrainer
grep -n "def " src/axolotl/core/trainers/grpo/fast_async_trainer.py | head -30
# Let's check if accelerate library even has autocast method
python3 <<'EOF'
try:
from accelerate import Accelerator
# Check what methods are available
a = Accelerator()
if hasattr(a, 'autocast'):
print("accelerator.autocast exists")
else:
print("accelerator.autocast does NOT exist")
# List relevant autocast-like methods
methods = [m for m in dir(a) if 'autocast' in m.lower() or 'cast' in m.lower()]
print(f"Available casting methods: {methods}")
except Exception as e:
print(f"Error: {e}")
EOF
Repository: axolotl-ai-cloud/axolotl
Length of output: 2094
🏁 Script executed:
#!/bin/bash
# Check how torch.amp.autocast is configured/used elsewhere in axolotl
rg -rn "torch\.amp\|autocast\|bfloat16" src/axolotl/core/trainers/ | grep -v ".pyc" | head -30
# Check if there's a standard way to get device and dtype in GRPOTrainer
grep -n "args\.\|self\.\(device\|dtype\)" src/axolotl/core/trainers/grpo/fast_async_trainer.py | head -30
# Check what args are available in the trainer
sed -n '226,270p' src/axolotl/core/trainers/grpo/fast_async_trainer.py
Repository: axolotl-ai-cloud/axolotl
Length of output: 2907
🏁 Script executed:
#!/bin/bash
# Check if there are other forward passes in async_trainer.py and how they handle autocast/device
rg -n "with torch\.\|with self\.accelerator\.\|model(" src/axolotl/core/trainers/grpo/async_trainer.py | grep -A2 -B2 "model(" | head -50
# Look for actual model forward calls in async_trainer to see if they use autocast
sed -n '2300,2450p' src/axolotl/core/trainers/grpo/async_trainer.py | grep -n "with\|model(\|autocast"
# Check what device the model is on and how it's handled
grep -n "self.model\|model.to\|\.device" src/axolotl/core/trainers/grpo/async_trainer.py | head -20
Repository: axolotl-ai-cloud/axolotl
Length of output: 1840
🏁 Script executed:
#!/bin/bash
# Get the exact context of the autocast usage in async_trainer around line 2332
sed -n '2325,2345p' src/axolotl/core/trainers/grpo/async_trainer.py
# Get the pattern used in the helper function around line 32
sed -n '25,50p' src/axolotl/core/trainers/grpo/async_trainer.py
Repository: axolotl-ai-cloud/axolotl
Length of output: 1636
🏁 Script executed:
#!/bin/bash
# Get more context on the autocast pattern - check what device and dtype are available
sed -n '2320,2340p' src/axolotl/core/trainers/grpo/async_trainer.py
# Check if input_ids is available in _compute_loss and how device is determined
sed -n '744,770p' src/axolotl/core/trainers/grpo/fast_async_trainer.py
# Check what self.args has available for dtype
rg -n "mixed_precision\|bf16\|fp16\|dtype" src/axolotl/core/trainers/grpo/fast_async_trainer.py | head -20
Length of output: 2100
🏁 Script executed:
#!/bin/bash
# Check if there's a pattern for getting dtype from args in the trainer
rg -n "mixed_precision\|compute_dtype\|bfloat16" src/axolotl/core/trainers/grpo/ | head -15
# Confirm the device.type pattern is the standard way in async_trainer
sed -n '2327,2335p' src/axolotl/core/trainers/grpo/async_trainer.py
# Look at how other forward passes handle dtype
rg -n "torch.autocast\|device.type" src/axolotl/core/trainers/grpo/async_trainer.py | head -10
Length of output: 523
🏁 Script executed:
#!/bin/bash
# Check if "prompt_mask" is a standard key in inputs or if it should be attention_mask
rg -n "prompt_mask\|prompt_ids" src/axolotl/core/trainers/grpo/fast_async_trainer.py | head -20
# Verify what keys are available in inputs for _compute_loss context
sed -n '744,800p' src/axolotl/core/trainers/grpo/fast_async_trainer.py
Length of output: 1034
🏁 Script executed:
#!/bin/bash
# Search for prompt_mask usage in the codebase
rg -n "prompt_mask" src/axolotl/core/trainers/grpo/
# Also check for attention_mask construction from inputs
rg -n "inputs\[.*attention" src/axolotl/core/trainers/grpo/ | head -20
Length of output: 2770
🏁 Script executed:
#!/bin/bash
# Verify that prompt_mask is indeed available in inputs for _compute_loss
sed -n '649,670p' src/axolotl/core/trainers/grpo/fast_async_trainer.py
# Final check: confirm the pattern from async_trainer for computing with minimal inputs
sed -n '2450,2465p' src/axolotl/core/trainers/grpo/async_trainer.py
Length of output: 1585
🏁 Script executed:
#!/bin/bash
# Check the exact shape and slicing pattern for prompt_mask
sed -n '645,675p' src/axolotl/core/trainers/grpo/fast_async_trainer.py
# Also verify from async_trainer how minimal inputs are passed
sed -n '2450,2470p' src/axolotl/core/trainers/grpo/async_trainer.py
Length of output: 2146
Fix zero-advantage fallback to support non-CUDA devices and multimodal models.
The hardcoded device_type="cuda" will crash on non-CUDA runs. Additionally, the forward pass only provides input_ids and attention_mask, so multimodal models requiring pixel_values, image_grid_thw, or other inputs will fail. Use torch.autocast() with the actual device type and conditionally include all model-specific inputs.
Suggested fix
-        prompt_ids = inputs["prompt_ids"][:1, :1]  # (1, 1)
-        attn = torch.ones_like(prompt_ids)
-        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            out = model(input_ids=prompt_ids, attention_mask=attn)
+        prompt_ids = inputs["prompt_ids"][:1, :1]  # (1, 1)
+        prompt_mask = inputs["prompt_mask"][:1]
+        model_kwargs = {
+            "input_ids": prompt_ids,
+            "attention_mask": prompt_mask,
+        }
+        for key in (
+            "pixel_values",
+            "image_grid_thw",
+            "pixel_attention_mask",
+            "image_sizes",
+            "token_type_ids",
+            "mm_token_type_ids",
+        ):
+            if key in inputs:
+                model_kwargs[key] = inputs[key][:1]
+        with torch.autocast(device_type=prompt_ids.device.type, dtype=torch.bfloat16):
+            out = model(**model_kwargs)
         return out.logits.sum() * 0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/core/trainers/grpo/fast_async_trainer.py` around lines 744 - 758,
The zero-advantage fallback in _compute_loss incorrectly hardcodes
torch.amp.autocast(device_type="cuda") and only passes input_ids/attention_mask,
causing failures on CPU or multimodal models; change to use
torch.autocast(device_type=torch.device(inputs["prompt_ids"].device).type,
dtype=torch.bfloat16) (or omit device_type to let torch decide) and build the
minimal input dict from inputs by conditionally including keys the model may
expect (e.g., "input_ids", "attention_mask", "pixel_values", "image_grid_thw",
etc.) before calling model(...), so the tiny forward pass works on CPU and with
multimodal models while preserving grad_fn for DeepSpeed/ZeRO.
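The conditional-inputs pattern behind this suggestion can be sketched independently of torch. Here `inputs` is an ordinary dict standing in for the trainer's batch, and the optional key names are the multimodal keys listed in the suggested fix; the helper name is illustrative, not part of the trainer:

```python
def build_model_kwargs(inputs: dict) -> dict:
    """Collect required inputs plus any optional multimodal keys present.

    Always forward input_ids/attention_mask; forward extras such as
    pixel_values or image_grid_thw only when the batch actually
    contains them, so text-only and multimodal batches both work.
    """
    model_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
    }
    optional_keys = (
        "pixel_values",
        "image_grid_thw",
        "pixel_attention_mask",
        "image_sizes",
        "token_type_ids",
        "mm_token_type_ids",
    )
    for key in optional_keys:
        if key in inputs:
            model_kwargs[key] = inputs[key]
    return model_kwargs
```

A text-only batch produces only the two required keys, while a batch carrying `pixel_values` forwards it unchanged, which is the behavior the review asks for in the fallback forward pass.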
    # FP8 models: LoRA A/B inherit FP8 dtype from base weights, but training
    # requires a compute dtype (bf16/fp16). Cast trainable LoRA params.
    for _name, param in model.named_parameters():
        if param.requires_grad and param.dtype == torch.float8_e4m3fn:
            param.data = param.data.to(cfg.torch_dtype or torch.bfloat16)
Resolve the FP8 adapter cast from the effective training precision.
Line 167 hardcodes torch.bfloat16 as the fallback whenever cfg.torch_dtype is unset. That will miscast adapters in fp16-configured runs that drive precision via trainer flags instead of cfg.torch_dtype, and it can fail outright on non-bf16 hardware.
Possible fix
-    # FP8 models: LoRA A/B inherit FP8 dtype from base weights, but training
-    # requires a compute dtype (bf16/fp16). Cast trainable LoRA params.
-    for _name, param in model.named_parameters():
-        if param.requires_grad and param.dtype == torch.float8_e4m3fn:
-            param.data = param.data.to(cfg.torch_dtype or torch.bfloat16)
+    # FP8 models: LoRA A/B inherit FP8 dtype from base weights, but training
+    # requires the actual compute dtype used by the run.
+    compute_dtype = cfg.torch_dtype
+    if compute_dtype is None:
+        if getattr(cfg, "bf16", False):
+            compute_dtype = torch.bfloat16
+        elif getattr(cfg, "fp16", False):
+            compute_dtype = torch.float16
+
+    for _name, param in model.named_parameters():
+        if (
+            compute_dtype is not None
+            and param.requires_grad
+            and param.dtype == torch.float8_e4m3fn
+        ):
+            param.data = param.data.to(compute_dtype)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    # FP8 models: LoRA A/B inherit FP8 dtype from base weights, but training
    # requires the actual compute dtype used by the run.
    compute_dtype = cfg.torch_dtype
    if compute_dtype is None:
        if getattr(cfg, "bf16", False):
            compute_dtype = torch.bfloat16
        elif getattr(cfg, "fp16", False):
            compute_dtype = torch.float16

    for _name, param in model.named_parameters():
        if (
            compute_dtype is not None
            and param.requires_grad
            and param.dtype == torch.float8_e4m3fn
        ):
            param.data = param.data.to(compute_dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/loaders/adapter.py` around lines 163 - 167, The cast for FP8 LoRA
params currently falls back to hardcoded torch.bfloat16 when cfg.torch_dtype is
unset; change it to choose an effective training dtype instead: compute
effective_train_dtype = cfg.torch_dtype if set, otherwise select torch.bfloat16
when the device supports BF16 (torch.cuda.is_bf16_supported()) and fall back to
torch.float16 when BF16 is not supported (or when CUDA is available but BF16
unsupported), then use that effective_train_dtype in the loop where
model.named_parameters() checks param.requires_grad and param.dtype ==
torch.float8_e4m3fn and call param.data = param.data.to(effective_train_dtype).
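The precedence the review asks for (explicit `cfg.torch_dtype` first, then the trainer precision flags, then no cast at all) can be sketched in isolation. Strings stand in for `torch.dtype` objects here, and the helper name is illustrative rather than anything in axolotl:

```python
def resolve_compute_dtype(torch_dtype, bf16: bool, fp16: bool):
    """Pick the compute dtype for casting FP8 LoRA params.

    An explicit torch_dtype always wins; otherwise fall back to the
    bf16/fp16 trainer flags. Returning None lets the caller skip the
    cast entirely instead of guessing bfloat16 on unsupported hardware.
    """
    if torch_dtype is not None:
        return torch_dtype
    if bf16:
        return "bfloat16"
    if fp16:
        return "float16"
    return None
```

The key design choice is the `None` sentinel: rather than hardcoding a fallback dtype, an unconfigured run leaves the adapters untouched and surfaces the misconfiguration instead of silently miscasting.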
        if self.cfg.use_kernels:
            self.model_kwargs["use_kernels"] = self.cfg.use_kernels
            self.model_kwargs["allow_all_kernels"] = self.cfg.use_kernels
Preserve explicit allow_all_kernels overrides.
Line 218 currently clobbers any value already supplied through cfg.overrides_of_model_kwargs, so a caller can no longer enable kernels while keeping allow_all_kernels=False. Please only default this flag when it was not explicitly provided.
Proposed fix
         if self.cfg.use_kernels:
             self.model_kwargs["use_kernels"] = self.cfg.use_kernels
-            self.model_kwargs["allow_all_kernels"] = self.cfg.use_kernels
+            self.model_kwargs.setdefault(
+                "allow_all_kernels", self.cfg.use_kernels
+            )
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/loaders/model.py` around lines 216 - 218, The current logic
unconditionally sets self.model_kwargs["allow_all_kernels"] when
self.cfg.use_kernels is true, overwriting any explicit value supplied via
cfg.overrides_of_model_kwargs; change it so you only set allow_all_kernels when
it is not already present in self.model_kwargs (i.e., check "allow_all_kernels"
in self.model_kwargs and skip assignment if found), while still setting
self.model_kwargs["use_kernels"] based on self.cfg.use_kernels.
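The `setdefault` behavior the fix relies on is easy to demonstrate on a plain dict; the helper below is a standalone sketch of the intended logic, not the loader's actual method:

```python
def apply_kernel_defaults(model_kwargs: dict, use_kernels: bool) -> dict:
    """Default allow_all_kernels without clobbering an explicit override."""
    if use_kernels:
        model_kwargs["use_kernels"] = use_kernels
        # setdefault only writes when the key is absent, so a caller who
        # pinned allow_all_kernels=False via overrides keeps that value.
        model_kwargs.setdefault("allow_all_kernels", use_kernels)
    return model_kwargs
```

With this shape, `use_kernels: true` plus an explicit `allow_all_kernels: false` override no longer collapses into `True`.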
class ChatRequest(BaseModel):
    messages: list[list[dict]]
    n: int = 1
    repetition_penalty: float = 1.0
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1
    min_p: float = 0.0
    max_tokens: int = 16
    logprobs: int | None = 0
    truncate_prompt_tokens: int | None = None
    structured_outputs_regex: str | None = None
    generation_kwargs: dict = field(default_factory=dict)
    chat_template_kwargs: dict = field(default_factory=dict)
/chat/ drops two request fields on the floor.
structured_outputs_regex and chat_template_kwargs are part of ChatRequest, but the handler never reads them. Callers that rely on structured output or template overrides will silently get plain chat behavior instead.
Also applies to: 363-388
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/scripts/vllm_serve_lora.py` around lines 155 - 168, ChatRequest
declares structured_outputs_regex and chat_template_kwargs but the /chat/
handler never reads or forwards them; update the chat endpoint handler (the
function handling the /chat/ route) to extract request.structured_outputs_regex
and request.chat_template_kwargs and pass them into the downstream
generation/template call (where generation_kwargs and the chat template are
applied), ensuring None/defaults are honored and the fields are included in any
structured-output parsing or template override logic; also apply the same fix to
the other handler block referenced around the 363-388 region so both code paths
consume these fields.
@asynccontextmanager
async def lifespan(app: FastAPI):
    ready: set[int] = set()
    while len(ready) < script_args.data_parallel_size:
        for conn in connections:
            if id(conn) not in ready and conn.poll():
                msg = conn.recv()
                if isinstance(msg, dict) and msg.get("status") == "ready":
                    ready.add(id(conn))
    yield
    for p in processes:
        p.join(timeout=10)
        if p.is_alive():
            p.terminate()
            p.join()
Don't wait forever for worker readiness.
lifespan() busy-polls until every child reports "ready", but it never checks Process.is_alive() or applies a timeout. Any init failure inside LLM(...) will hang server startup indefinitely.
Suggested fix
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     ready: set[int] = set()
+    loop = asyncio.get_running_loop()
+    deadline = loop.time() + 300
     while len(ready) < script_args.data_parallel_size:
+        for idx, process in enumerate(processes):
+            if id(connections[idx]) not in ready and not process.is_alive():
+                raise RuntimeError("vLLM worker exited before reporting ready")
         for conn in connections:
             if id(conn) not in ready and conn.poll():
                 msg = conn.recv()
                 if isinstance(msg, dict) and msg.get("status") == "ready":
                     ready.add(id(conn))
+        if loop.time() > deadline:
+            raise TimeoutError("Timed out waiting for vLLM workers to become ready")
+        await asyncio.sleep(0.05)
     yield
🪛 Ruff (0.15.5)
[warning] 210-210: Unused function argument: app
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/scripts/vllm_serve_lora.py` around lines 209 - 223, The
lifespan() loop busy-waits for all workers to report "ready" without any timeout
or liveness checks; update the readiness logic in lifespan to track elapsed time
(use a configurable startup timeout, e.g., script_args.startup_timeout or a
reasonable default), periodically sleep to avoid tight spinning, and during each
iteration check each Process in processes with p.is_alive()—if any process has
exited or the timeout is exceeded, break the wait loop and raise/return an
explicit startup error (including info from conn.recv() if available) rather
than hanging; keep the existing checks on conn.poll()/conn.recv() and the ready
set but add the timeout and process liveness handling to fail fast.
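The fail-fast readiness loop generalizes beyond vLLM. Below is a synchronous stdlib sketch of the same pattern; `workers` here are objects with assumed `is_alive()` and `poll_ready()` methods standing in for the real `Process`/`Connection` pairs:

```python
import time

def wait_until_ready(workers, timeout_s: float = 5.0, poll_s: float = 0.01) -> None:
    """Wait for every worker to report ready, failing fast instead of hanging.

    Raises RuntimeError if a worker dies before reporting ready, and
    TimeoutError once the deadline passes; a short sleep between sweeps
    avoids busy-spinning.
    """
    deadline = time.monotonic() + timeout_s
    ready: set[int] = set()
    while len(ready) < len(workers):
        for idx, worker in enumerate(workers):
            if idx in ready:
                continue
            if not worker.is_alive():
                raise RuntimeError(f"worker {idx} exited before reporting ready")
            if worker.poll_ready():
                ready.add(idx)
        if time.monotonic() > deadline:
            raise TimeoutError("timed out waiting for workers")
        time.sleep(poll_s)
```

The liveness check must come before the readiness poll: a crashed child will never send a message, so polling the pipe alone can wait forever.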
    prefetch_depth: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of rollouts to prefetch ahead of training."
        },
    )
    vllm_sync_interval: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Sync model weights to vLLM every N optimizer steps (async mode only)."
        },
    )
    streaming_partial_batch: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Score prompt groups incrementally instead of the full batch at once."
        },
    )
    streaming_min_groups: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Minimum prompt groups to score per streaming chunk."
        },
    )
    vllm_importance_sampling_correction: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Apply IS correction for distribution mismatch between vLLM and training model."
        },
    )
    vllm_importance_sampling_mode: (
        Literal["token_truncate", "token_mask", "sequence_truncate", "sequence_mask"]
        | None
    ) = Field(
        default=None,
        json_schema_extra={
            "description": "IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask."
        },
    )
    vllm_importance_sampling_cap: float | None = Field(
        default=None,
        json_schema_extra={"description": "Cap C for IS ratio clipping/masking."},
    )
    off_policy_mask_threshold: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "KL threshold for off-policy sequence masking (OPSM). None = disabled."
        },
    )
    use_bias_correction_kl: bool | None = Field(
        default=None,
        json_schema_extra={"description": "Apply IS correction to KL divergence term."},
Add range validation for the new async GRPO numeric knobs.
Right now values like prefetch_depth=0, reward_num_workers=0, replay_buffer_size=-1, or reroll_start_fraction=1.5 will all pass schema validation and only fail deeper in the trainer/replay path. These should be rejected here.
Representative schema constraints
     prefetch_depth: int | None = Field(
         default=None,
+        ge=1,
         json_schema_extra={
             "description": "Number of rollouts to prefetch ahead of training."
         },
     )
     vllm_sync_interval: int | None = Field(
         default=None,
+        ge=1,
         json_schema_extra={
             "description": "Sync model weights to vLLM every N optimizer steps (async mode only)."
         },
     )
     streaming_min_groups: int | None = Field(
         default=None,
+        ge=1,
         json_schema_extra={
             "description": "Minimum prompt groups to score per streaming chunk."
         },
     )
     reward_num_workers: int = Field(
         default=1,
+        ge=1,
         json_schema_extra={
             "description": "Number of persistent subprocess workers for parallel reward computation. Each worker has its "
             "own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across "
             "workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions."
         },
     )
     replay_buffer_size: int = Field(
         default=0,
+        ge=0,
         json_schema_extra={
             "description": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout "
             "groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups "
             "(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True."
         },
     )
     reroll_start_fraction: float = Field(
         default=1.0,
+        ge=0.0,
+        le=1.0,
         json_schema_extra={
             "description": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
             "(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
             "model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True."
         },
     )
     reroll_max_groups: int = Field(
         default=1,
+        ge=1,
         json_schema_extra={
             "description": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values "
             "increase data utilization but reduce prompt diversity. Only used with use_data_producer=True."
         },
     )

Apply the same pattern to the threshold/cap fields as well.
Also applies to: 260-313
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/utils/schemas/trl.py` around lines 206 - 257, The schema must
validate numeric ranges for the new async GRPO knobs: update the Pydantic Field
declarations in the TRL schema (the class containing prefetch_depth,
vllm_sync_interval, streaming_min_groups, vllm_importance_sampling_cap,
off_policy_mask_threshold, vllm_importance_sampling_mode,
streaming_partial_batch, etc.) and add appropriate ge/le constraints (or use
conint/confloat) so invalid values are rejected at validation time; e.g. set
prefetch_depth, vllm_sync_interval, streaming_min_groups, reward_num_workers,
replay_buffer_size to have minimum 1 (Field(..., ge=1)), set
reroll_start_fraction to a float range 0.0..1.0 (Field(..., ge=0.0, le=1.0)),
and set vllm_importance_sampling_cap and off_policy_mask_threshold to
non-negative floats (Field(..., ge=0.0)); apply the same pattern to the other
numeric knobs referenced in lines ~260-313.
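The constraints proposed above can be mirrored by a hand-rolled check, shown here as a stdlib sketch standing in for pydantic's `Field(ge=..., le=...)`; the function name and the dict-based config are illustrative:

```python
def validate_async_grpo_knobs(cfg: dict) -> list[str]:
    """Return human-readable errors for out-of-range async GRPO knobs.

    Mirrors the review's proposed bounds: ge=1 for counts and intervals,
    ge=0 for the buffer size, and [0, 1] for the re-roll fraction.
    None means "unset" and is always accepted.
    """
    errors = []
    for name in (
        "prefetch_depth",
        "vllm_sync_interval",
        "streaming_min_groups",
        "reward_num_workers",
        "reroll_max_groups",
    ):
        value = cfg.get(name)
        if value is not None and value < 1:
            errors.append(f"{name} must be >= 1, got {value}")
    size = cfg.get("replay_buffer_size")
    if size is not None and size < 0:
        errors.append(f"replay_buffer_size must be >= 0, got {size}")
    frac = cfg.get("reroll_start_fraction")
    if frac is not None and not 0.0 <= frac <= 1.0:
        errors.append(f"reroll_start_fraction must be in [0, 1], got {frac}")
    return errors
```

Rejecting these values at schema-load time keeps the failure at the config boundary instead of deep inside the trainer or replay-buffer path.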
Codecov Report
❌ Patch coverage is
📢 Thoughts on this report? Let us know!
Summary by CodeRabbit
Release Notes
New Features
Improvements
Bug Fixes