
Commit 1695d5a

add trainingStepTimeLog in train controller
1 parent 265707b commit 1695d5a

3 files changed: +30 additions, -5 deletions

xtuner/v1/rl/base/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,4 @@
-from .controller import TrainingController, TrainingControllerProxy
+from .controller import TrainingController, TrainingControllerProxy, TrainingStepTimeLog
 from .loss import BaseRLLossConfig, RLLossContextInputItem
 from .worker import TrainingWorker, TrainingWorkerClass, TrainingWorkerProxy, WorkerConfig, WorkerLogItem

@@ -13,4 +13,5 @@
     "BaseRLLossConfig",
     "RLLossContextInputItem",
     "WorkerLogItem",
+    "TrainingStepTimeLog",
 ]
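
With the updated export list, downstream code can import the new timing type straight from the package root. A minimal usage sketch (the `report` helper below is illustrative, not part of the commit):

```python
from xtuner.v1.rl.base import TrainingStepTimeLog


def report(step_time: TrainingStepTimeLog) -> str:
    # Format the two timing fields defined by the TypedDict.
    return (
        f"pack={step_time['data_packing_time']:.2f}s, "
        f"train={step_time['worker_training_time']:.2f}s"
    )
```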

xtuner/v1/rl/base/controller.py

Lines changed: 20 additions & 3 deletions
@@ -1,13 +1,15 @@
 import math
 import os
 import random
+import time
 from pathlib import Path
 from typing import Literal, cast

 import numpy as np
 import ray
 import torch
 from ray.actor import ActorProxy
+from typing_extensions import TypedDict

 from xtuner.v1.data_proto.sequence_context import SequenceContext
 from xtuner.v1.model.compose.base import BaseComposeConfig
@@ -21,6 +23,11 @@
 from .worker import TrainingWorker, WorkerInputItem, WorkerLogItem


+class TrainingStepTimeLog(TypedDict):
+    data_packing_time: float
+    worker_training_time: float
+
+
 class RawTrainingController:
     def __init__(self, workers: list[TrainingWorker]) -> None:
         self.workers = workers
@@ -367,7 +374,8 @@ def fit(
         pack_max_length: int,
         rollout_idx: int,
         enable_dp_balance: bool = True,
-    ) -> list[WorkerLogItem]:
+    ) -> tuple[list[WorkerLogItem], TrainingStepTimeLog]:
+        pack_start_time = time.perf_counter()
         self._set_data_batches_properties(data_batches)

         world_size = len(self.workers)
@@ -422,6 +430,9 @@ def fit(
             max_packs = max_packs_per_step[step_idx]
             self._pad_to_max_packs_across_workes(packed_data_batches, step_idx, max_packs, pack_max_length)

+        pack_end_time = time.perf_counter()
+        self.logger.info(f"Data packing took {pack_end_time - pack_start_time:.2f} seconds.")
+
         handles = []
         for worker_idx, worker in enumerate(self.workers):
             handles.append(
@@ -430,8 +441,14 @@ def fit(
                     rollout_idx=rollout_idx,
                 )
             )
-        log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT)
-        return log_infos
+        worker_log_infos = ray.get(handles)
+        fit_end_time = time.perf_counter()
+        self.logger.info(f"Training step took {fit_end_time - pack_end_time:.2f} seconds.")
+        training_time: TrainingStepTimeLog = {
+            "data_packing_time": pack_end_time - pack_start_time,
+            "worker_training_time": fit_end_time - pack_end_time,
+        }
+        return worker_log_infos, training_time

     @ray_method
     def offload(self, target: Literal["model", "optimizer", "all"] = "all"):
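
The change brackets the two phases of `fit` with `time.perf_counter()` and returns the per-phase durations alongside the worker log items. A standalone sketch of that pattern, stripped of the Ray and data-packing details (the helper names here are placeholders, not part of the commit):

```python
import time

from typing_extensions import TypedDict


class StepTimeLog(TypedDict):
    # Mirrors TrainingStepTimeLog from the diff above.
    data_packing_time: float
    worker_training_time: float


def timed_step(pack_fn, train_fn) -> StepTimeLog:
    """Split one training step into a packing phase and a worker-training phase."""
    pack_start = time.perf_counter()
    pack_fn()   # stands in for the data-packing loop inside fit()
    pack_end = time.perf_counter()
    train_fn()  # stands in for ray.get(handles) waiting on the workers
    fit_end = time.perf_counter()
    return {
        "data_packing_time": pack_end - pack_start,
        "worker_training_time": fit_end - pack_end,
    }
```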

xtuner/v1/train/rl_trainer.py

Lines changed: 8 additions & 1 deletion
@@ -29,6 +29,7 @@
 from xtuner.v1.rl.base import (
     TrainingController,
     TrainingControllerProxy,
+    TrainingStepTimeLog,
     TrainingWorkerClass,
     TrainingWorkerProxy,
     WorkerConfig,
@@ -555,12 +556,18 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
         )

         with timer("training", step_timer_dict):
-            workers_log_item: List[WorkerLogItem] = ray.get(
+            workers_log_item: List[WorkerLogItem]
+            training_time: TrainingStepTimeLog
+            workers_log_item, training_time = ray.get(
                 self._train_controller.fit.remote(
                     data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx
                 )
             )
         self._writer.add_scalar(tag="time/training", scalar_value=step_timer_dict["training"], global_step=rollout_idx)
+        self._writer.add_scalars(
+            tag_scalar_dict={f"time/train_{key}": cast(float, value) for key, value in training_time.items()},
+            global_step=rollout_idx,
+        )

         rank0_log_item = workers_log_item[0]
         # These metrics are already aggregated across distributed workers and logging only the metrics from rank 0.
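
Each entry of the returned timing dict becomes its own scalar under a `time/train_*` tag, next to the existing `time/training` total. A minimal sketch of the same effect using a plain `torch.utils.tensorboard.SummaryWriter` (the trainer's `self._writer` may expose a slightly different `add_scalars` interface, and the values below are made up):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="./tb_logs")

# Example payload shaped like TrainingStepTimeLog.
training_time = {"data_packing_time": 3.2, "worker_training_time": 41.7}

rollout_idx = 0
for key, value in training_time.items():
    # Logged as "time/train_data_packing_time" and "time/train_worker_training_time".
    writer.add_scalar(tag=f"time/train_{key}", scalar_value=value, global_step=rollout_idx)
writer.close()
```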
