Allow non-Tensor values in a batch with dispatch_batches=True #3850

Merged
SunMarc merged 3 commits into huggingface:main from tomaarsen:feat/dispatch_batches_non_tensor_samples on Nov 26, 2025

Conversation

@tomaarsen (Member)

Resolves #3849

What does this PR do?

  • Allow non-Tensor values in a batch with dispatch_batches=True, matching the behaviour when dispatch_batches=False or when accelerate is not used

Details

Rerunning the script from #3849 now gives the following output for the previously broken case as well:

Accelerator, with IterableDataset

Batch:
query_input_ids: <class 'torch.Tensor'> with shape torch.Size([4, 13])
query_token_type_ids: <class 'torch.Tensor'> with shape torch.Size([4, 13])
query_attention_mask: <class 'torch.Tensor'> with shape torch.Size([4, 13])
query_str_parameter: <class 'str'> parameter_value
query_bool_parameter: <class 'bool'> True
query_str_list: <class 'list'> ['list_item_1', 'list_item_2']
answer_input_ids: <class 'torch.Tensor'> with shape torch.Size([4, 328])
answer_token_type_ids: <class 'torch.Tensor'> with shape torch.Size([4, 328])
answer_attention_mask: <class 'torch.Tensor'> with shape torch.Size([4, 328])
answer_str_parameter: <class 'str'> parameter_value
answer_bool_parameter: <class 'bool'> True
answer_str_list: <class 'list'> ['list_item_1', 'list_item_2']

Batch:
query_input_ids: <class 'torch.Tensor'> with shape torch.Size([1, 11])
query_token_type_ids: <class 'torch.Tensor'> with shape torch.Size([1, 11])
query_attention_mask: <class 'torch.Tensor'> with shape torch.Size([1, 11])
query_str_parameter: <class 'str'> parameter_value
query_bool_parameter: <class 'bool'> True
query_str_list: <class 'list'> ['list_item_1', 'list_item_2']
answer_input_ids: <class 'torch.Tensor'> with shape torch.Size([1, 164])
answer_token_type_ids: <class 'torch.Tensor'> with shape torch.Size([1, 164])
answer_attention_mask: <class 'torch.Tensor'> with shape torch.Size([1, 164])
answer_str_parameter: <class 'str'> parameter_value
answer_bool_parameter: <class 'bool'> True
answer_str_list: <class 'list'> ['list_item_1', 'list_item_2']

As expected!
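For context, here is a minimal sketch of the kind of reproduction script described in #3849 (hypothetical: the dataset, collator, and field names below are illustrative stand-ins, not the original script, and it assumes a recent accelerate with the DataLoaderConfiguration API for enabling dispatch_batches):

import torch
from torch.utils.data import DataLoader, IterableDataset

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration


class ToyIterableDataset(IterableDataset):
    # Hypothetical stand-in for the query/answer dataset in #3849.
    def __iter__(self):
        for _ in range(8):
            yield {"text": "some example text"}


def collate_fn(features):
    # The collator mixes Tensors with plain Python values (str, bool, list),
    # which is the case dispatch_batches=True previously rejected.
    return {
        "input_ids": torch.ones(len(features), 13, dtype=torch.long),
        "str_parameter": "parameter_value",
        "bool_parameter": True,
        "str_list": ["list_item_1", "list_item_2"],
    }


accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(dispatch_batches=True))
dataloader = accelerator.prepare(
    DataLoader(ToyIterableDataset(), batch_size=4, collate_fn=collate_fn)
)

for batch in dataloader:
    print("Batch:")
    for key, value in batch.items():
        print(f"{key}: {type(value)} {getattr(value, 'shape', value)}")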

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@BenjaminBossan @SunMarc

  • Tom Aarsen

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc (Member) left a comment

Thanks for this! I left a comment, let me know if it's unclear.

-    return torch.cat(data, dim=dim)
+    elif isinstance(data[0], torch.Tensor):
+        return torch.cat(data, dim=dim)
+    return data[0]
@SunMarc (Member) · Nov 25, 2025

This is a fix that only works when we have one GPU (only one batch is passed to concatenate). The issue happens when we have multiple GPUs: for 2 batches we might have the following situation:
[{'key1': ["str1", "str2"]}, {'key1': ["str3", "str4"]}], and we would get the following result:
{'key1': ["str1", "str2"]}. Not sure if this is what we want, unless str1 == str3 and str2 == str4.

For now, what we can do is check whether len(data) >= 2 when the data is not a tensor, list, or mapping. If it is 1, we return data[0] as before; otherwise we raise an error saying that we can only concatenate tensors.

Also, can you add some simple tests for concatenate?
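To make the failure mode concrete, here is a toy re-implementation of the relevant branches (not accelerate's actual code) showing how the initial fix would silently drop the second process's values:

import torch


def toy_concatenate(data, dim=0):
    # Mirrors the branch structure under discussion: recurse into lists/tuples
    # and dicts, concatenate Tensors, and otherwise fall back to data[0].
    if isinstance(data[0], (tuple, list)):
        return type(data[0])(toy_concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0])))
    elif isinstance(data[0], dict):
        return {k: toy_concatenate([d[k] for d in data], dim=dim) for k in data[0]}
    elif isinstance(data[0], torch.Tensor):
        return torch.cat(data, dim=dim)
    return data[0]


# Two per-process batches, as in the multi-GPU scenario above:
batches = [{"key1": ["str1", "str2"]}, {"key1": ["str3", "str4"]}]
print(toy_concatenate(batches))
# {'key1': ['str1', 'str2']} -- "str3" and "str4" are silently lost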

@tomaarsen (Member, Author)

Damn, you're right. Apologies, I was thinking that the elif isinstance(data[0], Mapping) branch would take care of it, but the recursive call with ['str1', 'str3'] and ['str2', 'str4'] will result in 'str1' and 'str2', dropping the other two.

I think the len(data) check is smart; I'll incorporate it and run some tests.

@SunMarc (Member)

Sounds good! Yeah, this part is quite tricky, so I prefer to be extra cautious.

@tomaarsen (Member, Author)

Suggested change
-    return data[0]
+    elif isinstance(data, (tuple, list)) and len(data) == 1:
+        return data[0]
+    else:
+        raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")

The suggestion here works fine in single-process settings, but len(data) is simply equal to the number of processes. In short: it will always fail in multi-process settings. That tells me that perhaps it's simply not viable to pass string parameters from tokenization or collation to the model during training in this way. Bools are simpler: I can just turn those into a singleton bool tensor that gets concatenated.
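A minimal sketch of that bool workaround (the collator emits a singleton bool tensor instead of a Python bool, so the per-process values concatenate like any other tensor):

import torch

# One value per process; torch.cat merges them where a plain Python bool could not be concatenated.
per_process_values = [torch.tensor([True]), torch.tensor([True])]
print(torch.cat(per_process_values, dim=0))  # tensor([True, True])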

Perhaps we should leave this PR be?

@SunMarc (Member)

Still, I think it might be worth making it work for single-process, no? If you think this will create more issues, then we can leave this PR be.
We can add a note to the docstring that if we receive only one batch of data, we return it as-is.
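A sketch of the kind of docstring note suggested here (illustrative wording, not the exact text added in the PR):

def concatenate(data, dim=0):
    """
    Recursively concatenate the tensors in a nested list/tuple/dictionary of lists of tensors with the same shape.

    Note:
        If `data` contains a single batch of non-Tensor values, that batch is returned as-is;
        concatenating non-Tensor values across multiple batches is not supported.
    """
    ...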

@tomaarsen (Member, Author)

Agreed, it's a step in the right direction, even though multi-process support is not possible. We can still merge it. I've also added some tests and updated the docstring slightly.

I do think I'll have to reconsider some things in Sentence Transformers, e.g. whether I want to use string values in my batches to pass parameters during training. Not supporting IterableDataset + multi-GPU is a bit annoying.
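For reference, a sketch of what such tests might look like (hypothetical test names; the actual tests live in the PR). It assumes accelerate.utils.concatenate and the single-process behaviour discussed above: a lone batch with non-Tensor values is returned as-is, while multiple batches with non-Tensor values raise a TypeError.

import pytest
import torch

from accelerate.utils import concatenate


def test_concatenate_single_batch_with_non_tensor_values():
    # A single batch (e.g. a single process) should pass non-Tensor values through unchanged.
    batch = {"input_ids": torch.ones(2, 4), "str_parameter": "parameter_value"}
    result = concatenate([batch])
    assert torch.equal(result["input_ids"], batch["input_ids"])
    assert result["str_parameter"] == "parameter_value"


def test_concatenate_multiple_batches_with_non_tensor_values_raises():
    # With several batches (e.g. multiple processes), non-Tensor values cannot be merged.
    with pytest.raises(TypeError):
        concatenate([{"str_parameter": "a"}, {"str_parameter": "b"}])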

@SunMarc (Member) left a comment

Awesome, thanks for adding these nice tests

@SunMarc merged commit b521400 into huggingface:main on Nov 26, 2025
25 checks passed

Development

Successfully merging this pull request may close these issues.

DataLoaderDispatcher doesn't accept non-Tensor values from the data collator
