Skip to content
Open
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
2b411bb
Adding video modality
AdnanElAssadi56 Feb 23, 2026
fd0ce74
Add Kinetics-400 dataset
AdnanElAssadi56 Feb 23, 2026
a65f505
Add pe_av model
AdnanElAssadi56 Feb 23, 2026
ecca13e
fix typo
AdnanElAssadi56 Feb 23, 2026
8210f82
fix collator bug
AdnanElAssadi56 Feb 23, 2026
f4e0ece
Edit selecting column in classification abstask
AdnanElAssadi56 Feb 23, 2026
8f67fb7
Properly handle frames in PE_AV
AdnanElAssadi56 Feb 23, 2026
80d9217
add self kwarg to method
AdnanElAssadi56 Feb 23, 2026
c01d591
Add audio collator
AdnanElAssadi56 Feb 23, 2026
287d47c
fix type error
AdnanElAssadi56 Feb 23, 2026
66c108f
fix audio_video embeds object handling
AdnanElAssadi56 Feb 23, 2026
b24794b
Add Ravdess_av clustering
AdnanElAssadi56 Feb 23, 2026
82ccf4d
fix task metadata
AdnanElAssadi56 Feb 23, 2026
6979034
start video integration
Samoed Feb 23, 2026
4af8520
start video integration
Samoed Feb 23, 2026
fa753b4
upd task structure
Samoed Feb 24, 2026
f5e7a8f
upd video input type
Samoed Feb 24, 2026
32f3b4f
combine video and audio to dict
Samoed Feb 24, 2026
77e964a
fix task side
AdnanElAssadi56 Feb 25, 2026
f1b7989
fix pe_av model
AdnanElAssadi56 Feb 25, 2026
95d75d9
lower writer batch size
AdnanElAssadi56 Feb 25, 2026
e59f283
fix col labels
AdnanElAssadi56 Feb 25, 2026
5321b3c
lint
AdnanElAssadi56 Mar 5, 2026
7b36363
add pe_av model metadata
AdnanElAssadi56 Mar 5, 2026
05cd7f6
fix datasets metadata
AdnanElAssadi56 Mar 5, 2026
23c3135
remove accidentally committed files
AdnanElAssadi56 Mar 5, 2026
7dabd1c
remove nested list structure from datasets
AdnanElAssadi56 Mar 5, 2026
7238edc
edit collator to handle one video item
AdnanElAssadi56 Mar 5, 2026
6ef8678
multimodal collator + fix comments
AdnanElAssadi56 Mar 5, 2026
0a3aa0f
lint
AdnanElAssadi56 Mar 5, 2026
52d70a3
metadata update
AdnanElAssadi56 Mar 5, 2026
412da86
using forward pass to get embeds
AdnanElAssadi56 Mar 5, 2026
5e719bd
replace forward pass + add audio to msrvtt
AdnanElAssadi56 Mar 5, 2026
c6855a3
fix category metadata
AdnanElAssadi56 Mar 5, 2026
3168318
edit get embeddings
AdnanElAssadi56 Mar 5, 2026
68747c2
add n_embedding_parameters
AdnanElAssadi56 Mar 7, 2026
fa9c3d6
fix task type test
AdnanElAssadi56 Mar 7, 2026
b9273d9
change input col name to list
AdnanElAssadi56 Mar 9, 2026
7907adb
lint + type check
AdnanElAssadi56 Mar 9, 2026
75bc5c7
add classvar
AdnanElAssadi56 Mar 9, 2026
4c87896
add str to classvar
AdnanElAssadi56 Mar 9, 2026
cb39536
Change list to sequence
AdnanElAssadi56 Mar 9, 2026
61c775f
lint + type check error
AdnanElAssadi56 Mar 9, 2026
400925b
edit dataloader and msrvtt handling of input column
AdnanElAssadi56 Mar 9, 2026
64c94b5
move sequence out of type checking
AdnanElAssadi56 Mar 9, 2026
a131a89
fix random baseline
AdnanElAssadi56 Mar 10, 2026
73bf160
add collator to random baseline
AdnanElAssadi56 Mar 10, 2026
978622e
restore previous dict structure + make audio optional
AdnanElAssadi56 Mar 10, 2026
939eefa
clean structure
AdnanElAssadi56 Mar 10, 2026
57eb8d9
lint
AdnanElAssadi56 Mar 10, 2026
ac7484f
safety check
AdnanElAssadi56 Mar 10, 2026
bb68de2
decrease writer batch size
AdnanElAssadi56 Mar 11, 2026
91cada2
match msrvtt format
AdnanElAssadi56 Mar 11, 2026
56c243f
type check fix
AdnanElAssadi56 Mar 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 69 additions & 27 deletions mteb/_create_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ImageInput,
QueryInput,
TextInput,
VideoInput,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -313,7 +314,7 @@ def _custom_collate_fn(batch: list[dict[str, Any]]) -> BatchedInput:
"image", # images can be with different sizes
"conversation", # conversations are lists of varying lengths
"audio", # audio can have different lengths
"video", # video VideoDecoder objects can't be collated
"video", # video can have different lengths
):
collated[key] = [item[key] for item in batch]
else:
Expand Down Expand Up @@ -383,7 +384,7 @@ def _create_queries_dataloader(
input_column: str | None = None,
batch_size: int = 32,
num_proc: int | None = None,
) -> DataLoader[QueryInput | ImageInput | AudioInput]:
) -> DataLoader[QueryInput | ImageInput | AudioInput | VideoInput]:
"""Create a dataloader for queries."""
queries_type = task_metadata.get_modalities(PromptType.query)
if queries_type == ["text"]: # text only
Expand Down Expand Up @@ -424,7 +425,7 @@ def _create_document_dataloader(
input_column: str | None = None,
batch_size: int = 32,
num_proc: int | None = None,
) -> DataLoader[CorpusInput | ImageInput | AudioInput]:
) -> DataLoader[CorpusInput | ImageInput | AudioInput | VideoInput]:
"""Create a dataloader for documents.

Args:
Expand Down Expand Up @@ -511,18 +512,18 @@ def _create_video_dataloader(
input_column: str | None = None,
batch_size: int = 32,
num_proc: int | None = None,
) -> DataLoader[AudioInput]:
) -> DataLoader[VideoInput]:
"""Create a dataloader for video.

