diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py
index bbbe30a61..a67b1df4b 100644
--- a/xtuner/v1/engine/train_engine.py
+++ b/xtuner/v1/engine/train_engine.py
@@ -54,7 +54,7 @@ class OtherLog(TypedDict):
     __pydantic_config__ = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[misc]
     maxvio: NotRequired[float]
     step_consumed_tokens: int
-    step_consumed_img_tokens: NotRequired[float]
+    step_consumed_img_tokens: NotRequired[int]
     extra_info: ModelForwardExtraLogInfo
     efficient_attn_ratio: float
 
@@ -351,9 +351,20 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
         reduced_z_loss = step_z_loss
         dist.all_reduce(reduced_z_loss.div_(dist.get_world_size()))
         loss_log["reduced_z_loss"] = reduced_z_loss.item()
-        other_log["step_consumed_tokens"] = cast(int, step_consumed_tokens.item())
+        other_log["step_consumed_tokens"] = int(step_consumed_tokens.item())
         other_log["extra_info"] = train_engine_extra_info
         other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
+
+        extra_info = other_log.get("extra_info", {})  # type: ignore
+
+        # TODO: @duanyanhui `extra_info` should be redesigned.
+        if not isinstance(extra_info, ModelForwardExtraLogInfo):
+            extra_info = ModelForwardExtraLogInfo(extra_info)
+        loss_log.update(extra_info.get())
+
+        if "maxvio" in other_log:
+            loss_log["maxvio"] = other_log["maxvio"]  # type: ignore
+        loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]  # type: ignore
         return loss_log, other_log
 
     def from_hf(self, hf_path: str | Path, strict: bool = False):
diff --git a/xtuner/v1/engine/vision_compose_train_engine.py b/xtuner/v1/engine/vision_compose_train_engine.py
index 48e335b92..292e67f81 100644
--- a/xtuner/v1/engine/vision_compose_train_engine.py
+++ b/xtuner/v1/engine/vision_compose_train_engine.py
@@ -224,8 +224,20 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
         reduced_z_loss = step_z_loss
         dist.all_reduce(reduced_z_loss.div_(dist.get_world_size()))
         loss_log["reduced_z_loss"] = reduced_z_loss.item()
-        other_log["step_consumed_tokens"] = cast(int, step_consumed_tokens.item())
+
+        other_log["step_consumed_tokens"] = int(step_consumed_tokens.item())
         other_log["extra_info"] = train_engine_extra_info  # type: ignore[assignment]
         other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
-        other_log["step_consumed_img_tokens"] = step_consumed_img_tokens
+        other_log["step_consumed_img_tokens"] = int(step_consumed_img_tokens)
+
+        extra_info = other_log.get("extra_info", {})  # type: ignore
+
+        # TODO: @duanyanhui `extra_info` should be redesigned.
+        if not isinstance(extra_info, ModelForwardExtraLogInfo):
+            extra_info = ModelForwardExtraLogInfo(extra_info)
+        loss_log.update(extra_info.get())
+
+        if "maxvio" in other_log:
+            loss_log["maxvio"] = other_log["maxvio"]  # type: ignore
+        loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]  # type: ignore
         return loss_log, other_log
diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py
index 7573ee8fa..5d45dbf41 100644
--- a/xtuner/v1/train/trainer.py
+++ b/xtuner/v1/train/trainer.py
@@ -37,7 +37,6 @@
 from xtuner.v1.model.base import ModelItem, TransformerConfig, XTunerBaseModelConfig
 from xtuner.v1.model.compose.base import BaseComposeConfig
 from xtuner.v1.model.moe.moe import MoEConfig
-from xtuner.v1.model.utils import ModelForwardExtraLogInfo
 from xtuner.v1.patch import patch_default_save_plan
 from xtuner.v1.profiler import profiling_memory, profiling_time
 from xtuner.v1.profiler.prober import ProberList
@@ -700,45 +699,13 @@ def fit(self):
         train_begin = time.time()
         time_before_get_data = time.time()
         for data_batch in self._data_iter():
+            consumed_samples = len(data_batch)
+            time_before_train_step = time.time()
+
             ProberList.set_step(self._cur_step + 1)
             DEVICE_MODULE.reset_peak_memory_stats()
-            time_before_train_step = time.time()
-            data_time = time_before_train_step - time_before_get_data
-            cur_sample_num = len(data_batch)
-
-            seq_ctx_list: list[SequenceContext] = []
-            loss_ctx_input_list: list[CELossContextInputItem] = []
-            for data in data_batch:
-                seq_ctx = data["seq_ctx"].to(DEVICE)
-                loss_ctx_input = CELossContextInputItem(shifted_labels=data["shifted_labels"]).to(DEVICE)
-                if self.sp_mesh.size() > 1:
-                    seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
-                    loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh)
-                seq_ctx_list.append(seq_ctx)
-                loss_ctx_input_list.append(loss_ctx_input)
-
-            del data_batch
-
-            LossContext = self.loss_cfg.loss_ctx_cls
-            batches_loss_kwargs = LossContext.build_batches_loss_kwargs(
-                loss_ctx_input_list,
-                self.loss_cfg,
-                cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list],
-                sp_mesh=self.sp_mesh,
-            )
-            engine_input = []
-            for seq_ctx, loss_kwargs in zip(seq_ctx_list, batches_loss_kwargs):
-                loss_ctx = LossContext(
-                    loss_cfg=self.loss_cfg,
-                    loss_kwargs=loss_kwargs,
-                )
-                engine_input.append(
-                    ModelItem(
-                        seq_ctx=seq_ctx,
-                        loss_ctx=loss_ctx,
-                    )
-                )
+            engine_input = self._prepare_model_input(data_batch)
 
             with self._maybe_profiling():
                 loss_log, other_log = self._engine.train_step(engine_input)
@@ -756,46 +723,32 @@ def fit(self):
             grad_norm = self._engine.clip_grad_norm(do_clip=self._do_clip, dtype=self._grad_norm_dtype)
             self._engine.step_optimizer(grad_norm)
+
             time_after_train_step = time.time()
             ProberList.after_step()
-            step_time = time_after_train_step - time_before_train_step
-            step_consumed_tokens = other_log["step_consumed_tokens"]
-            step_consumed_img_tokens = other_log.get("step_consumed_img_tokens", None)
-
-            extra_info = other_log.get("extra_info", {})
-            if isinstance(extra_info, ModelForwardExtraLogInfo):
-                extra_info_dict = extra_info.get()
-            else:
-                extra_info_updated = ModelForwardExtraLogInfo(extra_info)
-                extra_info_dict = extra_info_updated.get()
-            loss_log.update(extra_info_dict)
-            loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]
+            data_time = time_before_train_step - time_before_get_data
+            step_time = time_after_train_step - time_before_train_step
 
             internal_metrics = self._maybe_pop_model_internal_metrics(engine_input)
 
             self._cur_step += 1
-
-            self._total_consumed_samples += self._reduce_number_across_rank(cur_sample_num)
-            reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens)
+            reduced_step_consumed_tokens = self._reduce_number_across_rank(other_log["step_consumed_tokens"])
             self._total_consumed_tokens += reduced_step_consumed_tokens
             self._exp_consumed_tokens += reduced_step_consumed_tokens
+            self._total_consumed_samples += self._reduce_number_across_rank(consumed_samples)
             self._train_time = time_after_train_step - train_begin
 
             # TODO: This log should be moved before lr_scheduler.step, but for CI BC, keep it temporarily
             self._log_step(
                 loss_log=loss_log,
-                local_step_consumed_tokens=step_consumed_tokens,
+                local_step_consumed_tokens=other_log["step_consumed_tokens"],
+                local_step_consumed_img_tokens=other_log.get("step_consumed_img_tokens"),
                 step_consumed_tokens=reduced_step_consumed_tokens,
-                exp_consumed_tokens=self._exp_consumed_tokens,
-                total_consumed_tokens=self._total_consumed_tokens,
-                data_time=data_time,
-                step_time=step_time,
-                train_time=self._train_time,
-                train_time_offset=self._train_time_offset,
                 grad_norm=grad_norm.item(),
-                local_step_consumed_img_tokens=step_consumed_img_tokens,
                 internal_metrics=internal_metrics,
+                data_time=data_time,
+                step_time=step_time,
             )
 
             self._lr_scheduler.step()
@@ -816,7 +769,44 @@ def fit(self):
         self._metrics_recorder.close()
         self.logger.info(f"Training finished in {time.time() - train_begin:.2f} seconds")
 
-    def _reduce_number_across_rank(self, rank_number: int) -> int:
+    def _prepare_model_input(self, data_batch) -> list[ModelItem]:
+        seq_ctx_list: list[SequenceContext] = []
+        loss_ctx_input_list: list[CELossContextInputItem] = []
+
+        for data in data_batch:
+            seq_ctx = data["seq_ctx"].to(DEVICE)
+            loss_ctx_input = CELossContextInputItem(shifted_labels=data["shifted_labels"]).to(DEVICE)
+            if self.sp_mesh.size() > 1:
+                seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
+                loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh)
+            seq_ctx_list.append(seq_ctx)
+            loss_ctx_input_list.append(loss_ctx_input)
+
+        # TODO: Consider moving data_batch deletion to the caller for better memory management.
+        del data_batch
+
+        LossContext = self.loss_cfg.loss_ctx_cls
+        batches_loss_kwargs = LossContext.build_batches_loss_kwargs(
+            loss_ctx_input_list,
+            self.loss_cfg,
+            cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list],
+            sp_mesh=self.sp_mesh,
+        )
+        engine_input = []
+        for seq_ctx, loss_kwargs in zip(seq_ctx_list, batches_loss_kwargs):
+            loss_ctx = LossContext(
+                loss_cfg=self.loss_cfg,
+                loss_kwargs=loss_kwargs,
+            )
+            engine_input.append(
+                ModelItem(
+                    seq_ctx=seq_ctx,
+                    loss_ctx=loss_ctx,
+                )
+            )
+        return engine_input
+
+    def _reduce_number_across_rank(self, rank_number: int | float) -> int:
         _gathered_list = [None for _ in range(self.world_size)]
         dist.all_gather_object(_gathered_list, rank_number)
         reduced_number = sum(_gathered_list)  # type: ignore[arg-type]
@@ -1437,26 +1427,21 @@ def _maybe_profiling(self):
     def _log_step(
         self,
         loss_log: LossLog,
-        local_step_consumed_tokens: int,
         step_consumed_tokens: int,
-        exp_consumed_tokens: int,
-        total_consumed_tokens: int,
+        local_step_consumed_tokens: int,
+        local_step_consumed_img_tokens: int | None,
+        grad_norm: float,
         data_time: float,
         step_time: float,
-        train_time: float,
-        train_time_offset: float,
-        grad_norm: float,
-        local_step_consumed_img_tokens: float | None,
         internal_metrics: InternalMetrics | None = None,
     ):
         """Log the training step information."""
-        e2e_train_time = train_time + train_time_offset
-        total_consumed_tokens_per_rank = total_consumed_tokens / self.world_size
-        exp_consumed_tokens_per_rank = exp_consumed_tokens / self.world_size
+        e2e_train_time = self._train_time + self._train_time_offset
         tgs = local_step_consumed_tokens / step_time
+        total_consumed_tokens_per_rank = self._total_consumed_tokens / self.world_size
         e2e_tgs = total_consumed_tokens_per_rank / e2e_train_time
-        exp_tgs = exp_consumed_tokens_per_rank / train_time
+        exp_tgs = self._exp_consumed_tokens / self.world_size / self._train_time
 
         lr = self._lr_scheduler.get_last_lr()[0]
         remaining_steps = self.total_step - self.cur_step
@@ -1485,7 +1470,7 @@ def _log_step(
             f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} "
             f"text_tokens: {local_step_consumed_tokens} {img_tokens_str}"
             f"step_consumed_tokens: {step_consumed_tokens} "
-            f"total_consumed_tokens: {total_consumed_tokens} "
+            f"total_consumed_tokens: {self._total_consumed_tokens} "
             f"{loss_log_str} "
             f"grad_norm: {grad_norm:.8f} "
             f"max_memory: {max_memory / (1024**3):.2f} GB "
@@ -1500,11 +1485,11 @@ def _log_step(
             "lr": lr,
             "time/data_time": round(data_time, 4),
             "time/step_time": round(step_time, 4),
-            "time/train_time": round(train_time, 4),
+            "time/train_time": round(self._train_time, 4),
             "time/eta_seconds": round(eta_seconds, 1),
             "runtime_info/text_tokens": local_step_consumed_tokens,
             "runtime_info/step_consumed_tokens": step_consumed_tokens,
-            "runtime_info/total_consumed_tokens": total_consumed_tokens,
+            "runtime_info/total_consumed_tokens": self._total_consumed_tokens,
             "runtime_info/tgs": tgs,
             "runtime_info/exp_tgs": exp_tgs,
             "runtime_info/e2e_tgs": e2e_tgs,
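
A note on the gather-and-sum pattern used by `_reduce_number_across_rank` above:
it all-gathers one Python scalar per rank with `dist.all_gather_object` and sums
the resulting list, so every rank ends up holding the global total. Below is a
minimal, self-contained sketch of that pattern; the gloo backend, loopback
address, port, and single-process world size are illustrative assumptions, not
part of this patch.

    import torch.distributed as dist

    # Single-process process group, for illustration only (assumed setup).
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )

    local_value = 3  # e.g. tokens consumed on this rank
    gathered = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(gathered, local_value)  # every rank receives all ranks' values
    print(sum(gathered))  # cluster-wide total; prints 3 with world_size=1

    dist.destroy_process_group()

Because the result is a cluster-wide total, per-rank throughput metrics such as
`exp_tgs` and `e2e_tgs` divide the accumulated counts by `self.world_size` again
before dividing by elapsed time.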