Open
Description
As part of #5119
Motivation:
After sub-issue #5121 decouples rollout dispatch, _generate_single_turn still has
three backend-specific branches inlined in the method body. These branches share
common structure (apply chat template -> call backend -> extract ids/logprobs)
but differ in backend-specific details. Encapsulating them behind a common
interface would:
- Reduce the cognitive load of reading _generate_single_turn
- Make it possible to add new backends without modifying the trainer
- Allow openenv/utils.py to call the generation backend without knowing which
one is active
- Align TRL with the patterns used by other frameworks
Scope:
Define a minimal internal interface (not necessarily a full ABC) that all
generation paths implement:
class GenerationBackend(Protocol):
    def generate(
        self,
        prompts: list,
        num_generations: int,
        processing_class,
        generation_config: GenerationConfig | None = None,
    ) -> tuple[list[list[int]], list[list[int]], list[list[float]] | None]:
        """
        Returns (prompt_ids, completion_ids, logprobs).
        Chat templating happens inside this method.
        """
        ...

    def sync_weights(self) -> None:
        """Sync model weights from trainer to inference engine (no-op for in-process backends)."""
        ...

CC: @albertvillanova
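To make the proposal concrete, here is a minimal sketch of how a backend could satisfy the interface and how a caller could use it without branching on the backend type. Everything except the `GenerationBackend` method names and signatures is hypothetical: `ToyTokenizer`, `EchoBackend`, and `generate_single_turn` are illustrative stand-ins, `generation_config` is typed as `Any` so the sketch runs without `transformers`, and returning one row per (prompt, generation) pair is an assumption rather than something the issue specifies.

```python
from __future__ import annotations

from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class GenerationBackend(Protocol):
    """The interface proposed above (generation_config loosened to Any for this sketch)."""

    def generate(
        self,
        prompts: list,
        num_generations: int,
        processing_class: Any,
        generation_config: Any | None = None,
    ) -> tuple[list[list[int]], list[list[int]], list[list[float]] | None]: ...

    def sync_weights(self) -> None: ...


class ToyTokenizer:
    """Hypothetical stand-in for processing_class: one token id per character."""

    def encode(self, text: str) -> list[int]:
        return [ord(ch) for ch in text]


class EchoBackend:
    """Hypothetical in-process backend whose 'completion' is the reversed prompt."""

    def generate(self, prompts, num_generations, processing_class, generation_config=None):
        prompt_ids, completion_ids, logprobs = [], [], []
        for prompt in prompts:
            # Stands in for "apply chat template -> tokenize" inside the backend.
            ids = processing_class.encode(prompt)
            for _ in range(num_generations):
                # Assumption: one output row per (prompt, generation) pair.
                prompt_ids.append(ids)
                completion_ids.append(ids[::-1])
                logprobs.append([0.0] * len(ids))  # dummy per-token logprobs
        return prompt_ids, completion_ids, logprobs

    def sync_weights(self) -> None:
        # In-process backend: weights are already shared, so nothing to do.
        pass


def generate_single_turn(backend: GenerationBackend, prompts, num_generations, processing_class):
    """Hypothetical caller: no backend-specific branches, only the interface."""
    backend.sync_weights()
    return backend.generate(prompts, num_generations, processing_class)
```

Because `GenerationBackend` is a `Protocol`, `EchoBackend` conforms structurally without inheriting from it, which matches the "not necessarily a full ABC" wording in the scope. A server-backed implementation would differ only inside `generate` and `sync_weights`.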