|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +import itertools |
14 | 15 | import logging
|
15 | 16 | from unittest.mock import Mock
|
16 | 17 |
|
17 | 18 | import pytest
|
18 | 19 | import torch
|
| 20 | +from torch.utils.data import DataLoader |
19 | 21 |
|
20 | 22 | from lightning.pytorch import Trainer, seed_everything
|
21 | 23 | from lightning.pytorch.demos.boring_classes import BoringModel
|
@@ -206,3 +208,72 @@ def test_should_stop_early_stopping_conditions_met(
|
206 | 208 |
|
207 | 209 | assert (message in caplog.text) is raise_debug_msg
|
208 | 210 | assert trainer.fit_loop._can_stop_early is early_stop
|
| 211 | + |
| 212 | + |
@pytest.mark.parametrize("max_steps", [7, 20])
def test_tqdm_total_steps_with_iterator_no_length(tmp_path, max_steps):
    """Progress-bar total falls back to ``max_steps`` when the train dataloader has no ``__len__``."""

    batch_size = 4
    model = BoringModel()

    # Endless stream of batches with no __len__.
    # NOTE: BoringModel expects 32-feature inputs.
    def _endless_batches():
        while True:
            yield torch.randn(batch_size, 32, dtype=torch.float32)

    trainer = Trainer(
        accelerator="cpu",
        default_root_dir=tmp_path,
        enable_model_summary=False,
        enable_progress_bar=True,
        limit_val_batches=0,
        max_epochs=-1,
        max_steps=max_steps,
    )

    # Feed the unsized iterator directly as the training data source.
    model.train_dataloader = lambda: _endless_batches()
    progress_cb = trainer.progress_bar_callback
    trainer.fit(model)

    # With no dataloader length available, the bar total must equal max_steps.
    assert progress_cb.train_progress_bar.total == max_steps
| 241 | + |
| 242 | + |
@pytest.mark.parametrize("max_steps", [10, 15])
def test_progress_bar_steps(tmp_path, max_steps):
    """Progress-bar total honors ``max_steps`` for an ``IterableDataset``-backed dataloader."""
    batch_size = 4
    model = BoringModel()

    # Fixed pool of samples built outside the model.
    # NOTE: BoringModel expects 32-feature inputs.
    samples = torch.randn(100, 32)

    class _TensorStream(torch.utils.data.IterableDataset):
        """Iterable dataset yielding each row of ``data`` as a bare tensor (not a tuple)."""

        def __init__(self, data):
            super().__init__()
            self.data = data

        def __iter__(self):
            return iter(self.data)

    loader = DataLoader(_TensorStream(samples), batch_size=batch_size)

    # Patch the model so training pulls from this dataloader.
    model.train_dataloader = lambda: loader

    trainer = Trainer(
        accelerator="cpu",
        default_root_dir=tmp_path,
        enable_model_summary=False,
        enable_progress_bar=True,
        limit_val_batches=0,
        max_epochs=-1,
        max_steps=max_steps,
    )
    progress_cb = trainer.progress_bar_callback
    trainer.fit(model)

    # The bar total must match the configured number of optimizer steps.
    assert progress_cb.train_progress_bar.total == max_steps
0 commit comments