Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions swift/arguments/base_args/base_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def get_template(self, processor: Optional[Processor] = None, **kwargs) -> Templ
template_type = self.template
template_kwargs['template_type'] = template_type
template = get_template(processor, **template_kwargs)
template.loss_type = getattr(self, 'loss_type', None)
return template

def get_model_processor(self,
Expand Down
7 changes: 4 additions & 3 deletions swift/megatron/trainers/reranker_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ def _get_listwise_reranker_preds(logits, labels):
labels = torch.tensor([0] * (len(positive_indices) - 1), device=preds.device)
return preds, labels

def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed_seq_params=None):
def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, group_sizes=None, packed_seq_params=None):
training = self.unwrapped_models[0].training
logits = self.get_last_tokens(output_tensor, packed_seq_params)
loss = self._loss_func(ModelOutputs(logits=logits), labels)
args = self.args
logits_detach = logits.detach().squeeze(-1)
if not training:
self.eval_metrics.update(logits_detach, labels)
self.eval_metrics.update(logits_detach, labels, group_sizes)
if args.loss_type == 'listwise_reranker':
preds, labels = self._get_listwise_reranker_preds(logits_detach, labels)
else:
Expand All @@ -64,7 +64,8 @@ def forward_step(self, data_iterator, model):
vp_stage = model.module.module.vp_stage
data = self.get_batch(data_iterator, vp_stage)
labels = data.pop('labels', None)
group_sizes = data.pop('group_sizes', None)
output_tensor = model(**data)
packed_seq_params = data.get('packed_seq_params')
loss_func = partial(self.loss_func, labels=labels, packed_seq_params=packed_seq_params)
loss_func = partial(self.loss_func, labels=labels, group_sizes=group_sizes, packed_seq_params=packed_seq_params)
return output_tensor, loss_func
123 changes: 88 additions & 35 deletions swift/metrics/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,82 @@ def __init__(self, *args, **kwargs):
Metric.__init__(self)
self.add_state('logits', default_factory=list)
self.add_state('labels', default_factory=list)
self.add_state('group_sizes', default_factory=list)

def update(self, logits, labels, group_sizes=None):
    """Accumulate one evaluation batch into the metric state.

    Args:
        logits: Tensor of reranker scores for the batch.
        labels: Tensor of relevance labels for the batch.
        group_sizes: Optional tensor with the number of documents per query
            in this batch. Nothing is stored when it is ``None``.
    """
    logits = logits.detach()
    # NumPy has no bfloat16 dtype, so ``.numpy()`` raises on bf16 tensors.
    # Upcasting 16-bit floats to float32 is lossless; wider dtypes are left untouched.
    if logits.is_floating_point() and logits.element_size() < 4:
        logits = logits.float()
    self.logits.append(logits.cpu().numpy())
    self.labels.append(labels.detach().cpu().numpy())
    if group_sizes is not None:
        self.group_sizes.append(group_sizes.detach().cpu().numpy())

def compute(self):
    """Concatenate the accumulated batches and compute the metrics.

    Returns:
        The result of ``self._calculate_metrics`` for the concatenated
        predictions, labels and group sizes. Group sizes are passed as
        ``None`` when none were recorded.
    """
    predictions = np.concatenate(self.logits)
    labels = np.concatenate(self.labels)
    # ``group_sizes`` is only filled when ``update`` received it, so an empty state means "not available".
    group_sizes = np.concatenate(self.group_sizes) if self.group_sizes else None
    return self._calculate_metrics(predictions, labels, group_sizes)

def compute_metrics(self, eval_prediction: EvalPrediction) -> Dict[str, float]:
    """Compute reranker metrics from an ``EvalPrediction``.

    ``eval_prediction.label_ids`` may be either the labels themselves or a
    tuple/list whose first element is the labels and whose optional second
    element is the per-query group sizes.

    Args:
        eval_prediction: Object exposing ``predictions`` and ``label_ids``.

    Returns:
        The result of ``self._calculate_metrics``.
    """
    label_ids = eval_prediction.label_ids
    group_sizes = None
    if isinstance(label_ids, (tuple, list)):
        labels = label_ids[0]
        # Group sizes ride along as a second "label" when they were collected.
        if len(label_ids) > 1:
            group_sizes = label_ids[1]
    else:
        labels = label_ids
    return self._calculate_metrics(eval_prediction.predictions, labels, group_sizes)

