@@ -377,6 +377,49 @@ def _encode_chat_inputs_openai_format(
377
377
378
378
return conversation_ids
379
379
380
+ def _encode_chat_inputs_oneturn (
381
+ self ,
382
+ conversations : Dict [str , Any ],
383
+ add_generation_prompt = True ,
384
+ ):
385
+ conversation_dict = {} if "tools" not in conversations else {"tools" : conversations ["tools" ]}
386
+ conversation_dict ["messages" ] = (
387
+ [conversations ["messages" ][0 ]] if conversations ["messages" ][0 ]["role" ] == "system" else []
388
+ )
389
+
390
+ if conversations ["messages" ][0 ]["role" ] == "system" :
391
+ conversations ["messages" ] = conversations ["messages" ][1 :]
392
+
393
+ cur_str = ""
394
+ conversation_ids = []
395
+ for idx in range (0 , len (conversations ["messages" ]), 2 ):
396
+ conversation_id = []
397
+ conversation_dict ["messages" ].append (conversations ["messages" ][idx ])
398
+ round_str = self .apply_chat_template (
399
+ conversation_dict ["messages" ], add_generation_prompt = True , tokenize = False
400
+ )
401
+ # query: user prefix + user content + assist prefix
402
+ query = round_str [len (cur_str ) :]
403
+ input_ids = self .convert_tokens_to_ids (self .tokenize (query ))
404
+ conversation_id .append (input_ids )
405
+ cur_str = round_str
406
+
407
+ if idx + 1 < len (conversations ["messages" ]):
408
+ conversation_dict ["messages" ].append (conversations ["messages" ][idx + 1 ])
409
+ round_str = self .apply_chat_template (
410
+ conversation_dict ["messages" ], add_generation_prompt = False , tokenize = False
411
+ )
412
+ # answer: assistant content
413
+ answer = round_str [len (cur_str ) :]
414
+ output_ids = self .convert_tokens_to_ids (self .tokenize (answer ))
415
+ conversation_id .append (output_ids )
416
+
417
+ conversation_ids .append (conversation_id )
418
+ conversation_dict ["messages" ] = []
419
+ cur_str = ""
420
+
421
+ return conversation_ids
422
+
380
423
def _extract_non_learnable_parts (self , origin_msg : List [Dict [str , str ]], split_s : List [str ]):
381
424
"""Split the entire chat by specified words. Extract the non-learnable parts."""
382
425
# TODO:We will upgrade this feature later
@@ -458,14 +501,18 @@ def encode_chat_inputs(
458
501
if not self .chat_template :
459
502
raise ValueError ("chat_template is not set, please set chat_template first." )
460
503
else :
504
+ encode_one_turn = kwargs .pop ("encode_one_turn" , True )
461
505
add_generation_prompt = kwargs .pop ("add_generation_prompt" , True )
462
506
if not isinstance (conversations , dict ):
463
507
query = self ._encode_chat_inputs (
464
508
conversations , context_data , add_generation_prompt = add_generation_prompt
465
509
)
466
510
else :
467
511
conversations .update (add_generation_prompt = add_generation_prompt )
468
- query = self ._encode_chat_inputs_openai_format (conversations )
512
+ if encode_one_turn :
513
+ query = self ._encode_chat_inputs_oneturn (conversations )
514
+ else :
515
+ query = self ._encode_chat_inputs_openai_format (conversations )
469
516
return query
470
517
471
518
def decode_token (
0 commit comments