Merged · Changes from 5 commits
2 changes: 1 addition & 1 deletion requirements/framework.txt
@@ -19,7 +19,7 @@ numpy
openai
oss2
pandas
peft>=0.11,<0.17
peft>=0.11,<0.18
pillow
PyYAML>=5.4
requests
4 changes: 2 additions & 2 deletions swift/llm/argument/train_args.py
@@ -6,7 +6,7 @@
from transformers import Seq2SeqTrainingArguments
from transformers.utils.versions import require_version

from swift.plugin import LOSS_MAPPING
from swift.plugin import loss_mapping
from swift.trainers import TrainerFactory
from swift.trainers.arguments import TrainArgumentsMixin
from swift.utils import (add_version_to_work_dir, get_device_count, get_logger, get_pai_tensorboard_dir, is_master,
@@ -118,7 +118,7 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra
create_checkpoint_symlink: bool = False

# plugin
loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'})
loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(loss_mapping.keys())}'})
metric: Optional[str] = None

# extra
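The only functional change here is the rename inside the help metadata. As a quick illustration (keys taken from the loss_mapping introduced in swift/plugin/loss.py further down this diff), the rendered help string now expands roughly to:

from swift.plugin import loss_mapping

help_text = f'loss_func choices: {list(loss_mapping.keys())}'
# e.g. "loss_func choices: ['per_token_cross_entropy', 'channel_loss', 'cosine_similarity',
#       'contrastive', 'online_contrastive', 'infonce', 'reranker', 'generative_reranker',
#       'listwise_reranker', 'listwise_generative_reranker']"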
8 changes: 2 additions & 6 deletions swift/llm/template/base.py
@@ -971,10 +971,6 @@ def _encode_context_list(
tokenizer_kwargs = {}
if loss_scale_list is None:
loss_scale_list = [0.] * len(context_list)
if self.loss_scale.keep_loss_scale:
ignore_loss_scale = False
else:
ignore_loss_scale = all(loss_scale in {0, 1} for loss_scale in loss_scale_list)
for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)):
if isinstance(context, str):
# tokenizer_kwargs is the returned tokenizer_kwargs,
@@ -987,9 +983,9 @@
labels += token_list
else:
labels += [-100] * len(token_list)
if not ignore_loss_scale:
if not self.loss_scale.is_binary:
loss_scale.extend([loss_weight] * len(token_list))
if ignore_loss_scale:
if self.loss_scale.is_binary:
loss_scale = None
return input_ids, labels, loss_scale, tokenizer_kwargs

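A toy illustration (plain Python, not repository code) of the branch above: when the active loss scale is binary the per-token list is dropped and the labels alone carry the mask; otherwise every token inherits its context's weight.

token_lists = [[101, 102], [7, 8, 9]]   # tokenized contexts (made-up token ids)
loss_scale_list = [0., 2.]              # one weight per context
is_binary = False                       # stand-in for self.loss_scale.is_binary

loss_scale = []
for token_list, loss_weight in zip(token_lists, loss_scale_list):
    if not is_binary:
        loss_scale.extend([loss_weight] * len(token_list))
if is_binary:
    loss_scale = None

print(loss_scale)  # [0.0, 0.0, 2.0, 2.0, 2.0]; None when is_binary is True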
2 changes: 2 additions & 0 deletions swift/megatron/utils/convert.py
@@ -182,6 +182,7 @@ def convert_hf2mcore(args: ExportArguments) -> None:
megatron_model_meta.convert_hf2mcore(hf_model, mg_model)
if args.test_convert_precision:
test_convert_precision(hf_model, mg_model, template)
del hf_model
logger.info('Successfully transferred HF model weights to MG model.')
args.save_args()
mg_save_checkpoint(1, [mg_model], None, None, 0)
@@ -228,6 +229,7 @@ def convert_mcore2hf(args: ExportArguments) -> None:
megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
if args.test_convert_precision:
test_convert_precision(hf_model, mg_model, template)
del mg_model
logger.info('Successfully transferred MG model weights to HF model.')
ckpt_dir = megatron_args.load if megatron_args.adapter_load is None else megatron_args.adapter_load
save_checkpoint(
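The two del statements drop the source model once its weights have been copied, so its tensors can be reclaimed before the checkpoint is written. A generic stand-alone sketch of the idea (not repository code; the gc/empty_cache calls are extra illustration, the PR relies on del alone):

import gc

import torch

src = {'w': torch.zeros(1024, 1024)}  # stand-in for the converted-from model
dst = {'w': src['w'].clone()}         # stand-in for the converted-to model

del src        # drop the last reference, as convert_hf2mcore/convert_mcore2hf now do
gc.collect()   # optional: CPython usually frees as soon as the refcount hits zero
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # optional: return cached GPU blocks to the driver
# ... only `dst` remains resident while the checkpoint is saved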
4 changes: 2 additions & 2 deletions swift/plugin/__init__.py
@@ -5,7 +5,7 @@

if TYPE_CHECKING:
from .callback import extra_callbacks
from .loss import LOSS_MAPPING, get_loss_func
from .loss import loss_mapping, get_loss_func
from .loss_scale import loss_scale_map
from .metric import InferStats, MeanMetric, Metric, compute_acc, get_metric, compute_rouge_bleu
from .optimizer import optimizers_map
@@ -21,7 +21,7 @@
else:
_import_structure = {
'callback': ['extra_callbacks'],
'loss': ['LOSS_MAPPING', 'get_loss_func'],
'loss': ['loss_mapping', 'get_loss_func'],
'loss_scale': ['loss_scale_map'],
'metric': ['InferStats', 'MeanMetric', 'Metric', 'compute_acc', 'get_metric', 'compute_rouge_bleu'],
'optimizer': ['optimizers_map'],
108 changes: 28 additions & 80 deletions swift/plugin/loss.py
@@ -15,78 +15,18 @@
from swift.plugin import MeanMetric


class LossType:
loss_scale = 'loss_scale'
cosine_similarity = 'cosine_similarity'
contrastive = 'contrastive'
online_contrastive = 'online_contrastive'
infonce = 'infonce'
channel_loss = 'channel_loss'
reranker = 'reranker'
generative_reranker = 'generative_reranker'
listwise_reranker = 'listwise_reranker'
listwise_generative_reranker = 'listwise_generative_reranker'


LOSS_MAPPING = {}


def register_loss_func(loss_type: str, loss_func: Optional[Callable] = None):
loss_info = {}

if loss_func is not None:
loss_info['loss_func'] = loss_func
LOSS_MAPPING[loss_type] = loss_info
return

def _register_loss_func(loss_func: Callable) -> Callable:
loss_info['loss_func'] = loss_func
LOSS_MAPPING[loss_type] = loss_info
return loss_func

return _register_loss_func


def ce_loss_func(outputs, labels):
def per_token_loss_func(outputs, labels, **kwargs):
logits = outputs.logits
device = logits.device
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:].to(device)
# Save memory
masks = shift_labels != -100
shift_logits = shift_logits[masks]
shift_labels = shift_labels[masks]
# Flatten the tokens
loss_fct = CrossEntropyLoss(reduction='none')
loss = loss_fct(shift_logits, shift_labels)
return loss, masks

# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
labels = torch.roll(labels, shifts=-1, dims=-1)

# Use @register_loss_func to decorate your own loss, use --loss_type xxx to train
@register_loss_func(LossType.loss_scale)
def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
"""Loss func

Args:
outputs: The model outputs
labels: The labels
loss_scale: The loss scale
num_items_in_batch: Number of tokens in the labels of gradient accumulation round that are not -100.

Returns:

"""
loss, masks = ce_loss_func(outputs, labels)
if loss_scale is not None:
shift_scale = loss_scale[..., 1:].to(masks.device)
shift_scale = shift_scale[masks]
loss = (shift_scale * loss)
if num_items_in_batch is None:
loss = loss.mean()
else:
# compat transformers>=4.46
loss = loss.sum() / num_items_in_batch
# Flatten the tokens
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = F.cross_entropy(logits, labels, ignore_index=-100, reduction='none')
return loss


@@ -117,7 +57,6 @@ class SiameseDistanceMetric(Enum):
COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y) # noqa


@register_loss_func(LossType.cosine_similarity)
def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
cos_score_transformation = nn.Identity()
loss_fct = MSELoss()
@@ -126,7 +65,6 @@ def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch=
return loss_fct(output, labels.to(output.dtype).view(-1))


@register_loss_func(LossType.contrastive)
def contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
sentence1, sentence2 = _parse_pair_sentence(outputs)
distance_metric = SiameseDistanceMetric.COSINE_DISTANCE
@@ -390,7 +328,6 @@ def _parse_multi_negative_sentences(sentences, labels, hard_negatives=None):
return split_tensors


@register_loss_func(LossType.infonce)
def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
temperature = float(os.environ.get('INFONCE_TEMPERATURE', '0.01')) # temperature
# calculate CE across the batch, meaning all samples will be negative except the matching positive
@@ -491,7 +428,6 @@ def mask_fake_negative(sim_matrix, sim_labels):
return loss


@register_loss_func(LossType.online_contrastive)
def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
sentence1, sentence2 = _parse_pair_sentence(outputs)
distance_metric = SiameseDistanceMetric.COSINE_DISTANCE
Expand All @@ -510,13 +446,13 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch
return loss


@register_loss_func(LossType.channel_loss)
def channel_loss_func(outputs,
labels,
num_items_in_batch=None,
sample_channels=None,
trainer=None,
position_ids=None) -> torch.Tensor:
# Note: loss_scale is not supported at the moment.
channels = trainer.args.channels
assert channels is not None, 'Please pass --channels as a hyperparameter.'
assert sample_channels is not None, 'Data does not have channel field.'
@@ -583,7 +519,6 @@ def channel_loss_func(outputs,
else total_loss / (total_tokens.float() + 1e-12)


@register_loss_func(LossType.reranker)
def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
logits = outputs.logits
logits = logits.squeeze(1)
@@ -593,7 +528,6 @@ def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) ->
return loss


@register_loss_func(LossType.generative_reranker)
def generative_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, trainer=None) -> torch.Tensor:
"""
Generative reranker loss function.
@@ -649,7 +583,6 @@ def generative_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batc
return loss


@register_loss_func(LossType.listwise_reranker)
def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
"""
List-wise reranker loss function.
@@ -739,7 +672,6 @@ def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=
return total_loss / num_groups


@register_loss_func(LossType.listwise_generative_reranker)
def listwise_generative_reranker_loss(outputs,
labels,
loss_scale=None,
@@ -863,7 +795,23 @@ def listwise_generative_reranker_loss(outputs,
return total_loss / num_groups


loss_mapping = {
'per_token_cross_entropy': per_token_loss_func,
'channel_loss': channel_loss_func,
# embedding
'cosine_similarity': cosine_similarity_func,
'contrastive': contrastive_loss,
'online_contrastive': online_contrastive_loss,
'infonce': infonce_loss,
# reranker
'reranker': reranker_loss,
'generative_reranker': generative_reranker_loss,
'listwise_reranker': listwise_reranker_loss,
'listwise_generative_reranker': listwise_generative_reranker_loss,
}


def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]:
if loss_type is None:
return None
return LOSS_MAPPING[loss_type]['loss_func']
return loss_mapping[loss_type]
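With register_loss_func and the LossType constants gone, a custom loss is now wired in by adding a callable to loss_mapping (e.g. in this plugin file) and selecting it with --loss_type. A minimal sketch; the my_token_ce name is hypothetical and simply mirrors per_token_loss_func's shift-by-one convention:

import torch
import torch.nn.functional as F

from swift.plugin import get_loss_func
from swift.plugin.loss import loss_mapping


def my_token_ce(outputs, labels, **kwargs) -> torch.Tensor:
    # Same shift-by-one convention as per_token_loss_func above.
    logits = outputs.logits.float()
    labels = torch.roll(labels, shifts=-1, dims=-1)
    logits = logits.view(-1, logits.shape[-1])
    labels = labels.view(-1).to(logits.device)
    return F.cross_entropy(logits, labels, ignore_index=-100, reduction='none')


loss_mapping['my_token_ce'] = my_token_ce            # then train with --loss_type my_token_ce
assert get_loss_func('my_token_ce') is my_token_ce   # get_loss_func now returns the callable directly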
35 changes: 17 additions & 18 deletions swift/plugin/loss_scale/loss_scale.py
@@ -10,18 +10,14 @@


class LossScale:
# Indicates whether loss_scale contains only 0 and 1.
# If set to True, loss_scale will be replaced by labels to stay compatible with
# acceleration techniques such as liger_kernel.
# If set to False, an additional 'loss_scale' key will be stored and the
# corresponding loss function will be used.
is_binary = False
loss_scale_config = None # path

def _set_keep_loss_scale(self):
self.keep_loss_scale = False
if self.loss_scale_map is None:
return
res = set()
for v in self.loss_scale_map.values():
res.update(v)
if len(res - {0., 1.}) > 0:
self.keep_loss_scale = True

def __init__(self):
if self.loss_scale_config is not None:
path = os.path.dirname(os.path.abspath(__file__))
@@ -30,14 +26,9 @@ def __init__(self):
self.loss_scale_map = json.load(json_file)
else:
self.loss_scale_map = None
self._set_keep_loss_scale()

def get_loss_scale(self,
context: str,
context_type: ContextType,
is_last_round: bool,
*,
query: Optional[str] = None) -> Tuple[List[str], List[float]]:
def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool,
**kwargs) -> Tuple[List[str], List[float]]:
"""Calculate loss scale

Args:
@@ -80,7 +71,12 @@ def __call__(self, context_list: List[str], context_types: List[ContextType], me
return res_context_list, res_loss_scale


class DefaultLossScale(LossScale):
is_binary = True


class LastRoundLossScale(LossScale):
is_binary = True

def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):
if context_type == ContextType.RESPONSE:
@@ -129,17 +125,20 @@ class AlphaUmiLossScale(REACTLossScale):


class TrainAllLossScale(LossScale):
is_binary = True

def get_loss_scale(self, context: str, context_type: ContextType, *args, **kwargs):
return [context], [1.]


class IgnoreEmptyThink(REACTLossScale):
loss_scale_config = 'ignore_empty_think.json'
is_binary = True


class LastRoundWithIgnoreEmptyThink(LossScale):
loss_scale_config = 'ignore_empty_think.json'
is_binary = True

def get_loss_scale(self,
context: str,
@@ -159,7 +158,7 @@ def get_loss_scale(self,
# Add your loss scale here, use --loss_scale xxx to train
loss_scale_map = {
'last_round': LastRoundLossScale,
'default': LossScale,
'default': DefaultLossScale,
'all': TrainAllLossScale,
'ignore_empty_think': IgnoreEmptyThink,
'last_round_with_ignore_empty_think': LastRoundWithIgnoreEmptyThink,
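Following the comment above, a custom scheme is added to loss_scale_map. A hypothetical sketch of a non-binary scale (responses weighted 2.0, everything else 0.0), so is_binary stays False and the per-token 'loss_scale' tensor is kept for the loss function; the ContextType import path is assumed, not taken from this diff:

from typing import List, Tuple

from swift.llm.template.utils import ContextType  # assumed location of ContextType
from swift.plugin.loss_scale.loss_scale import LossScale, loss_scale_map


class DoubleResponseLossScale(LossScale):
    is_binary = False  # weights outside {0, 1}: keep the per-token 'loss_scale' key

    def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool,
                       **kwargs) -> Tuple[List[str], List[float]]:
        weight = 2. if context_type == ContextType.RESPONSE else 0.
        return [context], [weight]


loss_scale_map['double_response'] = DoubleResponseLossScale  # then train with --loss_scale double_response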
12 changes: 10 additions & 2 deletions swift/trainers/mixin.py
@@ -786,18 +786,26 @@ def _evalscope_eval(self):
self.model.train()
return eval_dict

def get_logits_to_keep(self, labels):
def prepare_logits_to_keep(self, inputs):
labels = inputs['labels']
loss_scale = inputs.get('loss_scale')
if labels.shape[0] == 1 and not is_mp():
# device_map may encounter device mismatch issues.
loss_mask = (labels != -100)[0]
labels = labels[:, loss_mask]
labels = nn.functional.pad(labels, (1, 0), value=-100)
if loss_scale is not None:
loss_scale = loss_scale[:, loss_mask]
inputs['loss_scale'] = nn.functional.pad(loss_scale, (1, 0), value=0)
logits_to_keep = nn.functional.pad(loss_mask[1:], (0, 1), value=True)
else:
logits_to_keep = labels.shape[-1] - ((labels != -100).int().argmax(-1).min().item()) + 1
assert logits_to_keep > 0
labels = labels[:, -logits_to_keep:]
return labels, logits_to_keep
if loss_scale is not None:
inputs['loss_scale'] = loss_scale[:, -logits_to_keep:]
inputs['labels'] = labels
inputs['logits_to_keep'] = logits_to_keep

def get_cu_seqlens(self, position_ids, logits_to_keep) -> torch.Tensor:
from swift.llm import get_packed_seq_params
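A stand-alone walk-through (toy tensors; batch size 1 and no model parallelism assumed) of the single-sample branch above: prompt positions with label -100 are dropped, a single -100 is padded back on the left, and loss_scale is trimmed and padded the same way so it stays aligned with labels.

import torch
import torch.nn as nn

labels = torch.tensor([[-100, -100, -100, 11, 12, 13]])
loss_scale = torch.tensor([[0., 0., 0., 1., 2., 1.]])

loss_mask = (labels != -100)[0]
labels = nn.functional.pad(labels[:, loss_mask], (1, 0), value=-100)
loss_scale = nn.functional.pad(loss_scale[:, loss_mask], (1, 0), value=0)
logits_to_keep = nn.functional.pad(loss_mask[1:], (0, 1), value=True)

print(labels)          # tensor([[-100,   11,   12,   13]])
print(loss_scale)      # tensor([[0., 1., 2., 1.]])
print(logits_to_keep)  # tensor([False, False,  True,  True,  True,  True])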