Skip to content

Commit d2d7107

Browse files
committed
update grpo.main
1 parent 9278d75 commit d2d7107

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

apps/grpo/main.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212

1313
import torch
1414
from datasets import load_dataset
15-
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
15+
from forge.actors.policy import Policy
1616
from forge.actors.reference_actor import compute_sequence_logprobs, TitanRefModel
1717
from forge.actors.replay_buffer import ReplayBuffer
1818
from forge.controller.actor import ForgeActor
@@ -305,12 +305,8 @@ async def main():
305305
spawn_service(
306306
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
307307
Policy,
308-
config=PolicyConfig(
309-
worker_params=WorkerConfig(model=model),
310-
sampling_params=SamplingOverrides(
311-
num_samples=group_size, max_tokens=16
312-
),
313-
),
308+
worker_params={"model": model, "vllm_args": None},
309+
sampling_params={"num_samples": group_size, "max_tokens": 16},
314310
),
315311
spawn_service(
316312
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),

0 commit comments

Comments
 (0)