@staticmethod
def _split_query_groups(logits, labels, group_sizes=None):
    """Split flat score/label arrays into per-query ``(logits, labels)`` pairs.

    When ``group_sizes`` is given and its sum equals ``len(labels)``, the
    arrays are cut into consecutive chunks of those sizes (non-positive sizes
    are ignored). Otherwise a warning is logged and boundaries are inferred
    from the labels: every label equal to 1 starts a group that runs up to the
    next such label or the end of the data. Items located before the first
    positive label do not belong to any inferred group.

    Args:
        logits: 1-D array-like of scores.
        labels: 1-D array-like of labels aligned with ``logits``.
        group_sizes: Optional array-like with the number of items per query.

    Returns:
        A list of ``(query_logits, query_labels)`` tuples; empty when no
        group boundary can be derived.
    """
    # Coerce up front: with a plain list, ``labels == 1`` would be a scalar
    # comparison and the label-based fallback would silently find no groups.
    logits = np.asarray(logits)
    labels = np.asarray(labels)

    if group_sizes is not None:
        group_sizes = np.array(group_sizes).astype(int).flatten()
        if int(group_sizes.sum()) == len(labels):
            query_groups = []
            start = 0
            for group_size in group_sizes:
                if group_size <= 0:
                    continue
                end = start + group_size
                query_groups.append((logits[start:end], labels[start:end]))
                start = end
            return query_groups
        logger.warning('The sum of group_sizes does not match the number of labels. Falling back to label-based '
                       'query boundary inference.')

    positive_indices = np.where(labels == 1)[0]
    # Each group ends where the next positive starts; the last one runs to the end of the data.
    group_ends = [*positive_indices[1:], len(labels)]
    return [(logits[start:end], labels[start:end]) for start, end in zip(positive_indices, group_ends)]