Args:
dataset: The dataset containing the audio.
task_metadata: Metadata of the task to determine the audio type.
input_column: The column to use as input. If None, it will use the first column that matches the audio.
dataset: The dataset containing the video.
task_metadata: Metadata of the task to determine the video type.
input_column: The column to use as input. If None, it will use the first column that matches the video.
batch_size: Batch size for the dataloader.
num_proc: The number of workers for the dataloader.

Returns:
A DataLoader with the audio dataset.
A DataLoader with the video dataset.
"""
if (
input_column
Expand Down Expand Up @@ -590,6 +591,14 @@ def create_dataloader(
batch_size=batch_size,
num_proc=num_proc,
)
if "video" in task_metadata.modalities:
return _create_video_dataloader(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we consider doing a more principled refactor here as discussed in #4182

That issue also showed how the current approach can lead to some odd interactions between modalities.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'll do it a bit later

dataset,
task_metadata,
input_column=input_column,
batch_size=batch_size,
num_proc=num_proc,
)
if "audio" in task_metadata.modalities:
return _create_audio_dataloader(
dataset,
Expand Down Expand Up @@ -742,25 +751,20 @@ def __call__(self, inputs: list[dict[str, Any]]) -> BatchedInput:

collated_inputs = []
for row in inputs:
videos = row.pop("video")
video_inputs = []
for video in videos:
frames = self.resample_video(video["frames"], self.max_frames)
audio = self.audio_collator.resample_audio(
video,
target_sampling_rate=self.audio_collator.target_sampling_rate,
max_samples=self.audio_collator.max_samples,
)
video_inputs.append(
VideoInputItem(
frames=frames,
audio=AudioInputItem(
array=audio,
sampling_rate=self.audio_collator.target_sampling_rate,
),
)
)
row["video"] = video_inputs
video = row.pop("video")
frames = self.resample_video(video["frames"], self.max_frames)
audio = self.audio_collator.resample_audio(
video,
target_sampling_rate=self.audio_collator.target_sampling_rate,
max_samples=self.audio_collator.max_samples,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can also add fps for resampling

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe in another PR

)
row["video"] = VideoInputItem(
frames=frames,
audio=AudioInputItem(
array=audio,
sampling_rate=self.audio_collator.target_sampling_rate,
),
)
collated_inputs.append(row)

return cast(
Expand Down Expand Up @@ -794,3 +798,41 @@ def resample_video(
else list(range(video_frames))
)
return video.get_frames_at(selected_frames).data


class MultimodalCollator:
    """Dispatch collator for batches mixing video and audio modalities.

    Routes each batch to the appropriate specialized collator: batches
    containing video go to ``VideoCollator`` (which also processes the audio
    track embedded in ``VideoInputItem``), audio-only batches go to
    ``AudioCollator``, and anything else falls through to the generic
    ``_custom_collate_fn``.
    """

    def __init__(
        self,
        target_sampling_rate: int,
        max_frames: int = 16,
        max_samples: int | None = None,
    ) -> None:
        """Initialize the collator.

        Args:
            target_sampling_rate: The sampling rate to resample audio to.
            max_frames: Maximum number of frames to keep per video.
            max_samples: Maximum number of audio samples to keep. If None, no truncation.
        """
        # Audio-only batches bypass the video pipeline entirely.
        self.audio_collator = AudioCollator(
            target_sampling_rate=target_sampling_rate,
            max_samples=max_samples,
        )
        # VideoCollator internally handles the audio stream of each video.
        self.video_collator = VideoCollator(
            max_frames=max_frames,
            target_sampling_rate=target_sampling_rate,
            max_samples=max_samples,
        )

    def __call__(self, inputs: list[dict[str, Any]]) -> BatchedInput:
        """Collate a batch, delegating based on the modalities of the first row."""
        sample = inputs[0]
        if "video" in sample:
            return self.video_collator(inputs)
        if "audio" in sample:
            return self.audio_collator(inputs)
        # No video/audio present: fall back to the generic collate function.
        return cast("BatchedInput", _custom_collate_fn(inputs))
18 changes: 17 additions & 1 deletion mteb/abstasks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,23 @@ def evaluate(
ds = self.dataset[hf_subset]

if isinstance(ds, Dataset | DatasetDict):
ds = ds.select_columns([self.label_column_name, self.input_column_name])
# Keep label and input columns, plus any columns required by
# the task's declared modalities (e.g., audio for va2c tasks)
modality_to_column = {
"video": "video",
"audio": "audio",
"image": "image",
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just extend input_column_name to list?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will also cause changes in dataloader; we can do this separately.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What changs? I think easier to use list for processing rather than processing like this

columns_to_keep = {self.label_column_name, self.input_column_name}
if isinstance(ds, DatasetDict):
available = set(next(iter(ds.values())).column_names)
else:
available = set(ds.column_names)
for mod in self.metadata.modalities:
col = modality_to_column.get(mod)
if col and col in available:
columns_to_keep.add(col)
ds = ds.select_columns(list(columns_to_keep))
eval_function = (
self._evaluate_subset
if not self.is_cross_validation
Expand Down
14 changes: 11 additions & 3 deletions mteb/abstasks/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,17 @@ def _evaluate_subset(
)
downsampled_dataset = data_split.select(example_indices)

downsampled_dataset = downsampled_dataset.select_columns(
[self.input_column_name, self.label_column_name]
)
# Keep label and input columns, plus any columns required by
# the task's declared modalities (e.g., audio for va2c tasks)
modality_to_column = {"video": "video", "audio": "audio", "image": "image"}
columns_to_keep = {self.label_column_name, self.input_column_name}
available = set(data_split.column_names)
for mod in self.metadata.modalities:
col = modality_to_column.get(mod)
if col and col in available:
columns_to_keep.add(col)

downsampled_dataset = downsampled_dataset.select_columns(list(columns_to_keep))

logger.info("Running clustering - Encoding samples...")
embeddings = model.encode(
Expand Down
4 changes: 2 additions & 2 deletions mteb/abstasks/retrieval_dataset_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class RetrievalSplitData(TypedDict):
"""A dictionary containing the corpus, queries, relevant documents, instructions, and top-ranked documents for a retrieval task.

