[plugin] refactor loss_type/loss_scale #5337
Conversation
Summary of Changes
Hello @Jintao-Huang, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've refactored parts of the plugin system, primarily focusing on how loss functions are defined and how loss scaling is handled. This change simplifies the internal architecture by moving from decorator-based registration for loss functions to a direct dictionary mapping, and by streamlining the loss-scaling logic with an explicit `is_binary` flag. Additionally, I've updated a dependency and introduced memory optimizations during model conversion.
Highlights
- Loss Function Refactoring: I've simplified the loss function plugin architecture by removing the `LossType` enum and `@register_loss_func` decorators, replacing them with a direct `loss_mapping` dictionary. This makes the registration and retrieval of loss functions more straightforward (see the sketch after this list).
- Loss Scaling Simplification: I've streamlined the loss scaling mechanism by introducing an `is_binary` flag on loss scale classes, removing the need for dynamic `keep_loss_scale` calculation. This makes the loss scaling behavior more explicit and easier to understand.
- Memory Optimization: To improve memory management, I've added explicit deletion of large model objects (`hf_model`, `mg_model`) after conversion operations in the Megatron utility, ensuring memory is freed up promptly.
- Dependency Update: I've updated the `peft` library version constraint in `requirements/framework.txt` to allow for newer versions, ensuring compatibility and access to the latest features or fixes.
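As a quick orientation, here is a minimal sketch of what the first two highlights could look like in code. All names (`LossType`, `register_loss_func`, `loss_mapping`, `is_binary`, the example loss function, and the loss-scale class) are taken from the summary above or assumed for illustration; the actual code in `swift/plugin/loss.py` and the loss-scale plugin may differ.

```python
from typing import Callable, Dict, Optional

# Before (illustrative): loss functions were registered through a decorator
# keyed by a LossType enum, building a nested LOSS_MAPPING dictionary:
#
#     @register_loss_func(LossType.cosine_similarity)
#     def cosine_similarity_loss(outputs, labels, **kwargs): ...
#
# After: a plain module-level dict maps loss-type strings to functions.

def cosine_similarity_loss(outputs, labels, **kwargs):
    """Example loss function; the real implementation is elided."""
    raise NotImplementedError


loss_mapping: Dict[str, Callable] = {
    'cosine_similarity': cosine_similarity_loss,
}


def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]:
    if loss_type is None:
        return None
    return loss_mapping[loss_type]


# Loss-scale classes now declare up front whether they produce binary (0/1)
# token weights, instead of deriving a keep_loss_scale flag dynamically.
class LastRoundLossScale:  # hypothetical class name for illustration
    is_binary = True  # every token weight is either 0 or 1
```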
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.
Code Review
This pull request refactors the plugin system, primarily by replacing the dynamic registration of loss functions with a more explicit mapping. This improves code clarity and maintainability. Other changes include memory optimizations by deleting large models after use and simplifying the loss scaling logic.
However, I've identified a critical issue in `swift/plugin/loss.py` where `get_loss_func` was not fully updated to reflect the new structure of `loss_mapping`, which will lead to a `TypeError`. Please see my comment for the fix.
swift/plugin/loss.py
Outdated
```diff
 def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]:
     if loss_type is None:
         return None
-    return LOSS_MAPPING[loss_type]['loss_func']
+    return loss_mapping[loss_type]['loss_func']
```
The structure of `loss_mapping` has been refactored to be a direct mapping from a string to a function. However, the access pattern here still assumes the old nested dictionary structure (`['loss_func']`). This will cause a `TypeError` at runtime. You should directly return the function from the map.
```suggestion
    return loss_mapping[loss_type]
```
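To make the failure mode concrete, here is a tiny self-contained example (with a placeholder loss function) of why the old nested access breaks once the map holds bare functions:

```python
# After the refactor, values in loss_mapping are functions, not dicts.
loss_mapping = {'cosine_similarity': lambda outputs, labels: 0.0}

fn = loss_mapping['cosine_similarity']  # correct: retrieves the function
fn(outputs=None, labels=None)           # callable as expected

# The old nested access pattern now subscripts a function object:
fn['loss_func']  # TypeError: 'function' object is not subscriptable
```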
/gemini review
Code Review
This pull request primarily refactors the plugin system, with a focus on loss functions and loss scaling. The changes simplify the codebase by replacing a decorator-based registration system with a more direct dictionary mapping for loss functions and encapsulating loss scaling logic more effectively. This improves code clarity and maintainability. I've identified one critical issue in `gkd_trainer.py` where the refactoring appears to have introduced a bug, causing an important optimization to be missed. My review includes a suggested fix for this.
```diff
-        inputs['labels'], logits_to_keep = self.get_logits_to_keep(inputs['labels'])
-        if logits_to_keep is not None:
-            model_inputs['logits_to_keep'] = logits_to_keep
+        self.prepare_logits_to_keep(inputs)
```
The `logits_to_keep` key, which is added to the `inputs` dictionary by `prepare_logits_to_keep`, is not being passed to `model_inputs`. Since `model_inputs` is what's used in the model call, this optimization is being missed. This appears to be a bug introduced during refactoring.
```suggestion
        self.prepare_logits_to_keep(inputs)
        if 'logits_to_keep' in inputs:
            model_inputs['logits_to_keep'] = inputs['logits_to_keep']
```
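For context, `logits_to_keep` lets the model materialize logits only for positions that actually contribute to the loss, which saves memory on long sequences. Below is a minimal sketch of the intended flow, assuming the trainer-side names from this review comment (`prepare_logits_to_keep`, `model_inputs`) rather than quoting the repository:

```python
def compute_loss(self, model, inputs):
    # Hypothetical shape of the surrounding trainer code.
    model_inputs = {k: v for k, v in inputs.items() if k != 'labels'}

    # Adds 'logits_to_keep' to `inputs` when the optimization applies.
    self.prepare_logits_to_keep(inputs)
    # Forward it explicitly; model(**model_inputs) is what the model sees,
    # so without this line the optimization is silently dropped.
    if 'logits_to_keep' in inputs:
        model_inputs['logits_to_keep'] = inputs['logits_to_keep']

    outputs = model(**model_inputs)
    ...
```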
No description provided.