@staticmethod
def _calculate_classification_metrics(logits, labels):
    """Compute binary classification metrics by thresholding logits at 0.

    A logit strictly greater than 0 is predicted as class 1, anything else as
    class 0. Labels are cast to ``int`` before comparison.

    Args:
        logits: Array-like of scores.
        labels: Array-like of labels aligned with ``logits``.

    Returns:
        A dict with ``acc``, ``precision``, ``recall`` and ``f1`` as plain
        Python floats. A ratio whose denominator is zero is reported as 0.0.
    """
    # Accept any array-like: the comparison and ``astype`` below need ndarrays.
    preds = (np.asarray(logits) > 0).astype(int)
    labels = np.asarray(labels).astype(int)

    tp = int(np.sum((preds == 1) & (labels == 1)))
    fp = int(np.sum((preds == 1) & (labels == 0)))
    fn = int(np.sum((preds == 0) & (labels == 1)))

    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    acc = float(np.mean(preds == labels)) if len(labels) > 0 else 0.0
    return {
        'acc': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

def _calculate_metrics(self, logits, labels, group_sizes=None):
"""
Calculate MRR and NDCG metrics for reranker.

Expand Down Expand Up @@ -58,50 +120,36 @@ def _calculate_metrics(self, logits, labels):
logits = np.array(logits).flatten()
labels = np.array(labels).flatten()

# Step 1: Find all positive sample indices (query boundaries)
positive_indices = np.where(labels == 1)[0]

if len(positive_indices) == 0:
return {'mrr': 0.0, 'ndcg': 0.0}

# Step 2: Split into groups (queries)
query_groups = []
for i, pos_idx in enumerate(positive_indices):
# Each group starts at a positive index
group_start = pos_idx

# Group ends at the next positive index or end of data
if i + 1 < len(positive_indices):
group_end = positive_indices[i + 1]
else:
group_end = len(labels)
metrics = {}
if getattr(self.args, 'loss_type', None) == 'pointwise_reranker':
metrics.update(self._calculate_classification_metrics(logits, labels))

# Extract this query's data
query_logits = logits[group_start:group_end]
query_labels = labels[group_start:group_end]

query_groups.append((query_logits, query_labels))
query_groups = self._split_query_groups(logits, labels, group_sizes)
metrics['query_count'] = float(len(query_groups))

# Step 3: Calculate metrics for each query independently
mrr_scores = []
ndcg_scores = []
negative_only_query_count = 0
skipped_query_count = 0

for query_idx, (query_logits, query_labels) in enumerate(query_groups):
# Skip groups that are too small (need at least 1 positive + 1 negative)
if len(query_logits) < 2:
logger.info(f'Query {query_idx}: Skipped (too small: {len(query_logits)} items)')
skipped_query_count += 1
continue

# Verify that the first sample is positive (data format validation)
if query_labels[0] != 1:
logger.info(f'Query {query_idx}: Skipped (first sample not positive)')
if np.sum(query_labels == 1) == 0:
negative_only_query_count += 1
skipped_query_count += 1
continue

# Step 3a: Calculate ranking within this query
ranking = np.argsort(-query_logits) # Sort by logits descending

# Step 3b: Find position of positive document (should be at index 0 in query)
pos_rank = np.where(ranking == 0)[0][0] + 1 # +1 for 1-based ranking
# Step 3b: Find the rank of the highest-ranked positive document.
positive_mask = query_labels[ranking] == 1
pos_rank = np.where(positive_mask)[0][0] + 1 # +1 for 1-based ranking

# Step 3c: Calculate MRR for this query
mrr = 1.0 / pos_rank
Expand Down Expand Up @@ -133,14 +181,19 @@ def calculate_ndcg_single_query(relevance_scores, ranking):
ndcg_scores.append(ndcg)

# Step 4: Calculate mean metrics across all valid queries
metrics['ranking_query_count'] = float(len(mrr_scores))
metrics['negative_only_query_count'] = float(negative_only_query_count)
metrics['skipped_query_count'] = float(skipped_query_count)
if len(mrr_scores) == 0:
logger.warning('No valid queries found for metric calculation')
return {'mrr': 0.0, 'ndcg': 0.0}
metrics.update({'mrr': 0.0, 'ndcg': 0.0})
return metrics

mean_mrr = np.mean(mrr_scores)
mean_ndcg = np.mean(ndcg_scores)

return {
metrics.update({
'mrr': mean_mrr,
'ndcg': mean_ndcg,
}
})
return metrics
22 changes: 21 additions & 1 deletion swift/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,30 +1655,50 @@ def _reranker_data_collator(self,
if self.is_training:
max_positive_samples = int(os.environ.get('MAX_POSITIVE_SAMPLES', 1))
max_negative_samples = int(os.environ.get('MAX_NEGATIVE_SAMPLES', 7))
pointwise_negative_only = getattr(self, 'loss_type', None) == 'pointwise_reranker'
labels_list = []
group_sizes = [] if pointwise_negative_only else None
new_batch = []
for b in batch:
labels = b.pop('labels', None)
positive_num = sum(labels)
negative_num = len(labels) - positive_num
max_positive = min(positive_num, max_positive_samples)
max_negative = min(negative_num, max_negative_samples)
if pointwise_negative_only and positive_num == 0:
# Pointwise BCE can train on all-negative samples, so keep them instead of dropping the row.
sampled_negative_indices = random.sample(range(negative_num), max_negative)
for j in sampled_negative_indices:
new_batch.append(
{key: b[key][j]
for key in b.keys() if isinstance(b[key], list) and b[key][j] is not None})
labels_list.append(0)
if sampled_negative_indices and group_sizes is not None:
group_sizes.append(len(sampled_negative_indices))
continue
for i in random.sample(range(positive_num), max_positive):
group_size = 1
new_batch.append(
{key: b[key][i]
for key in b.keys() if isinstance(b[key], list) and b[key][i] is not None})
labels_list.append(1)
for j in random.sample(range(negative_num), max_negative):
sampled_negative_indices = random.sample(range(negative_num), max_negative)
for j in sampled_negative_indices:
new_batch.append({
key: b[key][j + positive_num]
for key in b.keys() if isinstance(b[key], list) and b[key][j + positive_num] is not None
})
labels_list.append(0)
group_size += 1
if group_sizes is not None:
group_sizes.append(group_size)
num_samples = len(new_batch)
res = self._data_collator(new_batch, padding_to=padding_to)
res['num_samples'] = num_samples
if labels_list:
res['labels'] = torch.tensor(labels_list, dtype=torch.long)
if group_sizes:
res['group_sizes'] = torch.tensor(group_sizes, dtype=torch.long)
else:
res = self._data_collator(batch, padding_to=padding_to)
return res
Expand Down
15 changes: 13 additions & 2 deletions swift/trainers/reranker_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,28 @@
logger = get_logger()


def gather_for_reranker_metrics(input_data, use_gather_object=False):
    """Gather evaluation data, recursing into tuples and lists.

    Containers keep their type (tuple stays tuple, list stays list); every
    other value is handed to ``gather_for_unpadded_tensors``.

    Args:
        input_data: A value, or an arbitrarily nested tuple/list of values.
        use_gather_object: Forwarded unchanged to ``gather_for_unpadded_tensors``.

    Returns:
        The gathered data with the same nesting as ``input_data``.
    """
    if not isinstance(input_data, (tuple, list)):
        return gather_for_unpadded_tensors(input_data, use_gather_object=use_gather_object)

    gathered = [gather_for_reranker_metrics(item, use_gather_object=use_gather_object) for item in input_data]
    return tuple(gathered) if isinstance(input_data, tuple) else gathered


class RerankerTrainer(Trainer):

def __init__(self, *args, **kwargs):
    """Initialise the trainer and install the reranker-specific gather function.

    For the ``pointwise_reranker`` loss type, ``group_sizes`` is registered as
    an additional label name unless it is already present.
    """
    super().__init__(*args, **kwargs)
    self.gather_function = gather_for_reranker_metrics
    uses_pointwise_loss = getattr(self.args, 'loss_type', None) == 'pointwise_reranker'
    if uses_pointwise_loss and 'group_sizes' not in self.label_names:
        # Rebind instead of appending in place: ``label_names`` may alias a list owned elsewhere
        # (presumably the training arguments — confirm), which must not be mutated as a side effect.
        self.label_names = [*self.label_names, 'group_sizes']

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
# Check if we have a custom loss function
if self.compute_loss_func is not None:
# Get labels and compute outputs
labels = inputs.pop('labels', None)
group_sizes = inputs.pop('group_sizes', None)
outputs = model(**inputs)
if self.task_type == 'generative_reranker':
logits = outputs.logits
Expand Down Expand Up @@ -46,5 +57,5 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N

def evaluation_loop(self, *args, **kwargs):
    """Run the parent evaluation loop, then re-install the reranker gather function.

    NOTE(review): the re-assignment suggests the parent loop may replace
    ``self.gather_function`` — confirm against the installed Trainer version.

    Returns:
        Whatever the parent ``evaluation_loop`` returns.
    """
    output = super().evaluation_loop(*args, **kwargs)
    self.gather_function = gather_for_reranker_metrics
    return output
74 changes: 74 additions & 0 deletions tests/train/test_reranker_collator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import torch

from swift.template.base import Template


def _build_template(loss_type):
    """Create a bare ``Template`` carrying only what the collator tests set up.

    ``Template.__init__`` is bypassed; the instance gets ``is_training``,
    ``loss_type`` and a stub ``_data_collator`` that echoes its batch back
    under the ``encoded_batch`` key.
    """

    def _echo_collator(batch, padding_to=None):
        return {'encoded_batch': batch}

    stub = Template.__new__(Template)
    stub.is_training = True
    stub.loss_type = loss_type
    stub._data_collator = _echo_collator
    return stub


def test_pointwise_reranker_collator_supports_negative_only():
    """A row with only negatives is kept by the collator under the pointwise loss."""
    negative_only_row = {
        'input_ids': [[101], [102]],
        'attention_mask': [[1], [1]],
        'labels': [0, 0],
    }
    collator_owner = _build_template('pointwise_reranker')

    collated = Template._reranker_data_collator(collator_owner, [negative_only_row])

    expected_labels = torch.tensor([0, 0], dtype=torch.long)
    expected_group_sizes = torch.tensor([2], dtype=torch.long)
    assert collated['num_samples'] == 2
    assert torch.equal(collated['labels'], expected_labels)
    assert torch.equal(collated['group_sizes'], expected_group_sizes)
    assert len(collated['encoded_batch']) == 2


def test_pointwise_reranker_collator_supports_positive_only():
    """A row with only positives yields one single-item group per sampled positive."""
    import os
    from unittest import mock

    template = _build_template('pointwise_reranker')
    batch = [{
        'input_ids': [[201], [202]],
        'attention_mask': [[1], [1]],
        'labels': [1, 1],
    }]

    # The collator reads MAX_POSITIVE_SAMPLES from the environment (default 1) and samples at most
    # that many positives per row. Pin the cap to 2 so both positives are kept regardless of the
    # ambient environment; otherwise only one sample would be produced.
    with mock.patch.dict(os.environ, {'MAX_POSITIVE_SAMPLES': '2'}):
        res = Template._reranker_data_collator(template, batch)

    assert res['num_samples'] == 2
    assert torch.equal(res['labels'], torch.tensor([1, 1], dtype=torch.long))
    assert torch.equal(res['group_sizes'], torch.tensor([1, 1], dtype=torch.long))
    assert len(res['encoded_batch']) == 2


def test_listwise_reranker_collator_still_skips_negative_only():
    """Under the listwise loss a row without positives contributes no samples."""
    negative_only_row = {
        'input_ids': [[301], [302]],
        'attention_mask': [[1], [1]],
        'labels': [0, 0],
    }
    collator_owner = _build_template('listwise_reranker')

    collated = Template._reranker_data_collator(collator_owner, [negative_only_row])

    assert collated['num_samples'] == 0
    for absent_key in ('labels', 'group_sizes'):
        assert absent_key not in collated
    assert collated['encoded_batch'] == []


def test_reranker_collator_does_not_emit_group_sizes_without_custom_loss():
    """Without a loss type the collator returns labels but no ``group_sizes``."""
    mixed_row = {
        'input_ids': [[401], [402]],
        'attention_mask': [[1], [1]],
        'labels': [1, 0],
    }
    collator_owner = _build_template(None)

    collated = Template._reranker_data_collator(collator_owner, [mixed_row])

    expected_labels = torch.tensor([1, 0], dtype=torch.long)
    assert collated['num_samples'] == 2
    assert torch.equal(collated['labels'], expected_labels)
    assert 'group_sizes' not in collated
Loading