Refactor queries and document dataloader to allow multiple modalities #4232
ayush1298 wants to merge 13 commits into embeddings-benchmark:main
Pull request overview
This PR refactors retrieval query/document dataloader creation to support datasets with multiple modalities (e.g., ["image", "text"] for it2it tasks), addressing issue #4182 where only one modality was previously included in the dataloader batches.
Changes:
- Introduces `_prepare_multimodal_dataset()` to centralize modality-specific dataset transformations before wrapping in a `DataLoader`.
- Updates `_create_queries_dataloader()` and `_create_document_dataloader()` to use the shared preparation function and to select a custom collate function when needed.
- Enables multimodal retrieval tasks to produce batches containing all expected modality keys (e.g., both `image` and `text`).
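To make the last point concrete, here is a minimal sketch of what "batches containing all expected modality keys" means. It is not the PR's implementation: `collate_all_modalities` and the sample rows are hypothetical names invented for illustration, and plain Python lists stand in for real image and text data.

```python
from typing import Any, Dict, List

Row = Dict[str, Any]


def collate_all_modalities(rows: List[Row], modalities: List[str]) -> Dict[str, List[Any]]:
    """Hypothetical collate step: batch rows column-wise, one key per modality."""
    missing = [m for m in modalities if any(m not in row for row in rows)]
    if missing:
        raise KeyError(f"rows are missing modality columns: {missing}")
    return {m: [row[m] for row in rows] for m in modalities}


# Stand-in rows for an it2it task; byte strings are placeholders for images.
rows = [
    {"id": "q1", "image": b"<image 1>", "text": "What is shown here?"},
    {"id": "q2", "image": b"<image 2>", "text": "Which species is this?"},
]

# Selecting a single modality drops the other key from the batch.
single = collate_all_modalities(rows, ["image"])

# Passing every modality of the task keeps both keys in the batch.
both = collate_all_modalities(rows, ["image", "text"])
```

With one modality the batch has no `text` key, which is the symptom described in issue #4182; with both modalities a model can read `image` and `text` from the same batch.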
mteb/_create_dataloaders.py
    prepared = _prepare_multimodal_dataset(
        dataset,
        task_metadata,
        prompt_type=PromptType.document,
        input_column=input_column,
        num_proc=num_proc,
    )
_create_document_dataloader/_prepare_multimodal_dataset don't actually implement the behavior described in the surrounding docstring (selecting the first column matching the modality when input_column is None). Right now the code assumes canonical column names (e.g., "image") unless input_column is explicitly provided. Consider either adding the advertised column inference or adjusting the docstring so callers aren't misled.
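One possible shape for the column inference mentioned in this comment is sketched below. This is an assumption-laden illustration, not code from the PR: `resolve_input_column` is a hypothetical helper, and the `column_modalities` mapping (column name to modality) is assumed to be derived elsewhere, for example from the dataset's feature types.

```python
from typing import Dict, Optional


def resolve_input_column(
    column_modalities: Dict[str, str],
    modality: str,
    input_column: Optional[str] = None,
) -> str:
    """Hypothetical helper: pick the column that holds the given modality."""
    if input_column is not None:
        # An explicit column always wins, but it must exist.
        if input_column not in column_modalities:
            raise KeyError(f"column {input_column!r} not found in dataset")
        return input_column
    # Otherwise fall back to the first column whose modality matches.
    for name, kind in column_modalities.items():
        if kind == modality:
            return name
    raise ValueError(f"no column found for modality {modality!r}")


# Assumed mapping for a dataset whose image column is not named "image".
columns = {"id": "id", "photo": "image", "caption": "text"}

inferred = resolve_input_column(columns, "image")
explicit = resolve_input_column(columns, "image", input_column="caption")
```

If the inference is not wanted, the alternative raised in the comment is simply to make the docstring state that canonical column names are assumed.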
    if prompt_type == PromptType.document:
        new_ds = new_ds.map(
            _corpus_to_dict,
            desc="Standardizing text corpus format",
            num_proc=num_proc,
        )
    elif prompt_type == PromptType.query:
        if isinstance(new_ds["text"][0], list):
            new_ds = new_ds.map(
                _convert_conv_history_to_query,
                desc="Converting conversations to queries",
                num_proc=num_proc,
            )
        else:
            new_ds = new_ds.map(
                _combine_queries_with_instruction_text,
In _prepare_multimodal_dataset, the .map(...) calls for text processing run over the full multimodal rows. For datasets that include heavy columns (e.g., image/audio/video), this can force decoding/serialization of those columns during mapping even though the mapper only needs text fields, which can significantly slow preprocessing and increase memory use (especially with num_proc). Consider using Dataset.map(..., input_columns=[...]) (or an equivalent approach) so the mapper only receives the columns it actually needs (e.g., ['id','text','title'] for _corpus_to_dict, ['text','instruction'] for _combine_queries_with_instruction_text).
Suggested change:

    if prompt_type == PromptType.document:
        corpus_input_columns = [
            col
            for col in ("id", "text", "title")
            if col in new_ds.column_names
        ]
        new_ds = new_ds.map(
            _corpus_to_dict,
            input_columns=corpus_input_columns,
            desc="Standardizing text corpus format",
            num_proc=num_proc,
        )
    elif prompt_type == PromptType.query:
        if isinstance(new_ds["text"][0], list):
            conv_input_columns = [
                col
                for col in ("text",)
                if col in new_ds.column_names
            ]
            new_ds = new_ds.map(
                _convert_conv_history_to_query,
                input_columns=conv_input_columns,
                desc="Converting conversations to queries",
                num_proc=num_proc,
            )
        else:
            query_input_columns = [
                col
                for col in ("text", "instruction")
                if col in new_ds.column_names
            ]
            new_ds = new_ds.map(
                _combine_queries_with_instruction_text,
                input_columns=query_input_columns,
Agree
Agree
…er_for_queries, _create_dataloader_for_queries_conversation functions
…e_video_dataloader under refactoring
Simplified as much as possible.
closes #4182
Both `_create_queries_dataloader` and `_create_document_dataloader` are refactored.

Before the refactor, these functions picked only one modality from [image, text, audio, video] and built the dataloader from that alone. This caused problems for tasks such as `EncyclopediaVQAIT2ITRetrieval`, whose dataset has both ['image', 'text'] modalities.

In the refactor, a helper function `_prepare_multimodal_dataset` applies all dataset-specific transformations together, and the conversion to a dataloader then happens in the respective function (`_create_queries_dataloader` or `_create_document_dataloader`). As a result, a dataset with more than one modality is now handled.

Tested with `EncyclopediaVQAIT2ITRetrieval` using the dummy script below:

Output of the above script:

So it is working as expected now.
I want to propose two more extensions to this refactor to make it cleaner and easier to use (removing the clutter of many functions):

We can deprecate/remove the following functions, since their work is already absorbed by other functions during the refactor and they are just duplicates:
`_create_dataloader_for_retrieval_corpus`, `_create_text_dataloader_for_queries`, `_create_dataloader_for_queries_conversation`

Also, I think we can update bm25 and bb25, which use `_create_text_queries_dataloader`, to use `_combine_queries_with_instruction_text` instead, and remove the `_create_text_queries_dataloader` function as well. In both of these files, we first create a dataloader and then immediately flatten it back:
https://github.com/embeddings-benchmark/mteb/blob/main/mteb/models/model_implementations/bm25.py#L89-L90
So, instead of that, we can simply use `_combine_queries_with_instruction_text` directly.

Also, functions like `_create_image_dataloader`, `_create_audio_dataloader`, and `_create_video_dataloader` are called only in `create_dataloader`, as part of the fallback when `prompt_type=None`. These functions only rename a column and wrap the dataset in a DataLoader, so that can be done directly in `create_dataloader`.

Would love to have your opinion on both of these refactor extensions @KennethEnevoldsen @Samoed
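The bm25/bb25 point above (build a dataloader, then flatten it straight back) can be illustrated with a small stand-alone sketch. Everything here is a stand-in: `combine_query_with_instruction` is a hypothetical simplification of what `_combine_queries_with_instruction_text` is assumed to do, and `batched` mimics dataloader batching with plain lists.

```python
from typing import Dict, List, Optional


def combine_query_with_instruction(row: Dict[str, Optional[str]]) -> str:
    """Hypothetical stand-in: append the instruction to the query text if present."""
    text = row["text"] or ""
    instruction = row.get("instruction")
    return f"{text} {instruction}" if instruction else text


def batched(items: List[str], batch_size: int) -> List[List[str]]:
    """Mimic dataloader batching with plain lists."""
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]


rows = [
    {"text": "capital of France", "instruction": "Answer briefly."},
    {"text": "largest planet", "instruction": None},
    {"text": "boiling point of water", "instruction": "Use Celsius."},
]

# Round trip: combine, batch as a dataloader would, then flatten again.
batches = batched([combine_query_with_instruction(r) for r in rows], batch_size=2)
flattened = [text for batch in batches for text in batch]

# Direct path: combine once, no batching step in between.
direct = [combine_query_with_instruction(r) for r in rows]
```

Under these assumptions both paths yield the same list of query strings, which is why the batching step adds nothing for a lexical model that consumes all queries at once.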