Allow non-Tensor values in a batch with dispatch_batches=True#3850
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
SunMarc
left a comment
Thanks for this! Left a comment, let me know if this is unclear.
src/accelerate/utils/operations.py (outdated)

```python
    elif isinstance(data[0], torch.Tensor):
        return torch.cat(data, dim=dim)
    return data[0]
```
This is a fix that only works when we have one GPU (only one batch is passed to concatenate). The issue happens when we have multi-GPU: for two batches we might have the following situation:
[{'key1': ["str1", "str2"]}, {'key1': ["str3", "str4"]}], and we would get the following result:
{'key1': ["str1", "str2"]}. Not sure if this is what we want, unless str1 == str3 and str2 == str4.
For now, what we can do is check len(data) when the value is not a tensor, list, or mapping: if len(data) == 1, return the single element; otherwise, raise an error saying that we can only concatenate tensors.
Also, can you add some simple tests for concatenate?
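A minimal sketch of how that guard could fit into a recursive concatenate is shown below. This is a simplified stand-in for illustration only, not the actual accelerate implementation (which uses `Mapping` and a type-preserving helper for sequences):

```python
import torch

def concatenate(data, dim=0):
    """Simplified sketch: recursively concatenate a list of per-process batches."""
    if isinstance(data[0], (tuple, list)):
        return type(data[0])(
            concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))
        )
    elif isinstance(data[0], dict):
        return type(data[0])(
            {k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()}
        )
    elif isinstance(data[0], torch.Tensor):
        return torch.cat(data, dim=dim)
    elif len(data) == 1:
        # Only one batch (single process): returning the value unchanged is safe.
        return data[0]
    raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
```

With a single batch, non-tensor leaves pass through unchanged; with two or more batches containing e.g. strings, the function raises instead of silently dropping data.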
Damn, you're right. Apologies, I was thinking that the elif isinstance(data[0], Mapping) branch would take care of it, but the recursive call with ['str1', 'str3'] and ['str2', 'str4'] will result in 'str1' and 'str2', dropping the other two.
I think the len(data) check is smart; I will incorporate it and run some tests.
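The dropping behaviour described above can be reproduced with a torch-free toy version of the old fallback (illustrative only, not the real accelerate code):

```python
def concatenate_old(data, dim=0):
    """Toy version of the old fallback: non-tensor leaves just return data[0]."""
    if isinstance(data[0], (tuple, list)):
        return type(data[0])(
            concatenate_old([d[i] for d in data], dim=dim) for i in range(len(data[0]))
        )
    elif isinstance(data[0], dict):
        return {k: concatenate_old([d[k] for d in data], dim=dim) for k in data[0]}
    # Old fallback: silently keep only the first batch's value.
    return data[0]

result = concatenate_old([{"key1": ["str1", "str2"]}, {"key1": ["str3", "str4"]}])
# result == {"key1": ["str1", "str2"]}: "str3" and "str4" are dropped.
```

The list branch recurses with ['str1', 'str3'] and ['str2', 'str4'], each of which bottoms out at `return data[0]`, which is exactly the data loss described in the comment above.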
Sounds good! Yeah, this part is quite tricky, so I prefer to be extra cautious.
Suggested change (replacing the bare `return data[0]`):

```python
    elif isinstance(data, (tuple, list)) and len(data) == 1:
        return data[0]
    else:
        raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
```
The suggestion here works fine in single-process settings, but len(data) is simply equal to the number of processes. In short: it will always fail in multi-process settings. That tells me it is perhaps simply not viable to pass string parameters from tokenization or collation to the model during training this way. Bools are simpler: I can just turn those into a singleton bool tensor that gets concatenated.
Perhaps we should leave this PR be?
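The bool workaround mentioned above could be sketched as a collator-side conversion. The helper below is hypothetical (it is not part of this PR or of accelerate):

```python
import torch

def bools_to_tensors(batch):
    # Replace plain Python bools with 1-element bool tensors so that
    # per-process batches can later be concatenated with torch.cat.
    return {
        k: torch.tensor([v]) if isinstance(v, bool) else v
        for k, v in batch.items()
    }
```

Because `isinstance(v, bool)` is checked before anything else, integer values are left untouched; only genuine bools become singleton tensors.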
Still, I think it might be worth making it work in single-process settings, no? If you think this will create more issues, then we can leave this PR be.
We can add in the docstring that if we receive only one batch of data, we will return it as is.
Agreed, it's a step in the right direction, even though multi-process support is not possible. We can still merge it. I've also added some tests and updated the docstring slightly.
I do think I'll have to reconsider some things in Sentence Transformers, e.g. whether I want to use string values in my batches to pass parameters during training. Not supporting IterableDataset + multi-GPU is a bit annoying.
SunMarc
left a comment
Awesome, thanks for adding these nice tests
Resolves #3849
What does this PR do?
Allow non-Tensor values in a batch with dispatch_batches=True, matching behaviour for when dispatch_batches=False or when accelerate is not used.
Details
Rerunning the script from #3849 now also gives this for the previously broken case:
Accelerator, with IterableDataset
As expected!
Who can review?
@BenjaminBossan @SunMarc