Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 98 additions & 60 deletions mteb/_create_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,67 @@ def _create_text_queries_dataloader(
)


def _prepare_multimodal_dataset(
    dataset: Dataset,
    task_metadata: TaskMetadata,
    prompt_type: PromptType,
    input_column: str | None = None,
    num_proc: int | None = None,
) -> Dataset:
    """Apply all modality-specific transformations to the dataset.

    Text columns are normalized into the canonical format expected by the
    downstream dataloaders; audio/video data is renamed to its canonical
    column name when ``input_column`` points at a differently named column.

    Args:
        dataset: The dataset to transform.
        task_metadata: Task metadata; its modalities drive which
            transformations run.
        prompt_type: Whether the rows are queries or documents.
        input_column: Optional source column for image/audio/video data.
            When ``None``, the canonical column name ("image"/"audio"/
            "video") is assumed to already exist in the dataset.
        num_proc: Number of worker processes forwarded to ``Dataset.map``.

    Returns:
        The transformed Dataset (no DataLoader wrapping).
    """
    modalities = task_metadata.get_modalities(prompt_type)
    new_ds = dataset

    if "text" in modalities:
        if prompt_type == PromptType.document:
            new_ds = new_ds.map(
                _corpus_to_dict,
                desc="Standardizing text corpus format",
                num_proc=num_proc,
            )
        elif prompt_type == PromptType.query:
            # Conversational queries arrive as a list of turns and must be
            # flattened into a single query string first.
            if isinstance(new_ds["text"][0], list):
                new_ds = new_ds.map(
                    _convert_conv_history_to_query,
                    desc="Converting conversations to queries",
                    num_proc=num_proc,
                )
            else:
                new_ds = new_ds.map(
                    _combine_queries_with_instruction_text,
                    desc="Processing queries for dataloading",
                    num_proc=num_proc,
                )

    if "image" in modalities:
        new_ds = _prepare_image_dataset(
            new_ds,
            image_column_name=input_column if input_column else "image",
            num_proc=num_proc,
        )

    # Audio and video only need the source column renamed to its canonical
    # name; guard against clobbering an already-canonical column.
    for modality in ("audio", "video"):
        if (
            modality in modalities
            and input_column
            and input_column in new_ds.column_names
            and modality not in new_ds.column_names
        ):
            new_ds = new_ds.rename_column(input_column, modality)

    return new_ds


def _create_queries_dataloader(
dataset: Dataset,
task_metadata: TaskMetadata,
Expand All @@ -386,36 +447,26 @@ def _create_queries_dataloader(
) -> DataLoader[QueryInput | ImageInput | AudioInput]:
"""Create a dataloader for queries."""
queries_type = task_metadata.get_modalities(PromptType.query)
if queries_type == ["text"]: # text only
return _create_text_queries_dataloader(
dataset,
batch_size=batch_size,
num_proc=num_proc,
)
if "image" in queries_type: # contains image
return _create_image_dataloader(
dataset,
image_column_name="image",
batch_size=batch_size,
num_proc=num_proc,
)
if "audio" in task_metadata.modalities:
return _create_audio_dataloader(
dataset,
task_metadata,
input_column="audio",
batch_size=batch_size,
num_proc=num_proc,
)
if "video" in task_metadata.modalities:
return _create_video_dataloader(
dataset,
task_metadata,
input_column="video",
batch_size=batch_size,
num_proc=num_proc,
)
raise ValueError(f"Can't handle queries type {queries_type}")
prepared = _prepare_multimodal_dataset(
dataset,
task_metadata,
prompt_type=PromptType.query,
input_column=input_column,
num_proc=num_proc,
)
needs_custom_collate = any(
m in queries_type for m in ("image", "audio", "video")
) or (
"text" in queries_type and isinstance(dataset["text"][0], list) # conversations
)

return DataLoader(
prepared,
batch_size=batch_size,
collate_fn=_custom_collate_fn if needs_custom_collate else None,
num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
shuffle=False,
)


def _create_document_dataloader(
Expand All @@ -438,36 +489,23 @@ def _create_document_dataloader(
A dataloader for the documents.
"""
document_type = task_metadata.get_modalities(PromptType.document)
if document_type == ["text"]: # text only
return _create_dataloader_for_retrieval_corpus(
dataset,
batch_size=batch_size,
num_proc=num_proc,
)
if "image" in document_type: # contains image
return _create_image_dataloader(
dataset,
image_column_name="image",
batch_size=batch_size,
num_proc=num_proc,
)
if "audio" in task_metadata.modalities:
return _create_audio_dataloader(
dataset,
task_metadata,
input_column="audio",
batch_size=batch_size,
num_proc=num_proc,
)
if "video" in task_metadata.modalities:
return _create_video_dataloader(
dataset,
task_metadata,
input_column="video",
batch_size=batch_size,
num_proc=num_proc,
)
raise ValueError(f"Can't handle queries type {document_type}")
prepared = _prepare_multimodal_dataset(
dataset,
task_metadata,
prompt_type=PromptType.document,
input_column=input_column,
num_proc=num_proc,
)
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_create_document_dataloader/_prepare_multimodal_dataset don't actually implement the behavior described in the surrounding docstring (selecting the first column matching the modality when input_column is None). Right now the code assumes canonical column names (e.g., "image") unless input_column is explicitly provided. Consider either adding the advertised column inference or adjusting the docstring so callers aren't misled.

Copilot uses AI. Check for mistakes.

needs_custom_collate = any(m in document_type for m in ("image", "audio", "video"))

return DataLoader(
prepared,
batch_size=batch_size,
collate_fn=_custom_collate_fn if needs_custom_collate else None,
num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
shuffle=False,
)


def _create_audio_dataloader(
Expand Down
Loading