-
Notifications
You must be signed in to change notification settings - Fork 568
[MVEB] PE-AV Model, Kinetics400 Dataset, RavdessAV Dataset #4199
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 37 commits
2b411bb
fd0ce74
a65f505
ecca13e
8210f82
f4e0ece
8f67fb7
80d9217
c01d591
287d47c
66c108f
b24794b
82ccf4d
6979034
4af8520
fa753b4
f5e7a8f
32f3b4f
77e964a
f1b7989
95d75d9
e59f283
5321b3c
7b36363
05cd7f6
23c3135
7dabd1c
7238edc
6ef8678
0a3aa0f
52d70a3
412da86
5e719bd
c6855a3
3168318
68747c2
fa9c3d6
b9273d9
7907adb
75bc5c7
4c87896
cb39536
61c775f
400925b
64c94b5
a131a89
73bf160
978622e
939eefa
57eb8d9
ac7484f
bb68de2
91cada2
56c243f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,6 +32,7 @@ | |
| ImageInput, | ||
| QueryInput, | ||
| TextInput, | ||
| VideoInput, | ||
| ) | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -313,7 +314,7 @@ def _custom_collate_fn(batch: list[dict[str, Any]]) -> BatchedInput: | |
| "image", # images can be with different sizes | ||
| "conversation", # conversations are lists of varying lengths | ||
| "audio", # audio can have different lengths | ||
| "video", # video VideoDecoder objects can't be collated | ||
| "video", # video can have different lengths | ||
| ): | ||
| collated[key] = [item[key] for item in batch] | ||
| else: | ||
|
|
@@ -383,7 +384,7 @@ def _create_queries_dataloader( | |
| input_column: str | None = None, | ||
| batch_size: int = 32, | ||
| num_proc: int | None = None, | ||
| ) -> DataLoader[QueryInput | ImageInput | AudioInput]: | ||
| ) -> DataLoader[QueryInput | ImageInput | AudioInput | VideoInput]: | ||
| """Create a dataloader for queries.""" | ||
| queries_type = task_metadata.get_modalities(PromptType.query) | ||
| if queries_type == ["text"]: # text only | ||
|
|
@@ -424,7 +425,7 @@ def _create_document_dataloader( | |
| input_column: str | None = None, | ||
| batch_size: int = 32, | ||
| num_proc: int | None = None, | ||
| ) -> DataLoader[CorpusInput | ImageInput | AudioInput]: | ||
| ) -> DataLoader[CorpusInput | ImageInput | AudioInput | VideoInput]: | ||
| """Create a dataloader for documents. | ||
|
|
||
| Args: | ||
|
|
@@ -511,18 +512,18 @@ def _create_video_dataloader( | |
| input_column: str | None = None, | ||
| batch_size: int = 32, | ||
| num_proc: int | None = None, | ||
| ) -> DataLoader[AudioInput]: | ||
| ) -> DataLoader[VideoInput]: | ||
| """Create a dataloader for video. | ||
|
|
||
| Args: | ||
| dataset: The dataset containing the audio. | ||
| task_metadata: Metadata of the task to determine the audio type. | ||
| input_column: The column to use as input. If None, it will use the first column that matches the audio. | ||
| dataset: The dataset containing the video. | ||
| task_metadata: Metadata of the task to determine the video type. | ||
| input_column: The column to use as input. If None, it will use the first column that matches the video. | ||
| batch_size: Batch size for the dataloader. | ||
| num_proc: The number of workers for the dataloader. | ||
|
|
||
| Returns: | ||
| A DataLoader with the audio dataset. | ||
| A DataLoader with the video dataset. | ||
| """ | ||
| if ( | ||
| input_column | ||
|
|
@@ -590,6 +591,14 @@ def create_dataloader( | |
| batch_size=batch_size, | ||
| num_proc=num_proc, | ||
| ) | ||
| if "video" in task_metadata.modalities: | ||
| return _create_video_dataloader( | ||
| dataset, | ||
| task_metadata, | ||
| input_column=input_column, | ||
| batch_size=batch_size, | ||
| num_proc=num_proc, | ||
| ) | ||
| if "audio" in task_metadata.modalities: | ||
| return _create_audio_dataloader( | ||
| dataset, | ||
|
|
@@ -742,25 +751,20 @@ def __call__(self, inputs: list[dict[str, Any]]) -> BatchedInput: | |
|
|
||
| collated_inputs = [] | ||
| for row in inputs: | ||
| videos = row.pop("video") | ||
| video_inputs = [] | ||
| for video in videos: | ||
| frames = self.resample_video(video["frames"], self.max_frames) | ||
| audio = self.audio_collator.resample_audio( | ||
| video, | ||
| target_sampling_rate=self.audio_collator.target_sampling_rate, | ||
| max_samples=self.audio_collator.max_samples, | ||
| ) | ||
| video_inputs.append( | ||
| VideoInputItem( | ||
| frames=frames, | ||
| audio=AudioInputItem( | ||
| array=audio, | ||
| sampling_rate=self.audio_collator.target_sampling_rate, | ||
| ), | ||
| ) | ||
| ) | ||
| row["video"] = video_inputs | ||
| video = row.pop("video") | ||
| frames = self.resample_video(video["frames"], self.max_frames) | ||
| audio = self.audio_collator.resample_audio( | ||
| video, | ||
| target_sampling_rate=self.audio_collator.target_sampling_rate, | ||
| max_samples=self.audio_collator.max_samples, | ||
|
||
| ) | ||
| row["video"] = VideoInputItem( | ||
| frames=frames, | ||
| audio=AudioInputItem( | ||
| array=audio, | ||
| sampling_rate=self.audio_collator.target_sampling_rate, | ||
| ), | ||
| ) | ||
| collated_inputs.append(row) | ||
|
|
||
| return cast( | ||
|
|
@@ -794,3 +798,41 @@ def resample_video( | |
| else list(range(video_frames)) | ||
| ) | ||
| return video.get_frames_at(selected_frames).data | ||
|
|
||
|
|
||
| class MultimodalCollator: | ||
| """Collator that handles any combination of video and audio modalities. | ||
|
|
||
| Delegates to VideoCollator when video is present (which also handles audio | ||
| embedded in VideoInputItem), and falls back to AudioCollator for audio-only. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| target_sampling_rate: int, | ||
| max_frames: int = 16, | ||
| max_samples: int | None = None, | ||
| ) -> None: | ||
| """Initialize the collator. | ||
|
|
||
| Args: | ||
| target_sampling_rate: The sampling rate to resample audio to. | ||
| max_frames: Maximum number of frames to keep per video. | ||
| max_samples: Maximum number of audio samples to keep. If None, no truncation. | ||
| """ | ||
| self.video_collator = VideoCollator( | ||
| max_frames=max_frames, | ||
| target_sampling_rate=target_sampling_rate, | ||
| max_samples=max_samples, | ||
| ) | ||
| self.audio_collator = AudioCollator( | ||
| target_sampling_rate=target_sampling_rate, | ||
| max_samples=max_samples, | ||
| ) | ||
|
|
||
| def __call__(self, inputs: list[dict[str, Any]]) -> BatchedInput: | ||
| if "video" in inputs[0]: | ||
| return self.video_collator(inputs) | ||
| if "audio" in inputs[0]: | ||
| return self.audio_collator(inputs) | ||
| return cast("BatchedInput", _custom_collate_fn(inputs)) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -193,7 +193,23 @@ def evaluate( | |
| ds = self.dataset[hf_subset] | ||
|
|
||
| if isinstance(ds, Dataset | DatasetDict): | ||
| ds = ds.select_columns([self.label_column_name, self.input_column_name]) | ||
| # Keep label and input columns, plus any columns required by | ||
| # the task's declared modalities (e.g., audio for va2c tasks) | ||
| modality_to_column = { | ||
| "video": "video", | ||
| "audio": "audio", | ||
| "image": "image", | ||
| } | ||
|
||
| columns_to_keep = {self.label_column_name, self.input_column_name} | ||
| if isinstance(ds, DatasetDict): | ||
| available = set(next(iter(ds.values())).column_names) | ||
| else: | ||
| available = set(ds.column_names) | ||
AdnanElAssadi56 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| for mod in self.metadata.modalities: | ||
| col = modality_to_column.get(mod) | ||
| if col and col in available: | ||
| columns_to_keep.add(col) | ||
| ds = ds.select_columns(list(columns_to_keep)) | ||
| eval_function = ( | ||
| self._evaluate_subset | ||
| if not self.is_cross_validation | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -201,6 +201,15 @@ | |||
| "Any2AnyRetrieval", | ||||
| ) | ||||
|
|
||||
| MVEB_TASK_TYPE = ( | ||||
| "VideoClassification", | ||||
| "VideoClustering", | ||||
| "VideoPairClassification", | ||||
| "VideoZeroshotClassification", | ||||
| "VideoCentricQA", | ||||
| "Any2AnyRetrieval", | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Duplicated?
Suggested change
|
||||
| ) | ||||
|
|
||||
|
|
||||
| _TASK_TYPE = ( | ||||
| ( | ||||
|
|
@@ -219,6 +228,7 @@ | |||
| ) | ||||
| + MIEB_TASK_TYPE | ||||
| + MAEB_TASK_TYPE | ||||
| + MVEB_TASK_TYPE | ||||
| ) | ||||
|
|
||||
| TaskType = Literal[_TASK_TYPE] # type: ignore[valid-type] | ||||
|
|
@@ -246,7 +256,20 @@ | |||
| "a2at", | ||||
| "t2at", | ||||
| "at2at", | ||||
| "v2v", | ||||
| "v2c", | ||||
| "v2t", | ||||
| "t2v", | ||||
| "vt2t", | ||||
| "vt2v", | ||||
| "v2vt", | ||||
| "t2vt", | ||||
| "vt2vt", | ||||
| "va2c", | ||||
| "va2t", | ||||
| "vat2t", | ||||
| "v2a", | ||||
| "a2v", | ||||
| ] | ||||
| """The category of the task. | ||||
|
|
||||
|
|
@@ -270,7 +293,20 @@ | |||
| 18. a2at: audio to audio+text | ||||
| 19. t2at: text to audio+text | ||||
| 20. at2at: audio+text to audio+text | ||||
| 21. v2t: video to text | ||||
| 21. v2v: video to video | ||||
| 22. v2c: video to category | ||||
| 23. v2t: video to text | ||||
| 24. t2v: text to video | ||||
| 25. vt2t: video+text to text | ||||
| 26. vt2v: video+text to video | ||||
| 27. v2vt: video to video+text | ||||
| 28. t2vt: text to video+text | ||||
| 29. vt2vt: video+text to video+text | ||||
| 30. va2c: video+audio to category | ||||
| 31. va2t: video+audio to text | ||||
| 32. vat2t: video+audio+text to text | ||||
| 33. v2a: video to audio | ||||
| 34. a2v: audio to video | ||||
| """ | ||||
|
|
||||
| AnnotatorType = Literal[ | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we consider doing a more principled refactor here as discussed in #4182
That issue also showed how the current approach can lead to some odd interactions between modalities.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I'll do it a bit later