55# LICENSE file in the root directory of this source tree.
66
77import asyncio
8- import copy
98import logging
109import time
1110import uuid
@@ -118,14 +117,16 @@ def new_group(
118117 ):
119118 episodes = []
120119 for i in range (group_size ):
121- Episode (
122- episode_id = str (uuid .uuid4 ()),
123- request = copy .deepcopy (messages ),
124- policy_version = policy_version ,
125- pad_id = pad_iddd ,
126- request_len = request_len ,
127- response_len = response_len ,
128- target = target ,
120+ episodes .append (
121+ Episode (
122+ episode_id = str (uuid .uuid4 ()),
123+ request = request ,
124+ policy_version = policy_version ,
125+ pad_id = pad_id ,
126+ request_len = request_len ,
127+ response_len = response_len ,
128+ target = target ,
129+ )
129130 )
130131 return cls (group_id , episodes )
131132
@@ -148,7 +149,7 @@ def setup(self):
148149
149150 # Initialize model
150151 self .model = AutoModelForCausalLM .from_pretrained (
151- model_name ,
152+ self . model_name ,
152153 torch_dtype = torch .bfloat16 ,
153154 trust_remote_code = True ,
154155 ).to (self .device )
@@ -313,7 +314,7 @@ class DatasetActor(ForgeActor):
313314 """Actor wrapper for HuggingFace dataset to provide async interface."""
314315
315316 path : str
316- name : str
317+ revision : str
317318 data_split : str
318319 streaming : bool
319320 model : str
@@ -334,7 +335,7 @@ def gsm8k_transform(sample):
334335 return {"request" : formatted_request , "target" : formatted_target }
335336
336337 ds = load_dataset (
337- self .path , self .name , split = self .data_split , streaming = self .streaming
338+ self .path , self .revision , split = self .data_split , streaming = self .streaming
338339 )
339340 ds = ds .map (gsm8k_transform )
340341 ds = ds .shuffle ()
@@ -382,7 +383,7 @@ async def main():
382383 ServiceConfig (procs_per_replica = 1 , num_replicas = 1 ),
383384 DatasetActor ,
384385 path = "openai/gsm8k" ,
385- name = "main" ,
386+ revision = "main" ,
386387 data_split = "train" ,
387388 streaming = True ,
388389 model = model ,
@@ -416,7 +417,7 @@ async def main():
416417 spawn_service (
417418 ServiceConfig (procs_per_replica = 1 , num_replicas = 1 , with_gpus = True ),
418419 RefModel ,
419- model = titan_model ,
420+ model_name = model ,
420421 ),
421422 spawn_service (
422423 ServiceConfig (procs_per_replica = 1 , num_replicas = 1 ),
0 commit comments