diff --git a/test/stateful_dataloader/test_dataloader.py b/test/stateful_dataloader/test_dataloader.py
index 17abd0dc5..4d9368811 100644
--- a/test/stateful_dataloader/test_dataloader.py
+++ b/test/stateful_dataloader/test_dataloader.py
@@ -3145,5 +3145,209 @@ def test_out_of_order_iterable_ds(self):
 instantiate_device_type_tests(TestDataLoaderDeviceType, globals())
 
 
+@unittest.skipIf(
+    TEST_WITH_TSAN,
+    "Fails with TSAN with the following error: starting new threads after multi-threaded "
+    "fork is not supported. Dying (set die_after_fork=0 to override)",
+)
+class TestStatefulDataLoaderEnumerate(TestCase):
+    def setUp(self):
+        super().setUp()
+        self.data = torch.arange(20)
+        self.dataset = TensorDataset(self.data)
+
+    def test_custom_enumerate_basic(self):
+        """Test that custom enumerate works correctly without state restoration."""
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=False)
+
+        # Test custom enumerate produces correct indices
+        custom_results = list(dataloader.enumerate())
+        builtin_results = list(enumerate(dataloader))
+
+        # Both should produce the same results when no state is loaded
+        self.assertEqual(len(custom_results), len(builtin_results))
+        for (custom_idx, custom_data), (builtin_idx, builtin_data) in zip(custom_results, builtin_results):
+            self.assertEqual(custom_idx, builtin_idx)
+            self.assertTrue(torch.equal(custom_data[0], builtin_data[0]))
+
+    def test_custom_enumerate_with_start_parameter(self):
+        """Test that custom enumerate works correctly with the start parameter."""
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=False)
+
+        start_value = 100
+        results = list(dataloader.enumerate(start=start_value))
+
+        expected_indices = list(range(start_value, start_value + len(dataloader)))
+        actual_indices = [idx for idx, _ in results]
+
+        self.assertEqual(actual_indices, expected_indices)
+
+    def test_custom_enumerate_with_state_restoration(self):
+        """Test that custom enumerate correctly handles state restoration."""
+        # Create initial dataloader and process some batches
+        dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False)
+
+        # Process first 3 batches (indices 0, 1, 2) and save state
+        processed_count = 0
+        for i, (batch,) in enumerate(dataloader1):
+            processed_count += 1
+            if i == 2:  # After processing batches 0, 1, 2
+                state = dataloader1.state_dict()
+                break
+
+        self.assertEqual(processed_count, 3)
+
+        # Create new dataloader and restore state
+        dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False)
+        dataloader2.load_state_dict(state)
+
+        # Use custom enumerate to continue
+        remaining_results = list(dataloader2.enumerate())
+
+        # Should start from index 3 (since we processed 0, 1, 2)
+        expected_start_index = 3
+        expected_indices = list(range(expected_start_index, len(dataloader1)))
+        actual_indices = [idx for idx, _ in remaining_results]
+
+        self.assertEqual(actual_indices, expected_indices)
+
+        # Verify data correctness
+        expected_data_start = 6  # batch 3 should contain [6, 7]
+        first_batch_data = remaining_results[0][1][0]
+        self.assertTrue(torch.equal(first_batch_data, torch.tensor([expected_data_start, expected_data_start + 1])))
+
+    def test_custom_enumerate_vs_builtin_after_restoration(self):
+        """Test that demonstrates the difference between custom and builtin enumerate after state restoration."""
+        # Create initial dataloader and process some batches
+        dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False)
+
+        # Process first 2 batches and save state
+        for i, batch in enumerate(dataloader1):
+            if i == 1:  # After processing batches 0, 1
+                state = dataloader1.state_dict()
+                break
+
+        # Test builtin enumerate (demonstrates the problem)
+        dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False)
+        dataloader2.load_state_dict(state)
+        builtin_results = list(enumerate(dataloader2))
+        builtin_indices = [idx for idx, _ in builtin_results]
+
+        # Test custom enumerate (shows the fix)
+        dataloader3 = DataLoader(self.dataset, batch_size=2, shuffle=False)
+        dataloader3.load_state_dict(state)
+        custom_results = list(dataloader3.enumerate())
+        custom_indices = [idx for idx, _ in custom_results]
+
+        # Builtin enumerate should start from 0 (the problem)
+        self.assertEqual(builtin_indices, [0, 1, 2, 3, 4, 5, 6, 7])
+
+        # Custom enumerate should start from 2 (the fix)
+        self.assertEqual(custom_indices, [2, 3, 4, 5, 6, 7, 8, 9])
+
+        # Data should be the same for both
+        for (_, builtin_data), (_, custom_data) in zip(builtin_results, custom_results):
+            self.assertTrue(torch.equal(builtin_data[0], custom_data[0]))
+
+    def test_custom_enumerate_with_multiprocessing(self):
+        """Test that custom enumerate works correctly with multiprocessing."""
+        # Test with 2 workers
+        dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False, num_workers=2)
+
+        # Process some batches and save state
+        for i, batch in enumerate(dataloader1):
+            if i == 2:
+                state = dataloader1.state_dict()
+                break
+
+        # Restore state and use custom enumerate
+        dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False, num_workers=2)
+        dataloader2.load_state_dict(state)
+        results = list(dataloader2.enumerate())
+
+        # Should start from the correct index
+        expected_start_index = 3
+        actual_indices = [idx for idx, _ in results]
+        self.assertEqual(actual_indices[0], expected_start_index)
+
+    def test_custom_enumerate_empty_after_restoration(self):
+        """Test custom enumerate when only the final batch remains after state restoration."""
+        # Use a small dataset
+        small_data = torch.arange(4)
+        small_dataset = TensorDataset(small_data)
+        dataloader1 = DataLoader(small_dataset, batch_size=2, shuffle=False)
+
+        # Save state after the first batch, so only the last batch remains on resume
+        state = None
+        batches_processed = 0
+        for i, batch in enumerate(dataloader1):
+            batches_processed += 1
+            if i == 0:  # After first batch, only one batch remains
+                state = dataloader1.state_dict()
+
+        # Restore state and process remaining data
+        dataloader2 = DataLoader(small_dataset, batch_size=2, shuffle=False)
+        dataloader2.load_state_dict(state)
+        remaining_results = list(dataloader2.enumerate())
+
+        # Should have exactly one batch remaining with correct index
+        self.assertEqual(len(remaining_results), 1)
+        self.assertEqual(remaining_results[0][0], 1)  # Should be index 1
+
+    def test_custom_enumerate_single_batch(self):
+        """Test custom enumerate with single batch scenarios."""
+        # Create dataset with exactly one batch
+        single_batch_data = torch.arange(4)
+        single_batch_dataset = TensorDataset(single_batch_data)
+        dataloader = DataLoader(single_batch_dataset, batch_size=4, shuffle=False)
+
+        # Should produce one result with index 0
+        results = list(dataloader.enumerate())
+        self.assertEqual(len(results), 1)
+        self.assertEqual(results[0][0], 0)
+
+        # Test with start parameter
+        results_with_start = list(dataloader.enumerate(start=50))
+        self.assertEqual(len(results_with_start), 1)
+        self.assertEqual(results_with_start[0][0], 50)
+
+    def test_custom_enumerate_iterable_dataset(self):
+        """Test custom enumerate with IterableDataset."""
+
+        class SimpleIterableDataset(IterableDataset):
+            def __init__(self, data):
+                self.data = data
+
+            def __iter__(self):
+                return iter(self.data)
+
+            def __len__(self):
+                return len(self.data)
+
+        iterable_dataset = SimpleIterableDataset(list(range(10)))
+        dataloader = DataLoader(iterable_dataset, batch_size=2, shuffle=False)
+
+        # Test basic custom enumerate
+        results = list(dataloader.enumerate())
+        expected_indices = list(range(5))  # 10 items / 2 batch_size = 5 batches
+        actual_indices = [idx for idx, _ in results]
+
+        self.assertEqual(actual_indices, expected_indices)
+
+    def test_custom_enumerate_consistency(self):
+        """Test that multiple calls to custom enumerate produce consistent results."""
+        dataloader = DataLoader(self.dataset, batch_size=3, shuffle=False)
+
+        # Call enumerate multiple times
+        results1 = list(dataloader.enumerate())
+        results2 = list(dataloader.enumerate(start=0))
+
+        # Results should be identical
+        self.assertEqual(len(results1), len(results2))
+        for (idx1, data1), (idx2, data2) in zip(results1, results2):
+            self.assertEqual(idx1, idx2)
+            self.assertTrue(torch.equal(data1[0], data2[0]))
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py
index 1ffeec298..208b17033 100644
--- a/torchdata/stateful_dataloader/stateful_dataloader.py
+++ b/torchdata/stateful_dataloader/stateful_dataloader.py
@@ -425,6 +425,46 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             return
         self.next_iter_state = state_dict
 
+    def enumerate(self, start: int = 0):
+        """
+        Return an iterator that yields (index, data) pairs where the index
+        reflects the actual position in the dataset, accounting for any previously
+        loaded state.
+
+        This is useful when resuming from a checkpoint, as the standard enumerate()
+        function always starts from 0, while this method starts from the correct
+        position based on the dataloader's internal state.
+
+        Args:
+            start (int): Value to start the enumeration from. This is added to the
+                internal batch position. Default is 0.
+
+        Returns:
+            An iterator yielding (index, data) tuples where index is the actual
+            batch position in the dataset.
+
+        Example:
+            >>> dataloader = StatefulDataLoader(dataset, batch_size=2)
+            >>> # Process some batches
+            >>> for i, batch in enumerate(dataloader):
+            >>>     if i == 2:
+            >>>         state = dataloader.state_dict()
+            >>>         break
+            >>>
+            >>> # Create new dataloader and restore state
+            >>> dataloader2 = StatefulDataLoader(dataset, batch_size=2)
+            >>> dataloader2.load_state_dict(state)
+            >>>
+            >>> # Use custom enumerate method to get correct indices
+            >>> for i, batch in dataloader2.enumerate():
+            >>>     print(f"Batch {i}: {batch}")  # Will print "Batch 3: ..."
+        """
+        for data in self:
+            # _num_yielded has already been incremented for the batch just fetched,
+            # so subtract one to recover that batch's zero-based position
+            current_batch = getattr(self._iterator, "_num_yielded", 1) - 1
+            yield current_batch + start, data
+
 
 class _StatefulBaseDataLoaderIter(_BaseDataLoaderIter):
     def __init__(self, loader: StatefulDataLoader) -> None:
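
A minimal sketch of the checkpoint/resume flow the new enumerate() method targets, assuming the StatefulDataLoader API added above and a map-style TensorDataset; the checkpoint variable and print call are illustrative only, not part of the patch:

    import torch
    from torch.utils.data import TensorDataset
    from torchdata.stateful_dataloader import StatefulDataLoader

    dataset = TensorDataset(torch.arange(20))
    loader = StatefulDataLoader(dataset, batch_size=2)

    # First run: consume a few batches, then checkpoint the loader state.
    for i, batch in enumerate(loader):
        if i == 2:
            checkpoint = loader.state_dict()
            break

    # Later run: restore the loader; .enumerate() keeps indices aligned with the dataset.
    resumed = StatefulDataLoader(dataset, batch_size=2)
    resumed.load_state_dict(checkpoint)
    for i, batch in resumed.enumerate():
        print(i, batch)  # prints indices 3, 4, ... instead of restarting at 0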