Skip to content

Commit a103c70

Browse files
authored
add thinking dataset support + bugfix (#2672)
1 parent 32e4a08 commit a103c70

File tree

4 files changed

+33
-7
lines changed

4 files changed

+33
-7
lines changed

paddleformers/datasets/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def __init__(
161161
sub_dataset_type=["erniekit"],
162162
random_seed=11,
163163
process_fn=None,
164+
process_fn_fc=None,
164165
shuffle_file=False,
165166
shuffle_files=False,
166167
):
@@ -226,7 +227,7 @@ def __init__(
226227
task["dataset"] = FileDataset(
227228
task["filepath"],
228229
process_fn=(
229-
partial(process_fn, task_name=task["task_name"]) if "task_name" in task else process_fn
230+
partial(process_fn_fc, task_name=task["task_name"]) if "task_name" in task else process_fn_fc
230231
),
231232
shuffle_file=shuffle_file,
232233
)

paddleformers/datasets/finetuning.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def create_dataset(**dataset_config):
6969
task_dataset_path=task_dataset_path,
7070
task_dataset_prob=task_dataset_prob,
7171
sub_dataset_type=sub_dataset_type,
72-
process_fn=(process_fc if dataset_config["sub_dataset_type"] == "chatml" else process_example),
72+
process_fn=process_example,
73+
process_fn_fc=process_fc,
7374
)
7475
sequence_dataset = SequenceDataset(
7576
dataset=example_dataset,
@@ -174,7 +175,7 @@ def collate_fn(batch: List[List[Sequence]], tokenizer, model_args, max_seq_len:
174175

175176
def process_fc(data, input_file):
176177
multi_turns_messages = data["messages"]
177-
tools_list = data["tools"]
178+
tools_list = data["tools"] if "tools" in data else None
178179
label = data["label"] if "label" in data else None
179180

180181
system = ""
@@ -507,17 +508,26 @@ def __iter__(self):
507508

508509
def function_call_chat_template(self, messages, tools):
509510
history = messages[:-1]
511+
input_dict = dict()
512+
input_dict["messages"] = history
513+
if tools is not None:
514+
input_dict["tools"] = tools
510515
history_str = self.tokenizer.apply_chat_template(
511-
{"messages": history, "tools": tools},
516+
input_dict,
512517
add_generation_prompt=True,
513518
tokenize=False,
514519
)
515520
history_len = len(history_str)
521+
input_dict["messages"] = messages
516522
all_str = self.tokenizer.apply_chat_template(
517-
{"messages": messages, "tools": tools},
523+
input_dict,
518524
add_generation_prompt=False,
519525
tokenize=False,
520526
)
527+
# (21b think model) remove generation content
528+
s = "<|im_end|>\n\n<|im_start|>assistant\n<think>\n"
529+
if all_str.endswith(s):
530+
all_str = all_str[: -len(s)]
521531
response_str = all_str[history_len:]
522532
history_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(history_str))
523533
response_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(response_str))
@@ -591,7 +601,7 @@ def _postprocess_sequence(self, example, actual_example_num):
591601
if LOGGER_COUNT <= 5:
592602
logger.warning(f"even one turn, example_output:'{{'src':[{sub_src}, ……],'tgt':[……{sub_tgt}]}}'")
593603
except Exception:
594-
logger.warning(f"[SKIP] wrong example: {example}")
604+
logger.warning("[SKIP] wrong example")
595605

596606
return None
597607

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,9 @@ def get_package_data_files(package, data, package_dir=None):
186186
where=".",
187187
exclude=("examples*", "tests*", "applications*", "fast_generation*", "model_zoo*"),
188188
),
189-
package_data={},
189+
package_data={
190+
"paddleformers": ["datasets/hf/data_info.json"],
191+
},
190192
setup_requires=["cython", "numpy"],
191193
install_requires=REQUIRED_PACKAGES,
192194
entry_points={"console_scripts": ["paddleformers = paddleformers.cli:main"]},

0 commit comments

Comments
 (0)