Skip to content

Commit 5000cb3

Browse files
grab sys prompt too from dataset (axolotl-ai-cloud#2397) [skip ci]
* grab sys prompt too from dataset
* chore: add field_system to docs
---------
Co-authored-by: NanoCode012 <[email protected]>
1 parent 170cdb5 commit 5000cb3

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

docs/config.qmd

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ datasets:
154154
# Key containing the messages (default: "messages")
155155
field_messages: messages
156156

157+
# Key containing the system message (default: "system")
158+
# If the first message in the conversation is not a system message and the dataset sample contains this key, its value is prepended as the system message.
159+
field_system: system
160+
157161
# Mapping of properties from the input dataset to the chat template.
158162
# (default: message_property_mappings={'role':'role', 'content':'content'})
159163
# If a property exists in the template but not in this mapping, the system will attempt

src/axolotl/prompt_strategies/chat_template.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333
message_field_training: Optional[str] = None,
3434
message_field_training_detail: Optional[str] = None,
3535
field_messages: str = "messages",
36+
field_system: str = "system",
3637
roles: Optional[Dict[str, List[str]]] = None,
3738
drop_system_message: bool = False,
3839
):
@@ -62,6 +63,7 @@ def __init__(
6263
self.message_field_training = message_field_training
6364
self.message_field_training_detail = message_field_training_detail
6465
self.field_messages = field_messages
66+
self.field_system = field_system
6567
self.tokenizer = tokenizer
6668
self.processor: Optional[ProcessorMixin] = processor
6769
self.chat_template = chat_template
@@ -488,6 +490,17 @@ def find_turn(self, turns: list[dict], turn_idx: int):
488490

489491
def get_conversation_thread(self, prompt):
490492
turns = []
493+
494+
possible_sys_turn = self.transform_message(
495+
prompt[self.prompter.field_messages][0]
496+
)
497+
if (
498+
possible_sys_turn["role"] != "system"
499+
and self.prompter.field_system in prompt
500+
):
501+
turn = {"role": "system", "content": prompt[self.prompter.field_system]}
502+
turns.append(turn)
503+
491504
for message in prompt[self.prompter.field_messages]:
492505
transformed_message = self.transform_message(message)
493506

0 commit comments

Comments
 (0)