Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/accelerate/utils/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ def _slice_tensor(tensor, tensor_slice):
def concatenate(data, dim=0):
"""
Recursively concatenate the tensors in a nested list/tuple/dictionary of lists of tensors with the same shape.
If there is only a single batch of data, it is returned as-is.

Args:
data (nested list/tuple/dictionary of lists of tensors `torch.Tensor`):
Expand All @@ -615,9 +616,12 @@ def concatenate(data, dim=0):
return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
elif isinstance(data[0], Mapping):
return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
elif not isinstance(data[0], torch.Tensor):
elif isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=dim)
elif isinstance(data, (tuple, list)) and len(data) == 1:
return data[0]
else:
raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
return torch.cat(data, dim=dim)


class CannotPadNestedTensorWarning(UserWarning):
Expand Down
16 changes: 16 additions & 0 deletions tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,22 @@ def test_iterable_dataset_using_none_batch_size(self):
for d in dataloader:
assert isinstance(d, torch.Tensor)

def test_iterable_dataset_with_non_tensor_samples(self):
dataset = SimpleIterableDataset(10)

def collate_fn(features):
return {
"tensor": torch.stack(features),
"non_tensor": "non_tensor_value",
}

dataloader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
accelerator = Accelerator()
dataloader = accelerator.prepare_data_loader(dataloader)
for d in dataloader:
assert isinstance(d["tensor"], torch.Tensor)
assert d["non_tensor"] == "non_tensor_value"

@parameterized.expand([1, 2], name_func=parameterized_custom_name_func)
def test_reproducibility(self, num_processes):
set_seed(21)
Expand Down
89 changes: 89 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
CannotPadNestedTensorWarning,
check_os_kernel,
clear_environment,
concatenate,
convert_dict_to_env_variables,
convert_outputs_to_fp32,
convert_to_fp32,
Expand Down Expand Up @@ -442,6 +443,94 @@ def test_has_offloaded_params(self):
attach_align_device_hook(model, offload=True)
assert has_offloaded_params(model)

def test_concatenate(self):
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)
result = concatenate([tensor1, tensor2])
assert result.shape == torch.Size([4, 3])
assert torch.equal(result[:2], tensor1)
assert torch.equal(result[2:], tensor2)

single_tensor = torch.randn(3, 4)
result = concatenate([single_tensor])
assert result.shape == torch.Size([3, 4])
assert torch.equal(result, single_tensor)

# NOTE: We return as-is if there's just a single batch of data, even if it's not a tensor
single_value = "test_string"
result = concatenate([single_value])
assert result == single_value

data = [
[torch.randn(2, 3), torch.randn(2, 4)],
[torch.randn(2, 3), torch.randn(2, 4)],
]
result = concatenate(data)
assert isinstance(result, list)
assert len(result) == 2
assert result[0].shape == torch.Size([4, 3])
assert result[1].shape == torch.Size([4, 4])

data = [
(torch.randn(2, 3), torch.randn(2, 4)),
(torch.randn(2, 3), torch.randn(2, 4)),
]
result = concatenate(data)
assert isinstance(result, tuple)
assert len(result) == 2
assert result[0].shape == torch.Size([4, 3])
assert result[1].shape == torch.Size([4, 4])

data = [
{"a": torch.randn(2, 3), "b": torch.randn(2, 4)},
{"a": torch.randn(2, 3), "b": torch.randn(2, 4)},
]
result = concatenate(data)
assert isinstance(result, dict)
assert "a" in result and "b" in result
assert result["a"].shape == torch.Size([4, 3])
assert result["b"].shape == torch.Size([4, 4])

# NOTE: We can't merge multiple batches of non-tensor data
data = [
{"a": torch.randn(2, 3), "b": torch.randn(2, 4), "c": "test_string1"},
{"a": torch.randn(2, 3), "b": torch.randn(2, 4), "c": "test_string2"},
]
with self.assertRaises(TypeError):
result = concatenate(data)

batch1 = torch.randn(5, 10)
batch2 = torch.randn(5, 10)
batch3 = torch.randn(5, 10)
result = concatenate([batch1, batch2, batch3])
assert result.shape == torch.Size([15, 10])
assert torch.equal(result[:5], batch1)
assert torch.equal(result[5:10], batch2)
assert torch.equal(result[10:], batch3)

# NOTE: We can't merge misaligned batches, the torch.cat will raise a RuntimeError
batch1 = torch.randn(5, 10)
batch2 = torch.randn(5, 12)
with self.assertRaises(RuntimeError):
result = concatenate([batch1, batch2])

tensor1 = torch.randn(3, 2, 4)
tensor2 = torch.randn(3, 2, 4)
result = concatenate([tensor1, tensor2], dim=1)
assert result.shape == torch.Size([3, 4, 4])

data = [
{"inputs": [torch.randn(2, 3), torch.randn(2, 4)], "labels": torch.randn(2, 1)},
{"inputs": [torch.randn(2, 3), torch.randn(2, 4)], "labels": torch.randn(2, 1)},
{"inputs": [torch.randn(2, 3), torch.randn(2, 4)], "labels": torch.randn(2, 1)},
]
result = concatenate(data)
assert isinstance(result, dict)
assert isinstance(result["inputs"], list)
assert result["inputs"][0].shape == torch.Size([6, 3])
assert result["inputs"][1].shape == torch.Size([6, 4])
assert result["labels"].shape == torch.Size([6, 1])


def set_dummy_accelerate_env_var():
"""Set an accelerate env var
Expand Down