-
Notifications
You must be signed in to change notification settings - Fork 398
[Refactor] refactor trainer fit loop for better code organization #1388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
HAOCHENYE
commented
Dec 24, 2025
- Extract model input preparation logic into _prepare_model_input method
- Move loss_log update logic from trainer to train_engine
- Simplify _log_step method signature by using instance variables
- Fix type hints: consumed_tokens and consumed_img_tokens should be int
- Adjust consumed_samples calculation position for better logic flow
8428bfb to
1ef1c72
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR refactors the trainer fit loop to improve code organization by extracting model input preparation logic, relocating loss_log update logic, simplifying method signatures, and fixing type hints.
- Extracted model input preparation into a dedicated
_prepare_model_inputmethod for better code modularity - Moved loss_log update logic from trainer to train_engine for better separation of concerns
- Simplified
_log_stepmethod signature by using instance variables instead of passing them as parameters - Fixed type hints for
consumed_tokensandconsumed_img_tokensfrom float to int with appropriate conversions
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| xtuner/v1/train/trainer.py | Refactored fit loop by extracting _prepare_model_input method, removed loss_log update logic (moved to engine), simplified _log_step signature, adjusted consumed_samples calculation timing, updated _reduce_number_across_rank type hints, and removed unused ModelForwardExtraLogInfo import |
| xtuner/v1/engine/train_engine.py | Updated type hints for consumed_tokens and consumed_img_tokens to int, added loss_log update logic (moved from trainer), and added int conversion for consumed_tokens |
| xtuner/v1/engine/vision_compose_train_engine.py | Added int conversions for consumed_tokens and consumed_img_tokens to match updated type hints |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| other_log["extra_info"] = train_engine_extra_info # type: ignore[assignment] | ||
| other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item() | ||
| other_log["consumed_img_tokens"] = step_consumed_img_tokens | ||
| other_log["consumed_img_tokens"] = int(step_consumed_img_tokens) |
Copilot
AI
Dec 24, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable step_consumed_img_tokens is initialized as a float (0.0) on line 148 and may contain a fractional value after division on line 163. Converting to int here will truncate any fractional part. Consider using 0 instead of 0.0 on line 148 and using integer division (//) on line 163 if integer values are required, or document that truncation is intentional.
- Extract model input preparation logic into _prepare_model_input method - Move loss_log update logic from trainer to train_engine - Simplify _log_step method signature by using instance variables - Fix type hints: consumed_tokens and consumed_img_tokens should be int - Adjust consumed_samples calculation position for better logic flow
1ef1c72 to
4f6412f
Compare
| else: | ||
| extra_info_updated = ModelForwardExtraLogInfo(extra_info) | ||
| extra_info_dict = extra_info_updated.get() | ||
| loss_log.update(extra_info_dict) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里不更新extra_info的话,sft/pretrain应该就不打印了每张卡的loss了,这个是符合预期的不
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This part of the logic has been moved to 'TrainEngine', and 'Trainer' should not be aware of this part of the logic.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK
YanhuiDua
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
/gemini review |
jayhenry
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| consumed_tokens: float | ||
| consumed_img_tokens: NotRequired[float] | ||
| consumed_tokens: int | ||
| consumed_img_tokens: NotRequired[int] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在下面的PR已经rename为step_consumed_tokens,在rebase时需要注意下:
参考 rename的PR,
统计变量前缀规则是:
- 空间上(dp rank还是reduce求和),rank的用 local_,默认reduced无前缀。
- 时间上(step还是累积),用step_和total_。