@@ -459,22 +459,22 @@ def get_train_data_batch(self, max_prompt_length, max_response_length, device):
459459 n_transition = len (input_ids_list )
460460 print ("***************************************" ,n_transition )
461461
462- # # 直接扔掉多余的 transitions,限制最大数量
463- MAX_TRANSITIONS = 96
464- if n_transition > MAX_TRANSITIONS :
465- # 确保所有列表长度一致
466- input_ids_list = input_ids_list [:MAX_TRANSITIONS ]
467- input_attention_mask_list = input_attention_mask_list [:MAX_TRANSITIONS ]
468- response_ids_list = response_ids_list [:MAX_TRANSITIONS ]
469- response_attention_mask_list = response_attention_mask_list [:MAX_TRANSITIONS ]
470- reward_list = reward_list [:MAX_TRANSITIONS ]
471- data_id_list = data_id_list [:MAX_TRANSITIONS ]
472- rollout_id_list = rollout_id_list [:MAX_TRANSITIONS ]
473- turn_index_list = turn_index_list [:MAX_TRANSITIONS ]
474- is_drop_list = is_drop_list [:MAX_TRANSITIONS ]
462+ # # 直接扔掉多余的 transitions,限制最大数量(会报错)
463+ # MAX_TRANSITIONS = 96
464+ # if n_transition > MAX_TRANSITIONS:
465+ # # 确保所有列表长度一致
466+ # input_ids_list = input_ids_list[:MAX_TRANSITIONS]
467+ # input_attention_mask_list = input_attention_mask_list[:MAX_TRANSITIONS]
468+ # response_ids_list = response_ids_list[:MAX_TRANSITIONS]
469+ # response_attention_mask_list = response_attention_mask_list[:MAX_TRANSITIONS]
470+ # reward_list = reward_list[:MAX_TRANSITIONS]
471+ # data_id_list = data_id_list[:MAX_TRANSITIONS]
472+ # rollout_id_list = rollout_id_list[:MAX_TRANSITIONS]
473+ # turn_index_list = turn_index_list[:MAX_TRANSITIONS]
474+ # is_drop_list = is_drop_list[:MAX_TRANSITIONS]
475475
476- n_transition = MAX_TRANSITIONS
477-
476+ # n_transition = MAX_TRANSITIONS
477+ # print("********************MAX_TRANSITIONS*******************",n_transition)
478478 batch_input_ids = torch .LongTensor (input_ids_list ).to (device )
479479 input_attention_mask = torch .LongTensor (input_attention_mask_list ).to (device )
480480 batch_response_ids = torch .LongTensor (response_ids_list ).to (device )
0 commit comments