66import torch
77import decord
88from pathlib import Path
9- from llava .action .utils import AvionMultiChoiceGenerator , avion_video_loader , EK100_frame_loader
9+ from llava .action .utils import AvionMultiChoiceGenerator , RandomMultiChoiceGenerator , avion_video_loader , EK100_frame_loader
1010from llava .action .prediction_analysis import PredictionAnalysis
1111import torch .distributed as dist
1212
@@ -147,6 +147,7 @@ def __init__(
147147 verb_maps = None ,
148148 noun_maps = None ,
149149 eval_result_folder = None ,
150+ gen_type = 'action_model' ,
150151 action_representation = 'GT_random_narration_cut' ,
151152 mapping_vn2narration = None ,
152153 avion_predictions = None ,
@@ -175,14 +176,19 @@ def __init__(
175176 self .labels = labels
176177 self .topk_predictions = topk_predictions
177178 self .ann_root = Path (metadata ).parent
178- self .mc_generator = AvionMultiChoiceGenerator (self .ann_root )
179+ self .gen_type = gen_type
180+ if gen_type == 'action_model' :
181+ self .mc_generator = AvionMultiChoiceGenerator (self .ann_root )
182+ elif gen_type == 'random' :
183+ self .mc_generator = RandomMultiChoiceGenerator (self .ann_root )
179184 self .rank = dist .get_rank ()
180185 self .prediction_analysis = PredictionAnalysis (rank = self .rank , save_folder = eval_result_folder )
181186 self .action_representation = action_representation
182187 self .n_narrations = n_narrations
183188 self .mapping_vn2narration = mapping_vn2narration
184189 self .avion_predictions = avion_predictions
185190
191+
186192 def __getitem__ (self , i ):
187193 frames , label , time_meta = self .get_raw_item (
188194 i , is_training = self .is_training ,
@@ -205,19 +211,31 @@ def __getitem__(self, i):
205211 frames = self .transform (frames )
206212 narration = self .samples [i ][4 ]
207213 avion_preds = self .avion_predictions [str (i )]['predictions' ]
208-
209- data = self .mc_generator .generate_multi_choice (label ,
210- avion_preds ,
211- narration ,
212- self .topk_predictions ,
213- self .action_representation ,
214- self .n_narrations ,
215- self .labels ,
216- self .mapping_vn2narration ,
217- self .verb_maps ,
218- self .noun_maps ,
219- benchmark_testing = self .eval_args .benchmark_testing ,
220- is_train = False ) # note we only use this dataset for evaluation for now.
221-
214+ if self .gen_type == 'action_model' :
215+ data = self .mc_generator .generate_multi_choice (label ,
216+ avion_preds ,
217+ narration ,
218+ self .topk_predictions ,
219+ self .action_representation ,
220+ self .n_narrations ,
221+ self .labels ,
222+ self .mapping_vn2narration ,
223+ self .verb_maps ,
224+ self .noun_maps ,
225+ benchmark_testing = self .eval_args .benchmark_testing ,
226+ is_train = False ) # note we only use this dataset for evaluation for now.
227+ else :
228+ data = self .mc_generator .generate_multi_choice (label ,
229+ narration ,
230+ self .topk_predictions ,
231+ self .action_representation ,
232+ self .n_narrations ,
233+ self .labels ,
234+ self .mapping_vn2narration ,
235+ self .verb_maps ,
236+ self .noun_maps ,
237+ benchmark_testing = self .eval_args .benchmark_testing ,
238+ is_train = False ) # note we only use this dataset for evaluation for now.
239+
222240
223241 return frames , data , time_meta , i
0 commit comments