@@ -33,6 +33,7 @@ def __init__(
33
33
message_field_training : Optional [str ] = None ,
34
34
message_field_training_detail : Optional [str ] = None ,
35
35
field_messages : str = "messages" ,
36
+ field_system : str = "system" ,
36
37
roles : Optional [Dict [str , List [str ]]] = None ,
37
38
drop_system_message : bool = False ,
38
39
):
@@ -62,6 +63,7 @@ def __init__(
62
63
self .message_field_training = message_field_training
63
64
self .message_field_training_detail = message_field_training_detail
64
65
self .field_messages = field_messages
66
+ self .field_system = field_system
65
67
self .tokenizer = tokenizer
66
68
self .processor : Optional [ProcessorMixin ] = processor
67
69
self .chat_template = chat_template
@@ -488,6 +490,17 @@ def find_turn(self, turns: list[dict], turn_idx: int):
488
490
489
491
def get_conversation_thread (self , prompt ):
490
492
turns = []
493
+
494
+ possible_sys_turn = self .transform_message (
495
+ prompt [self .prompter .field_messages ][0 ]
496
+ )
497
+ if (
498
+ possible_sys_turn ["role" ] != "system"
499
+ and self .prompter .field_system in prompt
500
+ ):
501
+ turn = {"role" : "system" , "content" : prompt [self .prompter .field_system ]}
502
+ turns .append (turn )
503
+
491
504
for message in prompt [self .prompter .field_messages ]:
492
505
transformed_message = self .transform_message (message )
493
506
0 commit comments