Commit 8428bfb

[Refactor] refactor trainer fit loop for better code organization
- Extract model input preparation logic into _prepare_model_input method
- Move loss_log update logic from trainer to train_engine
- Simplify _log_step method signature by using instance variables
- Fix type hints: consumed_tokens and consumed_img_tokens should be int
- Adjust consumed_samples calculation position for better logic flow
1 parent d786b83 commit 8428bfb
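For orientation, a minimal stand-alone sketch of the per-step flow this refactor produces. StubEngine, fit_once, and the dict shapes are illustrative stand-ins only, not xtuner's actual classes or signatures:

import time
from typing import Any


class StubEngine:
    """Illustrative stand-in for the train engine; not xtuner's API."""

    def train_step(self, engine_input: list[Any]) -> tuple[dict, dict]:
        # After this commit, the returned loss_log already carries extra_info,
        # maxvio, and efficient_attn_ratio, so the trainer can log it directly.
        loss_log = {"loss": 0.0, "efficient_attn_ratio": 1.0}
        other_log = {"consumed_tokens": 4096 * len(engine_input)}
        return loss_log, other_log


def fit_once(engine: StubEngine, data_batch: list[Any]) -> None:
    consumed_samples = len(data_batch)      # now counted inside the fit loop
    time_before_train_step = time.time()
    engine_input = data_batch               # stands in for self._prepare_model_input(data_batch)
    loss_log, other_log = engine.train_step(engine_input)
    step_time = time.time() - time_before_train_step
    print(consumed_samples, other_log["consumed_tokens"], loss_log, f"{step_time:.4f}s")


fit_once(StubEngine(), [object(), object()])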

2 files changed: 75 additions & 79 deletions


xtuner/v1/engine/train_engine.py

Lines changed: 13 additions & 2 deletions
@@ -53,8 +53,8 @@ class LossLog(TypedDict):
 class OtherLog(TypedDict):
     __pydantic_config__ = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[misc]
     maxvio: NotRequired[float]
-    consumed_tokens: float
-    consumed_img_tokens: NotRequired[float]
+    consumed_tokens: int
+    consumed_img_tokens: NotRequired[int]
     extra_info: ModelForwardExtraLogInfo
     efficient_attn_ratio: float

@@ -350,6 +350,17 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
         other_log["consumed_tokens"] = step_consumed_tokens.item()
         other_log["extra_info"] = train_engine_extra_info
         other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
+
+        extra_info = other_log.get("extra_info", {})
+
+        # TODO: @duanyanhui `extra_info` should be redesigned.
+        if not isinstance(extra_info, ModelForwardExtraLogInfo):
+            extra_info = ModelForwardExtraLogInfo(extra_info)
+        loss_log.update(extra_info.get())
+
+        if "maxvio" in other_log:
+            loss_log["maxvio"] = other_log["maxvio"]
+        loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]
         return loss_log, other_log

     def from_hf(self, hf_path: str | Path, strict: bool = False):
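For reference, a stand-alone sketch of the merge that train_step now performs before returning. FakeExtraLogInfo and merge_logs are illustrative stand-ins (plain dicts instead of the real LossLog/OtherLog TypedDicts and ModelForwardExtraLogInfo), not part of xtuner:

from typing import Any


class FakeExtraLogInfo:
    """Stand-in for ModelForwardExtraLogInfo: wraps a dict and exposes .get()."""

    def __init__(self, data: dict[str, Any] | None = None):
        self._data = dict(data or {})

    def get(self) -> dict[str, Any]:
        return dict(self._data)


def merge_logs(loss_log: dict[str, Any], other_log: dict[str, Any]) -> dict[str, Any]:
    extra_info = other_log.get("extra_info", {})
    # Normalize plain dicts into the wrapper, mirroring the isinstance check in the hunk above.
    if not isinstance(extra_info, FakeExtraLogInfo):
        extra_info = FakeExtraLogInfo(extra_info)
    loss_log.update(extra_info.get())
    if "maxvio" in other_log:
        loss_log["maxvio"] = other_log["maxvio"]
    loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]
    return loss_log


print(merge_logs({"loss": 1.2}, {"efficient_attn_ratio": 0.93, "maxvio": 0.05}))

With this logic inside train_step, the caller receives a fully populated loss_log and no longer needs its own extra_info handling, which is what the trainer.py hunks below remove.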

xtuner/v1/train/trainer.py

Lines changed: 62 additions & 77 deletions
@@ -701,44 +701,13 @@ def fit(self):
         train_begin = time.time()
         time_before_get_data = time.time()
         for data_batch in self._data_iter():
+            consumed_samples = len(data_batch)
+            time_before_train_step = time.time()
+
             ProberList.set_step(self._cur_step + 1)
             DEVICE_MODULE.reset_peak_memory_stats()
 
-            time_before_train_step = time.time()
-            data_time = time_before_train_step - time_before_get_data
-
-            seq_ctx_list: list[SequenceContext] = []
-            loss_ctx_input_list: list[CELossContextInputItem] = []
-            for data in data_batch:
-                seq_ctx = data["seq_ctx"].to(DEVICE)
-                loss_ctx_input = CELossContextInputItem(shifted_labels=data["shifted_labels"]).to(DEVICE)
-                if self.sp_mesh.size() > 1:
-                    seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
-                    loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh)
-                seq_ctx_list.append(seq_ctx)
-                loss_ctx_input_list.append(loss_ctx_input)
-
-            del data_batch
-
-            LossContext = self.loss_cfg.loss_ctx_cls
-            batches_loss_kwargs = LossContext.build_batches_loss_kwargs(
-                loss_ctx_input_list,
-                self.loss_cfg,
-                cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list],
-                sp_mesh=self.sp_mesh,
-            )
-            engine_input = []
-            for seq_ctx, loss_kwargs in zip(seq_ctx_list, batches_loss_kwargs):
-                loss_ctx = LossContext(
-                    loss_cfg=self.loss_cfg,
-                    loss_kwargs=loss_kwargs,
-                )
-                engine_input.append(
-                    ModelItem(
-                        seq_ctx=seq_ctx,
-                        loss_ctx=loss_ctx,
-                    )
-                )
+            engine_input = self._prepare_model_input(data_batch)
 
             with self._maybe_profiling():
                 loss_log, other_log = self._engine.train_step(engine_input)
@@ -756,47 +725,30 @@ def fit(self):
 
             grad_norm = self._engine.clip_grad_norm(do_clip=self._do_clip, dtype=self._grad_norm_dtype)
             self._engine.step_optimizer(grad_norm)
+
             time_after_train_step = time.time()
             ProberList.after_step()
-            step_time = time_after_train_step - time_before_train_step
-            step_consumed_tokens = other_log["consumed_tokens"]
-            step_consumed_img_tokens = other_log.get("consumed_img_tokens", None)
 
-            extra_info = other_log.get("extra_info", {})
-            if isinstance(extra_info, ModelForwardExtraLogInfo):
-                extra_info_dict = extra_info.get()
-            else:
-                extra_info_updated = ModelForwardExtraLogInfo(extra_info)
-                extra_info_dict = extra_info_updated.get()
-            loss_log.update(extra_info_dict)
-
-            if "maxvio" in other_log:
-                loss_log["maxvio"] = other_log["maxvio"]
-            loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]
+            data_time = time_before_train_step - time_before_get_data
+            step_time = time_after_train_step - time_before_train_step
 
             internal_metrics = self._maybe_pop_model_internal_metrics(engine_input)
 
             self._cur_step += 1
-
-            reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens)
-            self._reduced_consumed_tokens += reduced_step_consumed_tokens
-
-            self._exp_consumed_tokens += step_consumed_tokens
+            self._reduced_consumed_tokens += self._reduce_number_across_rank(other_log["consumed_tokens"])
+            self._reduced_consumed_samples += self._reduce_number_across_rank(consumed_samples)
+            self._exp_consumed_tokens += other_log["consumed_tokens"]
             self._train_time = time_after_train_step - train_begin
 
             # TODO: This log should be move before lr_scheduler.step, but for CI BC, keep it temporarily
             self._log_step(
                 loss_log=loss_log,
-                step_consumed_tokens=step_consumed_tokens,
-                exp_consumed_tokens=self._exp_consumed_tokens,
-                step_consumed_img_tokens=step_consumed_img_tokens,
-                reduced_consumed_tokens=self._reduced_consumed_tokens,
-                data_time=data_time,
-                step_time=step_time,
-                train_time=self._train_time,
-                train_time_offset=self._train_time_offset,
+                step_consumed_tokens=other_log["consumed_tokens"],
+                step_consumed_img_tokens=other_log.get("consumed_img_tokens", None),
                 grad_norm=grad_norm.item(),
                 internal_metrics=internal_metrics,
+                data_time=data_time,
+                step_time=step_time,
             )
 
             self._lr_scheduler.step()
@@ -817,7 +769,44 @@ def fit(self):
         self._metrics_recorder.close()
         self.logger.info(f"Training finished in {time.time() - train_begin:.2f} seconds")
 
-    def _reduce_number_across_rank(self, rank_number: int) -> int:
+    def _prepare_model_input(self, data_batch) -> list[ModelItem]:
+        seq_ctx_list: list[SequenceContext] = []
+        loss_ctx_input_list: list[CELossContextInputItem] = []
+
+        for data in data_batch:
+            seq_ctx = data["seq_ctx"].to(DEVICE)
+            loss_ctx_input = CELossContextInputItem(shifted_labels=data["shifted_labels"]).to(DEVICE)
+            if self.sp_mesh.size() > 1:
+                seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
+                loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh)
+            seq_ctx_list.append(seq_ctx)
+            loss_ctx_input_list.append(loss_ctx_input)
+
+        # TODO: Consider moving data_batch deletion to the caller for better memory management.
+        del data_batch
+
+        LossContext = self.loss_cfg.loss_ctx_cls
+        batches_loss_kwargs = LossContext.build_batches_loss_kwargs(
+            loss_ctx_input_list,
+            self.loss_cfg,
+            cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list],
+            sp_mesh=self.sp_mesh,
+        )
+        engine_input = []
+        for seq_ctx, loss_kwargs in zip(seq_ctx_list, batches_loss_kwargs):
+            loss_ctx = LossContext(
+                loss_cfg=self.loss_cfg,
+                loss_kwargs=loss_kwargs,
+            )
+            engine_input.append(
+                ModelItem(
+                    seq_ctx=seq_ctx,
+                    loss_ctx=loss_ctx,
+                )
+            )
+        return engine_input
+
+    def _reduce_number_across_rank(self, rank_number: int | float) -> int:
         _gathered_list = [None for _ in range(self.world_size)]
         dist.all_gather_object(_gathered_list, rank_number)
         reduced_number = sum(_gathered_list)  # type: ignore[arg-type]
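The consumed_samples and consumed_tokens counters accumulated above are summed across ranks via _reduce_number_across_rank, which (per the context lines) gathers one value per rank with all_gather_object and sums the result. A stand-alone emulation with a plain list of per-rank values, for illustration only:

def reduce_across_ranks(per_rank_values: list[int | float]) -> int | float:
    """Emulates gathering one number from every rank and summing the results."""
    return sum(per_rank_values)


# e.g. per-rank consumed samples on a hypothetical 4-GPU run
print(reduce_across_ranks([128, 130, 127, 131]))  # -> 516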
@@ -1257,7 +1246,6 @@ def _data_iter(self):
            data_iter = iter(self._dataloader)
            data = next(data_iter)
 
-            self._reduced_consumed_samples += self._reduce_number_across_rank(len(data))
            yield data
 
     def _get_checkpoint_path(self, epoch: int, step: int, is_snapshot: bool = False) -> Path:
@@ -1434,24 +1422,21 @@ def _maybe_profiling(self):
 
     def _log_step(
         self,
-        loss_log: dict,
+        loss_log: LossLog,
         step_consumed_tokens: int,
-        exp_consumed_tokens: int,
-        reduced_consumed_tokens: int,
+        step_consumed_img_tokens: int | None,
+        grad_norm: float,
         data_time: float,
         step_time: float,
-        train_time: float,
-        train_time_offset: float,
-        grad_norm: float,
-        step_consumed_img_tokens: float | None,
         internal_metrics: InternalMetrics | None = None,
     ):
         """Log the training step information."""
-        e2e_train_time = train_time + train_time_offset
+        e2e_train_time = self._train_time + self._train_time_offset
+
         tgs = step_consumed_tokens / step_time
-        rank_consumed_tokens = reduced_consumed_tokens / self.world_size
+        rank_consumed_tokens = self._reduced_consumed_tokens / self.world_size
         e2e_tgs = rank_consumed_tokens / e2e_train_time
-        exp_tgs = exp_consumed_tokens / train_time
+        exp_tgs = self._exp_consumed_tokens / self._train_time
         lr = self._lr_scheduler.get_last_lr()[0]
 
         remaining_steps = self.total_step - self.cur_step
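The throughput numbers in this hunk are simple ratios over the instance counters. A stand-alone sketch of the same arithmetic with plain values (the function and parameter names are illustrative, not the trainer's attributes):

def throughput_metrics(
    step_consumed_tokens: int,
    step_time: float,
    reduced_consumed_tokens: int,
    world_size: int,
    exp_consumed_tokens: int,
    train_time: float,
    train_time_offset: float,
) -> dict[str, float]:
    """Mirrors the tgs / e2e_tgs / exp_tgs computation shown above."""
    e2e_train_time = train_time + train_time_offset
    tgs = step_consumed_tokens / step_time                        # tokens/s for this step
    rank_consumed_tokens = reduced_consumed_tokens / world_size   # average tokens per rank
    e2e_tgs = rank_consumed_tokens / e2e_train_time               # end-to-end tokens/s per rank
    exp_tgs = exp_consumed_tokens / train_time                    # tokens/s over this run, excluding the offset
    return {"tgs": tgs, "e2e_tgs": e2e_tgs, "exp_tgs": exp_tgs}


print(throughput_metrics(8192, 2.0, 65536, 8, 32768, 100.0, 20.0))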
@@ -1481,7 +1466,7 @@ def _log_step(
            f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} "
            f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} "
            f"text_tokens: {step_consumed_tokens} {img_tokens_str}"
-            f"reduced_consumed_tokens: {reduced_consumed_tokens} "
+            f"reduced_consumed_tokens: {self._reduced_consumed_tokens} "
            f"{loss_log_str} "
            f"grad_norm: {grad_norm:.8f} "
            f"max_memory: {max_memory / (1024**3):.2f} GB "
@@ -1497,11 +1482,11 @@ def _log_step(
            "lr": lr,
            "time/data_time": round(data_time, 4),
            "time/step_time": round(step_time, 4),
-            "time/train_time": round(train_time, 4),
+            "time/train_time": round(self._train_time, 4),
            "time/eta_seconds": round(eta_seconds, 1),
            "runtime_info/text_tokens": step_consumed_tokens,
            "runtime_info/est_global_batch_tokens": est_global_batch_tokens,
-            "runtime_info/reduced_consumed_tokens": reduced_consumed_tokens,
+            "runtime_info/reduced_consumed_tokens": self._reduced_consumed_tokens,
            "runtime_info/tgs": tgs,
            "runtime_info/exp_tgs": exp_tgs,
            "runtime_info/e2e_tgs": e2e_tgs,
