# LICENSE file in the root directory of this source tree.

import asyncio
+ import logging
import time
from dataclasses import dataclass
from typing import Callable

from forge.actors.reference_actor import compute_sequence_logprobs, RefModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller.actor import ForgeActor
- from forge.controller.service import ServiceConfig, spawn_service
+ from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from forge.data.rewards import MathReward, ThinkingReward
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

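+ # Module-level logger for this script, set to DEBUG.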
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.DEBUG)
+

@dataclass
class Group:
@@ -242,18 +246,18 @@ async def __call__(self, groups: list[Group]) -> list[float]:
class DatasetActor(ForgeActor):
    """Actor wrapper for HuggingFace dataset to provide async interface."""

-     def __init__(self, *args, **kwargs):
+     def __init__(
+         self, path: str, config_name: str, split: str, streaming: bool, **kwargs
+     ):
        super().__init__()
-         self._setup_dataset(*args, **kwargs)

-     def _setup_dataset(self, *args, **kwargs):
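+         # Convert a raw GSM8K record into a {question, answer} pair; the answer text follows the "#### " marker.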
        def gsm8k_to_messages(sample):
            question = sample["question"]
            full_answer: str = sample["answer"]
            answer = full_answer.split("#### ")[1]
            return {"question": question, "answer": answer}

-         ds = load_dataset(*args, **kwargs)
+         ds = load_dataset(path, config_name, split=split, streaming=streaming)
        ds = ds.map(gsm8k_to_messages)
        ds = ds.shuffle()
        self._iterator = iter(ds)
@@ -279,66 +283,69 @@ async def main():
    )

    # ---- Setup services ---- #
-     policy = await spawn_service(
-         ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
-         Policy,
-         PolicyConfig(
-             num_workers=1,
-             worker_params=WorkerConfig(model=model),
-             sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
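+     # Spawn every service concurrently; the unpacked tuple follows the order of the spawn_service calls below.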
+     (
+         dataloader,
+         policy,
+         trainer,
+         replay_buffer,
+         compute_advantages,
+         ref_model,
+         reward_actor,
+     ) = await asyncio.gather(
+         spawn_service(
+             ServiceConfig(procs_per_replica=1, num_replicas=1),
+             DatasetActor,
+             path="openai/gsm8k",
+             config_name="main",
+             split="train",
+             streaming=True,
+         ),
+         spawn_service(
+             ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+             Policy,
+             config=PolicyConfig(
+                 worker_params=WorkerConfig(model=model),
+                 sampling_params=SamplingOverrides(
+                     num_samples=group_size, max_tokens=16
+                 ),
+             ),
+         ),
+         spawn_service(
+             ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+             Trainer,
+             learning_rate=1e-5,
+             beta=0.1,
+             model_name=model,
+         ),
+         spawn_service(
+             ServiceConfig(procs_per_replica=1, num_replicas=1),
+             ReplayBuffer,
+             batch_size=4,
+             max_policy_age=1,
+         ),
+         spawn_service(
+             ServiceConfig(procs_per_replica=1, num_replicas=1),
+             ComputeAdvantages,
+             gamma=0.99,
+             lambda_=0.95,
+         ),
+         spawn_service(
+             ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
+             RefModel,
+             model_name=model,
+         ),
+         spawn_service(
+             ServiceConfig(procs_per_replica=1, num_replicas=1),
+             RewardActor,
+             reward_functions=[MathReward(), ThinkingReward()],
        ),
-     )
-
-     trainer = await spawn_service(
-         ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
-         Trainer,
-         learning_rate=1e-5,
-         beta=0.1,
-         model_name=model,
-     )
-
-     replay_buffer = await spawn_service(
-         ServiceConfig(procs_per_replica=1, num_replicas=1),
-         ReplayBuffer,
-         batch_size=4,
-         max_policy_age=1,
-     )
-
-     dataloader = await spawn_service(
-         ServiceConfig(procs_per_replica=1, num_replicas=1),
-         DatasetActor,
-         "openai/gsm8k",
-         "main",
-         split="train",
-         streaming=True,
-     )
-
-     compute_advantages = await spawn_service(
-         ServiceConfig(procs_per_replica=1, num_replicas=1),
-         ComputeAdvantages,
-         gamma=0.99,
-         lambda_=0.95,
-     )
-
-     ref_model = await spawn_service(
-         ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
-         RefModel,
-         model_name=model,
-     )
-
-     reward_actor = await spawn_service(
-         ServiceConfig(procs_per_replica=1, num_replicas=1),
-         RewardActor,
-         reward_functions=[MathReward(), ThinkingReward()],
    )

    print("All services initialized successfully!")

    # ---- Core RL loops ---- #
    async def continuous_rollouts():
        rollout_count = 0
-         # TODO: Move this into setup
-         asyncio.create_task(policy.run_processing.call())
        while True:
            sample = await dataloader.__next__.choose()
            if sample is None:
@@ -409,6 +416,17 @@ async def continuous_training():
        print("Training interrupted by user")
        rollout_task.cancel()
        training_task.cancel()
+     finally:
+         print("Shutting down...")
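+         # Tear down all spawned services concurrently before the process exits.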
+         await asyncio.gather(
+             shutdown_service(policy),
+             shutdown_service(trainer),
+             shutdown_service(replay_buffer),
+             shutdown_service(dataloader),
+             shutdown_service(compute_advantages),
+             shutdown_service(ref_model),
+             shutdown_service(reward_actor),
+         )


if __name__ == "__main__":