66
77import asyncio
88import copy
9+ import logging
910import time
1011import uuid
1112from dataclasses import dataclass
1516from datasets import load_dataset
1617from forge .actors .policy import Policy , PolicyConfig , SamplingOverrides , WorkerConfig
1718from forge .actors .replay_buffer import ReplayBuffer
18- from forge .controller import ServiceConfig , spawn_service
1919from forge .controller .actor import ForgeActor
20+ from forge .controller .service import ServiceConfig , shutdown_service , spawn_service
2021from forge .data .rewards import MathReward , ThinkingReward
2122from forge .util .metric_logging import get_metric_logger
2223from monarch .actor import endpoint
2324from torch import nn
2425from transformers import AutoModelForCausalLM
2526from vllm .transformers_utils .tokenizer import get_tokenizer
2627
28+ logger = logging .getLogger (__name__ )
29+ logger .setLevel (logging .DEBUG )
30+
2731
2832def compute_logprobs (
2933 logits : torch .Tensor , input_ids : torch .Tensor , temperature : float = 1.0
@@ -365,66 +369,60 @@ async def main():
365369 )
366370
367371 # ---- Setup services ---- #
368- default_service_cfg = ServiceConfig (
369- procs_per_replica = 1 ,
370- num_replicas = 1 ,
371- )
372-
373- policy = await spawn_service (
374- default_service_cfg ,
375- Policy ,
376- PolicyConfig (
377- num_workers = 1 ,
378- worker_params = WorkerConfig (model = model ),
379- sampling_params = SamplingOverrides (n = group_size , max_tokens = max_res_tokens ),
380- available_devices = "3" ,
372+ (
373+ dataloader ,
374+ policy ,
375+ trainer ,
376+ replay_buffer ,
377+ compute_advantages ,
378+ ref_model ,
379+ reward_actor ,
380+ ) = await asyncio .gather (
381+ spawn_service (
382+ ServiceConfig (procs_per_replica = 1 , num_replicas = 1 ),
383+ DatasetActor ,
384+ path = "openai/gsm8k" ,
385+ name = "main" ,
386+ data_split = "train" ,
387+ streaming = True ,
388+ model = model ,
389+ ),
390+ spawn_service (
391+ ServiceConfig (procs_per_replica = 1 , with_gpus = True , num_replicas = 1 ),
392+ Policy ,
393+ config = PolicyConfig (
394+ worker_params = WorkerConfig (model = model ),
395+ sampling_params = SamplingOverrides (
396+ n = group_size , max_tokens = max_res_tokens
397+ ),
398+ ),
399+ ),
400+ spawn_service (
401+ ServiceConfig (procs_per_replica = 1 , with_gpus = True , num_replicas = 1 ),
402+ Trainer ,
403+ learning_rate = 1e-5 ,
404+ model_name = model ,
405+ ),
406+ spawn_service (
407+ ServiceConfig (procs_per_replica = 1 , num_replicas = 1 ),
408+ ReplayBuffer ,
409+ batch_size = 4 ,
410+ max_policy_age = 1 ,
411+ ),
412+ spawn_service (
413+ ServiceConfig (procs_per_replica = 1 , num_replicas = 1 ),
414+ ComputeAdvantages ,
415+ ),
416+ spawn_service (
417+ ServiceConfig (procs_per_replica = 1 , num_replicas = 1 , with_gpus = True ),
418+ RefModel ,
419+ model = titan_model ,
420+ ),
421+ spawn_service (
422+ ServiceConfig (procs_per_replica = 1 , num_replicas = 1 ),
423+ RewardActor ,
424+ reward_functions = [MathReward (), ThinkingReward ()],
381425 ),
382- )
383-
384- trainer = await spawn_service (
385- default_service_cfg ,
386- Trainer ,
387- learning_rate = 1e-5 ,
388- beta = 0.1 ,
389- model_name = model ,
390- device = torch .device ("cuda:1" ),
391- )
392-
393- replay_buffer = await spawn_service (
394- default_service_cfg ,
395- ReplayBuffer ,
396- batch_size = 4 ,
397- max_policy_age = 1 ,
398- )
399-
400- dataloader = await spawn_service (
401- default_service_cfg ,
402- DatasetActor ,
403- "openai/gsm8k" ,
404- "main" ,
405- data_split = "train" ,
406- streaming = True ,
407- model = model ,
408- )
409-
410- compute_advantages = await spawn_service (
411- default_service_cfg ,
412- ComputeAdvantages ,
413- gamma = 0.99 ,
414- lambda_ = 0.95 ,
415- )
416-
417- ref_model = await spawn_service (
418- default_service_cfg ,
419- RefModel ,
420- model_name = model ,
421- device = torch .device ("cuda:2" ),
422- )
423-
424- reward_actor = await spawn_service (
425- default_service_cfg ,
426- RewardActor ,
427- reward_functions = [MathReward (), ThinkingReward ()],
428426 )
429427
430428 print ("All services initialized successfully!" )
@@ -433,8 +431,6 @@ async def main():
433431 async def continuous_rollouts ():
434432 rollout_count = 0
435433 pad_id = dataloader .pad_token .choose ()
436- # TODO: Move this into setup
437- asyncio .create_task (policy .run_processing .call ())
438434 while True :
439435 sample = await dataloader .sample .choose ()
440436 if sample is None :
@@ -501,6 +497,17 @@ async def continuous_training():
501497 print ("Training interrupted by user" )
502498 rollout_task .cancel ()
503499 training_task .cancel ()
500+ finally :
501+ print ("Shutting down..." )
502+ await asyncio .gather (
503+ shutdown_service (policy ),
504+ shutdown_service (trainer ),
505+ shutdown_service (replay_buffer ),
506+ shutdown_service (dataloader ),
507+ shutdown_service (compute_advantages ),
508+ shutdown_service (ref_model ),
509+ shutdown_service (reward_actor ),
510+ )
504511
505512
506513if __name__ == "__main__" :
0 commit comments