
Commit b66f661

Fix the missing eval_acc issue (when use_logits_to_keep is True) (#4938)
1 parent 4b6bba4 commit b66f661

File tree

2 files changed: +20 −6 lines changed


swift/llm/utils.py

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn
 from modelscope.hub.utils.utils import get_cache_dir
+from peft import PeftModel
 from transformers import FeatureExtractionMixin, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase
 from transformers import ProcessorMixin as HfProcessorMixin
 
@@ -152,6 +153,8 @@ def _new_forward(self, *args, **kwargs):
 
 def dynamic_gradient_checkpointing(model, including_vit: bool = False) -> None:
     from .model import ModelMeta, get_model_arch
+    if isinstance(model, PeftModel):
+        model = model.model
     model_meta: ModelMeta = model.model_meta
     model_arch = get_model_arch(model_meta.model_arch)
     if model_meta.is_multimodal and model_arch:
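
A note on this change: dynamic_gradient_checkpointing reads model.model_meta, which swift sets on the underlying transformers model, so a PEFT-wrapped model has to be unwrapped first. A minimal sketch of that unwrap, using a tiny stand-in model and a hypothetical model_meta value (not swift's real metadata):

    from peft import LoraConfig, PeftModel, get_peft_model
    from transformers import AutoConfig, AutoModelForCausalLM

    # Tiny GPT-2 built from config so the sketch runs without downloading weights.
    config = AutoConfig.for_model("gpt2", n_layer=2, n_head=2, n_embd=64)
    base = AutoModelForCausalLM.from_config(config)
    base.model_meta = "stand-in for swift's ModelMeta"  # hypothetical attribute value

    peft_model = get_peft_model(base, LoraConfig(target_modules=["c_attn"]))

    # PeftModel wraps the base model; peel the wrapper off the same way
    # the commit does before reading model-specific metadata.
    model = peft_model
    if isinstance(model, PeftModel):
        model = model.model
    print(type(model).__name__, model.model_meta)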

swift/trainers/trainers.py

Lines changed: 17 additions & 6 deletions
@@ -265,10 +265,10 @@ def prediction_step(
         labels_list = pad_sequence(labels_list, batch_first=True, padding_value=0)
         return None, response_list, labels_list
 
-    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+    def _prepare_inputs(self, inputs):
+        inputs = super()._prepare_inputs(inputs)
         from swift.plugin.loss import get_loss_func
         loss_kwargs = {}
-        labels = None
         compute_loss_func = self.compute_loss_func
         loss_scale = inputs.pop('loss_scale', None)
         if loss_scale is not None:
@@ -287,14 +287,25 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         if inputs.get('position_ids') is not None:
             loss_kwargs['position_ids'] = inputs['position_ids']
 
-        if (self.label_smoother is not None or compute_loss_func is not None) and 'labels' in inputs:
-            labels = inputs.pop('labels')
-
-        use_logits_to_keep = self.get_use_logits_to_keep('labels' in inputs)
+        use_logits_to_keep = self.get_use_logits_to_keep('labels' in inputs and self.label_smoother is None
+                                                         and compute_loss_func is None)
         if use_logits_to_keep:
             inputs['labels'], logits_to_keep = self.get_logits_to_keep(inputs['labels'])
             if logits_to_keep is not None:
                 inputs['logits_to_keep'] = logits_to_keep
+
+        inputs['compute_loss_func'] = compute_loss_func
+        inputs['loss_kwargs'] = loss_kwargs
+        return inputs
+
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        labels = None
+        compute_loss_func = inputs.pop('compute_loss_func', None)
+        loss_kwargs = inputs.pop('loss_kwargs', {})
+
+        if (self.label_smoother is not None or compute_loss_func is not None) and 'labels' in inputs:
+            labels = inputs.pop('labels')
+
         outputs = model(**inputs)
         # Save past state if it exists
         # TODO: this needs to be fixed and made cleaner later.
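
For context on this change: get_use_logits_to_keep and get_logits_to_keep are swift's own helpers (visible above); the sketch below is only an illustration of the idea behind them, not their actual implementation. It assumes the standard transformers conventions that -100 marks ignored label positions and that logits at position t score the token at position t + 1, so when an entire prompt prefix is masked the model only needs to materialize the trailing logits:

    import torch

    def trim_labels(labels: torch.Tensor, ignore_index: int = -100):
        # Illustrative only: drop fully-masked prompt positions and return
        # how many trailing logits the model still needs (None = keep all).
        valid = labels != ignore_index
        if not valid.any():
            return labels, None  # nothing supervised
        # Earliest supervised position anywhere in the batch.
        first = int(valid.int().argmax(dim=-1).min())
        if first <= 1:
            return labels, None  # no prefix worth skipping
        # Keep one extra position: logits[t] predict labels[t + 1].
        logits_to_keep = labels.shape[-1] - first + 1
        return labels[:, -logits_to_keep:], logits_to_keep

    labels = torch.full((2, 8), -100)
    labels[:, 5:] = 1  # only the last 3 positions are supervised
    trimmed, n = trim_labels(labels)
    print(trimmed.shape, n)  # torch.Size([2, 4]) 4

In the patched flow this trimming happens once in _prepare_inputs, and it is enabled only when the default loss path will run (labels present, no label_smoother, no custom compute_loss_func), so compute_loss still receives labels when a smoother or custom loss function needs them.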
