Add loss scale from data #8430
base: main
Changes from 3 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | `@@ -31,6 +31,7 @@ class RowPreprocessor:` |
| | | `        'channel',` |
| | | `        'margin',` |
| | | `        'teacher_prompt',` |
| | | `        'loss_scale',` |
| | | `    ]` |
| | | |
| | | `    def __init__(self,` |
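To make the intent of this hunk concrete, here is a minimal sketch of a data row that carries a per-row `loss_scale` value. The names `STANDARD_KEYS` and `keep_standard_keys`, the `'messages'` entry, and the idea that unrecognised keys are dropped are assumptions made for illustration; the diff only shows that `'loss_scale'` joins a key list in `RowPreprocessor`.

```python
from typing import Any, Dict

# Hypothetical stand-in for the key list in the hunk above. Only the four
# entries visible in the diff are copied; 'messages' is an assumption.
STANDARD_KEYS = ['messages', 'channel', 'margin', 'teacher_prompt', 'loss_scale']


def keep_standard_keys(row: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `row` limited to recognised keys (assumed behaviour)."""
    return {k: v for k, v in row.items() if k in STANDARD_KEYS}


row = {
    'messages': [
        {'role': 'user', 'content': 'Hi'},
        {'role': 'assistant', 'content': 'Hello!'},
    ],
    'loss_scale': 'last_round',  # per-row override of the global strategy
    'note': 'not a recognised key in this sketch',
}

filtered = keep_standard_keys(row)
```

Under these assumptions, `filtered` keeps `messages` and `loss_scale` and drops `note`, which is why the new key would need to be listed for the override to reach the template.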
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | `@@ -4,8 +4,11 @@` |
| | | `from typing import List, Literal, Optional, Tuple` |
| | | |
| | | `from swift.template import ContextType, Messages, get_last_user_round` |
| | | `from swift.utils import get_logger` |
| | | `from .utils import calculate_loss_scale` |
| | | |
| | | `logger = get_logger()` |
| | | |
| | | `ALL_BASE_STRATEGY = ['default', 'last_round', 'all']` |
| | | |
| | | |
| | | `@@ -77,6 +80,10 @@ def __call__(self, context_list: List[str], context_types: List[ContextType], me` |
| | | `            context_types: List of context types corresponding to each context, indicating` |
| | | `                whether it's a system prompt, user query, assistant response, etc.` |
| | | `            messages: Complete message list containing the conversation history.` |
| | | `            **kwargs: Additional keyword arguments. Supports 'loss_scale' to override` |
| | | `                the global loss scale strategy for this specific data row. The value` |
| | | `                can be a string like 'default', 'last_round', 'all', or combined` |
| | | `                strategies like 'last_round+ignore_empty_think'.` |
| | | |
| | | `        Returns:` |
| | | `            A tuple containing:` |
| | | `@@ -85,6 +92,22 @@ def __call__(self, context_list: List[str], context_types: List[ContextType], me` |
| | | `                - List[float]: Loss scale values corresponding one-to-one with the` |
| | | `                  returned context list` |
| | | `        """` |
| | | `        # Check for per-row loss_scale override in kwargs (from data row)` |

Collaborator
Is it possible to use a different loss_scale in the template? ms-swift/swift/template/base.py Line 140 in 9092bc5

Contributor, Author
Yes, the loss_scale from the data can be passed in: ms-swift/swift/template/base.py Lines 1233 to 1234 in 9092bc5

| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | `        row_loss_scale = kwargs.get('loss_scale')` |
| | | `        if row_loss_scale is not None:` |
| | | `            # Use per-row loss_scale with higher priority than global setting` |
| | | `            from .mapping import get_loss_scale` |
| | | |
| | | `            try:` |
| | | `                loss_scale_handler = get_loss_scale(row_loss_scale)` |
| | | `                # Call the handler without 'loss_scale' in kwargs to avoid infinite recursion` |
| | | `                kwargs_without_loss_scale = {k: v for k, v in kwargs.items() if k != 'loss_scale'}` |
| | | `                return loss_scale_handler(context_list, context_types, messages, **kwargs_without_loss_scale)` |
| | | `            except (KeyError, ValueError) as e:` |
| | | `                # If invalid loss_scale specified in data row, fall back to global setting` |
| | | `                logger.warning(f"Invalid loss_scale '{row_loss_scale}' specified in data row, "` |
| | | `                               f"falling back to global setting '{self.base_strategy}'. Error: {e}")` |
| | | `                pass` |
| | | |
| | | `        res_context_list = []` |
| | | `        res_loss_scale = []` |
| | | `        i = 0` |
| | | |
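
The override-then-fall-back flow in this hunk can be sketched in isolation. This is a simplified, self-contained illustration, not the ms-swift implementation: the registry `TOY_REGISTRY`, the strategies `toy_all` and `toy_last`, and the helpers `get_toy_handler` and `compute_loss_scale` are hypothetical names, and the toy strategies do not reproduce the semantics of the real `'default'`, `'last_round'`, or `'all'` strategies.

```python
import logging
from typing import Callable, Dict, List, Tuple

logger = logging.getLogger(__name__)

Result = Tuple[List[str], List[float]]


def toy_all(context_list: List[str], **kwargs) -> Result:
    """Toy strategy: weight every context with 1.0."""
    return list(context_list), [1.0] * len(context_list)


def toy_last(context_list: List[str], **kwargs) -> Result:
    """Toy strategy: weight only the final context with 1.0."""
    scales = [0.0] * len(context_list)
    if scales:
        scales[-1] = 1.0
    return list(context_list), scales


# Hypothetical registry standing in for the real strategy lookup.
TOY_REGISTRY: Dict[str, Callable[..., Result]] = {
    'toy_all': toy_all,
    'toy_last': toy_last,
}


def get_toy_handler(name: str) -> Callable[..., Result]:
    """Look up a strategy by name; raise ValueError for unknown names."""
    if name not in TOY_REGISTRY:
        raise ValueError(f'unknown loss_scale strategy: {name!r}')
    return TOY_REGISTRY[name]


def compute_loss_scale(context_list: List[str],
                       global_strategy: str = 'toy_all',
                       **kwargs) -> Result:
    # Remove the per-row value so it is not forwarded to the handler.
    row_strategy = kwargs.pop('loss_scale', None)
    if row_strategy is not None:
        try:
            return get_toy_handler(row_strategy)(context_list, **kwargs)
        except (KeyError, ValueError) as e:
            # Invalid per-row value: warn and use the global strategy.
            logger.warning("Invalid loss_scale %r, falling back to %r. Error: %s",
                           row_strategy, global_strategy, e)
    return get_toy_handler(global_strategy)(context_list, **kwargs)
```

As in the diff, the `try` block here covers the handler call as well as the lookup, so a `KeyError` or `ValueError` raised while computing the scales also triggers the fallback rather than propagating.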
Can this be implemented by modifying here?
ms-swift/swift/template/template_inputs.py
Lines 57 to 62 in 9092bc5

That part of the code is not affected; the change has been reverted.