docs/source-pytorch/common/evaluation_intermediate.rst: 115 additions, 0 deletions

@@ -134,6 +134,121 @@ you can also pass in a :doc:`datamodule <../data/datamodule>` that has overridden
# test (pass in datamodule)
trainer.test(datamodule=dm)
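
For illustration, a minimal sketch of such a datamodule (``MyDataModule`` and ``my_test_dataset`` are hypothetical
names used only for this example):

.. code-block:: python

    class MyDataModule(L.LightningDataModule):
        # only the test-related hook is shown
        def test_dataloader(self):
            return DataLoader(my_test_dataset, batch_size=32)


    dm = MyDataModule()
    trainer.test(datamodule=dm)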


Test with Multiple DataLoaders
==============================

When you need to evaluate your model on multiple test datasets (e.g., different domains, conditions, or
evaluation scenarios), PyTorch Lightning supports multiple test dataloaders out of the box; each one is evaluated
in turn during ``trainer.test()``.

To use multiple test dataloaders, simply return a list of dataloaders from your ``test_dataloader()`` method:

.. code-block:: python

class LitModel(L.LightningModule):
def test_dataloader(self):
return [
DataLoader(clean_test_dataset, batch_size=32),
DataLoader(noisy_test_dataset, batch_size=32),
DataLoader(adversarial_test_dataset, batch_size=32),
]
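
If you prefer not to override ``test_dataloader()``, the same list can be passed straight to ``trainer.test()``
via its ``dataloaders`` argument (a sketch reusing the datasets above):

.. code-block:: python

    model = LitModel()
    trainer = L.Trainer()

    # dataloader_idx follows the order of this list: 0=clean, 1=noisy, 2=adversarial
    trainer.test(
        model,
        dataloaders=[
            DataLoader(clean_test_dataset, batch_size=32),
            DataLoader(noisy_test_dataset, batch_size=32),
            DataLoader(adversarial_test_dataset, batch_size=32),
        ],
    )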

When using multiple test dataloaders, your ``test_step`` method **must** include a ``dataloader_idx`` parameter:

.. code-block:: python

    def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        # dataloader_idx tells you which test set this batch came from (see the sketch below)
        return {'test_loss': loss}
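
The index lets you branch per dataset. For example (an illustrative sketch; the extra confidence statistic for the
adversarial set is hypothetical, not part of the Lightning API):

.. code-block:: python

    def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        if dataloader_idx == 2:
            # track mean softmax confidence on the adversarial set only
            confidence = y_hat.softmax(dim=1).max(dim=1).values.mean()
            self.log('adversarial_confidence', confidence)

        return {'test_loss': loss}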

Logging Metrics Per Dataloader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Lightning provides automatic support for logging metrics per dataloader:

.. code-block:: python

def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
acc = (y_hat.argmax(dim=1) == y).float().mean()

        # Lightning automatically appends a "/dataloader_idx_X" suffix (add_dataloader_idx=True is the default)
        self.log('test_loss', loss, add_dataloader_idx=True)
        self.log('test_acc', acc, add_dataloader_idx=True)

return loss

This will create metrics like ``test_loss/dataloader_idx_0``, ``test_loss/dataloader_idx_1``, etc.

For more meaningful metric names, pass ``add_dataloader_idx=False`` and build the names yourself. Lightning then
no longer appends the index suffix, so you must make sure each name is unique across dataloaders.

.. code-block:: python

    def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()

        # Define meaningful names for each dataloader
        dataloader_names = {0: "clean", 1: "noisy", 2: "adversarial"}
        dataset_name = dataloader_names.get(dataloader_idx, f"dataset_{dataloader_idx}")

        # Log with custom names instead of the automatic suffix
        self.log(f'test_loss_{dataset_name}', loss, add_dataloader_idx=False)
        self.log(f'test_acc_{dataset_name}', acc, add_dataloader_idx=False)
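
If you track several metrics, a stateful alternative is one metric object per dataloader (a sketch assuming
``torchmetrics`` is installed and a 10-class classifier; both are assumptions for this example):

.. code-block:: python

    import torchmetrics

    class LitModel(L.LightningModule):
        def __init__(self, num_classes: int = 10, num_test_sets: int = 3):
            super().__init__()
            # one accuracy object per test dataloader, registered as submodules
            self.test_accs = torch.nn.ModuleList(
                torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes)
                for _ in range(num_test_sets)
            )

        def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
            x, y = batch
            y_hat = self(x)
            self.test_accs[dataloader_idx](y_hat, y)
            # logging the metric object lets Lightning compute and reset it at epoch end
            self.log(f'test_acc_{dataloader_idx}', self.test_accs[dataloader_idx], add_dataloader_idx=False)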

Processing Entire Datasets Per Dataloader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To perform calculations on the entire test dataset for each dataloader (e.g., computing overall metrics, creating
visualizations), accumulate results during ``test_step`` and process them in ``on_test_epoch_end``:

.. code-block:: python

class LitModel(L.LightningModule):
def __init__(self):
super().__init__()
# Store outputs per dataloader
self.test_outputs = {}

def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)

# Initialize and store results
if dataloader_idx not in self.test_outputs:
self.test_outputs[dataloader_idx] = {'predictions': [], 'targets': []}
self.test_outputs[dataloader_idx]['predictions'].append(y_hat)
self.test_outputs[dataloader_idx]['targets'].append(y)
return loss

def on_test_epoch_end(self):
for dataloader_idx, outputs in self.test_outputs.items():
# Concatenate all predictions and targets for this dataloader
all_predictions = torch.cat(outputs['predictions'], dim=0)
all_targets = torch.cat(outputs['targets'], dim=0)

# Calculate metrics on the entire dataset, log and create visualizations
overall_accuracy = (all_predictions.argmax(dim=1) == all_targets).float().mean()
self.log(f'test_overall_acc_dataloader_{dataloader_idx}', overall_accuracy)
                self._save_results(all_predictions, all_targets, dataloader_idx)  # user-defined saving/plotting helper

self.test_outputs.clear()
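
One caveat: under distributed strategies (e.g. DDP) each process only accumulates its own shard of each test set.
A sketch of ``on_test_epoch_end`` that merges the shards first (it assumes dataset sizes divide evenly across
ranks, since the distributed sampler may otherwise duplicate a few samples):

.. code-block:: python

    def on_test_epoch_end(self):
        for dataloader_idx, outputs in self.test_outputs.items():
            predictions = torch.cat(outputs['predictions'], dim=0)
            targets = torch.cat(outputs['targets'], dim=0)

            if self.trainer.world_size > 1:
                # all_gather returns (world_size, N, ...); fold the rank dimension back in
                predictions = self.all_gather(predictions).flatten(0, 1)
                targets = self.all_gather(targets).flatten(0, 1)

            overall_accuracy = (predictions.argmax(dim=1) == targets).float().mean()
            self.log(f'test_overall_acc_dataloader_{dataloader_idx}', overall_accuracy)

        self.test_outputs.clear()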

.. note::
When using multiple test dataloaders, ``trainer.test()`` returns a list of results, one for each dataloader:

.. code-block:: python

results = trainer.test(model)
print(f"Results from {len(results)} test dataloaders:")
for i, result in enumerate(results):
print(f"Dataloader {i}: {result}")

----------
