
Commit 03f4b1d

add prompt template (#6273)

Authored by TongLi3701 (Tong Li)
Co-authored-by: Tong Li <[email protected]>

1 parent 9467c10 · commit 03f4b1d

File tree: 2 files changed (+16, -12 lines)


applications/ColossalChat/coati/dataset/loader.py

Lines changed: 6 additions & 3 deletions
@@ -352,12 +352,14 @@ def apply_chat_template_and_mask(
     tokenizer: PreTrainedTokenizer,
     chat: List[Dict[str, str]],
     max_length: Optional[int] = None,
+    system_prompt: str = None,
     padding: bool = True,
     truncation: bool = True,
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
 
-    system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
+    if system_prompt is None:
+        system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
 
     system_element = {
         "role": "system",

@@ -419,21 +421,22 @@ class RawConversationDataset(Dataset):
     Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`.
     """
 
-    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int) -> None:
+    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str) -> None:
         self.tokenizer = tokenizer
         self.raw_texts = []
         with jsonlines.open(input_file) as f:
             for line in f:
                 self.raw_texts.append(line)
         self.tokenized_texts = [None] * len(self.raw_texts)
         self.max_length = max_length
+        self.system_prompt = system_prompt
 
     def __len__(self) -> int:
         return len(self.raw_texts)
 
     def __getitem__(self, index: int):
         if self.tokenized_texts[index] is None:
             message = self.raw_texts[index]
-            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length)
+            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
             self.tokenized_texts[index] = dict(tokens)
         return self.tokenized_texts[index]
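
Taken together, this change turns the hard-coded <think>/<answer> template into a fallback that callers can override. A minimal sketch of the updated API, assuming the module path coati.dataset.loader (inferred from the file location); the tokenizer name, data path, and prompt text are placeholders, not from the commit:

from transformers import AutoTokenizer

from coati.dataset.loader import RawConversationDataset

# Placeholder tokenizer and data path, for illustration only.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")

# Supply a custom system prompt; passing None instead falls back to the
# built-in math-reasoning template inside apply_chat_template_and_mask.
dataset = RawConversationDataset(
    tokenizer,
    "data/train.jsonl",
    max_length=300,
    system_prompt="You are a careful math tutor. Wrap the final answer in <answer> </answer> tags.\n\n",
)

sample = dataset[0]  # tokenized lazily on first access; a dict of tensors

Note that the new __init__ parameter has no default, so existing RawConversationDataset call sites must now pass a system_prompt explicitly (None recovers the old behavior).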

applications/ColossalChat/rl_example.py

Lines changed: 10 additions & 9 deletions
@@ -49,6 +49,7 @@
     )
     parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
+    parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
     args = parser.parse_args()
 
     assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"

@@ -112,20 +113,20 @@
         train_batch_size=args.train_batch_size,
         train_minibatch_size=args.train_minibatch_size,
         train_microbatch_size=args.train_microbatch_size,
-        dataset_config={"path": args.dataset, "max_length": 300},
+        dataset_config={"path": args.dataset, "max_length": 300, "system_prompt": args.system_prompt},
         dataloaders_config={},
         inference_model_config=inference_model_config,
         generate_config=generate_config,
         num_generations=args.num_generations,
         train_model_config=train_model_config,
-        # plugin_config={}, # for zero
-        plugin_config={
-            "pp_size": 2,
-            "tp_size": 2,
-            "microbatch_size": args.train_microbatch_size // 2,
-            "zero_stage": 0,
-            "max_norm": 1.0,
-        }, # for pp
+        plugin_config={}, # Default setting: zero.
+        # plugin_config={
+        #     "pp_size": 2,
+        #     "tp_size": 2,
+        #     "microbatch_size": args.train_microbatch_size // 2,
+        #     "zero_stage": 0,
+        #     "max_norm": 1.0,
+        # }, # for pp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=29506,
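
End to end, the new -s/--system-prompt flag flows through dataset_config into RawConversationDataset. A hypothetical invocation, with the script's other required arguments omitted for brevity:

python rl_example.py -d data/train.jsonl -b vllm -s "Answer the question and put the final result in <answer> </answer> tags."

Omitting -s leaves args.system_prompt as None, so the loader falls back to the built-in reasoning prompt.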
