55"""
66
77import asyncio
8+ import torch
89
9- from areal .api .cli_args import TrainEngineConfig
10+ from torch import Tensor
11+ from collections .abc import Callable
12+ from typing import Any
13+ from areal .extension .asystem .api .cli_args import TrainEngineConfig
1014from areal .api .engine_api import TrainEngine
1115from areal .api .io_struct import AllocationMode , FinetuneSpec
1216from areal .api .scheduler_api import Job , Scheduler
1317from areal .controller .train_controller import TrainController as BaseTrainController
1418from areal .extension .asystem .remote_hybrid_train_worker import RemoteMegatronInitConfig
15- from areal .utils import logging
19+ from areal .utils import logging , stats_tracker
20+ from areal .controller .batch import DistributedBatch
+from areal.api.io_struct import SaveLoadMeta, WeightUpdateMeta

logger = logging.getLogger("TrainController")


+def _execute_parallel_tasks(workers, scheduler, method_name, *args):
+    """Execute tasks in parallel across all workers.
+
+    This is a helper function to reduce code duplication when executing
+    the same method on all workers with identical parameters.
+
+    Parameters
+    ----------
+    workers : list
+        List of worker objects
+    scheduler : Scheduler
+        Scheduler instance for async calls
+    method_name : str
+        Name of the method to call on each worker's engine
+    *args
+        Positional arguments to pass to the method
+
+    Returns
+    -------
+    list
+        Results from all workers
+
+    Raises
+    ------
+    RuntimeError
+        If any worker fails to execute the task
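+
+    Examples
+    --------
+    Hypothetical call; ``controller`` and ``save_meta`` are illustrative
+    names, not part of this module:
+
+    >>> results = _execute_parallel_tasks(
+    ...     controller.workers, controller.scheduler, "save", save_meta
+    ... )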
52+ """
+    tasks = [
+        scheduler.async_call_engine(
+            worker.id, method_name, *args, _should_bcast=False
+        )
+        for worker in workers
+    ]
+
+    try:
+        return asyncio.run(asyncio.gather(*tasks, return_exceptions=False))
+    except KeyboardInterrupt:
+        raise
+    except Exception as e:
+        raise RuntimeError(f"{method_name} failed, error: {e}") from e
+
+
+def _calc_metrics(batch_inputs):
+    """Log the spread of per-DP-shard token counts as a load-balance signal."""
+    # Standard deviation of the total sequence length per shard
+    seqlens = [td["seqlen"].sum().item() for td in batch_inputs]
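+    # e.g. shard token totals [512, 480, 530, 498] give an (unbiased) std of ~21.2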
+    seqlen_std = torch.tensor(seqlens).float().std().item()
+    stats_tracker.scalar(seqlen_std=seqlen_std)
+
+
class TrainController(BaseTrainController):
    """ASystem-specific TrainController.

@@ -69,9 +124,16 @@ def initialize(
        self.logger = logging.getLogger("[TrainController]")

        # Store configuration
+        self.parallel_strategy = alloc_mode.train
        self._worker_role = role
        self.alloc_mode = alloc_mode
-        self.parallel_strategy = alloc_mode.train
+        self.world_size = self.alloc_mode.train.world_size
+        self.dp_size = self.alloc_mode.train.dp_size
+        self.tp_size = self.alloc_mode.train.tp_size
+        self.pp_size = self.alloc_mode.train.pp_size
+        self.group_size = kwargs.get("group_size")
+        self.enable_colocate_mode = kwargs.get("enable_colocate_mode")
+        self.storage_prefix = kwargs.get("storage_prefix")

        # Create job for scheduler
        job = Job(
@@ -99,10 +161,6 @@ def initialize(
        asyncio.run(self._async_create_engines(engine_path))
        asyncio.run(self._async_initialize(job, ft_spec, **kwargs))

-        # Identify DP head workers
-        # todo: @chucai, implement this, record rank info in hybrid train worker and implement is_data_parallel_head...
-        # self._identify_dp_heads()
-
        self.logger.info("TrainController initialization complete")

    async def _async_initialize(self, job: Job, ft_spec: FinetuneSpec, **kwargs):
@@ -121,7 +179,17 @@ async def _async_initialize(self, job: Job, ft_spec: FinetuneSpec, **kwargs):
            for worker, init_config in zip(self.workers, init_configs)
        ]

-        await asyncio.gather(*tasks)
+        self.rank_info = {}
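+        # rank_info maps worker index -> the rank metadata returned by that
+        # worker's engine (e.g. its "dp_rank", used later to route batches)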
+        try:
+            gather_results = await asyncio.gather(*tasks, return_exceptions=False)
+        except Exception as e:
+            self.logger.error(f"Initialization failed with error: {e}")
+            raise RuntimeError(f"Failed to initialize workers, error: {e}") from e
+
+        for worker_index, result in enumerate(gather_results):
+            self.rank_info[worker_index] = result
+            self.logger.info(f"Worker {worker_index} succeeded: {result}")
+
        self.logger.info("All engines are initialized!")

    def _build_engine_initialize_config(
@@ -139,3 +207,101 @@ def _build_engine_initialize_config(
            )
            for index, worker in enumerate(self.workers)
        ]
+
+    def train_batch(
+        self,
+        input_: DistributedBatch,
+        loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor],
+        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
+    ) -> dict[str, float]:
+        """Split the batch across DP ranks and run one training step on all workers."""
+        self.logger.info("start to train_batch")
+        with stats_tracker.record_timing("train_batch_data_split"):
+            batches = input_.chunk_by_ffd(self.group_size, self.dp_size)
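+            # chunk_by_ffd (first-fit-decreasing packing, judging by the name)
+            # spreads whole groups of `group_size` sequences over `dp_size`
+            # shards so token counts stay balanced; batches[i] feeds dp_rank i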
+
+        _calc_metrics(batches)
+
+        tasks = [
+            self.scheduler.async_call_engine(
+                worker.id,
+                "train_batch",
+                batches[self.rank_info[index]["dp_rank"]],
+                _should_bcast=False,
+            )
+            for index, worker in enumerate(self.workers)
+        ]
+
+        try:
+            results = asyncio.run(asyncio.gather(*tasks, return_exceptions=False))
+        except KeyboardInterrupt:
+            raise
+        except Exception as e:
+            raise RuntimeError(f"train_batch failed, error: {e}") from e
+
+        # Each worker returns a list of per-minibatch stats dicts; log them all
+        for worker_result in results:
+            for minibatch in worker_result:
+                stats_tracker.scalar(**minibatch)
+
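+        # Metrics are reported through stats_tracker above; the empty dict
+        # only satisfies the dict[str, float] return annotation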
+        return {}
+
+    def compute_logp(self, input_: DistributedBatch) -> Tensor:
+        """Compute per-token log-probabilities for a batch, gathered across DP ranks."""
+        self.logger.info("start to compute_logp")
+        with stats_tracker.record_timing("compute_logp_data_split"):
+            batches = input_.chunk(self.dp_size)
+        tasks = [
+            self.scheduler.async_call_engine(
+                worker.id,
+                "compute_logprobs",
+                batches[self.rank_info[index]["dp_rank"]],
+                _should_bcast=False,
+            )
+            for index, worker in enumerate(self.workers)
+        ]
+
+        try:
+            results = asyncio.run(asyncio.gather(*tasks, return_exceptions=False))
+        except KeyboardInterrupt:
+            raise
+        except Exception as e:
+            raise RuntimeError(f"compute_logp failed, error: {e}") from e
+
+        # Concatenate the DP heads' tensors with right padding; this assumes
+        # the first dp_size workers cover each DP rank exactly once
+        tensors_from_dp_heads = results[: self.dp_size]
+        if not tensors_from_dp_heads:
+            return torch.tensor([])
+
+        # Find the max length along dim 1
+        max_len = max(t.shape[1] for t in tensors_from_dp_heads)
+        max_len_all = max(t.shape[1] for t in results)
+        assert max_len_all == max_len
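+        # e.g. shapes [(4, 120), (4, 96)] are padded to (4, 120) each and then
+        # concatenated into a single (8, 120) log-prob tensor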
+        # Pad all tensors to the max length
+        padded_tensors = []
+        for t in tensors_from_dp_heads:
+            pad_size = max_len - t.shape[1]
+            padded = torch.nn.functional.pad(t, (0, pad_size), value=0.0)
+            padded_tensors.append(padded)
+
+        # Concatenate along the batch dimension
+        concatenated_result = torch.cat(padded_tensors, dim=0)
+        return concatenated_result
+
+    def upload_weights(self, meta: WeightUpdateMeta):
+        """Upload weights to the inference engine."""
+        _execute_parallel_tasks(self.workers, self.scheduler, "upload_weights", meta)
+
+    def save(self, meta: SaveLoadMeta):
+        """Save model weights (and optimizer states) for later use."""
+        _execute_parallel_tasks(self.workers, self.scheduler, "save", meta)
+
+    def load(self, meta: SaveLoadMeta):
+        """Load model weights and optimizer states from a file."""
+        _execute_parallel_tasks(self.workers, self.scheduler, "load", meta)
+
+    def notify_event(self, event: str, global_step: int) -> None:
+        """Notify workers about training start/end events.
+
+        Args:
+            event: "train_start" or "train_end"
+            global_step: Current global step
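+
+        Example (illustrative values):
+            controller.notify_event("train_start", global_step=0)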
305+ """
306+ _execute_parallel_tasks (self .workers , self .scheduler , "notify_event" , event , global_step )
307+ return None