11import math
22import os
33import random
4+ import time
45from pathlib import Path
56from typing import Literal , cast
67
78import numpy as np
89import ray
910import torch
1011from ray .actor import ActorProxy
12+ from typing_extensions import TypedDict
1113
1214from xtuner .v1 .data_proto .sequence_context import SequenceContext
1315from xtuner .v1 .model .compose .base import BaseComposeConfig
2123from .worker import TrainingWorker , WorkerInputItem , WorkerLogItem
2224
2325
# Wall-clock timing breakdown for one training step, as returned by the
# controller's ``fit``: seconds spent packing/balancing data batches, and
# seconds the workers then spent in the training step itself.
TrainingStepTimeLog = TypedDict(
    "TrainingStepTimeLog",
    {
        "data_packing_time": float,
        "worker_training_time": float,
    },
)
class RawTrainingController:
    """Coordinates a group of ``TrainingWorker`` actors for training.

    NOTE(review): only part of this class is visible in this view; the
    visible surface packs/balances data batches and fans the per-step work
    out to the workers via Ray remote calls.
    """

    def __init__(self, workers: list[TrainingWorker]) -> None:
        """Store the worker handles this controller coordinates.

        Args:
            workers: Handles to the training workers. Their count defines
                the data-parallel world size used when distributing packed
                data batches.
        """
        self.workers = workers
@@ -367,7 +374,8 @@ def fit(
367374 pack_max_length : int ,
368375 rollout_idx : int ,
369376 enable_dp_balance : bool = True ,
370- ) -> list [WorkerLogItem ]:
377+ ) -> tuple [list [WorkerLogItem ], TrainingStepTimeLog ]:
378+ pack_start_time = time .perf_counter ()
371379 self ._set_data_batches_properties (data_batches )
372380
373381 world_size = len (self .workers )
@@ -422,6 +430,9 @@ def fit(
422430 max_packs = max_packs_per_step [step_idx ]
423431 self ._pad_to_max_packs_across_workes (packed_data_batches , step_idx , max_packs , pack_max_length )
424432
433+ pack_end_time = time .perf_counter ()
434+ self .logger .info (f"Data packing took { pack_end_time - pack_start_time :.2f} seconds." )
435+
425436 handles = []
426437 for worker_idx , worker in enumerate (self .workers ):
427438 handles .append (
@@ -430,8 +441,14 @@ def fit(
430441 rollout_idx = rollout_idx ,
431442 )
432443 )
433- log_infos = ray .get (handles , timeout = TRAIN_RAY_GET_TIMEOUT )
434- return log_infos
444+ worker_log_infos = ray .get (handles )
445+ fit_end_time = time .perf_counter ()
446+ self .logger .info (f"Training step took { fit_end_time - pack_end_time :.2f} seconds." )
447+ training_time : TrainingStepTimeLog = {
448+ "data_packing_time" : pack_end_time - pack_start_time ,
449+ "worker_training_time" : fit_end_time - pack_end_time ,
450+ }
451+ return worker_log_infos , training_time
435452
436453 @ray_method
437454 def offload (self , target : Literal ["model" , "optimizer" , "all" ] = "all" ):
0 commit comments