Commit e2a3a68

debug merge
1 parent 52028a5 commit e2a3a68

2 files changed: 16 additions, 15 deletions

2 files changed

+16
-15
lines changed

apps/grpo/main.py

Lines changed: 15 additions & 14 deletions
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.

 import asyncio
-import copy
 import logging
 import time
 import uuid
@@ -118,14 +117,16 @@ def new_group(
     ):
         episodes = []
         for i in range(group_size):
-            Episode(
-                episode_id=str(uuid.uuid4()),
-                request=copy.deepcopy(messages),
-                policy_version=policy_version,
-                pad_id=pad_iddd,
-                request_len=request_len,
-                response_len=response_len,
-                target=target,
+            episodes.append(
+                Episode(
+                    episode_id=str(uuid.uuid4()),
+                    request=request,
+                    policy_version=policy_version,
+                    pad_id=pad_id,
+                    request_len=request_len,
+                    response_len=response_len,
+                    target=target,
+                )
             )
         return cls(group_id, episodes)
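
The second hunk is the substantive fix: the old loop constructed Episode objects but never stored them, so cls(group_id, episodes) always received an empty list, and the constructor call also referenced pad_iddd (a typo for pad_id) and a stale messages variable. Below is a minimal sketch of the corrected accumulation pattern; the stripped-down Episode stand-in and the free-standing new_group helper are illustrative only and carry far fewer fields than the real class in apps/grpo/main.py.

    import uuid
    from dataclasses import dataclass

    @dataclass
    class Episode:  # hypothetical stand-in; the real Episode has more fields
        episode_id: str
        request: str
        pad_id: int

    def new_group(request: str, pad_id: int, group_size: int) -> list[Episode]:
        episodes = []
        for _ in range(group_size):
            # The buggy version built Episode(...) here and dropped the result;
            # appending is what actually populates the group.
            episodes.append(
                Episode(episode_id=str(uuid.uuid4()), request=request, pad_id=pad_id)
            )
        return episodes

    assert len(new_group("2 + 2 = ?", pad_id=0, group_size=4)) == 4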

@@ -148,7 +149,7 @@ def setup(self):

         # Initialize model
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
+            self.model_name,
             torch_dtype=torch.bfloat16,
             trust_remote_code=True,
         ).to(self.device)
@@ -313,7 +314,7 @@ class DatasetActor(ForgeActor):
     """Actor wrapper for HuggingFace dataset to provide async interface."""

     path: str
-    name: str
+    revision: str
     data_split: str
     streaming: bool
     model: str
@@ -334,7 +335,7 @@ def gsm8k_transform(sample):
             return {"request": formatted_request, "target": formatted_target}

         ds = load_dataset(
-            self.path, self.name, split=self.data_split, streaming=self.streaming
+            self.path, self.revision, split=self.data_split, streaming=self.streaming
         )
         ds = ds.map(gsm8k_transform)
         ds = ds.shuffle()
@@ -382,7 +383,7 @@ async def main():
             ServiceConfig(procs_per_replica=1, num_replicas=1),
             DatasetActor,
             path="openai/gsm8k",
-            name="main",
+            revision="main",
             data_split="train",
             streaming=True,
             model=model,
@@ -416,7 +417,7 @@ async def main():
         spawn_service(
             ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
             RefModel,
-            model=titan_model,
+            model_name=model,
         ),
         spawn_service(
             ServiceConfig(procs_per_replica=1, num_replicas=1),
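
A note on the name → revision rename in DatasetActor: the value is still passed positionally, so it binds to the second parameter of load_dataset, which in the Hugging Face datasets API is the config name rather than the revision keyword. For openai/gsm8k the string "main" is a valid config name, so the call in the diff behaves like the sketch below; this is a reading of the datasets API, not something stated in the commit.

    from datasets import load_dataset

    # "main" selects the GSM8K config (the other config is "socratic");
    # a git revision would instead be passed as revision="main".
    ds = load_dataset("openai/gsm8k", "main", split="train", streaming=True)
    print(next(iter(ds)))  # {'question': ..., 'answer': ...}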

src/forge/actors/policy.py

Lines changed: 1 addition & 1 deletion
@@ -226,7 +226,7 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu
         request_id = str(self.request_id)  # implement from a counter

         # Wraps prompt into a dict
-        prompt: Dict[str, str] = convert_input(prompt_token_ids=prompt)
+        prompt: Dict[str, str] = convert_input(prompt=prompt)

         # truncate prmpt
         tokenization_kwargs = self.tokenization_kwargs or {}
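
The one-line fix in policy.py forwards the input under the keyword that matches what generate actually receives: prompt is typed str, so passing it as prompt_token_ids (which implies a pre-tokenized list of ids) routed it down the wrong path. A sketch of the distinction, with a hypothetical convert_input; the real helper in src/forge/actors/policy.py may differ.

    from typing import Dict, List, Optional, Union

    def convert_input(
        prompt: Optional[str] = None,
        prompt_token_ids: Optional[List[int]] = None,
    ) -> Dict[str, Union[str, List[int]]]:
        # Wrap whichever form of input was provided into a prompt dict.
        if prompt_token_ids is not None:
            return {"prompt_token_ids": prompt_token_ids}
        return {"prompt": prompt}

    # generate() receives raw text, so it must take the text branch:
    assert convert_input(prompt="What is 2 + 2?") == {"prompt": "What is 2 + 2?"}
    assert convert_input(prompt_token_ids=[1, 2, 3]) == {"prompt_token_ids": [1, 2, 3]}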
