Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,13 @@ Callbacks
.. toctree::
:titlesonly:

Processing callback <callback/processing_callback.rst>
Optimizer callback <callback/optimizer_callback.rst>
Switch Scheduler <callback/switch_scheduler.rst>
R3 Refinement callback <callback/refinement/r3_refinement.rst>
Refinement Interface callback <callback/refinement/refinement_interface.rst>
Normalizer callback <callback/normalizer_data_callback.rst>
Switch Optimizer <callback/optim/switch_optimizer.rst>
Switch Scheduler <callback/optim/switch_scheduler.rst>
Normalizer Data <callback/processing/normalizer_data_callback.rst>
PINA Progress Bar <callback/processing/pina_progress_bar.rst>
Metric Tracker <callback/processing/metric_tracker.rst>
Refinement Interface <callback/refinement/refinement_interface.rst>
R3 Refinement <callback/refinement/r3_refinement.rst>

Losses and Weightings
---------------------
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Optimizer callbacks
Switch Optimizer
=====================

.. currentmodule:: pina.callback.optimizer_callback
.. currentmodule:: pina.callback.optim.switch_optimizer
.. autoclass:: SwitchOptimizer
:members:
:show-inheritance:
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Switch Scheduler
=====================

.. currentmodule:: pina.callback.switch_scheduler
.. currentmodule:: pina.callback.optim.switch_scheduler
.. autoclass:: SwitchScheduler
:members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/callback/processing/metric_tracker.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Metric Tracker
==================
.. currentmodule:: pina.callback.processing.metric_tracker

.. autoclass:: MetricTracker
:members:
:show-inheritance:
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Normalizer callbacks
Normalizer Data
=======================

.. currentmodule:: pina.callback.normalizer_data_callback
.. currentmodule:: pina.callback.processing.normalizer_data_callback
.. autoclass:: NormalizerDataCallback
:members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/callback/processing/pina_progress_bar.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
PINA Progress Bar
==================
.. currentmodule:: pina.callback.processing.pina_progress_bar

.. autoclass:: PINAProgressBar
:members:
:show-inheritance:
11 changes: 0 additions & 11 deletions docs/source/_rst/callback/processing_callback.rst

This file was deleted.

13 changes: 7 additions & 6 deletions pina/callback/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
__all__ = [
"SwitchOptimizer",
"SwitchScheduler",
"MetricTracker",
"NormalizerDataCallback",
"PINAProgressBar",
"MetricTracker",
"R3Refinement",
"NormalizerDataCallback",
]

from .optimizer_callback import SwitchOptimizer
from .processing_callback import MetricTracker, PINAProgressBar
from .optim.switch_optimizer import SwitchOptimizer
from .optim.switch_scheduler import SwitchScheduler
from .processing.normalizer_data_callback import NormalizerDataCallback
from .processing.pina_progress_bar import PINAProgressBar
from .processing.metric_tracker import MetricTracker
from .refinement import R3Refinement
from .normalizer_data_callback import NormalizerDataCallback
from .switch_scheduler import SwitchScheduler
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Module for the SwitchOptimizer callback."""

from lightning.pytorch.callbacks import Callback
from ..optim import TorchOptimizer
from ..utils import check_consistency
from ...optim import TorchOptimizer
from ...utils import check_consistency


class SwitchOptimizer(Callback):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Module for the SwitchScheduler callback."""

from lightning.pytorch.callbacks import Callback
from ..optim import TorchScheduler
from ..utils import check_consistency, check_positive_integer
from ...optim import TorchScheduler
from ...utils import check_consistency, check_positive_integer


