Skip to content

Commit 9008153

Browse files
authored
fix template bug for qwen3 reranker (#4795)
1 parent 841c2e7 commit 9008153

File tree

2 files changed

+59
-9
lines changed

2 files changed

+59
-9
lines changed

swift/llm/template/base.py

Lines changed: 58 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -405,6 +405,10 @@ def _reranker_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
405405

406406
positive = deepcopy(inputs)
407407
positive.rejected_response = []
408+
if '{doc}' in positive.messages[-2]['content']:
409+
positive.messages[-2]['content'] = positive.messages[-2]['content'].replace(
410+
'{doc}', inputs.messages[-1]['content'])
411+
positive.messages.pop(-1)
408412
positive_encoded = self._encode_truncated(positive)
409413
for key in positive_encoded:
410414
_encoded[f'positive_{key}'] = positive_encoded[key]
@@ -414,7 +418,12 @@ def _reranker_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
414418
rejected_len = len(inputs.rejected_response) if inputs.rejected_response else 0
415419
for i in range(rejected_len):
416420
negative = deepcopy(inputs)
417-
negative.messages[-1]['content'] = negative.rejected_response[i]
421+
if '{doc}' in negative.messages[-2]['content']:
422+
negative.messages[-2]['content'] = negative.messages[-2]['content'].replace(
423+
'{doc}', negative.rejected_response[i])
424+
negative.messages.pop(-1)
425+
else:
426+
negative.messages[-1]['content'] = negative.rejected_response[i]
418427
negative.rejected_response = []
419428
negative_encoded = self._encode_truncated(negative)
420429
for key in negative_encoded:
@@ -1637,19 +1646,62 @@ def _torchacc_xtuner_data_collator(self, res, padding_to, tokenizer, padding_sid
16371646
def print_inputs(self, inputs: Dict[str, Any], tokenizer_kwargs: Optional[Dict[str, Any]] = None) -> None:
16381647
if tokenizer_kwargs is None:
16391648
tokenizer_kwargs = {}
1640-
for key in [
1641-
'input', 'labels', 'generate', 'chosen_input', 'chosen_labels', 'rejected_input', 'rejected_labels'
1642-
]:
1649+
1650+
# Base keys to check
1651+
base_keys = [
1652+
'input', 'labels', 'generate', 'chosen_input', 'chosen_labels', 'rejected_input', 'rejected_labels'
1653+
]
1654+
1655+
# For reranker/embedding modes, also check prefixed keys
1656+
if self.mode in {'reranker', 'generative_reranker', 'embedding'}:
1657+
prefixes = []
1658+
if self.mode in {'reranker', 'generative_reranker'}:
1659+
prefixes = ['positive_', 'negative_']
1660+
elif self.mode == 'embedding':
1661+
prefixes = ['anchor_', 'positive_', 'negative_']
1662+
1663+
# Add prefixed keys for reranker/embedding modes
1664+
extended_keys = base_keys.copy()
1665+
for prefix in prefixes:
1666+
for base_key in ['input', 'labels']:
1667+
extended_keys.append(f'{prefix}{base_key}')
1668+
1669+
# Also check for numbered negative keys (negative0_, negative1_, etc.)
1670+
input_keys = list(inputs.keys())
1671+
for key in input_keys:
1672+
if any(key.startswith(f'{prefix}') for prefix in prefixes):
1673+
# Extract the base key after removing prefix
1674+
for prefix in prefixes:
1675+
if key.startswith(prefix):
1676+
base_key = key[len(prefix):]
1677+
if base_key in ['input_ids', 'labels'
1678+
] or base_key.rstrip('0123456789_') in ['input', 'labels']:
1679+
extended_keys.append(key.replace('_ids', ''))
1680+
break
1681+
1682+
keys_to_check = list(set(extended_keys))
1683+
else:
1684+
keys_to_check = base_keys
1685+
1686+
for key in keys_to_check:
1687+
# Skip labels completely for certain modes
1688+
if key.endswith('labels') and self.mode in {'reranker', 'generative_reranker'}:
1689+
continue
1690+
16431691
val = inputs.get(key) # fix val is a tensor
16441692
if val is None:
16451693
val = inputs.get(f'{key}_ids')
16461694
if val is not None:
16471695
key_upper = key.upper()
16481696
logger.info(f'[{key_upper}_IDS] {val}')
1649-
if key == 'labels' and self.mode in {'seq_cls', 'embedding', 'reranker', 'generative_reranker'}:
1697+
if key.endswith('labels') and self.mode in {'seq_cls', 'embedding'}:
16501698
continue
16511699
if isinstance(val, (list, tuple, torch.Tensor)):
1652-
val_str = self.safe_decode(val, **tokenizer_kwargs)
1700+
# Handle nested lists (e.g., for reranker negative samples)
1701+
if isinstance(val, (list, tuple)) and len(val) > 0 and isinstance(val[0], (list, tuple)):
1702+
val_str = [self.safe_decode(sub_val, **tokenizer_kwargs) for sub_val in val]
1703+
else:
1704+
val_str = self.safe_decode(val, **tokenizer_kwargs)
16531705
logger.info(f'[{key_upper}] {val_str}')
16541706
if inputs.get('loss_scale') is not None:
16551707
val = inputs['loss_scale']

swift/llm/template/template/qwen.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -69,10 +69,8 @@ class Qwen3RerankerTemplate(Template):
6969
def _preprocess_inputs(self, inputs: StdTemplateInputs) -> None:
7070
super()._preprocess_inputs(inputs)
7171
query = inputs.messages[-2]['content']
72-
doc = inputs.messages[-1]['content']
73-
user_message = '<Instruct>: ' + self.instruction + '\n' + '<Query>: ' + query + '\n' + '<Document>: ' + doc
72+
user_message = '<Instruct>: ' + self.instruction + '\n' + '<Query>: ' + query + '\n' + '<Document>: {doc}'
7473
inputs.messages[-2]['content'] = user_message
75-
inputs.messages.pop(-1)
7674

7775

7876
qwen3_reranker_system = (

0 commit comments

Comments
 (0)