Skip to content
Merged
5 changes: 4 additions & 1 deletion apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
get_dcp_whole_state_dict_key,
get_param_prefix,
)
from forge.actors.policy import Policy
from forge.actors.generator import Generator
from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import RLTrainer
Expand Down Expand Up @@ -79,6 +79,9 @@ def response_tensor(self) -> torch.Tensor:
# Represents the group (G) of episodes in GRPO
Group = list[Episode]

# Represents the Policy Model to collect data from
Policy = Generator


def collate(
batches: list[Group],
Expand Down
10 changes: 5 additions & 5 deletions docs/source/api_generator.md
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
# Generator

```{eval-rst}
.. currentmodule:: forge.actors.policy
.. currentmodule:: forge.actors.generator
```

The Generator (Policy) is the core inference engine in TorchForge,
built on top of [vLLM](https://docs.vllm.ai/en/latest/).
It manages model serving, text generation, and weight updates for reinforcement learning workflows.

## Policy
## Generator

```{eval-rst}
.. autoclass:: Policy
.. autoclass:: Generator
:members: generate, update_weights, get_version, stop
:exclude-members: __init__, launch
:no-inherited-members:
```

## PolicyWorker
## GeneratorWorker

```{eval-rst}
.. autoclass:: PolicyWorker
.. autoclass:: GeneratorWorker
:members: execute_model, update, setup_kv_cache
:show-inheritance:
:exclude-members: __init__
Expand Down
9 changes: 4 additions & 5 deletions src/forge/actors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
# LICENSE file in the root directory of this source tree.

__all__ = [
"Policy",
"PolicyRouter",
"Generator" "PolicyRouter",
"RLTrainer",
"ReplayBuffer",
"TitanRefModel",
Expand All @@ -15,10 +14,10 @@


def __getattr__(name):
if name == "Policy":
from .policy import Policy
if name == "Generator":
from .policy import Generator

return Policy
return Generator
elif name == "PolicyRouter":
from .policy import PolicyRouter

Expand Down
Loading
Loading