 from typing import Any
 from areal.extension.asystem.api.cli_args import TrainEngineConfig
 from areal.api.engine_api import TrainEngine
-from areal.api.io_struct import AllocationMode, FinetuneSpec
+from areal.api.io_struct import FinetuneSpec
 from areal.api.scheduler_api import Job, Scheduler
 from areal.controller.train_controller import TrainController as BaseTrainController
+from areal.extension.asystem.controller.util import execute_parallel_tasks, calc_metrics
 from areal.extension.asystem.remote_hybrid_train_worker import RemoteMegatronInitConfig
 from areal.utils import logging, stats_tracker
 from areal.controller.batch import DistributedBatch

 logger = logging.getLogger("TrainController")

-def _execute_parallel_tasks(workers, scheduler, method_name, *args):
-    """Execute tasks in parallel across all workers.
-
-    This is a helper function to reduce code duplication when executing
-    the same method on all workers with identical parameters.
-
-    Parameters
-    ----------
-    workers : list
-        List of worker objects
-    scheduler : Scheduler
-        Scheduler instance for async calls
-    method_name : str
-        Name of the method to call on each worker's engine
-    *args, **kwargs
-        Arguments to pass to the method
-
-    Returns
-    -------
-    list
-        Results from all workers
-
-    Raises
-    ------
-    RuntimeError
-        If any worker fails to execute the task
-    """
-    tasks = [
-        scheduler.async_call_engine(
-            worker.id, method_name, *args, _should_bcast=False
-        )
-        for worker in workers
-    ]
-
-    try:
-        return asyncio.run(asyncio.gather(*tasks, return_exceptions=False))
-    except KeyboardInterrupt:
-        raise
-    except Exception as e:
-        raise RuntimeError(f"{method_name} failed, error: {e}")
-
-
-def _calc_metrics(batch_inputs):
-    # seqlen std
-    seqlens = [td["seqlen"].sum().item() for td in batch_inputs]
-    seqlen_std = torch.tensor(seqlens).float().std().item()
-    stats_tracker.scalar(**{"seqlen_std": seqlen_std})
-
-
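The two deleted helpers reappear as `execute_parallel_tasks` and `calc_metrics` in the `areal.extension.asystem.controller.util` module imported above. A minimal sketch of what that module plausibly contains, reconstructed from the removed bodies; the `asyncio.gather` call is wrapped in a coroutine here because `asyncio.run` only accepts a coroutine, and the real module may differ:

    # areal/extension/asystem/controller/util.py: sketch reconstructed from the removed code
    import asyncio

    import torch

    from areal.utils import stats_tracker


    def execute_parallel_tasks(workers, scheduler, method_name, *args):
        """Call the same engine method on every worker and gather the results."""

        async def _run_all():
            # Fan the call out to all workers concurrently; any failure aborts the whole call.
            coros = [
                scheduler.async_call_engine(worker.id, method_name, *args, _should_bcast=False)
                for worker in workers
            ]
            return await asyncio.gather(*coros, return_exceptions=False)

        try:
            return asyncio.run(_run_all())
        except KeyboardInterrupt:
            raise
        except Exception as e:
            raise RuntimeError(f"{method_name} failed, error: {e}") from e


    def calc_metrics(batch_inputs):
        # Track the spread of total sequence lengths across chunks to spot load imbalance.
        seqlens = [td["seqlen"].sum().item() for td in batch_inputs]
        stats_tracker.scalar(seqlen_std=torch.tensor(seqlens).float().std().item())
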
 class TrainController(BaseTrainController):
     """ASystem-specific TrainController.

@@ -218,7 +170,7 @@ def train_batch(
         with (stats_tracker.record_timing("train_batch_data_split"),):
             batches = input_.chunk_by_ffd(self.group_size, self.dp_size)

-        _calc_metrics(batches)
+        calc_metrics(batches)

         tasks = [
             self.scheduler.async_call_engine(
@@ -286,15 +238,17 @@ def compute_logp(self, input_: DistributedBatch) -> Tensor:

     def upload_weights(self, meta: WeightUpdateMeta):
         """Upload weights to the inference engine."""
-        _execute_parallel_tasks(self.workers, self.scheduler, "upload_weights", meta)
+        self.logger.info("begin upload_weights")
+        execute_parallel_tasks(self.workers, self.scheduler, "upload_weights", meta)
+        self.logger.info("finished upload_weights")

     def save(self, meta: SaveLoadMeta):
         """Save model weights (and optimizer states) for later use."""
-        _execute_parallel_tasks(self.workers, self.scheduler, "save", meta)
+        execute_parallel_tasks(self.workers, self.scheduler, "save", meta)

     def load(self, meta: SaveLoadMeta):
         """Load model weights and optimizer states from a file."""
-        _execute_parallel_tasks(self.workers, self.scheduler, "load", meta)
+        execute_parallel_tasks(self.workers, self.scheduler, "load", meta)

     def notify_event(self, event: str, global_step: int) -> None:
         """Notify workers about training start/end events.
@@ -303,5 +257,5 @@ def notify_event(self, event: str, global_step: int) -> None:
             event: "train_start" or "train_end"
             global_step: Current global step
         """
-        _execute_parallel_tasks(self.workers, self.scheduler, "notify_event", event, global_step)
+        execute_parallel_tasks(self.workers, self.scheduler, "notify_event", event, global_step)
         return None
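For reference, a hypothetical driver-side sequence tying these methods together; `controller`, `batch`, `weight_meta`, and `total_steps` are illustrative names, not part of this change:

    # Hypothetical usage sketch; controller construction and meta objects are elided.
    controller.notify_event("train_start", global_step=0)
    for step in range(total_steps):
        controller.train_batch(batch)           # FFD-splits the batch and fans it out to workers
        controller.upload_weights(weight_meta)  # pushes fresh weights to the inference engine
    controller.notify_event("train_end", global_step=total_steps)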