Introduce a Minimal Backend Generation Interface #5193

@rycerzes

As part of #5119

Motivation:
After sub-issue #5121 decouples rollout dispatch, _generate_single_turn still has
three backend-specific branches inlined in the method body. These branches share
common structure (apply chat template -> call backend -> extract ids/logprobs)
but differ in backend-specific details. Encapsulating them behind a common
interface would:

  • Reduce the cognitive load of reading _generate_single_turn
  • Make it possible to add new backends without modifying the trainer
  • Allow openenv/utils.py to call the generation backend without knowing which
    one is active
  • Align TRL with the patterns used by other frameworks

Scope:

Define a minimal internal interface (not necessarily a full ABC) that all
generation paths implement:

from typing import Protocol

from transformers import GenerationConfig


class GenerationBackend(Protocol):
    def generate(
        self,
        prompts: list,
        num_generations: int,
        processing_class,
        generation_config: GenerationConfig | None = None,
    ) -> tuple[list[list[int]], list[list[int]], list[list[float]] | None]:
        """
        Returns (prompt_ids, completion_ids, logprobs).
        Chat templating happens inside this method.
        """
        ...

    def sync_weights(self) -> None:
        """Sync model weights from trainer to inference engine (no-op for in-process backends)."""
        ...
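
To make the proposal concrete, here is a minimal sketch of how a backend could satisfy this interface structurally and how the trainer side could call it without knowing which backend is active. Everything except GenerationBackend, generate, and sync_weights is hypothetical and exists only for illustration: ToyTokenizer, EchoBackend, and the generate_single_turn helper are not TRL code. The sketch also assumes that each prompt is repeated num_generations times in the output, which the interface above does not specify, and it types generation_config as Any so that it runs without transformers installed.

```python
from __future__ import annotations

from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class GenerationBackend(Protocol):
    """Same shape as the proposed interface; generation_config is typed as Any here."""

    def generate(
        self,
        prompts: list,
        num_generations: int,
        processing_class: Any,
        generation_config: Any | None = None,
    ) -> tuple[list[list[int]], list[list[int]], list[list[float]] | None]: ...

    def sync_weights(self) -> None: ...


class ToyTokenizer:
    """Hypothetical stand-in for a processing class: one token id per character."""

    def encode(self, text: str) -> list[int]:
        return [ord(ch) for ch in text]


class EchoBackend:
    """Hypothetical in-process backend that 'completes' a prompt by reversing its ids."""

    def __init__(self) -> None:
        self.sync_calls = 0

    def generate(self, prompts, num_generations, processing_class, generation_config=None):
        prompt_ids, completion_ids, logprobs = [], [], []
        for prompt in prompts:
            ids = processing_class.encode(prompt)
            # Assumption: one output row per (prompt, generation) pair.
            for _ in range(num_generations):
                prompt_ids.append(ids)
                completion_ids.append(ids[::-1])
                logprobs.append([0.0] * len(ids))  # placeholder logprobs
        return prompt_ids, completion_ids, logprobs

    def sync_weights(self) -> None:
        # In-process backends share weights with the trainer, so this only
        # counts calls; a remote backend would push weights here instead.
        self.sync_calls += 1


def generate_single_turn(backend: GenerationBackend, prompts, num_generations, processing_class):
    """Hypothetical trainer-side call site with no backend-specific branches."""
    backend.sync_weights()
    return backend.generate(prompts, num_generations, processing_class)


if __name__ == "__main__":
    backend = EchoBackend()
    print(isinstance(backend, GenerationBackend))
    print(generate_single_turn(backend, ["ab"], 2, ToyTokenizer()))
```

Because Protocol uses structural typing, EchoBackend does not need to inherit from GenerationBackend, which fits the "not necessarily a full ABC" goal. Note that the runtime_checkable isinstance check only verifies that the methods exist, not that their signatures match.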

CC: @albertvillanova
