test/stateful_dataloader/test_dataloader.py (204 additions, 0 deletions)
@@ -3145,5 +3145,209 @@ def test_out_of_order_iterable_ds(self):
instantiate_device_type_tests(TestDataLoaderDeviceType, globals())


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestStatefulDataLoaderEnumerate(TestCase):
    def setUp(self):
        super().setUp()
        self.data = torch.arange(20)
        self.dataset = TensorDataset(self.data)

    def test_custom_enumerate_basic(self):
        """Test that custom enumerate works correctly without state restoration."""
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=False)

        # Test that custom enumerate produces correct indices
        custom_results = list(dataloader.enumerate())
        builtin_results = list(enumerate(dataloader))

        # Both should produce the same results when no state is loaded
        self.assertEqual(len(custom_results), len(builtin_results))
        for (custom_idx, custom_data), (builtin_idx, builtin_data) in zip(custom_results, builtin_results):
            self.assertEqual(custom_idx, builtin_idx)
            self.assertTrue(torch.equal(custom_data[0], builtin_data[0]))

    def test_custom_enumerate_with_start_parameter(self):
        """Test that custom enumerate works correctly with the start parameter."""
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=False)

        start_value = 100
        results = list(dataloader.enumerate(start=start_value))

        expected_indices = list(range(start_value, start_value + len(dataloader)))
        actual_indices = [idx for idx, _ in results]

        self.assertEqual(actual_indices, expected_indices)

    def test_custom_enumerate_with_state_restoration(self):
        """Test that custom enumerate correctly handles state restoration."""
        # Create initial dataloader and process some batches
        dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False)

        # Process first 3 batches (indices 0, 1, 2) and save state
        processed_count = 0
        for i, (batch,) in enumerate(dataloader1):
            processed_count += 1
            if i == 2:  # After processing batches 0, 1, 2
                state = dataloader1.state_dict()
                break

        self.assertEqual(processed_count, 3)

        # Create new dataloader and restore state
        dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False)
        dataloader2.load_state_dict(state)

        # Use custom enumerate to continue
        remaining_results = list(dataloader2.enumerate())

        # Should start from index 3 (since we processed 0, 1, 2)
        expected_start_index = 3
        expected_indices = list(range(expected_start_index, len(dataloader1)))
        actual_indices = [idx for idx, _ in remaining_results]

        self.assertEqual(actual_indices, expected_indices)

        # Verify data correctness
        expected_data_start = 6  # batch 3 should contain [6, 7]
        first_batch_data = remaining_results[0][1][0]
        self.assertTrue(torch.equal(first_batch_data, torch.tensor([expected_data_start, expected_data_start + 1])))

    def test_custom_enumerate_vs_builtin_after_restoration(self):
        """Test that demonstrates the difference between custom and builtin enumerate after state restoration."""
        # Create initial dataloader and process some batches
        dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False)

        # Process first 2 batches and save state
        for i, batch in enumerate(dataloader1):
            if i == 1:  # After processing batches 0, 1
                state = dataloader1.state_dict()
                break

        # Test builtin enumerate (demonstrates the problem)
        dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False)
        dataloader2.load_state_dict(state)
        builtin_results = list(enumerate(dataloader2))
        builtin_indices = [idx for idx, _ in builtin_results]

        # Test custom enumerate (shows the fix)
        dataloader3 = DataLoader(self.dataset, batch_size=2, shuffle=False)
        dataloader3.load_state_dict(state)
        custom_results = list(dataloader3.enumerate())
        custom_indices = [idx for idx, _ in custom_results]

        # Builtin enumerate should start from 0 (the problem)
        self.assertEqual(builtin_indices, [0, 1, 2, 3, 4, 5, 6, 7])

        # Custom enumerate should start from 2 (the fix)
        self.assertEqual(custom_indices, [2, 3, 4, 5, 6, 7, 8, 9])

        # Data should be the same for both
        for (_, builtin_data), (_, custom_data) in zip(builtin_results, custom_results):
            self.assertTrue(torch.equal(builtin_data[0], custom_data[0]))

    def test_custom_enumerate_with_multiprocessing(self):
        """Test that custom enumerate works correctly with multiprocessing."""
        # Test with 2 workers
        dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False, num_workers=2)

        # Process some batches and save state
        for i, batch in enumerate(dataloader1):
            if i == 2:
                state = dataloader1.state_dict()
                break

        # Restore state and use custom enumerate
        dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False, num_workers=2)
        dataloader2.load_state_dict(state)
        results = list(dataloader2.enumerate())

        # Should start from the correct index
        expected_start_index = 3
        actual_indices = [idx for idx, _ in results]
        self.assertEqual(actual_indices[0], expected_start_index)

    def test_custom_enumerate_empty_after_restoration(self):
        """Test custom enumerate when only the final batch remains after state restoration."""
        # Use a small dataset
        small_data = torch.arange(4)
        small_dataset = TensorDataset(small_data)
        dataloader1 = DataLoader(small_dataset, batch_size=2, shuffle=False)

        # Run through the epoch, saving state after the first batch
        state = None
        batches_processed = 0
        for i, batch in enumerate(dataloader1):
            batches_processed += 1
            if i == 0:  # After the first batch, only one batch remains
                state = dataloader1.state_dict()

        # Restore state and process remaining data
        dataloader2 = DataLoader(small_dataset, batch_size=2, shuffle=False)
        dataloader2.load_state_dict(state)
        remaining_results = list(dataloader2.enumerate())

        # Should have exactly one batch remaining with the correct index
        self.assertEqual(len(remaining_results), 1)
        self.assertEqual(remaining_results[0][0], 1)  # Should be index 1

    def test_custom_enumerate_single_batch(self):
        """Test custom enumerate with single batch scenarios."""
        # Create dataset with exactly one batch
        single_batch_data = torch.arange(4)
        single_batch_dataset = TensorDataset(single_batch_data)
        dataloader = DataLoader(single_batch_dataset, batch_size=4, shuffle=False)

        # Should produce one result with index 0
        results = list(dataloader.enumerate())
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0][0], 0)

        # Test with start parameter
        results_with_start = list(dataloader.enumerate(start=50))
        self.assertEqual(len(results_with_start), 1)
        self.assertEqual(results_with_start[0][0], 50)

    def test_custom_enumerate_iterable_dataset(self):
        """Test custom enumerate with IterableDataset."""

        class SimpleIterableDataset(IterableDataset):
            def __init__(self, data):
                self.data = data

            def __iter__(self):
                return iter(self.data)

            def __len__(self):
                return len(self.data)

        iterable_dataset = SimpleIterableDataset(list(range(10)))
        dataloader = DataLoader(iterable_dataset, batch_size=2, shuffle=False)

        # Test basic custom enumerate
        results = list(dataloader.enumerate())
        expected_indices = list(range(5))  # 10 items / 2 batch_size = 5 batches
        actual_indices = [idx for idx, _ in results]

        self.assertEqual(actual_indices, expected_indices)

    def test_custom_enumerate_consistency(self):
        """Test that multiple calls to custom enumerate produce consistent results."""
        dataloader = DataLoader(self.dataset, batch_size=3, shuffle=False)

        # Call enumerate multiple times
        results1 = list(dataloader.enumerate())
        results2 = list(dataloader.enumerate(start=0))

        # Results should be identical
        self.assertEqual(len(results1), len(results2))
        for (idx1, data1), (idx2, data2) in zip(results1, results2):
            self.assertEqual(idx1, idx2)
            self.assertTrue(torch.equal(data1[0], data2[0]))


if __name__ == "__main__":
    run_tests()
torchdata/stateful_dataloader/stateful_dataloader.py (40 additions, 0 deletions)
@@ -425,6 +425,46 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            return
        self.next_iter_state = state_dict

    def enumerate(self, start: int = 0):
        """
        Return an iterator that yields (index, data) pairs where the index
        reflects the actual position in the dataset, accounting for any previously
        loaded state.

        This is useful when resuming from a checkpoint, as the built-in enumerate()
        function always starts from 0, while this method starts from the correct
        position based on the dataloader's internal state.

        Args:
            start (int): Value to start the enumeration from. This is added to the
                internal batch position. Default is 0.

        Returns:
            An iterator yielding (index, data) tuples where index is the actual
            batch position in the dataset.

        Example:
            >>> dataloader = StatefulDataLoader(dataset, batch_size=2)
            >>> # Process some batches
            >>> for i, batch in enumerate(dataloader):
            ...     if i == 2:
            ...         state = dataloader.state_dict()
            ...         break
            >>>
            >>> # Create a new dataloader and restore the state
            >>> dataloader2 = StatefulDataLoader(dataset, batch_size=2)
            >>> dataloader2.load_state_dict(state)
            >>>
            >>> # Use the custom enumerate method to get correct indices
            >>> for i, batch in dataloader2.enumerate():
            ...     print(f"Batch {i}: {batch}")  # Will print "Batch 3: ..."
        """
        for data in self:
            # _num_yielded counts the batches produced by the underlying iterator,
            # including progress restored via load_state_dict(), so subtracting one
            # gives the zero-based index of the batch just yielded.
            current_batch = getattr(self._iterator, "_num_yielded", 1) - 1
            yield current_batch + start, data


class _StatefulBaseDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader: StatefulDataLoader) -> None:
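As a usage sketch of the behavior exercised by the tests above (assuming `StatefulDataLoader` is importable from `torchdata.stateful_dataloader`; the dataset and batch size mirror the test fixtures, and the variable names are illustrative only):

import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = TensorDataset(torch.arange(20))

# Simulate an interrupted run: process three batches, then checkpoint.
loader = StatefulDataLoader(dataset, batch_size=2, shuffle=False)
state = None
for i, (batch,) in enumerate(loader):
    if i == 2:
        state = loader.state_dict()
        break

# Resume in a fresh dataloader. The built-in enumerate() would restart at 0;
# the custom enumerate() continues from the true batch position.
resumed = StatefulDataLoader(dataset, batch_size=2, shuffle=False)
resumed.load_state_dict(state)
for i, (batch,) in resumed.enumerate():
    print(i, batch)  # first iteration prints: 3 tensor([6, 7])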