15 changes: 13 additions & 2 deletions xtuner/v1/engine/train_engine.py
@@ -54,7 +54,7 @@ class OtherLog(TypedDict):
__pydantic_config__ = ConfigDict(arbitrary_types_allowed=True) # type: ignore[misc]
maxvio: NotRequired[float]
step_consumed_tokens: int
step_consumed_img_tokens: NotRequired[float]
step_consumed_img_tokens: NotRequired[int]
extra_info: ModelForwardExtraLogInfo
efficient_attn_ratio: float

@@ -351,9 +351,20 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
reduced_z_loss = step_z_loss
dist.all_reduce(reduced_z_loss.div_(dist.get_world_size()))
loss_log["reduced_z_loss"] = reduced_z_loss.item()
other_log["step_consumed_tokens"] = cast(int, step_consumed_tokens.item())
other_log["step_consumed_tokens"] = int(step_consumed_tokens.item())
other_log["extra_info"] = train_engine_extra_info
other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()

extra_info = other_log.get("extra_info", {}) # type: ignore

# TODO: @duanyanhui `extra_info` should be redesigned.
if not isinstance(extra_info, ModelForwardExtraLogInfo):
extra_info = ModelForwardExtraLogInfo(extra_info)
loss_log.update(extra_info.get())

if "maxvio" in other_log:
loss_log["maxvio"] = other_log["maxvio"] # type: ignore
loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"] # type: ignore
return loss_log, other_log

def from_hf(self, hf_path: str | Path, strict: bool = False):
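The `cast(int, ...)` to `int(...)` change above is a runtime fix, not only a typing one: `typing.cast` does nothing at runtime, so a float returned by `.item()` would be logged as a float even though `OtherLog` declares the field as `int`. A small illustration; the float tensor is a hypothetical stand-in for however the counter is actually accumulated:

```python
import torch
from typing import cast

# Hypothetical: the per-step token counter accumulated as a float tensor.
step_consumed_tokens = torch.tensor(4096.0)

# typing.cast is a no-op at runtime, so the logged value stays a float:
as_cast = cast(int, step_consumed_tokens.item())
print(type(as_cast))  # <class 'float'>

# int(...) actually converts, matching the TypedDict's declared int field:
as_int = int(step_consumed_tokens.item())
print(type(as_int))  # <class 'int'>
```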
16 changes: 14 additions & 2 deletions xtuner/v1/engine/vision_compose_train_engine.py
@@ -224,8 +224,20 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
reduced_z_loss = step_z_loss
dist.all_reduce(reduced_z_loss.div_(dist.get_world_size()))
loss_log["reduced_z_loss"] = reduced_z_loss.item()
other_log["step_consumed_tokens"] = cast(int, step_consumed_tokens.item())

other_log["step_consumed_tokens"] = int(step_consumed_tokens.item())
other_log["extra_info"] = train_engine_extra_info # type: ignore[assignment]
other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
other_log["step_consumed_img_tokens"] = step_consumed_img_tokens
other_log["step_consumed_img_tokens"] = int(step_consumed_img_tokens)

extra_info = other_log.get("extra_info", {}) # type: ignore

# TODO: @duanyanhui `extra_info` should be redesigned.
if not isinstance(extra_info, ModelForwardExtraLogInfo):
extra_info = ModelForwardExtraLogInfo(extra_info)
loss_log.update(extra_info.get())

if "maxvio" in other_log:
loss_log["maxvio"] = other_log["maxvio"] # type: ignore
loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"] # type: ignore
return loss_log, other_log
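The same log-merging logic now appears in both engines. A minimal sketch of the pattern, using a stand-in `ExtraLogInfo` class; the only things assumed about the real `ModelForwardExtraLogInfo` (from `xtuner.v1.model.utils`) are the from-dict constructor and the `.get()` accessor visible in the diff:

```python
class ExtraLogInfo:
    """Stand-in for ModelForwardExtraLogInfo: wraps a dict of extra metrics."""

    def __init__(self, info: dict | None = None):
        self._info = dict(info or {})

    def get(self) -> dict:
        return dict(self._info)


def merge_engine_logs(loss_log: dict, other_log: dict) -> dict:
    # Normalize extra_info to the wrapper type, then flatten it into loss_log
    # so the trainer no longer needs to know about the extra per-rank metrics.
    extra_info = other_log.get("extra_info", {})
    if not isinstance(extra_info, ExtraLogInfo):
        extra_info = ExtraLogInfo(extra_info)
    loss_log.update(extra_info.get())

    # Optional metrics are forwarded only when present.
    if "maxvio" in other_log:
        loss_log["maxvio"] = other_log["maxvio"]
    loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]
    return loss_log
```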
135 changes: 60 additions & 75 deletions xtuner/v1/train/trainer.py
@@ -37,7 +37,6 @@
from xtuner.v1.model.base import ModelItem, TransformerConfig, XTunerBaseModelConfig
from xtuner.v1.model.compose.base import BaseComposeConfig
from xtuner.v1.model.moe.moe import MoEConfig
from xtuner.v1.model.utils import ModelForwardExtraLogInfo
from xtuner.v1.patch import patch_default_save_plan
from xtuner.v1.profiler import profiling_memory, profiling_time
from xtuner.v1.profiler.prober import ProberList
@@ -700,45 +699,13 @@ def fit(self):
train_begin = time.time()
time_before_get_data = time.time()
for data_batch in self._data_iter():
consumed_samples = len(data_batch)
time_before_train_step = time.time()

ProberList.set_step(self._cur_step + 1)
DEVICE_MODULE.reset_peak_memory_stats()

time_before_train_step = time.time()
data_time = time_before_train_step - time_before_get_data
cur_sample_num = len(data_batch)

seq_ctx_list: list[SequenceContext] = []
loss_ctx_input_list: list[CELossContextInputItem] = []
for data in data_batch:
seq_ctx = data["seq_ctx"].to(DEVICE)
loss_ctx_input = CELossContextInputItem(shifted_labels=data["shifted_labels"]).to(DEVICE)
if self.sp_mesh.size() > 1:
seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh)
seq_ctx_list.append(seq_ctx)
loss_ctx_input_list.append(loss_ctx_input)

del data_batch

LossContext = self.loss_cfg.loss_ctx_cls
batches_loss_kwargs = LossContext.build_batches_loss_kwargs(
loss_ctx_input_list,
self.loss_cfg,
cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list],
sp_mesh=self.sp_mesh,
)
engine_input = []
for seq_ctx, loss_kwargs in zip(seq_ctx_list, batches_loss_kwargs):
loss_ctx = LossContext(
loss_cfg=self.loss_cfg,
loss_kwargs=loss_kwargs,
)
engine_input.append(
ModelItem(
seq_ctx=seq_ctx,
loss_ctx=loss_ctx,
)
)
engine_input = self._prepare_model_input(data_batch)

with self._maybe_profiling():
loss_log, other_log = self._engine.train_step(engine_input)
@@ -756,46 +723,32 @@

grad_norm = self._engine.clip_grad_norm(do_clip=self._do_clip, dtype=self._grad_norm_dtype)
self._engine.step_optimizer(grad_norm)

time_after_train_step = time.time()
ProberList.after_step()
step_time = time_after_train_step - time_before_train_step
step_consumed_tokens = other_log["step_consumed_tokens"]
step_consumed_img_tokens = other_log.get("step_consumed_img_tokens", None)

extra_info = other_log.get("extra_info", {})
if isinstance(extra_info, ModelForwardExtraLogInfo):
extra_info_dict = extra_info.get()
else:
extra_info_updated = ModelForwardExtraLogInfo(extra_info)
extra_info_dict = extra_info_updated.get()
loss_log.update(extra_info_dict)
Collaborator: If extra_info is no longer updated here, sft/pretrain will presumably stop printing the per-GPU loss. Is that the expected behavior?

Collaborator (Author): This part of the logic has been moved to 'TrainEngine', and 'Trainer' should not need to be aware of it.

Collaborator: OK

loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]
data_time = time_before_train_step - time_before_get_data
step_time = time_after_train_step - time_before_train_step

internal_metrics = self._maybe_pop_model_internal_metrics(engine_input)

self._cur_step += 1

self._total_consumed_samples += self._reduce_number_across_rank(cur_sample_num)
reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens)
reduced_step_consumed_tokens = self._reduce_number_across_rank(other_log["step_consumed_tokens"])
self._total_consumed_tokens += reduced_step_consumed_tokens
self._exp_consumed_tokens += reduced_step_consumed_tokens
self._total_consumed_samples += self._reduce_number_across_rank(consumed_samples)
self._train_time = time_after_train_step - train_begin

# TODO: This log should be moved before lr_scheduler.step, but for CI BC, keep it temporarily
self._log_step(
loss_log=loss_log,
local_step_consumed_tokens=step_consumed_tokens,
local_step_consumed_tokens=other_log["step_consumed_tokens"],
local_step_consumed_img_tokens=other_log.get("step_consumed_img_tokens", None),
step_consumed_tokens=reduced_step_consumed_tokens,
exp_consumed_tokens=self._exp_consumed_tokens,
total_consumed_tokens=self._total_consumed_tokens,
data_time=data_time,
step_time=step_time,
train_time=self._train_time,
train_time_offset=self._train_time_offset,
grad_norm=grad_norm.item(),
local_step_consumed_img_tokens=step_consumed_img_tokens,
internal_metrics=internal_metrics,
data_time=data_time,
step_time=step_time,
)

self._lr_scheduler.step()
@@ -816,7 +769,44 @@ def fit(self):
self._metrics_recorder.close()
self.logger.info(f"Training finished in {time.time() - train_begin:.2f} seconds")

def _reduce_number_across_rank(self, rank_number: int) -> int:
def _prepare_model_input(self, data_batch) -> list[ModelItem]:
seq_ctx_list: list[SequenceContext] = []
loss_ctx_input_list: list[CELossContextInputItem] = []

for data in data_batch:
seq_ctx = data["seq_ctx"].to(DEVICE)
loss_ctx_input = CELossContextInputItem(shifted_labels=data["shifted_labels"]).to(DEVICE)
if self.sp_mesh.size() > 1:
seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh)
seq_ctx_list.append(seq_ctx)
loss_ctx_input_list.append(loss_ctx_input)

# TODO: Consider moving data_batch deletion to the caller for better memory management.
del data_batch

LossContext = self.loss_cfg.loss_ctx_cls
batches_loss_kwargs = LossContext.build_batches_loss_kwargs(
loss_ctx_input_list,
self.loss_cfg,
cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list],
sp_mesh=self.sp_mesh,
)
engine_input = []
for seq_ctx, loss_kwargs in zip(seq_ctx_list, batches_loss_kwargs):
loss_ctx = LossContext(
loss_cfg=self.loss_cfg,
loss_kwargs=loss_kwargs,
)
engine_input.append(
ModelItem(
seq_ctx=seq_ctx,
loss_ctx=loss_ctx,
)
)
return engine_input

def _reduce_number_across_rank(self, rank_number: int | float) -> int:
_gathered_list = [None for _ in range(self.world_size)]
dist.all_gather_object(_gathered_list, rank_number)
reduced_number = sum(_gathered_list) # type: ignore[arg-type]
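For reference, the gather-and-sum pattern behind `_reduce_number_across_rank`, as a standalone sketch (it assumes an initialized default process group; the widened `int | float` signature lets callers pass either an integer token count or a float counter):

```python
import torch.distributed as dist


def reduce_number_across_rank(rank_number: int | float, world_size: int) -> int:
    # Each rank contributes its local number; all_gather_object collects the
    # Python objects from every rank, and their sum is the global total.
    gathered: list[int | float | None] = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, rank_number)
    return int(sum(gathered))  # type: ignore[arg-type]
```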
@@ -1437,26 +1427,21 @@ def _maybe_profiling(self):
def _log_step(
self,
loss_log: LossLog,
local_step_consumed_tokens: int,
step_consumed_tokens: int,
exp_consumed_tokens: int,
total_consumed_tokens: int,
local_step_consumed_tokens: int,
local_step_consumed_img_tokens: float | None,
grad_norm: float,
data_time: float,
step_time: float,
train_time: float,
train_time_offset: float,
grad_norm: float,
local_step_consumed_img_tokens: float | None,
internal_metrics: InternalMetrics | None = None,
):
"""Log the training step information."""
e2e_train_time = train_time + train_time_offset
total_consumed_tokens_per_rank = total_consumed_tokens / self.world_size
exp_consumed_tokens_per_rank = exp_consumed_tokens / self.world_size
e2e_train_time = self._train_time + self._train_time_offset

tgs = local_step_consumed_tokens / step_time
total_consumed_tokens_per_rank = self._total_consumed_tokens / self.world_size
e2e_tgs = total_consumed_tokens_per_rank / e2e_train_time
exp_tgs = exp_consumed_tokens_per_rank / train_time
exp_tgs = self._exp_consumed_tokens / self._train_time
lr = self._lr_scheduler.get_last_lr()[0]

remaining_steps = self.total_step - self.cur_step
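The throughput terms assembled here can be hard to keep apart; a worked example with made-up numbers (8 ranks, a resumed run), with names mirroring the attributes above minus the `self._` prefix:

```python
world_size = 8
local_step_consumed_tokens = 32_768   # tokens this rank processed this step
step_time = 4.0                       # seconds spent on this step
total_consumed_tokens = 2_621_440     # summed over all ranks since step 0
exp_consumed_tokens = 1_310_720       # summed over all ranks in the current window
train_time = 160.0                    # seconds trained in this run
train_time_offset = 40.0              # seconds carried over from the resumed run

tgs = local_step_consumed_tokens / step_time                      # 8192.0 tok/s on this rank
e2e_train_time = train_time + train_time_offset                   # 200.0 s
e2e_tgs = (total_consumed_tokens / world_size) / e2e_train_time   # 1638.4 tok/s per rank
# Note: with this change exp_tgs divides the globally reduced token count,
# so it reads as cluster-wide throughput rather than per-rank throughput.
exp_tgs = exp_consumed_tokens / train_time                         # 8192.0 tok/s
```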
@@ -1485,7 +1470,7 @@ def _log_step(
f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} "
f"text_tokens: {local_step_consumed_tokens} {img_tokens_str}"
f"step_consumed_tokens: {step_consumed_tokens} "
f"total_consumed_tokens: {total_consumed_tokens} "
f"total_consumed_tokens: {self._total_consumed_tokens} "
f"{loss_log_str} "
f"grad_norm: {grad_norm:.8f} "
f"max_memory: {max_memory / (1024**3):.2f} GB "
@@ -1500,11 +1485,11 @@ def _log_step(
"lr": lr,
"time/data_time": round(data_time, 4),
"time/step_time": round(step_time, 4),
"time/train_time": round(train_time, 4),
"time/train_time": round(self._train_time, 4),
"time/eta_seconds": round(eta_seconds, 1),
"runtime_info/text_tokens": local_step_consumed_tokens,
"runtime_info/step_consumed_tokens": step_consumed_tokens,
"runtime_info/total_consumed_tokens": total_consumed_tokens,
"runtime_info/total_consumed_tokens": self._total_consumed_tokens,
"runtime_info/tgs": tgs,
"runtime_info/exp_tgs": exp_tgs,
"runtime_info/e2e_tgs": e2e_tgs,