 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import init_provisioner, shutdown
 from forge.data.rewards import MathReward, ThinkingReward
-from forge.env import MONARCH_HOSTMESH_V1
+from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 
 @dataclass
 class Episode:
-    # TODO: add adtional layer for multi-turn
     episode_id: str
-    request: str
-    policy_version: int
     pad_id: int
     request_len: int
     response_len: int
     target: Any | None = None
-    # processed data
-    response: str | None = None
-    request_tokens: list[int] | None = None
-    response_tokens: list[int] | None = None
+    # Processed data
+    completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
     advantage: float | None = None
 
     @property
-    def request_tensor(self):
-        tensor = torch.tensor(self.request_tokens, dtype=torch.long)
+    def policy_version(self) -> int | None:
+        return self.completion.generator_version
+
+    @property
+    def request_tensor(self) -> torch.Tensor:
+        request_tokens: torch.Tensor = self.completion.prompt_ids
+        tensor = torch.tensor(request_tokens, dtype=torch.long)
         if tensor.shape[0] < self.request_len:  # left pad
             diff = self.request_len - tensor.shape[0]
             tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
         return tensor
 
     @property
-    def response_tensor(self):
-        tensor = torch.tensor(self.response_tokens, dtype=torch.long)
+    def response_tensor(self) -> torch.Tensor:
+        response_tokens: torch.Tensor = self.completion.token_ids
+        tensor = torch.tensor(response_tokens, dtype=torch.long)
         if tensor.shape[0] < self.response_len:  # right pad
             diff = self.response_len - tensor.shape[0]
             tensor = F.pad(tensor, (0, diff), value=self.pad_id)
         return tensor
 
 
-@dataclass
-class Group:
-    group_id: str
-    episodes: list[Episode]
-
-    @classmethod
-    def new_group(
-        cls,
-        group_id: int,
-        group_size: int,
-        request: str,
-        policy_version: int,
-        pad_id: int,
-        request_len: int,
-        response_len: int,
-        target: Any = None,
-    ):
-        episodes = []
-        for _ in range(group_size):
-            episodes.append(
-                Episode(
-                    episode_id=str(uuid.uuid4()),
-                    request=request,
-                    policy_version=policy_version,
-                    pad_id=pad_id,
-                    request_len=request_len,
-                    response_len=response_len,
-                    target=target,
-                )
-            )
-        return cls(str(group_id), episodes)
+# Represents the group (G) of episodes in GRPO
+Group = list[Episode]
 
 
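A minimal sketch (not part of the diff) of how the refactored Episode is meant to be used: the Completion returned by the policy carries the prompt/response tokens and the generator version, so Episode derives policy_version and the padded tensors from it rather than storing them separately. The Completion constructor fields below are assumptions inferred from how this diff reads them.

    # Sketch only; Completion's fields are assumed from their usage in the diff.
    example_completion = Completion(
        prompt_ids=torch.tensor([101, 2054, 2003]),  # prompt tokens
        token_ids=torch.tensor([1996, 3437, 102]),   # generated tokens
        text="the answer",
        generator_version=3,
    )
    episode = Episode(
        episode_id=str(uuid.uuid4()),
        pad_id=0,
        request_len=8,
        response_len=8,
        completion=example_completion,
    )
    assert episode.policy_version == 3             # read from the completion
    assert episode.request_tensor.shape[0] == 8    # left-padded to request_len
    assert episode.response_tensor.shape[0] == 8   # right-padded to response_len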
-def collate(batches: list[list[Episode]]):
+def collate(
+    batches: list[Group],
+) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+    """
+    Collates a list of batches into a single batch of inputs and targets.
+    Each batch is a list of episodes, and each episode is a dict of tensors.
+    """
     inputs = []
     targets = []
     for batch in batches:
@@ -222,7 +200,7 @@ class ComputeAdvantages(ForgeActor):
     @endpoint
     async def compute(self, group: Group) -> list[float]:
         # TODO: add batch processing
-        rewards = torch.tensor([[e.reward for e in group.episodes]])
+        rewards = torch.tensor([[e.reward for e in group]])
         mean = rewards.mean(1, keepdim=True)
         std = rewards.std(1, keepdim=True)
         advantages = (rewards - mean) / (std + 1e-4)
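A quick worked example (not part of the diff) of the group-relative normalization above: for a group with rewards [1.0, 0.0, 1.0, 0.0], the mean is 0.5 and the sample standard deviation is about 0.577, so the advantages come out to roughly ±0.866.

    rewards = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
    mean = rewards.mean(1, keepdim=True)           # 0.5
    std = rewards.std(1, keepdim=True)             # ~0.5774
    advantages = (rewards - mean) / (std + 1e-4)   # ~[[0.866, -0.866, 0.866, -0.866]]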
@@ -327,12 +305,6 @@ async def main(cfg: DictConfig):
     mlogger = await get_or_create_metric_logger()
     await mlogger.init_backends.call_one(metric_logging_cfg)
 
-    # In the host mesh v0 case, actors on remote hosts are not able to communicate
-    # with one another. Therefore we use the controller as our storage volume.
-    if not MONARCH_HOSTMESH_V1.get_value():
-        await ts.initialize(strategy=ts.ControllerStorageVolumes())
-        print("Torchstore successfully initialized with controller storage strategy")
-
     # ---- Setup services ---- #
 
     (
@@ -364,21 +336,19 @@ async def main(cfg: DictConfig):
 
     print("All services initialized successfully!")
     shutdown_event = asyncio.Event()
-    # In the HostMesh v1 case, we spawn a torchstore storage volume
-    # per trainer process.
+    # Here we spawn a torchstore storage volume per trainer process.
     # We initialize after service initialization because torchstore currently
     # requires access to the underlying proc meshes in the local rank strategy.
     # We should be able to hide this in the future.
-    if MONARCH_HOSTMESH_V1.get_value():
-        # TODO: support multiple host meshes
-        trainer_num_procs = cfg.actors.trainer["procs"]
-        trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
-        trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
-        await ts.initialize(
-            mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
-            strategy=ts.LocalRankStrategy(),
-        )
-        print("Torchstore successfully initialized with local rank strategy")
+    # TODO: support multiple host meshes
+    trainer_num_procs = cfg.actors.trainer["procs"]
+    trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
+    trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
+    await ts.initialize(
+        mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
+        strategy=ts.LocalRankStrategy(),
+    )
+    print("Torchstore successfully initialized with local rank strategy")
 
     # ---- Core RL loops ---- #
     async def continuous_rollouts():
@@ -395,44 +365,32 @@ async def continuous_rollouts():
             t.step("data_loading")
 
             prompt, target = sample["request"], sample["target"]
-            responses = await policy.generate.route(prompt)
-            # TODO: this shall be part of the responses metadata instead of a separate call
-            version = await policy.get_version.route()
-
+            responses: list[Completion] = await policy.generate.route(prompt)
             t.step("policy_generation")
 
-            assert (
-                len(responses) > 0
-            ), "Sanity check: Responses should NEVER return empty"
-            assert (
-                version := responses[0].generator_version
-            ) is not None, "Response must indicate a version"
-            group = Group.new_group(
-                group_id=rollout_count,
-                group_size=group_size,
-                request=prompt,
-                policy_version=version,
-                pad_id=pad_id,
-                request_len=max_req_tokens,
-                response_len=max_res_tokens,
-                target=target,
-            )
-
+            # Construct episodes and calculate rewards
+            episodes = []
             input_ids = torch.ones(
                 (group_size, max_req_tokens + max_res_tokens),
                 dtype=torch.long,
-                device="cuda",
             )
-            # Populate episode info and calculate rewards
-            for i, (episode, response) in enumerate(zip(group.episodes, responses)):
-                episode.request_tokens = response.prompt_ids
-                episode.response_tokens = response.token_ids
-                episode.response = response.text
-                input_ids[i, :max_req_tokens] = episode.request_tensor
-                input_ids[i, max_req_tokens:] = episode.response_tensor
+            for i, response in enumerate(responses):
+                episode = Episode(
+                    episode_id=str(uuid.uuid4()),
+                    pad_id=pad_id,
+                    request_len=max_req_tokens,
+                    response_len=max_res_tokens,
+                    target=target,
+                    completion=response,
+                )
                 episode.reward = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )
+                episodes.append(episode)
+
+                # Build input_ids for reference logprobs
+                input_ids[i, :max_req_tokens] = episode.request_tensor
+                input_ids[i, max_req_tokens:] = episode.response_tensor
 
             t.step("reward_evaluation")
 
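A small sketch (not part of the diff) of the per-row layout that the loop above writes into input_ids: the request is left-padded to max_req_tokens and the response is right-padded to max_res_tokens, so each row is [pads..., prompt, response, ...pads]. Token values and lengths here are made up for illustration.

    max_req_tokens, max_res_tokens, pad_id = 4, 4, 0
    prompt_ids = torch.tensor([7, 8])        # 2 prompt tokens -> 2 pads on the left
    token_ids = torch.tensor([9, 10, 11])    # 3 response tokens -> 1 pad on the right
    row = torch.cat(
        [
            F.pad(prompt_ids, (max_req_tokens - len(prompt_ids), 0), value=pad_id),
            F.pad(token_ids, (0, max_res_tokens - len(token_ids)), value=pad_id),
        ]
    )
    # row -> tensor([0, 0, 7, 8, 9, 10, 11, 0]); group_size such rows are stacked
    # into input_ids before the reference model computes logprobs.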
@@ -441,14 +399,13 @@ async def continuous_rollouts():
             )
             t.step("reference_model_calculate_logprobs")
 
-            for i, episode in enumerate(group.episodes):
+            for i, episode in enumerate(episodes):
                 episode.ref_logprobs = ref_logprobs[i]
             del ref_logprobs, input_ids
-            t.step("compute_logprobs")
 
             # Calculate advantages and add to replay buffer
-            advantages = await compute_advantages.compute.call_one(group)
-            for episode, advantage in zip(group.episodes, advantages):
+            advantages = await compute_advantages.compute.call_one(episodes)
+            for episode, advantage in zip(episodes, advantages):
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)
 