Attributes:
corpus: The corpus dataset containing documents. Should have columns `id`, `title`, `text` or `image`.
queries: The queries dataset containing queries. Should have columns `id`, `text`, `instruction` (for instruction retrieval/reranking) or `image`.
corpus: The corpus dataset containing documents. Should have columns `id`, `title`, `text` or `image` or `audio` or `video`.
queries: The queries dataset containing queries. Should have columns `id`, `text`, `instruction` (for instruction retrieval/reranking) or `image` or `audio` or `video`.
relevant_docs: A mapping of query IDs to relevant document IDs and their relevance scores. Should have columns `query-id`, `corpus-id`, `score`.
top_ranked: A mapping of query IDs to a list of top-ranked document IDs. Should have columns `query-id`, `corpus-ids` (list[str]). This is optional and used for reranking tasks.
"""
Expand Down
38 changes: 37 additions & 1 deletion mteb/abstasks/task_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,15 @@
"Any2AnyRetrieval",
)

# Task types introduced by the MVEB (video) benchmark extension.
# "Any2AnyRetrieval" is intentionally omitted here: it is already declared in
# the preceding task-type tuple (see the entry ending just above), and
# _TASK_TYPE is built by concatenating these tuples, so repeating it would
# duplicate the literal.
MVEB_TASK_TYPE = (
    "VideoClassification",
    "VideoClustering",
    "VideoPairClassification",
    "VideoZeroshotClassification",
    "VideoCentricQA",
)


_TASK_TYPE = (
(
Expand All @@ -219,6 +228,7 @@
)
+ MIEB_TASK_TYPE
+ MAEB_TASK_TYPE
+ MVEB_TASK_TYPE
)

TaskType = Literal[_TASK_TYPE] # type: ignore[valid-type]
Expand Down Expand Up @@ -246,7 +256,20 @@
"a2at",
"t2at",
"at2at",
"v2v",
"v2c",
"v2t",
"t2v",
"vt2t",
"vt2v",
"v2vt",
"t2vt",
"vt2vt",
"va2c",
"va2t",
"vat2t",
"v2a",
"a2v",
]
"""The category of the task.

Expand All @@ -270,7 +293,20 @@
18. a2at: audio to audio+text
19. t2at: text to audio+text
20. at2at: audio+text to audio+text
21. v2t: video to text
21. v2v: video to video
22. v2c: video to category
23. v2t: video to text
24. t2v: text to video
25. vt2t: video+text to text
26. vt2v: video+text to video
27. v2vt: video to video+text
28. t2vt: text to video+text
29. vt2vt: video+text to video+text
30. va2c: video+audio to category
31. va2t: video+audio to text
32. vat2t: video+audio+text to text
33. v2a: video to audio
34. a2v: audio to video
"""

AnnotatorType = Literal[
Expand Down
Loading
Loading