3737
3838@dataclass  
3939class  ReferenceActor (ForgeActor ):
40+     """ 
41+     Original idea (not updated); On second throught this might be overkill 
42+     if we can rely on the Service Replicas to handle the queue since there's no 
43+     real pre/post proc or host management (maybe later for DP?). For now just 
44+     directly spin up services of the reference models 
45+     """ 
46+ 
4047    model : Model  =  field (default_factory = Model )
4148    # parallelism: Parallelism = field(default_factory=Parallelism) 
4249    # comm: Comm = field(default_factory=Comm) 
@@ -95,13 +102,18 @@ async def setup(self):
95102        # Spawn the RefModel 
96103        self .ref_model  =  await  spawn_service (
97104            default_service_cfg ,
98-             RefModel ,
105+             HuggingFaceRefModel ,
99106            model_name = self .model .name ,
100107            device = self .device ,
101108        )
102109
103110        # Kick off background processing 
104-         asyncio .create_task (self .run_processing .call ())
111+         self .start_processing ()
112+ 
113+     def  start_processing (self ):
114+         """Start the replica's processing loop if not already running.""" 
115+         if  self ._run_task  is  None  or  self ._run_task .done ():
116+             self ._run_task  =  asyncio .create_task (self .run ())
105117
106118    @endpoint  
107119    async  def  forward (self , token_ids : list [int ]) ->  torch .Tensor :
@@ -112,8 +124,7 @@ async def forward(self, token_ids: list[int]) -> torch.Tensor:
112124        self .queue .append ((token_ids , fut ))
113125        return  await  fut 
114126
115-     @endpoint  
116-     async  def  run_processing (self ):
127+     async  def  run (self ):
117128        """ 
118129        Simple loop to pass things along to the ref model 
119130        """ 
@@ -127,11 +138,105 @@ async def run_processing(self):
127138            fut .set_result (model_output )
128139
    @endpoint
    async def stop(self) -> None:
        """Request shutdown of background processing by clearing the running flag.

        NOTE(review): presumably the run() loop polls ``self.running`` and exits
        once it is False — confirm against the loop condition (not visible here).
        """
        self.running = False
133144
134- class  RefModel (ForgeActor ):
@dataclass
class TitanRefModel(ForgeActor):
    """
    Represents a reference actor leveraging a torchtitan model for execution.

    Runs forward passes over token ids to produce reference model outputs
    (used as reference logprobs for the KL-divergence term).
    """

    # Refer to titan JobConfig for enabling more ForgeEngine configuration
    model: Model = field(default_factory=Model)
    parallelism: Parallelism = field(default_factory=Parallelism)

    # Populated in setup (commented out for now for engine_config parsing)
    # engine: ForgeEngine | None = None

    def __post_init__(self):
        """Initializes config types and env variables."""
        # Coerce mapping-valued fields (e.g. dicts parsed from config) into
        # their declared dataclass types; reject anything else early.
        for f in fields(self):
            attr = getattr(self, f.name)
            if isinstance(attr, Mapping):
                setattr(self, f.name, f.type(**attr))
            elif not isinstance(attr, f.type):
                # BUG FIX: the original f-string here was malformed
                # (unterminated); reconstructed message — confirm wording.
                raise TypeError(
                    f"{f.name} must be a {f.type} or a mapping, got {type(attr)}"
                )

        # torchrun normally hands out these env variables, but we need to do
        # it ourselves when launching under monarch for now.
        self.rank = current_rank().rank
        self.size = math.prod(current_size().values())

        env = {
            "RANK": str(self.rank),
            "LOCAL_RANK": str(self.rank),
            "LOCAL_WORLD_SIZE": str(self.size),
            # NOTE(review): torchrun sets GROUP_RANK to the group's index,
            # not the world size; preserved as-is from the original — confirm.
            "GROUP_RANK": str(self.size),
            "GROUP_WORLD_SIZE": str(self.size),
            "ROLE_RANK": str(self.rank),
            "ROLE_WORLD_SIZE": str(self.size),
            "ROLE_NAME": "rank",
            "WORLD_SIZE": str(self.size),
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        }
        os.environ.update(env)

    @endpoint
    async def setup(self):
        """Build the job config and engine from this dataclass's own fields."""
        engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
        # BUG FIX: keep the job config on the actor — forward() reads
        # self.job_config, which was never assigned in the original.
        self.job_config = ForgeJobConfig(**engine_config)
        self.engine = ForgeEngine(self.job_config)

    @endpoint
    async def forward(self, token_ids: list[int]) -> torch.Tensor:
        """
        Given token ids, return the model output for the sequence
        (used as the reference_logprobs for KL Divergence).

        Args:
            token_ids: flat list of token ids for a single sequence
                (a batch dimension of 1 is added internally).
        """
        model_parts = self.engine.model_parts
        parallel_dims = self.engine.parallel_dims

        # Use provided token_ids directly; unsqueeze adds the batch dim.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        input_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)

        # BUG FIX: the original referenced undefined names `inputs`/`labels`
        # (left over from a training loop) and would NameError whenever
        # context parallelism is enabled. A reference forward pass has no
        # labels, so only input_ids participates in the CP buffers.
        optional_context_parallel_ctx = (
            dist_utils.create_context_parallel_ctx(
                cp_mesh=parallel_dims.world_mesh["cp"],
                cp_buffers=[input_ids] + [m.freqs_cis for m in model_parts],
                cp_seq_dims=[1] + [0 for _ in model_parts],
                cp_no_restore_buffers={input_ids},
                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
            )
            if parallel_dims.cp_enabled
            else None
        )

        if parallel_dims.pp_enabled:
            raise NotImplementedError("PP not implemented yet")

        # Non-PP forward (reference model: no backward pass needed).
        with self.engine.train_context(optional_context_parallel_ctx):
            assert len(model_parts) == 1
            with self.engine.maybe_enable_amp:
                pred = model_parts[0](input_ids)

        # TODO: Update compute_sequence_logprobs to convert probs (logits) to logprobs
        return pred
232+ 
233+ 
234+ # Maintained to keep GRPO app prior to migration 
235+ class  HuggingFaceRefModel (ForgeActor ):
236+     """ 
237+     Represents a reference actor leveraging HuggingFace for execution 
238+     """ 
239+ 
135240    def  __init__ (self , model_name , device : torch .device  |  None  =  None ):
136241        super ().__init__ ()
137242        self .model_name  =  model_name 
0 commit comments