@@ -923,8 +923,12 @@ def _loglikelihood_tokens(  # noqa: C901
             starting_batch_size=starting_batch_size,
         )
         starting_batch_size = batch_size * 2
+        max_num_choices = max(len(d.choices) for d in split)
+        # We divide the batch size by the number of choices as batch is samples * num choices
+        # then round up to closest 8 multiple
+        batch_size = max(1, round(batch_size // max_num_choices / 8) * 8)
         logger.warning(
-            f"batch size is set to {batch_size} however, logliklehood evaluates on n choices per samples so batch size will be muiltiplied by number of choices per sample"
+            f"batch size is set to {batch_size} (it should be understood as '{batch_size} times the maximum number of choices per sample, {max_num_choices}')"
         )

         dataloader = DataLoader(split, batch_size=batch_size, collate_fn=lambda batch: batch)
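Note: the rescaling added above can be sanity-checked in isolation. The following is a minimal standalone sketch of the same arithmetic (the helper name and the example values are illustrative, not part of the diff); it shows that the formula rounds to the nearest multiple of 8 and clamps at 1.

def rescale_batch_size(batch_size: int, max_num_choices: int) -> int:
    # Each sample expands into one row per choice, so divide the row budget by the
    # number of choices, then snap to the nearest multiple of 8, never below 1.
    return max(1, round(batch_size // max_num_choices / 8) * 8)

print(rescale_batch_size(64, 4))   # 16 -> 16 samples * 4 choices = 64 rows per forward pass
print(rescale_batch_size(32, 10))  # 1  -> 32 // 10 = 3, which rounds down to 0 and is clamped to 1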
@@ -1026,33 +1030,46 @@ def _loglikelihood_tokens(  # noqa: C901
                 if self.accelerator:
                     # Convert lists to tensors for proper gathering
                     # Pad and stack the tensors to make them gatherable
-                    choices_lengths = [len(choices) for choices in batch_tokenized_continuations_processed]
-                    choices_lengths_tensor = torch.tensor(choices_lengths, device=self.device)
-                    gathered_choices_lengths = self.accelerator.gather_for_metrics(choices_lengths_tensor)
-                    global_max_choices = gathered_choices_lengths.max().item()
+                    shape_choices = [
+                        choices.shape for choices in batch_tokenized_continuations_processed
+                    ]  # num_choices * max len choices
+                    num_choices_tensor = torch.tensor([shape[0] for shape in shape_choices], device=self.device)
+                    len_choices_tensor = torch.tensor([shape[1] for shape in shape_choices], device=self.device)
+                    gathered_num_choices = self.accelerator.gather_for_metrics(num_choices_tensor)
+                    gathered_len_choices = self.accelerator.gather_for_metrics(len_choices_tensor)
+                    max_num_choices = gathered_num_choices.max().item()
+                    max_len_choices = gathered_len_choices.max().item()
+                    len_context_tensor = torch.tensor(
+                        [len(ctx) for ctx in batch_tokenized_contexts_processed], device=self.device
+                    )
+                    gathered_len_context = self.accelerator.gather_for_metrics(len_context_tensor)
+                    max_len_context = gathered_len_context.max().item()

-                    # Pad logits_sum_batch to same size
+                    # 1d - Pad logits_sum and max_equals to same number of choices
                     padded_logits_sums = []
                     for logits_sum_doc in batch_logits_sums:
-                        pad_amount = global_max_choices - len(logits_sum_doc)
+                        pad_amount = max_num_choices - len(logits_sum_doc)
                         padded = F.pad(logits_sum_doc, (0, pad_amount), value=-1)
                         padded_logits_sums.append(padded)

                     padded_max_equals = []
                     for max_equals_doc in batch_max_equals:
-                        pad_amount = global_max_choices - len(max_equals_doc)
+                        pad_amount = max_num_choices - len(max_equals_doc)
                         padded = F.pad(max_equals_doc, (0, pad_amount), value=False)
                         padded_max_equals.append(padded)

+                    # 2d - Pad continuations to max number of choices and max length
                     padded_continuations = []
                     for cont_batch in batch_tokenized_continuations_processed:
-                        pad_amount = global_max_choices - cont_batch.shape[0]
-                        padded = F.pad(cont_batch, (0, pad_amount), value=-1)
+                        pad_amount_num = max_num_choices - cont_batch.shape[0]
+                        pad_amount_len = max_len_choices - cont_batch.shape[1]
+                        padded = F.pad(cont_batch, (0, pad_amount_len, 0, pad_amount_num), value=-1)
                         padded_continuations.append(padded)

+                    # 1d - Pad context to maximum context size
                     padded_contexts = []
                     for ctx_batch in batch_tokenized_contexts_processed:
-                        pad_amount = global_max_choices - ctx_batch.shape[0]
+                        pad_amount = max_len_context - len(ctx_batch)
                         padded = F.pad(ctx_batch, (0, pad_amount), value=-1)
                         padded_contexts.append(padded)

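Note: gathering tensors across processes requires every rank to contribute identically shaped tensors, which is why the loop above pads each per-sample tensor up to the global maxima gathered just before it. A minimal sketch of that padding step on a single continuation tensor (the sizes are made up for illustration; no accelerator is needed to check the shapes):

import torch
import torch.nn.functional as F

cont = torch.tensor([[3, 5, 7], [2, 4, 6]])  # (num_choices=2, len_choices=3) on this rank
max_num_choices, max_len_choices = 4, 5      # global maxima across all ranks

padded = F.pad(
    cont,
    # F.pad pads the last dimension first: (left, right) for length, then (top, bottom) for choices
    (0, max_len_choices - cont.shape[1], 0, max_num_choices - cont.shape[0]),
    value=-1,  # -1 marks padding, matching the diff
)
assert padded.shape == (max_num_choices, max_len_choices)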
@@ -1075,12 +1092,19 @@ def _loglikelihood_tokens(  # noqa: C901
                     batch_tokenized_contexts_processed = []

                     # Only process if we have gathered results
-                    for i, actual_count in enumerate(gathered_choices_lengths):
-                        # Extract non-padded values based on actual counts
-                        batch_logits_sums.append(gathered_logits_sums[i][:actual_count])
-                        batch_max_equals.append(gathered_max_equals[i][:actual_count])
-                        batch_tokenized_continuations_processed.append(gathered_continuations[i][:actual_count])
-                        batch_tokenized_contexts_processed.append(gathered_contexts[i][:actual_count])
+                    for i, num_choices in enumerate(gathered_num_choices):
+                        # Extract non-padded values
+                        # 1d on num choices
+                        batch_logits_sums.append(gathered_logits_sums[i][:num_choices])
+                        batch_max_equals.append(gathered_max_equals[i][:num_choices])
+                        # 2d on num choices and max len
+                        len_choice = gathered_len_choices[i]
+                        batch_tokenized_continuations_processed.append(
+                            gathered_continuations[i][:num_choices][:len_choice]
+                        )
+                        # 1d on max len context
+                        len_context = gathered_len_context[i]
+                        batch_tokenized_contexts_processed.append(gathered_contexts[i][:len_context])

                 # Process the gathered results
                 for i in range(len(batch_logits_sums)):
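Note: after the gather, every row is padded to the global maxima, so the loop above trims each gathered tensor back to its true size using the per-sample counts that were gathered alongside it. A small sketch of that trimming on a 2-D tensor (sizes are made up); trimming both dimensions takes a single indexing step, since chained slices such as [:num_choices][:len_choice] both act on the first dimension:

import torch

gathered = torch.full((4, 5), -1)                # padded to the global maxima
gathered[:2, :3] = torch.tensor([[3, 5, 7], [2, 4, 6]])

num_choices, len_choice = 2, 3                   # true sizes gathered alongside the data
trimmed = gathered[:num_choices, :len_choice]    # slice both dimensions at once
assert trimmed.shape == (2, 3)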