class SwitchScheduler(Callback):
Expand Down
80 changes: 80 additions & 0 deletions pina/callback/processing/metric_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Module for the Metric Tracker."""

import copy
import torch
from lightning.pytorch.callbacks import Callback


class MetricTracker(Callback):
    """
    Lightning Callback for Metric Tracking.

    At the end of every training epoch (after the first), the metrics
    currently logged on the trainer are snapshotted; they can later be
    retrieved, stacked over epochs, through the :meth:`metrics` property.
    """

    def __init__(self, metrics_to_track=None):
        """
        Tracks specified metrics during training.

        :param metrics_to_track: List of metrics to track.
            Defaults to train loss.
        :type metrics_to_track: list[str], optional
        """
        super().__init__()
        # One entry per tracked epoch: {metric_name: tensor_value}.
        self._collection = []
        # Resolved lazily in setup(), because the default metric name
        # depends on the trainer's batching mode.
        self.metrics_to_track = metrics_to_track

    def setup(self, trainer, pl_module, stage):
        """
        Called when fit, validate, test, predict, or tune begins.

        Resolves the default tracked metric: ``train_loss`` when no
        batching is used, ``train_loss_epoch`` otherwise (presumably the
        epoch-aggregated name Lightning logs in batched mode — confirm
        against the solver's logging).

        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
        :param SolverInterface pl_module: A
            :class:`~pina.solver.solver.SolverInterface` instance.
        :param str stage: Either 'fit', 'test' or 'predict'.
        """
        # Single None-check instead of repeating it in an if/elif chain.
        if self.metrics_to_track is None:
            if trainer.batch_size is None:
                self.metrics_to_track = ["train_loss"]
            else:
                self.metrics_to_track = ["train_loss_epoch"]
        return super().setup(trainer, pl_module, stage)

    def on_train_epoch_end(self, trainer, pl_module):
        """
        Collect and track metrics at the end of each training epoch.

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
        :param pl_module: The model being trained (not used here).
        """
        # Track metrics after the first epoch onwards
        if trainer.current_epoch > 0:
            # Keep only the tracked metrics to avoid unnecessary data;
            # deepcopy detaches the snapshot from tensors Lightning may
            # overwrite in-place on later epochs.
            tracked_metrics = {
                k: v
                for k, v in trainer.logged_metrics.items()
                if k in self.metrics_to_track
            }
            self._collection.append(copy.deepcopy(tracked_metrics))

    @property
    def metrics(self):
        """
        Aggregate collected metrics over all epochs.

        :return: A dictionary containing aggregated metric values.
        :rtype: dict
        """
        if not self._collection:
            return {}

        # Keys present in every collected epoch; an epoch may miss a
        # metric (e.g. one logged on a different schedule), and stacking
        # requires a value for every epoch.
        common_keys = set(self._collection[0]).intersection(
            *self._collection[1:]
        )

        # Entries were already filtered against metrics_to_track at
        # collection time, so no per-key re-filtering is needed here.
        return {
            k: torch.stack([dic[k] for dic in self._collection])
            for k in common_keys
        }
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import torch
from lightning.pytorch import Callback
from ..label_tensor import LabelTensor
from ..utils import check_consistency, is_function
from ..condition import InputTargetCondition
from ..data.dataset import PinaGraphDataset
from ...label_tensor import LabelTensor
from ...utils import check_consistency, is_function
from ...condition import InputTargetCondition
from ...data.dataset import PinaGraphDataset


class NormalizerDataCallback(Callback):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,90 +1,12 @@
"""Module for the Processing Callbacks."""

import copy
import torch

from lightning.pytorch.callbacks import Callback, TQDMProgressBar
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.callbacks.progress.progress_bar import (
get_standard_metrics,
)
from pina.utils import check_consistency


class MetricTracker(Callback):
    """
    Lightning Callback for Metric Tracking.
    """

    def __init__(self, metrics_to_track=None):
        """
        Tracks specified metrics during training.

        :param metrics_to_track: List of metrics to track.
            Defaults to train loss.
        :type metrics_to_track: list[str], optional
        """
        super().__init__()
        # Snapshots of tracked metrics, one dict per tracked epoch.
        self._collection = []
        # Default to tracking 'train_loss' if not specified
        # (the actual default is resolved lazily in setup(), since it
        # depends on the trainer's batching mode).
        self.metrics_to_track = metrics_to_track

    def setup(self, trainer, pl_module, stage):
        """
        Called when fit, validate, test, predict, or tune begins.

        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
        :param SolverInterface pl_module: A
            :class:`~pina.solver.solver.SolverInterface` instance.
        :param str stage: Either 'fit', 'test' or 'predict'.
        """
        # No batching -> 'train_loss'; with batching the epoch-aggregated
        # name 'train_loss_epoch' is tracked instead (presumably what the
        # solver logs in that mode — confirm against the solver).
        if self.metrics_to_track is None and trainer.batch_size is None:
            self.metrics_to_track = ["train_loss"]
        elif self.metrics_to_track is None:
            self.metrics_to_track = ["train_loss_epoch"]
        return super().setup(trainer, pl_module, stage)

    def on_train_epoch_end(self, trainer, pl_module):
        """
        Collect and track metrics at the end of each training epoch.

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
        :param pl_module: The model being trained (not used here).
        """
        # Track metrics after the first epoch onwards
        if trainer.current_epoch > 0:
            # Append only the tracked metrics to avoid unnecessary data
            tracked_metrics = {
                k: v
                for k, v in trainer.logged_metrics.items()
                if k in self.metrics_to_track
            }
            # deepcopy detaches the snapshot from values Lightning may
            # overwrite in-place on later epochs.
            self._collection.append(copy.deepcopy(tracked_metrics))

    @property
    def metrics(self):
        """
        Aggregate collected metrics over all epochs.

        :return: A dictionary containing aggregated metric values.
        :rtype: dict
        """
        if not self._collection:
            return {}

        # Get intersection of keys across all collected dictionaries
        # (stacking requires a value for every epoch).
        common_keys = set(self._collection[0]).intersection(
            *self._collection[1:]
        )

        # Stack the metric values for common keys and return
        return {
            k: torch.stack([dic[k] for dic in self._collection])
            for k in common_keys
            if k in self.metrics_to_track
        }


class PINAProgressBar(TQDMProgressBar):
"""
PINA Implementation of a Lightning Callback for enriching the progress bar.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pina.solver import PINN
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.callback.processing_callback import PINAProgressBar
from pina.callback import PINAProgressBar
from pina.problem.zoo import Poisson2DSquareProblem as Poisson


Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import pytest

from torch.nn import MSELoss

from pina.solver import PINN
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from pina.callback.refinement import R3Refinement
from pina.callback import R3Refinement


# make the problem
Expand Down