[plugin] refactor loss_type/loss_scale #5337

Merged
merged 10 commits on Aug 11, 2025
2 changes: 1 addition & 1 deletion requirements/framework.txt
@@ -19,7 +19,7 @@ numpy
openai
oss2
pandas
peft>=0.11,<0.17
peft>=0.11,<0.18
pillow
PyYAML>=5.4
requests
4 changes: 2 additions & 2 deletions swift/llm/argument/train_args.py
@@ -6,7 +6,7 @@
from transformers import Seq2SeqTrainingArguments
from transformers.utils.versions import require_version

from swift.plugin import LOSS_MAPPING
from swift.plugin import loss_mapping
from swift.trainers import TrainerFactory
from swift.trainers.arguments import TrainArgumentsMixin
from swift.utils import (add_version_to_work_dir, get_device_count, get_logger, get_pai_tensorboard_dir, is_master,
@@ -118,7 +118,7 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra
create_checkpoint_symlink: bool = False

# plugin
loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'})
loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(loss_mapping.keys())}'})
metric: Optional[str] = None

# extra
8 changes: 2 additions & 6 deletions swift/llm/template/base.py
@@ -971,10 +971,6 @@ def _encode_context_list(
tokenizer_kwargs = {}
if loss_scale_list is None:
loss_scale_list = [0.] * len(context_list)
if self.loss_scale.keep_loss_scale:
ignore_loss_scale = False
else:
ignore_loss_scale = all(loss_scale in {0, 1} for loss_scale in loss_scale_list)
for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)):
if isinstance(context, str):
# tokenizer_kwargs is the returned tokenizer_kwargs,
@@ -987,9 +983,9 @@
labels += token_list
else:
labels += [-100] * len(token_list)
if not ignore_loss_scale:
if not self.loss_scale.is_binary:
loss_scale.extend([loss_weight] * len(token_list))
if ignore_loss_scale:
if self.loss_scale.is_binary:
loss_scale = None
return input_ids, labels, loss_scale, tokenizer_kwargs
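For context, the hunk above swaps the per-sample `ignore_loss_scale` computation for the template's `loss_scale.is_binary` class attribute. A minimal sketch of the resulting behavior, using simplified hypothetical names outside the diff:

```python
from typing import List, Optional, Tuple


def encode_with_loss_scale(token_lists: List[List[int]], loss_scale_list: List[float],
                           is_binary: bool) -> Tuple[List[int], List[int], Optional[List[float]]]:
    """When a scheme only ever emits 0/1 weights (is_binary), the labels
    (-100 vs. token id) already carry the masking, so the per-token
    loss_scale list is dropped and returned as None."""
    input_ids, labels, loss_scale = [], [], []
    for token_list, weight in zip(token_lists, loss_scale_list):
        input_ids += token_list
        labels += token_list if weight > 0 else [-100] * len(token_list)
        if not is_binary:
            loss_scale.extend([weight] * len(token_list))
    return input_ids, labels, None if is_binary else loss_scale


# binary case: only the labels encode what contributes to the loss
ids, labels, scale = encode_with_loss_scale([[1, 2], [3]], [0., 1.], is_binary=True)
assert scale is None
```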

2 changes: 2 additions & 0 deletions swift/megatron/utils/convert.py
@@ -182,6 +182,7 @@ def convert_hf2mcore(args: ExportArguments) -> None:
megatron_model_meta.convert_hf2mcore(hf_model, mg_model)
if args.test_convert_precision:
test_convert_precision(hf_model, mg_model, template)
del hf_model
logger.info('Successfully transferred HF model weights to MG model.')
args.save_args()
mg_save_checkpoint(1, [mg_model], None, None, 0)
@@ -228,6 +229,7 @@ def convert_mcore2hf(args: ExportArguments) -> None:
megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
if args.test_convert_precision:
test_convert_precision(hf_model, mg_model, template)
del mg_model
logger.info('Successfully transferred MG model weights to HF model.')
ckpt_dir = megatron_args.load if megatron_args.adapter_load is None else megatron_args.adapter_load
save_checkpoint(
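The two added `del` statements release the source model once its weights have been copied into the target, lowering peak memory before the checkpoint save. A generic sketch of that pattern — not the repository's code, and the CUDA cache clearing is an extra assumption:

```python
import gc

import torch


def transfer_then_save(src_model, dst_model, save_fn):
    # Copy weights from the source into the destination model.
    dst_model.load_state_dict(src_model.state_dict(), strict=False)
    # The source model is no longer needed; drop the reference so its memory
    # can be reclaimed before the save step allocates additional buffers.
    del src_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    save_fn(dst_model)
```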
4 changes: 2 additions & 2 deletions swift/plugin/__init__.py
@@ -5,7 +5,7 @@

if TYPE_CHECKING:
from .callback import extra_callbacks
from .loss import LOSS_MAPPING, get_loss_func
from .loss import loss_mapping, get_loss_func
from .loss_scale import loss_scale_map
from .metric import InferStats, MeanMetric, Metric, compute_acc, get_metric, compute_rouge_bleu
from .optimizer import optimizers_map
@@ -21,7 +21,7 @@
else:
_import_structure = {
'callback': ['extra_callbacks'],
'loss': ['LOSS_MAPPING', 'get_loss_func'],
'loss': ['loss_mapping', 'get_loss_func'],
'loss_scale': ['loss_scale_map'],
'metric': ['InferStats', 'MeanMetric', 'Metric', 'compute_acc', 'get_metric', 'compute_rouge_bleu'],
'optimizer': ['optimizers_map'],
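Note that the rename has to land in both branches: the `TYPE_CHECKING` imports give static type checkers the real symbols, while `_import_structure` feeds the runtime lazy loader. A generic sketch of that lazy-module pattern — the `_LazyModule` helper here is an assumption modeled on the common transformers-style implementation, not this repository's exact code:

```python
import importlib
import sys
from types import ModuleType


class _LazyModule(ModuleType):
    """Resolve attributes from submodules on first access instead of at import time."""

    def __init__(self, name: str, import_structure: dict):
        super().__init__(name)
        # attribute name -> submodule that defines it, e.g. 'get_loss_func' -> 'loss'
        self._attr_to_module = {attr: mod for mod, attrs in import_structure.items() for attr in attrs}

    def __getattr__(self, attr: str):
        module_name = self._attr_to_module.get(attr)
        if module_name is None:
            raise AttributeError(f'module {self.__name__!r} has no attribute {attr!r}')
        value = getattr(importlib.import_module(f'.{module_name}', self.__name__), attr)
        setattr(self, attr, value)  # cache so __getattr__ is not hit again
        return value


# typical usage at the bottom of a package __init__.py:
# sys.modules[__name__] = _LazyModule(__name__, _import_structure)
```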
63 changes: 19 additions & 44 deletions swift/plugin/loss.py
@@ -15,38 +15,6 @@
from swift.plugin import MeanMetric


class LossType:
loss_scale = 'loss_scale'
cosine_similarity = 'cosine_similarity'
contrastive = 'contrastive'
online_contrastive = 'online_contrastive'
infonce = 'infonce'
channel_loss = 'channel_loss'
reranker = 'reranker'
generative_reranker = 'generative_reranker'
listwise_reranker = 'listwise_reranker'
listwise_generative_reranker = 'listwise_generative_reranker'


LOSS_MAPPING = {}


def register_loss_func(loss_type: str, loss_func: Optional[Callable] = None):
loss_info = {}

if loss_func is not None:
loss_info['loss_func'] = loss_func
LOSS_MAPPING[loss_type] = loss_info
return

def _register_loss_func(loss_func: Callable) -> Callable:
loss_info['loss_func'] = loss_func
LOSS_MAPPING[loss_type] = loss_info
return loss_func

return _register_loss_func


def ce_loss_func(outputs, labels):
logits = outputs.logits
device = logits.device
@@ -63,8 +31,6 @@ def ce_loss_func(outputs, labels):
return loss, masks


# Use @register_loss_func to decorate your own loss, use --loss_type xxx to train
@register_loss_func(LossType.loss_scale)
def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
"""Loss func

@@ -117,7 +83,6 @@ class SiameseDistanceMetric(Enum):
COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y) # noqa


@register_loss_func(LossType.cosine_similarity)
def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
cos_score_transformation = nn.Identity()
loss_fct = MSELoss()
@@ -126,7 +91,6 @@ def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch=
return loss_fct(output, labels.to(output.dtype).view(-1))


@register_loss_func(LossType.contrastive)
def contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
sentence1, sentence2 = _parse_pair_sentence(outputs)
distance_metric = SiameseDistanceMetric.COSINE_DISTANCE
@@ -390,7 +354,6 @@ def _parse_multi_negative_sentences(sentences, labels, hard_negatives=None):
return split_tensors


@register_loss_func(LossType.infonce)
def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
temperature = float(os.environ.get('INFONCE_TEMPERATURE', '0.01')) # temperature
# calculate CE across the batch, meaning all samples will be negative except the matching positive
@@ -491,7 +454,6 @@ def mask_fake_negative(sim_matrix, sim_labels):
return loss


@register_loss_func(LossType.online_contrastive)
def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
sentence1, sentence2 = _parse_pair_sentence(outputs)
distance_metric = SiameseDistanceMetric.COSINE_DISTANCE
@@ -510,7 +472,6 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch
return loss


@register_loss_func(LossType.channel_loss)
def channel_loss_func(outputs,
labels,
num_items_in_batch=None,
@@ -583,7 +544,6 @@ def channel_loss_func(outputs,
else total_loss / (total_tokens.float() + 1e-12)


@register_loss_func(LossType.reranker)
def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
logits = outputs.logits
logits = logits.squeeze(1)
@@ -593,7 +553,6 @@ def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) ->
return loss


@register_loss_func(LossType.generative_reranker)
def generative_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, trainer=None) -> torch.Tensor:
"""
Generative reranker loss function.
@@ -649,7 +608,6 @@ def generative_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batc
return loss


@register_loss_func(LossType.listwise_reranker)
def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
"""
List-wise reranker loss function.
@@ -739,7 +697,6 @@ def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=
return total_loss / num_groups


@register_loss_func(LossType.listwise_generative_reranker)
def listwise_generative_reranker_loss(outputs,
labels,
loss_scale=None,
@@ -863,7 +820,25 @@ def listwise_generative_reranker_loss(outputs,
return total_loss / num_groups


loss_mapping = {
'loss_scale': loss_scale_func,
'channel_loss': channel_loss_func,
# embedding
'cosine_similarity': cosine_similarity_func,
'contrastive': contrastive_loss,
'online_contrastive': online_contrastive_loss,
'infonce': infonce_loss,
# reranker
'reranker': reranker_loss,
'generative_reranker': generative_reranker_loss,
'listwise_reranker': listwise_reranker_loss,
'listwise_generative_reranker': listwise_generative_reranker_loss,
}

# use --loss_type xxx to train


def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]:
if loss_type is None:
return None
return LOSS_MAPPING[loss_type]['loss_func']
return loss_mapping[loss_type]['loss_func']
Contributor

critical

The structure of loss_mapping has been refactored to be a direct mapping from a string to a function. However, the access pattern here still assumes the old nested dictionary structure (['loss_func']). This will cause a TypeError at runtime. You should directly return the function from the map.

Suggested change
return loss_mapping[loss_type]['loss_func']
return loss_mapping[loss_type]
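
Under the refactor the mapping goes straight from name to callable, so a plugin loss is added by putting the function into `loss_mapping` directly instead of via the removed `@register_loss_func` decorator. A minimal sketch of the new shape — the custom loss below is purely illustrative, not part of this PR:

```python
from typing import Callable, Optional

import torch
import torch.nn.functional as F

loss_mapping = {}  # name -> callable, as in the refactored swift/plugin/loss.py


def my_custom_ce(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    # Illustrative loss following the same signature convention as the built-in ones.
    logits, labels = outputs.logits[..., :-1, :], labels[..., 1:]
    return F.cross_entropy(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1), ignore_index=-100)


loss_mapping['my_custom_ce'] = my_custom_ce  # then train with --loss_type my_custom_ce


def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]:
    if loss_type is None:
        return None
    return loss_mapping[loss_type]  # direct lookup, no ['loss_func'] indirection
```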

31 changes: 13 additions & 18 deletions swift/plugin/loss_scale/loss_scale.py
@@ -10,18 +10,10 @@


class LossScale:
# If the values in loss_scale consist only of 0 and 1.
is_binary = False
loss_scale_config = None # path

def _set_keep_loss_scale(self):
self.keep_loss_scale = False
if self.loss_scale_map is None:
return
res = set()
for v in self.loss_scale_map.values():
res.update(v)
if len(res - {0., 1.}) > 0:
self.keep_loss_scale = True

def __init__(self):
if self.loss_scale_config is not None:
path = os.path.dirname(os.path.abspath(__file__))
@@ -30,14 +22,9 @@ def __init__(self):
self.loss_scale_map = json.load(json_file)
else:
self.loss_scale_map = None
self._set_keep_loss_scale()

def get_loss_scale(self,
context: str,
context_type: ContextType,
is_last_round: bool,
*,
query: Optional[str] = None) -> Tuple[List[str], List[float]]:
def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool,
**kwargs) -> Tuple[List[str], List[float]]:
"""Calculate loss scale

Args:
@@ -80,7 +67,12 @@ def __call__(self, context_list: List[str], context_types: List[ContextType], me
return res_context_list, res_loss_scale


class DefaultLossScale(LossScale):
is_binary = True


class LastRoundLossScale(LossScale):
is_binary = True

def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):
if context_type == ContextType.RESPONSE:
Expand Down Expand Up @@ -129,17 +121,20 @@ class AlphaUmiLossScale(REACTLossScale):


class TrainAllLossScale(LossScale):
is_binary = True

def get_loss_scale(self, context: str, context_type: ContextType, *args, **kwargs):
return [context], [1.]


class IgnoreEmptyThink(REACTLossScale):
loss_scale_config = 'ignore_empty_think.json'
is_binary = True


class LastRoundWithIgnoreEmptyThink(LossScale):
loss_scale_config = 'ignore_empty_think.json'
is_binary = True

def get_loss_scale(self,
context: str,
@@ -159,7 +154,7 @@ def get_loss_scale(self,
# Add your loss scale here, use --loss_scale xxx to train
loss_scale_map = {
'last_round': LastRoundLossScale,
'default': LossScale,
'default': DefaultLossScale,
'all': TrainAllLossScale,
'ignore_empty_think': IgnoreEmptyThink,
'last_round_with_ignore_empty_think': LastRoundWithIgnoreEmptyThink,
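
With this refactor, `is_binary` is a class-level declaration that a scheme only ever emits 0/1 weights (letting the template drop the per-token loss_scale list), while schemes with fractional or greater-than-1 weights leave it False. A hypothetical custom scheme registered the same way — the names are illustrative, not part of this PR, and the sketch assumes the surrounding definitions in loss_scale.py:

```python
class UpweightLastResponseLossScale(LossScale):
    # Emits a weight other than 0/1, so the per-token loss_scale must be kept.
    is_binary = False

    def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):
        if context_type == ContextType.RESPONSE and is_last_round:
            return [context], [2.]  # upweight the final response
        return super().get_loss_scale(context, context_type, is_last_round, **kwargs)


# register it, then train with --loss_scale upweight_last_response
loss_scale_map['upweight_last_response'] = UpweightLastResponseLossScale
```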