diff --git a/docs/source-pytorch/common/evaluation_intermediate.rst b/docs/source-pytorch/common/evaluation_intermediate.rst
index 35ae7ad3bac71..35107e1b75dc9 100644
--- a/docs/source-pytorch/common/evaluation_intermediate.rst
+++ b/docs/source-pytorch/common/evaluation_intermediate.rst
@@ -134,6 +134,121 @@ you can also pass in an :doc:`datamodules <../data/datamodule>` that have overri
     # test (pass in datamodule)
     trainer.test(datamodule=dm)
 
+
+Test with Multiple DataLoaders
+==============================
+
+When you need to evaluate your model on multiple test datasets in a single ``trainer.test()`` call (e.g., different
+domains, conditions, or evaluation scenarios), PyTorch Lightning supports multiple test dataloaders out of the box.
+
+To use multiple test dataloaders, simply return a list of dataloaders from your ``test_dataloader()`` method:
+
+.. code-block:: python
+
+    class LitModel(L.LightningModule):
+        def test_dataloader(self):
+            return [
+                DataLoader(clean_test_dataset, batch_size=32),
+                DataLoader(noisy_test_dataset, batch_size=32),
+                DataLoader(adversarial_test_dataset, batch_size=32),
+            ]
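+
+If the dataloaders are easier to build outside the model, the same list can also be passed to ``trainer.test``
+through its ``dataloaders`` argument instead of overriding ``test_dataloader()``. A minimal sketch under that
+assumption, reusing the datasets from the example above:
+
+.. code-block:: python
+
+    model = LitModel()
+    trainer = L.Trainer()
+
+    # Each entry in the list is evaluated as its own test dataloader
+    trainer.test(
+        model,
+        dataloaders=[
+            DataLoader(clean_test_dataset, batch_size=32),
+            DataLoader(noisy_test_dataset, batch_size=32),
+            DataLoader(adversarial_test_dataset, batch_size=32),
+        ],
+    )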
+
+When using multiple test dataloaders, your ``test_step`` method **must** include a ``dataloader_idx`` parameter:
+
+.. code-block:: python
+
+    def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+
+        # Use dataloader_idx to handle different test scenarios
+        return {'test_loss': loss}
+
+Logging Metrics Per Dataloader
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Lightning provides automatic support for logging metrics per dataloader:
+
+.. code-block:: python
+
+    def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        acc = (y_hat.argmax(dim=1) == y).float().mean()
+
+        # Lightning automatically adds a "/dataloader_idx_X" suffix
+        self.log('test_loss', loss, add_dataloader_idx=True)
+        self.log('test_acc', acc, add_dataloader_idx=True)
+
+        return loss
+
+This will create metrics like ``test_loss/dataloader_idx_0``, ``test_loss/dataloader_idx_1``, etc.
+
+For more meaningful metric names, you can use custom naming; just make sure the names are unique across dataloaders:
+
+.. code-block:: python
+
+    def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
+        # Compute loss and acc as in the previous example
+        ...
+
+        # Define meaningful names for each dataloader
+        dataloader_names = {0: "clean", 1: "noisy", 2: "adversarial"}
+        dataset_name = dataloader_names.get(dataloader_idx, f"dataset_{dataloader_idx}")
+
+        # Log with custom names
+        self.log(f'test_loss_{dataset_name}', loss, add_dataloader_idx=False)
+        self.log(f'test_acc_{dataset_name}', acc, add_dataloader_idx=False)
+
+Processing Entire Datasets Per Dataloader
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To perform calculations on the entire test dataset for each dataloader (e.g., computing overall metrics or creating
+visualizations), accumulate results during ``test_step`` and process them in ``on_test_epoch_end``:
+
+.. code-block:: python
+
+    class LitModel(L.LightningModule):
+        def __init__(self):
+            super().__init__()
+            # Store outputs per dataloader
+            self.test_outputs = {}
+
+        def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
+            x, y = batch
+            y_hat = self(x)
+            loss = F.cross_entropy(y_hat, y)
+
+            # Initialize and store results per dataloader
+            if dataloader_idx not in self.test_outputs:
+                self.test_outputs[dataloader_idx] = {'predictions': [], 'targets': []}
+            self.test_outputs[dataloader_idx]['predictions'].append(y_hat)
+            self.test_outputs[dataloader_idx]['targets'].append(y)
+            return loss
+
+        def on_test_epoch_end(self):
+            for dataloader_idx, outputs in self.test_outputs.items():
+                # Concatenate all predictions and targets for this dataloader
+                all_predictions = torch.cat(outputs['predictions'], dim=0)
+                all_targets = torch.cat(outputs['targets'], dim=0)
+
+                # Calculate metrics on the entire dataset, log them, and create visualizations
+                overall_accuracy = (all_predictions.argmax(dim=1) == all_targets).float().mean()
+                self.log(f'test_overall_acc_dataloader_{dataloader_idx}', overall_accuracy)
+                self._save_results(all_predictions, all_targets, dataloader_idx)
+
+            self.test_outputs.clear()
+
+.. note::
+    When using multiple test dataloaders, ``trainer.test()`` returns a list of results, one for each dataloader:
+
+    .. code-block:: python
+
+        results = trainer.test(model)
+        print(f"Results from {len(results)} test dataloaders:")
+        for i, result in enumerate(results):
+            print(f"Dataloader {i}: {result}")
+
 ----------
 
 **********