Skip to content

Commit 225724c

Browse files
Fix the conflict between agent and CT (#379)
1 parent 0248e8a commit 225724c

File tree

5 files changed

+14
-5
lines changed

5 files changed

+14
-5
lines changed

docs/source/LLM/命令行参数.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
- `--gpu_memory_fraction`: 默认为None. 该参数旨在指定显卡最大可用显存比例的情况下运行训练,用于极限测试.
9797
- `--train_dataset_mix_ratio`: 默认为0. 该参数定义了如何进行数据集打混训练. 指定该参数时, 训练集会以`train_dataset_mix_ratio`倍数混合`train_dataset_mix_ds`指定的通用知识数据集, 使整体数据集长度达到`train_dataset_sample`.
9898
- `--train_dataset_mix_ds`: 默认为`ms-bench`. 用于防止知识遗忘的通用知识数据集.
99+
- `--use_loss_scale`: 默认为True. 生效时会将Agent的部分字段(Action/Action Input部分)的loss权重加强以强化CoT, 对普通SFT场景没有任何效果.
99100

100101
### AdaLoRA微调参数
101102

swift/llm/agent/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ def split_agent_parts_by(text: str, delimiters: List[str]):
5656
return text_list
5757

5858

59-
def calculate_loss_scale(response: str) -> Tuple[List[str], List[float]]:
59+
def calculate_loss_scale(response: str,
60+
use_loss_scale=True) -> Tuple[List[str], List[float]]:
6061
"""Calculate the loss scale by splitting the agent response.
6162
6263
This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf
@@ -76,11 +77,12 @@ def calculate_loss_scale(response: str) -> Tuple[List[str], List[float]]:
7677
7778
Args:
7879
response: The response text
80+
use_loss_scale: Use weighted loss. With this, some part of the loss will be enhanced to improve performance.
7981
8082
Returns:
8183
A tuple of agent response parts and their weights.
8284
"""
83-
if 'Action:' in response and 'Thought:' in response:
85+
if 'Action:' in response and 'Observation:' in response and use_loss_scale:
8486
agent_keyword = [
8587
'Action:', 'Action Input:', 'Thought:', 'Final Answer:',
8688
'Observation:'

swift/llm/sft.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
134134
use_model = template_info.get('use_model', False)
135135
if use_model:
136136
template_kwargs['model'] = model
137+
template_kwargs['use_loss_scale'] = args.use_loss_scale
137138
template: Template = get_template(args.template_type, tokenizer,
138139
args.system, args.max_length,
139140
args.truncation_strategy,

swift/llm/utils/argument.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class SftArguments:
7070
train_dataset_mix_ds: List[str] = field(
7171
default_factory=lambda: ['ms-bench'])
7272
val_dataset_sample: Optional[int] = None # -1: all dataset
73+
use_loss_scale: Optional[bool] = True
7374
system: Optional[str] = None
7475
max_length: int = 2048 # -1: no limit
7576
truncation_strategy: Literal['delete', 'truncation_left'] = 'delete'

swift/llm/utils/template.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def _init_template(self,
177177
self.max_length = max_length
178178
self.truncation_strategy = truncation_strategy
179179
self.model = kwargs.get('model', None)
180+
self.use_loss_scale = kwargs.get('use_loss_scale', True)
180181
for key in [
181182
'prefix', 'prompt', 'chat_sep', 'suffix', 'prefix_has_system'
182183
]:
@@ -207,6 +208,8 @@ def encode(
207208
system = None
208209
else:
209210
assert self.prefix_has_system is not None, 'The template does not support `system`.'
211+
if query is None:
212+
query = ''
210213
inputs, tokenizer_kwargs = self._encode(query, response, history,
211214
system,
212215
self.truncation_strategy)
@@ -233,7 +236,8 @@ def _concat_context_list(
233236
if isinstance(context, str):
234237
if '{{RESPONSE}}' == context:
235238
assert response is not None
236-
content_part, weight_part = calculate_loss_scale(response)
239+
content_part, weight_part = calculate_loss_scale(
240+
response, self.use_loss_scale)
237241
res_context_list.extend(content_part)
238242
compute_loss_idx.extend(weight_part)
239243
continue
@@ -330,7 +334,7 @@ def _encode(
330334
# last response
331335
context_list.append('{{RESPONSE}}')
332336
context_list += self.suffix
333-
if q is not None:
337+
if q or r:
334338
self._concat_context_list(
335339
context_list,
336340
res_context_list,
@@ -457,7 +461,7 @@ def register_template(template_type: str,
457461
class DefaultGenerationTemplate(Template):
458462

459463
def __init__(self):
460-
return super().__init__([], ['{{QUERY}}'], None, [['eos_token_id']])
464+
super().__init__([], ['{{QUERY}}'], None, [['eos_token_id']])
461465

462466

463467
register_template(TemplateType.default_generation, DefaultGenerationTemplate())

0 commit comments

Comments
 (0)