Commit 4f6412f
[Refactor] refactor trainer fit loop for better code organization
- Extract model input preparation logic into a _prepare_model_input method
- Move loss_log update logic from the trainer into train_engine
- Simplify the _log_step method signature by reading instance variables instead of passing them in
- Fix type hints: consumed_tokens and consumed_img_tokens should be int
- Move the consumed_samples calculation into the fit loop for a clearer logic flow
1 parent d786b83 commit 4f6412f

3 files changed (+89, -83 lines):
- xtuner/v1/engine/train_engine.py
- xtuner/v1/engine/vision_compose_train_engine.py
- xtuner/v1/train/trainer.py
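To make the refactor described above concrete before diving into the per-file diffs, here is a minimal, hypothetical sketch of the resulting control flow. MiniTrainer and MiniEngine are illustrative stand-ins, not the project's Trainer, TrainEngine, or ModelItem types; only the overall shape follows the diffs below.

# Hypothetical stand-ins; only the shape of the refactor mirrors this commit.
class MiniEngine:
    def train_step(self, engine_input):
        # After this commit the engine owns the loss_log bookkeeping (extra_info, maxvio, ...);
        # the trainer only consumes the returned dicts.
        loss_log = {"loss": 0.1}
        other_log = {"consumed_tokens": 8 * len(engine_input)}
        return loss_log, other_log


class MiniTrainer:
    def __init__(self, engine):
        self._engine = engine
        self._reduced_consumed_tokens = 0
        self._reduced_consumed_samples = 0

    def _prepare_model_input(self, data_batch):
        # Stand-in for moving tensors to the device and wrapping them into ModelItem objects.
        return [{"seq_ctx": sample} for sample in data_batch]

    def fit(self, batches):
        for data_batch in batches:
            consumed_samples = len(data_batch)   # sample accounting now lives in the fit loop
            engine_input = self._prepare_model_input(data_batch)
            loss_log, other_log = self._engine.train_step(engine_input)
            self._reduced_consumed_tokens += other_log["consumed_tokens"]
            self._reduced_consumed_samples += consumed_samples
            print(loss_log, self._reduced_consumed_tokens, self._reduced_consumed_samples)


MiniTrainer(MiniEngine()).fit([["a", "b"], ["c"]])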

xtuner/v1/engine/train_engine.py

Lines changed: 14 additions & 3 deletions
@@ -53,8 +53,8 @@ class LossLog(TypedDict):
 class OtherLog(TypedDict):
     __pydantic_config__ = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[misc]
     maxvio: NotRequired[float]
-    consumed_tokens: float
-    consumed_img_tokens: NotRequired[float]
+    consumed_tokens: int
+    consumed_img_tokens: NotRequired[int]
     extra_info: ModelForwardExtraLogInfo
     efficient_attn_ratio: float

@@ -347,9 +347,20 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
         reduced_z_loss = step_z_loss
         dist.all_reduce(reduced_z_loss.div_(dist.get_world_size()))
         loss_log["reduced_z_loss"] = reduced_z_loss.item()
-        other_log["consumed_tokens"] = step_consumed_tokens.item()
+        other_log["consumed_tokens"] = int(step_consumed_tokens.item())
         other_log["extra_info"] = train_engine_extra_info
         other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
+
+        extra_info = other_log.get("extra_info", {})  # type: ignore
+
+        # TODO: @duanyanhui `extra_info` should be redesigned.
+        if not isinstance(extra_info, ModelForwardExtraLogInfo):
+            extra_info = ModelForwardExtraLogInfo(extra_info)
+        loss_log.update(extra_info.get())
+
+        if "maxvio" in other_log:
+            loss_log["maxvio"] = other_log["maxvio"]  # type: ignore
+        loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]  # type: ignore
         return loss_log, other_log

     def from_hf(self, hf_path: str | Path, strict: bool = False):
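A side note on the int(...) cast above (illustration only, not part of the commit): if step_consumed_tokens is accumulated in a floating-point tensor, .item() returns a Python float, which would no longer satisfy the tightened consumed_tokens: int annotation in OtherLog. A quick check, assuming PyTorch is available:

import torch

step_consumed_tokens = torch.tensor(4096.0)    # counter held in a float tensor (assumption)
print(type(step_consumed_tokens.item()))       # <class 'float'>
print(type(int(step_consumed_tokens.item())))  # <class 'int'>

# An integer tensor would already yield an int, so the cast also makes the logged
# type independent of the tensor's dtype.
print(type(torch.tensor(4096).item()))         # <class 'int'>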

xtuner/v1/engine/vision_compose_train_engine.py

Lines changed: 13 additions & 2 deletions
@@ -220,8 +220,19 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
         reduced_z_loss = step_z_loss
         dist.all_reduce(reduced_z_loss.div_(dist.get_world_size()))
         loss_log["reduced_z_loss"] = reduced_z_loss.item()
-        other_log["consumed_tokens"] = step_consumed_tokens.item()
+        other_log["consumed_tokens"] = int(step_consumed_tokens.item())
         other_log["extra_info"] = train_engine_extra_info  # type: ignore[assignment]
         other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
-        other_log["consumed_img_tokens"] = step_consumed_img_tokens
+        other_log["consumed_img_tokens"] = int(step_consumed_img_tokens)
+
+        extra_info = other_log.get("extra_info", {})  # type: ignore
+
+        # TODO: @duanyanhui `extra_info` should be redesigned.
+        if not isinstance(extra_info, ModelForwardExtraLogInfo):
+            extra_info = ModelForwardExtraLogInfo(extra_info)
+        loss_log.update(extra_info.get())
+
+        if "maxvio" in other_log:
+            loss_log["maxvio"] = other_log["maxvio"]  # type: ignore
+        loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]  # type: ignore
         return loss_log, other_log
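The same cast is applied to the image-token counter, which only this vision engine reports. For the NotRequired pattern that OtherLog relies on, a minimal standalone sketch (MiniOtherLog is hypothetical; only the two field names come from the diff):

from typing import TypedDict

from typing_extensions import NotRequired  # or `from typing import NotRequired` on Python 3.11+


class MiniOtherLog(TypedDict):
    consumed_tokens: int
    consumed_img_tokens: NotRequired[int]  # present only when a vision engine fills it in


text_log: MiniOtherLog = {"consumed_tokens": 4096}
vision_log: MiniOtherLog = {"consumed_tokens": 4096, "consumed_img_tokens": 1024}

# Downstream code reads the optional field defensively, as the trainer does:
print(text_log.get("consumed_img_tokens", None))    # None
print(vision_log.get("consumed_img_tokens", None))  # 1024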

xtuner/v1/train/trainer.py

Lines changed: 62 additions & 78 deletions
@@ -37,7 +37,6 @@
 from xtuner.v1.model.base import ModelItem, TransformerConfig, XTunerBaseModelConfig
 from xtuner.v1.model.compose.base import BaseComposeConfig
 from xtuner.v1.model.moe.moe import MoEConfig
-from xtuner.v1.model.utils import ModelForwardExtraLogInfo
 from xtuner.v1.patch import patch_default_save_plan
 from xtuner.v1.profiler import profiling_memory, profiling_time
 from xtuner.v1.profiler.prober import ProberList
@@ -701,44 +700,13 @@ def fit(self):
         train_begin = time.time()
         time_before_get_data = time.time()
         for data_batch in self._data_iter():
+            consumed_samples = len(data_batch)
+            time_before_train_step = time.time()
+
             ProberList.set_step(self._cur_step + 1)
             DEVICE_MODULE.reset_peak_memory_stats()

-            time_before_train_step = time.time()
-            data_time = time_before_train_step - time_before_get_data
-
-            seq_ctx_list: list[SequenceContext] = []
-            loss_ctx_input_list: list[CELossContextInputItem] = []
-            for data in data_batch:
-                seq_ctx = data["seq_ctx"].to(DEVICE)
-                loss_ctx_input = CELossContextInputItem(shifted_labels=data["shifted_labels"]).to(DEVICE)
-                if self.sp_mesh.size() > 1:
-                    seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
-                    loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh)
-                seq_ctx_list.append(seq_ctx)
-                loss_ctx_input_list.append(loss_ctx_input)
-
-            del data_batch
-
-            LossContext = self.loss_cfg.loss_ctx_cls
-            batches_loss_kwargs = LossContext.build_batches_loss_kwargs(
-                loss_ctx_input_list,
-                self.loss_cfg,
-                cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list],
-                sp_mesh=self.sp_mesh,
-            )
-            engine_input = []
-            for seq_ctx, loss_kwargs in zip(seq_ctx_list, batches_loss_kwargs):
-                loss_ctx = LossContext(
-                    loss_cfg=self.loss_cfg,
-                    loss_kwargs=loss_kwargs,
-                )
-                engine_input.append(
-                    ModelItem(
-                        seq_ctx=seq_ctx,
-                        loss_ctx=loss_ctx,
-                    )
-                )
+            engine_input = self._prepare_model_input(data_batch)

             with self._maybe_profiling():
                 loss_log, other_log = self._engine.train_step(engine_input)
@@ -756,47 +724,30 @@ def fit(self):

             grad_norm = self._engine.clip_grad_norm(do_clip=self._do_clip, dtype=self._grad_norm_dtype)
             self._engine.step_optimizer(grad_norm)
+
             time_after_train_step = time.time()
             ProberList.after_step()
-            step_time = time_after_train_step - time_before_train_step
-            step_consumed_tokens = other_log["consumed_tokens"]
-            step_consumed_img_tokens = other_log.get("consumed_img_tokens", None)

-            extra_info = other_log.get("extra_info", {})
-            if isinstance(extra_info, ModelForwardExtraLogInfo):
-                extra_info_dict = extra_info.get()
-            else:
-                extra_info_updated = ModelForwardExtraLogInfo(extra_info)
-                extra_info_dict = extra_info_updated.get()
-            loss_log.update(extra_info_dict)
-
-            if "maxvio" in other_log:
-                loss_log["maxvio"] = other_log["maxvio"]
-            loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]
+            data_time = time_before_train_step - time_before_get_data
+            step_time = time_after_train_step - time_before_train_step

             internal_metrics = self._maybe_pop_model_internal_metrics(engine_input)

             self._cur_step += 1
-
-            reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens)
-            self._reduced_consumed_tokens += reduced_step_consumed_tokens
-
-            self._exp_consumed_tokens += step_consumed_tokens
+            self._reduced_consumed_tokens += self._reduce_number_across_rank(other_log["consumed_tokens"])
+            self._reduced_consumed_samples += self._reduce_number_across_rank(consumed_samples)
+            self._exp_consumed_tokens += other_log["consumed_tokens"]
             self._train_time = time_after_train_step - train_begin

             # TODO: This log should be move before lr_scheduler.step, but for CI BC, keep it temporarily
             self._log_step(
                 loss_log=loss_log,
-                step_consumed_tokens=step_consumed_tokens,
-                exp_consumed_tokens=self._exp_consumed_tokens,
-                step_consumed_img_tokens=step_consumed_img_tokens,
-                reduced_consumed_tokens=self._reduced_consumed_tokens,
-                data_time=data_time,
-                step_time=step_time,
-                train_time=self._train_time,
-                train_time_offset=self._train_time_offset,
+                step_consumed_tokens=other_log["consumed_tokens"],
+                step_consumed_img_tokens=other_log.get("consumed_img_tokens", None),
                 grad_norm=grad_norm.item(),
                 internal_metrics=internal_metrics,
+                data_time=data_time,
+                step_time=step_time,
             )

             self._lr_scheduler.step()
@@ -817,7 +768,44 @@ def fit(self):
         self._metrics_recorder.close()
         self.logger.info(f"Training finished in {time.time() - train_begin:.2f} seconds")

-    def _reduce_number_across_rank(self, rank_number: int) -> int:
+    def _prepare_model_input(self, data_batch) -> list[ModelItem]:
+        seq_ctx_list: list[SequenceContext] = []
+        loss_ctx_input_list: list[CELossContextInputItem] = []
+
+        for data in data_batch:
+            seq_ctx = data["seq_ctx"].to(DEVICE)
+            loss_ctx_input = CELossContextInputItem(shifted_labels=data["shifted_labels"]).to(DEVICE)
+            if self.sp_mesh.size() > 1:
+                seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
+                loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh)
+            seq_ctx_list.append(seq_ctx)
+            loss_ctx_input_list.append(loss_ctx_input)
+
+        # TODO: Consider moving data_batch deletion to the caller for better memory management.
+        del data_batch
+
+        LossContext = self.loss_cfg.loss_ctx_cls
+        batches_loss_kwargs = LossContext.build_batches_loss_kwargs(
+            loss_ctx_input_list,
+            self.loss_cfg,
+            cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list],
+            sp_mesh=self.sp_mesh,
+        )
+        engine_input = []
+        for seq_ctx, loss_kwargs in zip(seq_ctx_list, batches_loss_kwargs):
+            loss_ctx = LossContext(
+                loss_cfg=self.loss_cfg,
+                loss_kwargs=loss_kwargs,
+            )
+            engine_input.append(
+                ModelItem(
+                    seq_ctx=seq_ctx,
+                    loss_ctx=loss_ctx,
+                )
+            )
+        return engine_input
+
+    def _reduce_number_across_rank(self, rank_number: int | float) -> int:
         _gathered_list = [None for _ in range(self.world_size)]
         dist.all_gather_object(_gathered_list, rank_number)
         reduced_number = sum(_gathered_list)  # type: ignore[arg-type]
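The signature of _reduce_number_across_rank now accepts plain ints or floats; its gather-and-sum body is unchanged. A self-contained sketch of that pattern, assuming a CPU-only gloo backend and a torchrun launch (script and helper names are illustrative):

import torch.distributed as dist


def reduce_number_across_rank(rank_number, world_size):
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, rank_number)  # every rank receives every rank's value
    return sum(gathered)


if __name__ == "__main__":
    # Launch with e.g.: torchrun --nproc_per_node=2 reduce_demo.py
    dist.init_process_group(backend="gloo")
    local_count = 1000 * (dist.get_rank() + 1)     # pretend each rank consumed a different count
    total = reduce_number_across_rank(local_count, dist.get_world_size())
    print(f"rank {dist.get_rank()}: reduced value = {total}")
    dist.destroy_process_group()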
@@ -1257,7 +1245,6 @@ def _data_iter(self):
             data_iter = iter(self._dataloader)
             data = next(data_iter)

-            self._reduced_consumed_samples += self._reduce_number_across_rank(len(data))
             yield data

     def _get_checkpoint_path(self, epoch: int, step: int, is_snapshot: bool = False) -> Path:
@@ -1434,24 +1421,21 @@ def _maybe_profiling(self):

     def _log_step(
         self,
-        loss_log: dict,
+        loss_log: LossLog,
         step_consumed_tokens: int,
-        exp_consumed_tokens: int,
-        reduced_consumed_tokens: int,
+        step_consumed_img_tokens: int | None,
+        grad_norm: float,
         data_time: float,
         step_time: float,
-        train_time: float,
-        train_time_offset: float,
-        grad_norm: float,
-        step_consumed_img_tokens: float | None,
         internal_metrics: InternalMetrics | None = None,
     ):
         """Log the training step information."""
-        e2e_train_time = train_time + train_time_offset
+        e2e_train_time = self._train_time + self._train_time_offset
+
         tgs = step_consumed_tokens / step_time
-        rank_consumed_tokens = reduced_consumed_tokens / self.world_size
+        rank_consumed_tokens = self._reduced_consumed_tokens / self.world_size
         e2e_tgs = rank_consumed_tokens / e2e_train_time
-        exp_tgs = exp_consumed_tokens / train_time
+        exp_tgs = self._exp_consumed_tokens / self._train_time
         lr = self._lr_scheduler.get_last_lr()[0]

         remaining_steps = self.total_step - self.cur_step
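Since _log_step now reads the accumulated counters from instance state, only per-step values travel through the call. A small numeric illustration of the throughput formulas above, with invented values:

# Invented values, mirroring the formulas in _log_step:
step_consumed_tokens = 32_768         # tokens this rank processed in the current step
step_time = 4.0                       # seconds spent on the current step
reduced_consumed_tokens = 2_097_152   # stand-in for self._reduced_consumed_tokens (summed over ranks)
exp_consumed_tokens = 524_288         # stand-in for self._exp_consumed_tokens
train_time, train_time_offset = 256.0, 64.0  # stand-ins for self._train_time and self._train_time_offset
world_size = 8

tgs = step_consumed_tokens / step_time                             # 8192.0 tokens/s this step
rank_consumed_tokens = reduced_consumed_tokens / world_size        # 262144.0 tokens per rank
e2e_tgs = rank_consumed_tokens / (train_time + train_time_offset)  # 819.2 tokens/s end to end
exp_tgs = exp_consumed_tokens / train_time                         # 2048.0 tokens/s for the current run
print(tgs, e2e_tgs, exp_tgs)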
@@ -1481,7 +1465,7 @@ def _log_step(
             f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} "
             f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} "
             f"text_tokens: {step_consumed_tokens} {img_tokens_str}"
-            f"reduced_consumed_tokens: {reduced_consumed_tokens} "
+            f"reduced_consumed_tokens: {self._reduced_consumed_tokens} "
             f"{loss_log_str} "
             f"grad_norm: {grad_norm:.8f} "
             f"max_memory: {max_memory / (1024**3):.2f} GB "
@@ -1497,11 +1481,11 @@ def _log_step(
                 "lr": lr,
                 "time/data_time": round(data_time, 4),
                 "time/step_time": round(step_time, 4),
-                "time/train_time": round(train_time, 4),
+                "time/train_time": round(self._train_time, 4),
                 "time/eta_seconds": round(eta_seconds, 1),
                 "runtime_info/text_tokens": step_consumed_tokens,
                 "runtime_info/est_global_batch_tokens": est_global_batch_tokens,
-                "runtime_info/reduced_consumed_tokens": reduced_consumed_tokens,
+                "runtime_info/reduced_consumed_tokens": self._reduced_consumed_tokens,
                 "runtime_info/tgs": tgs,
                 "runtime_info/exp_tgs": exp_tgs,
                 "runtime_info/e2e_tgs": e2e_tgs,
