diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 9962ebebb..55c9df1b1 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -254,12 +254,13 @@ Callbacks .. toctree:: :titlesonly: - Processing callback - Optimizer callback - Switch Scheduler - R3 Refinment callback - Refinment Interface callback - Normalizer callback + Switch Optimizer + Switch Scheduler + Normalizer Data + PINA Progress Bar + Metric Tracker + Refinement Interface + R3 Refinement Losses and Weightings --------------------- diff --git a/docs/source/_rst/callback/optimizer_callback.rst b/docs/source/_rst/callback/optim/switch_optimizer.rst similarity index 54% rename from docs/source/_rst/callback/optimizer_callback.rst rename to docs/source/_rst/callback/optim/switch_optimizer.rst index 0afdc2669..635e79a18 100644 --- a/docs/source/_rst/callback/optimizer_callback.rst +++ b/docs/source/_rst/callback/optim/switch_optimizer.rst @@ -1,7 +1,7 @@ -Optimizer callbacks +Switch Optimizer ===================== -.. currentmodule:: pina.callback.optimizer_callback +.. currentmodule:: pina.callback.optim.switch_optimizer .. autoclass:: SwitchOptimizer :members: :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/callback/switch_scheduler.rst b/docs/source/_rst/callback/optim/switch_scheduler.rst similarity index 65% rename from docs/source/_rst/callback/switch_scheduler.rst rename to docs/source/_rst/callback/optim/switch_scheduler.rst index 0e69ef0fb..3176904da 100644 --- a/docs/source/_rst/callback/switch_scheduler.rst +++ b/docs/source/_rst/callback/optim/switch_scheduler.rst @@ -1,7 +1,7 @@ Switch Scheduler ===================== -.. currentmodule:: pina.callback.switch_scheduler +.. currentmodule:: pina.callback.optim.switch_scheduler .. autoclass:: SwitchScheduler :members: :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/callback/processing/metric_tracker.rst b/docs/source/_rst/callback/processing/metric_tracker.rst new file mode 100644 index 000000000..f21cc7730 --- /dev/null +++ b/docs/source/_rst/callback/processing/metric_tracker.rst @@ -0,0 +1,7 @@ +Metric Tracker +================== +.. currentmodule:: pina.callback.processing.metric_tracker + +.. autoclass:: MetricTracker + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/callback/normalizer_data_callback.rst b/docs/source/_rst/callback/processing/normalizer_data_callback.rst similarity index 53% rename from docs/source/_rst/callback/normalizer_data_callback.rst rename to docs/source/_rst/callback/processing/normalizer_data_callback.rst index 6f59f7aee..a44f0c402 100644 --- a/docs/source/_rst/callback/normalizer_data_callback.rst +++ b/docs/source/_rst/callback/processing/normalizer_data_callback.rst @@ -1,7 +1,7 @@ -Normalizer callbacks +Normalizer Data ======================= -.. currentmodule:: pina.callback.normalizer_data_callback +.. currentmodule:: pina.callback.processing.normalizer_data_callback .. autoclass:: NormalizerDataCallback :members: :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/callback/processing/pina_progress_bar.rst b/docs/source/_rst/callback/processing/pina_progress_bar.rst new file mode 100644 index 000000000..1d42ad120 --- /dev/null +++ b/docs/source/_rst/callback/processing/pina_progress_bar.rst @@ -0,0 +1,7 @@ +PINA Progress Bar +================== +.. currentmodule:: pina.callback.processing.pina_progress_bar + +.. autoclass:: PINAProgressBar + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/callback/processing_callback.rst b/docs/source/_rst/callback/processing_callback.rst deleted file mode 100644 index a06bb8b17..000000000 --- a/docs/source/_rst/callback/processing_callback.rst +++ /dev/null @@ -1,11 +0,0 @@ -Processing callbacks -======================= - -.. currentmodule:: pina.callback.processing_callback -.. autoclass:: MetricTracker - :members: - :show-inheritance: - -.. autoclass:: PINAProgressBar - :members: - :show-inheritance: \ No newline at end of file diff --git a/pina/callback/__init__.py b/pina/callback/__init__.py index f2057257e..92da661cb 100644 --- a/pina/callback/__init__.py +++ b/pina/callback/__init__.py @@ -3,14 +3,15 @@ __all__ = [ "SwitchOptimizer", "SwitchScheduler", - "MetricTracker", + "NormalizerDataCallback", "PINAProgressBar", + "MetricTracker", "R3Refinement", - "NormalizerDataCallback", ] -from .optimizer_callback import SwitchOptimizer -from .processing_callback import MetricTracker, PINAProgressBar +from .optim.switch_optimizer import SwitchOptimizer +from .optim.switch_scheduler import SwitchScheduler +from .processing.normalizer_data_callback import NormalizerDataCallback +from .processing.pina_progress_bar import PINAProgressBar +from .processing.metric_tracker import MetricTracker from .refinement import R3Refinement -from .normalizer_data_callback import NormalizerDataCallback -from .switch_scheduler import SwitchScheduler diff --git a/pina/callback/optimizer_callback.py b/pina/callback/optim/switch_optimizer.py similarity index 97% rename from pina/callback/optimizer_callback.py rename to pina/callback/optim/switch_optimizer.py index 1b518406b..3072b7c2e 100644 --- a/pina/callback/optimizer_callback.py +++ b/pina/callback/optim/switch_optimizer.py @@ -1,8 +1,8 @@ """Module for the SwitchOptimizer callback.""" from lightning.pytorch.callbacks import Callback -from ..optim import TorchOptimizer -from ..utils import check_consistency +from ...optim import TorchOptimizer +from ...utils import check_consistency class SwitchOptimizer(Callback): diff --git a/pina/callback/switch_scheduler.py b/pina/callback/optim/switch_scheduler.py similarity index 96% rename from pina/callback/switch_scheduler.py rename to pina/callback/optim/switch_scheduler.py index 22ae8bd08..3641f4ee4 100644 --- a/pina/callback/switch_scheduler.py +++ b/pina/callback/optim/switch_scheduler.py @@ -1,8 +1,8 @@ """Module for the SwitchScheduler callback.""" from lightning.pytorch.callbacks import Callback -from ..optim import TorchScheduler -from ..utils import check_consistency, check_positive_integer +from ...optim import TorchScheduler +from ...utils import check_consistency, check_positive_integer class SwitchScheduler(Callback): diff --git a/pina/callback/processing/metric_tracker.py b/pina/callback/processing/metric_tracker.py new file mode 100644 index 000000000..9b1dc9d4a --- /dev/null +++ b/pina/callback/processing/metric_tracker.py @@ -0,0 +1,80 @@ +"""Module for the Metric Tracker.""" + +import copy +import torch +from lightning.pytorch.callbacks import Callback + + +class MetricTracker(Callback): + """ + Lightning Callback for Metric Tracking. + """ + + def __init__(self, metrics_to_track=None): + """ + Tracks specified metrics during training. + + :param metrics_to_track: List of metrics to track. + Defaults to train loss. + :type metrics_to_track: list[str], optional + """ + super().__init__() + self._collection = [] + # Default to tracking 'train_loss' if not specified + self.metrics_to_track = metrics_to_track + + def setup(self, trainer, pl_module, stage): + """ + Called when fit, validate, test, predict, or tune begins. + + :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance. + :param SolverInterface pl_module: A + :class:`~pina.solver.solver.SolverInterface` instance. + :param str stage: Either 'fit', 'test' or 'predict'. + """ + if self.metrics_to_track is None and trainer.batch_size is None: + self.metrics_to_track = ["train_loss"] + elif self.metrics_to_track is None: + self.metrics_to_track = ["train_loss_epoch"] + return super().setup(trainer, pl_module, stage) + + def on_train_epoch_end(self, trainer, pl_module): + """ + Collect and track metrics at the end of each training epoch. + + :param trainer: The trainer object managing the training process. + :type trainer: pytorch_lightning.Trainer + :param pl_module: The model being trained (not used here). + """ + # Track metrics after the first epoch onwards + if trainer.current_epoch > 0: + # Append only the tracked metrics to avoid unnecessary data + tracked_metrics = { + k: v + for k, v in trainer.logged_metrics.items() + if k in self.metrics_to_track + } + self._collection.append(copy.deepcopy(tracked_metrics)) + + @property + def metrics(self): + """ + Aggregate collected metrics over all epochs. + + :return: A dictionary containing aggregated metric values. + :rtype: dict + """ + if not self._collection: + return {} + + # Get intersection of keys across all collected dictionaries + common_keys = set(self._collection[0]).intersection( + *self._collection[1:] + ) + + # Stack the metric values for common keys and return + return { + k: torch.stack([dic[k] for dic in self._collection]) + for k in common_keys + if k in self.metrics_to_track + } diff --git a/pina/callback/normalizer_data_callback.py b/pina/callback/processing/normalizer_data_callback.py similarity index 97% rename from pina/callback/normalizer_data_callback.py rename to pina/callback/processing/normalizer_data_callback.py index ef957b9ef..4d85a7d9a 100644 --- a/pina/callback/normalizer_data_callback.py +++ b/pina/callback/processing/normalizer_data_callback.py @@ -2,10 +2,10 @@ import torch from lightning.pytorch import Callback -from ..label_tensor import LabelTensor -from ..utils import check_consistency, is_function -from ..condition import InputTargetCondition -from ..data.dataset import PinaGraphDataset +from ...label_tensor import LabelTensor +from ...utils import check_consistency, is_function +from ...condition import InputTargetCondition +from ...data.dataset import PinaGraphDataset class NormalizerDataCallback(Callback): diff --git a/pina/callback/processing_callback.py b/pina/callback/processing/pina_progress_bar.py similarity index 57% rename from pina/callback/processing_callback.py rename to pina/callback/processing/pina_progress_bar.py index 244c7266d..4c322a5e8 100644 --- a/pina/callback/processing_callback.py +++ b/pina/callback/processing/pina_progress_bar.py @@ -1,90 +1,12 @@ """Module for the Processing Callbacks.""" -import copy -import torch - -from lightning.pytorch.callbacks import Callback, TQDMProgressBar +from lightning.pytorch.callbacks import TQDMProgressBar from lightning.pytorch.callbacks.progress.progress_bar import ( get_standard_metrics, ) from pina.utils import check_consistency -class MetricTracker(Callback): - """ - Lightning Callback for Metric Tracking. - """ - - def __init__(self, metrics_to_track=None): - """ - Tracks specified metrics during training. - - :param metrics_to_track: List of metrics to track. - Defaults to train loss. - :type metrics_to_track: list[str], optional - """ - super().__init__() - self._collection = [] - # Default to tracking 'train_loss' if not specified - self.metrics_to_track = metrics_to_track - - def setup(self, trainer, pl_module, stage): - """ - Called when fit, validate, test, predict, or tune begins. - - :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance. - :param SolverInterface pl_module: A - :class:`~pina.solver.solver.SolverInterface` instance. - :param str stage: Either 'fit', 'test' or 'predict'. - """ - if self.metrics_to_track is None and trainer.batch_size is None: - self.metrics_to_track = ["train_loss"] - elif self.metrics_to_track is None: - self.metrics_to_track = ["train_loss_epoch"] - return super().setup(trainer, pl_module, stage) - - def on_train_epoch_end(self, trainer, pl_module): - """ - Collect and track metrics at the end of each training epoch. - - :param trainer: The trainer object managing the training process. - :type trainer: pytorch_lightning.Trainer - :param pl_module: The model being trained (not used here). - """ - # Track metrics after the first epoch onwards - if trainer.current_epoch > 0: - # Append only the tracked metrics to avoid unnecessary data - tracked_metrics = { - k: v - for k, v in trainer.logged_metrics.items() - if k in self.metrics_to_track - } - self._collection.append(copy.deepcopy(tracked_metrics)) - - @property - def metrics(self): - """ - Aggregate collected metrics over all epochs. - - :return: A dictionary containing aggregated metric values. - :rtype: dict - """ - if not self._collection: - return {} - - # Get intersection of keys across all collected dictionaries - common_keys = set(self._collection[0]).intersection( - *self._collection[1:] - ) - - # Stack the metric values for common keys and return - return { - k: torch.stack([dic[k] for dic in self._collection]) - for k in common_keys - if k in self.metrics_to_track - } - - class PINAProgressBar(TQDMProgressBar): """ PINA Implementation of a Lightning Callback for enriching the progress bar. diff --git a/tests/test_callback/test_progress_bar.py b/tests/test_callback/test_pina_progress_bar.py similarity index 94% rename from tests/test_callback/test_progress_bar.py rename to tests/test_callback/test_pina_progress_bar.py index d77408c42..1013e0c2e 100644 --- a/tests/test_callback/test_progress_bar.py +++ b/tests/test_callback/test_pina_progress_bar.py @@ -1,7 +1,7 @@ from pina.solver import PINN from pina.trainer import Trainer from pina.model import FeedForward -from pina.callback.processing_callback import PINAProgressBar +from pina.callback import PINAProgressBar from pina.problem.zoo import Poisson2DSquareProblem as Poisson diff --git a/tests/test_callback/test_adaptive_refinement_callback.py b/tests/test_callback/test_r3_refinement.py similarity index 97% rename from tests/test_callback/test_adaptive_refinement_callback.py rename to tests/test_callback/test_r3_refinement.py index 7866c7f7b..9f167bb06 100644 --- a/tests/test_callback/test_adaptive_refinement_callback.py +++ b/tests/test_callback/test_r3_refinement.py @@ -1,12 +1,10 @@ import pytest - from torch.nn import MSELoss - from pina.solver import PINN from pina.trainer import Trainer from pina.model import FeedForward from pina.problem.zoo import Poisson2DSquareProblem as Poisson -from pina.callback.refinement import R3Refinement +from pina.callback import R3Refinement # make the problem diff --git a/tests/test_callback/test_optimizer_callback.py b/tests/test_callback/test_switch_optimizer.py similarity index 100% rename from tests/test_callback/test_optimizer_callback.py rename to tests/test_callback/test_switch_optimizer.py