diff --git a/pina/callback/adaptive_refinement_callback.py b/pina/callback/adaptive_refinement_callback.py index 951ee750a..9b2932f95 100644 --- a/pina/callback/adaptive_refinement_callback.py +++ b/pina/callback/adaptive_refinement_callback.py @@ -1,5 +1,6 @@ """PINA Callbacks Implementations""" +import importlib.metadata import torch from lightning.pytorch.callbacks import Callback from ..label_tensor import LabelTensor @@ -7,17 +8,17 @@ class R3Refinement(Callback): + """ + PINA Implementation of an R3 Refinement Callback. + """ def __init__(self, sample_every): """ - PINA Implementation of an R3 Refinement Callback. - This callback implements the R3 (Retain-Resample-Release) routine for sampling new points based on adaptive search. The algorithm incrementally accumulates collocation points in regions - of high PDE residuals, and releases those - with low residuals. Points are sampled uniformly in all regions - where sampling is needed. + of high PDE residuals, and releases those with low residuals. + Points are sampled uniformly in all regions where sampling is needed. .. seealso:: @@ -33,142 +34,148 @@ def __init__(self, sample_every): Example: >>> r3_callback = R3Refinement(sample_every=5) """ - super().__init__() - - # sample every - check_consistency(sample_every, int) - self._sample_every = sample_every - self._const_pts = None - - def _compute_residual(self, trainer): - """ - Computes the residuals for a PINN object. - - :return: the total loss, and pointwise loss. - :rtype: tuple - """ - - # extract the solver and device from trainer - solver = trainer.solver - device = trainer._accelerator_connector._accelerator_flag - precision = trainer.precision - if precision == "64-true": - precision = torch.float64 - elif precision == "32-true": - precision = torch.float32 - else: - raise RuntimeError( - "Currently R3Refinement is only implemented " - "for precision '32-true' and '64-true', set " - "Trainer precision to match one of the " - "available precisions." - ) - - # compute residual - res_loss = {} - tot_loss = [] - for location in self._sampling_locations: # TODO fix for new collector - condition = solver.problem.conditions[location] - pts = solver.problem.input_pts[location] - # send points to correct device - pts = pts.to(device=device, dtype=precision) - pts = pts.requires_grad_(True) - pts.retain_grad() - # PINN loss: equation evaluated only for sampling locations - target = condition.equation.residual(pts, solver.forward(pts)) - res_loss[location] = torch.abs(target).as_subclass(torch.Tensor) - tot_loss.append(torch.abs(target)) - - print(tot_loss) - - return torch.vstack(tot_loss), res_loss - - def _r3_routine(self, trainer): - """ - R3 refinement main routine. - - :param Trainer trainer: PINA Trainer. - """ - # compute residual (all device possible) - tot_loss, res_loss = self._compute_residual(trainer) - tot_loss = tot_loss.as_subclass(torch.Tensor) - - # !!!!!! From now everything is performed on CPU !!!!!! 
- - # average loss - avg = (tot_loss.mean()).to("cpu") - old_pts = {} # points to be retained - for location in self._sampling_locations: - pts = trainer._model.problem.input_pts[location] - labels = pts.labels - pts = pts.cpu().detach().as_subclass(torch.Tensor) - residuals = res_loss[location].cpu() - mask = (residuals > avg).flatten() - if any(mask): # append residuals greater than average - pts = (pts[mask]).as_subclass(LabelTensor) - pts.labels = labels - old_pts[location] = pts - numb_pts = self._const_pts[location] - len(old_pts[location]) - # sample new points - trainer._model.problem.discretise_domain( - numb_pts, "random", locations=[location] - ) - - else: # if no res greater than average, samples all uniformly - numb_pts = self._const_pts[location] - # sample new points - trainer._model.problem.discretise_domain( - numb_pts, "random", locations=[location] - ) - # adding previous population points - trainer._model.problem.add_points(old_pts) - - # update dataloader - trainer._create_or_update_loader() - - def on_train_start(self, trainer, _): - """ - Callback function called at the start of training. - - This method extracts the locations for sampling from the problem - conditions and calculates the total population. - - :param trainer: The trainer object managing the training process. - :type trainer: pytorch_lightning.Trainer - :param _: Placeholder argument (not used). - - :return: None - :rtype: None - """ - # extract locations for sampling - problem = trainer.solver.problem - locations = [] - for condition_name in problem.conditions: - condition = problem.conditions[condition_name] - if hasattr(condition, "location"): - locations.append(condition_name) - self._sampling_locations = locations - - # extract total population - const_pts = {} # for each location, store the # of pts to keep constant - for location in self._sampling_locations: - pts = trainer._model.problem.input_pts[location] - const_pts[location] = len(pts) - self._const_pts = const_pts - - def on_train_epoch_end(self, trainer, __): - """ - Callback function called at the end of each training epoch. - - This method triggers the R3 routine for refinement if the current - epoch is a multiple of `_sample_every`. - - :param trainer: The trainer object managing the training process. - :type trainer: pytorch_lightning.Trainer - :param __: Placeholder argument (not used). - - :return: None - :rtype: None - """ - if trainer.current_epoch % self._sample_every == 0: - self._r3_routine(trainer) + raise NotImplementedError( + "R3Refinement callback is being refactored in the pina " + f"{importlib.metadata.metadata('pina-mathlab')['Version']} " + "version. Please use version 0.1 if R3Refinement is required." + ) + + # super().__init__() + + # # sample every + # check_consistency(sample_every, int) + # self._sample_every = sample_every + # self._const_pts = None + + # def _compute_residual(self, trainer): + # """ + # Computes the residuals for a PINN object. + + # :return: the total loss, and pointwise loss. 
+ # :rtype: tuple + # """ + + # # extract the solver and device from trainer + # solver = trainer.solver + # device = trainer._accelerator_connector._accelerator_flag + # precision = trainer.precision + # if precision == "64-true": + # precision = torch.float64 + # elif precision == "32-true": + # precision = torch.float32 + # else: + # raise RuntimeError( + # "Currently R3Refinement is only implemented " + # "for precision '32-true' and '64-true', set " + # "Trainer precision to match one of the " + # "available precisions." + # ) + + # # compute residual + # res_loss = {} + # tot_loss = [] + # for location in self._sampling_locations: + # condition = solver.problem.conditions[location] + # pts = solver.problem.input_pts[location] + # # send points to correct device + # pts = pts.to(device=device, dtype=precision) + # pts = pts.requires_grad_(True) + # pts.retain_grad() + # # PINN loss: equation evaluated only for sampling locations + # target = condition.equation.residual(pts, solver.forward(pts)) + # res_loss[location] = torch.abs(target).as_subclass(torch.Tensor) + # tot_loss.append(torch.abs(target)) + + # print(tot_loss) + + # return torch.vstack(tot_loss), res_loss + + # def _r3_routine(self, trainer): + # """ + # R3 refinement main routine. + + # :param Trainer trainer: PINA Trainer. + # """ + # # compute residual (all device possible) + # tot_loss, res_loss = self._compute_residual(trainer) + # tot_loss = tot_loss.as_subclass(torch.Tensor) + + # # !!!!!! From now everything is performed on CPU !!!!!! + + # # average loss + # avg = (tot_loss.mean()).to("cpu") + # old_pts = {} # points to be retained + # for location in self._sampling_locations: + # pts = trainer._model.problem.input_pts[location] + # labels = pts.labels + # pts = pts.cpu().detach().as_subclass(torch.Tensor) + # residuals = res_loss[location].cpu() + # mask = (residuals > avg).flatten() + # if any(mask): # append residuals greater than average + # pts = (pts[mask]).as_subclass(LabelTensor) + # pts.labels = labels + # old_pts[location] = pts + # numb_pts = self._const_pts[location] - len(old_pts[location]) + # # sample new points + # trainer._model.problem.discretise_domain( + # numb_pts, "random", locations=[location] + # ) + + # else: # if no res greater than average, samples all uniformly + # numb_pts = self._const_pts[location] + # # sample new points + # trainer._model.problem.discretise_domain( + # numb_pts, "random", locations=[location] + # ) + # # adding previous population points + # trainer._model.problem.add_points(old_pts) + + # # update dataloader + # trainer._create_or_update_loader() + + # def on_train_start(self, trainer, _): + # """ + # Callback function called at the start of training. + + # This method extracts the locations for sampling from the problem + # conditions and calculates the total population. + + # :param trainer: The trainer object managing the training process. + # :type trainer: pytorch_lightning.Trainer + # :param _: Placeholder argument (not used). 
+ + # :return: None + # :rtype: None + # """ + # # extract locations for sampling + # problem = trainer.solver.problem + # locations = [] + # for condition_name in problem.conditions: + # condition = problem.conditions[condition_name] + # if hasattr(condition, "location"): + # locations.append(condition_name) + # self._sampling_locations = locations + + # # extract total population + # const_pts = {} # for each location, store the pts to keep constant + # for location in self._sampling_locations: + # pts = trainer._model.problem.input_pts[location] + # const_pts[location] = len(pts) + # self._const_pts = const_pts + + # def on_train_epoch_end(self, trainer, __): + # """ + # Callback function called at the end of each training epoch. + + # This method triggers the R3 routine for refinement if the current + # epoch is a multiple of `_sample_every`. + + # :param trainer: The trainer object managing the training process. + # :type trainer: pytorch_lightning.Trainer + # :param __: Placeholder argument (not used). + + # :return: None + # :rtype: None + # """ + # if trainer.current_epoch % self._sample_every == 0: + # self._r3_routine(trainer) diff --git a/pina/callback/linear_weight_update_callback.py b/pina/callback/linear_weight_update_callback.py index 02a8878f0..da5431bf7 100644 --- a/pina/callback/linear_weight_update_callback.py +++ b/pina/callback/linear_weight_update_callback.py @@ -37,12 +37,13 @@ def __init__( check_consistency(self.initial_value, (float, int), subclass=False) check_consistency(self.target_value, (float, int), subclass=False) - def on_train_start(self, trainer, solver): + def on_train_start(self, trainer, pl_module): """ Initialize the weight of the condition to the specified `initial_value`. - :param Trainer trainer: a pina:class:`Trainer` instance. - :param SolverInterface solver: a pina:class:`SolverInterface` instance. + :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance. + :param SolverInterface pl_module: A + :class:`~pina.solver.solver.SolverInterface` instance. """ # Check that the target epoch is valid if not 0 < self.target_epoch <= trainer.max_epochs: @@ -52,7 +53,7 @@ def on_train_start(self, trainer, solver): ) # Check that the condition is a problem condition - if self.condition_name not in solver.problem.conditions: + if self.condition_name not in pl_module.problem.conditions: raise ValueError( f"`{self.condition_name}` must be a problem condition." ) @@ -66,20 +67,21 @@ def on_train_start(self, trainer, solver): ) # Check that the weighting schema is ScalarWeighting - if not isinstance(solver.weighting, ScalarWeighting): + if not isinstance(pl_module.weighting, ScalarWeighting): raise ValueError("The weighting schema must be ScalarWeighting.") # Initialize the weight of the condition - solver.weighting.weights[self.condition_name] = self.initial_value + pl_module.weighting.weights[self.condition_name] = self.initial_value - def on_train_epoch_start(self, trainer, solver): + def on_train_epoch_start(self, trainer, pl_module): """ Adjust at each epoch the weight of the condition. - :param Trainer trainer: a pina:class:`Trainer` instance. - :param SolverInterface solver: a pina:class:`SolverInterface` instance. + :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance. + :param SolverInterface pl_module: A + :class:`~pina.solver.solver.SolverInterface` instance. 
""" if 0 < trainer.current_epoch <= self.target_epoch: - solver.weighting.weights[self.condition_name] += ( + pl_module.weighting.weights[self.condition_name] += ( self.target_value - self.initial_value ) / (self.target_epoch - 1) diff --git a/pina/callback/optimizer_callback.py b/pina/callback/optimizer_callback.py index c76bbfc82..6b77b3d9a 100644 --- a/pina/callback/optimizer_callback.py +++ b/pina/callback/optimizer_callback.py @@ -1,27 +1,27 @@ """PINA Callbacks Implementations""" from lightning.pytorch.callbacks import Callback -import torch +from ..optim import TorchOptimizer from ..utils import check_consistency -from pina.optim import TorchOptimizer class SwitchOptimizer(Callback): + """ + PINA Implementation of a Lightning Callback to switch optimizer during + training. + """ def __init__(self, new_optimizers, epoch_switch): """ - PINA Implementation of a Lightning Callback to switch optimizer during - training. - - This callback allows for switching between different optimizers during + This callback allows switching between different optimizers during training, enabling the exploration of multiple optimization strategies - without the need to stop training. + without interrupting the training process. :param new_optimizers: The model optimizers to switch to. Can be a - single :class:`torch.optim.Optimizer` or a list of them for multiple - model solver. + single :class:`torch.optim.Optimizer` instance or a list of them + for multiple model solver. :type new_optimizers: pina.optim.TorchOptimizer | list - :param epoch_switch: The epoch at which to switch to the new optimizer. + :param epoch_switch: The epoch at which the optimizer switch occurs. :type epoch_switch: int Example: @@ -46,7 +46,7 @@ def __init__(self, new_optimizers, epoch_switch): def on_train_epoch_start(self, trainer, __): """ - Callback function to switch optimizer at the start of each training epoch. + Switch the optimizer at the start of the specified training epoch. :param trainer: The trainer object managing the training process. :type trainer: pytorch_lightning.Trainer @@ -59,7 +59,7 @@ def on_train_epoch_start(self, trainer, __): optims = [] for idx, optim in enumerate(self._new_optimizers): - optim.hook(trainer.solver.models[idx].parameters()) - optims.append(optim.instance) + optim.hook(trainer.solver._pina_models[idx].parameters()) + optims.append(optim) - trainer.optimizers = optims + trainer.solver._pina_optimizers = optims diff --git a/pina/callback/processing_callback.py b/pina/callback/processing_callback.py index f3a13c18c..79c5b9c13 100644 --- a/pina/callback/processing_callback.py +++ b/pina/callback/processing_callback.py @@ -1,7 +1,7 @@ """PINA Callbacks Implementations""" -import torch import copy +import torch from lightning.pytorch.callbacks import Callback, TQDMProgressBar from lightning.pytorch.callbacks.progress.progress_bar import ( @@ -11,22 +11,37 @@ class MetricTracker(Callback): + """ + Lightning Callback for Metric Tracking. + """ def __init__(self, metrics_to_track=None): """ - Lightning Callback for Metric Tracking. - - Tracks specific metrics during the training process. + Tracks specified metrics during training. - :ivar _collection: A list to store collected metrics after each epoch. - - :param metrics_to_track: List of metrics to track. Defaults to train/val loss. - :type metrics_to_track: list, optional + :param metrics_to_track: List of metrics to track. + Defaults to train loss. 
+ :type metrics_to_track: list[str], optional """ super().__init__() self._collection = [] - # Default to tracking 'train_loss' and 'val_loss' if not specified - self.metrics_to_track = metrics_to_track or ["train_loss", "val_loss"] + # Default to tracking 'train_loss' if not specified + self.metrics_to_track = metrics_to_track + + def setup(self, trainer, pl_module, stage): + """ + Called when fit, validate, test, predict, or tune begins. + + :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance. + :param SolverInterface pl_module: A + :class:`~pina.solver.solver.SolverInterface` instance. + :param str stage: Either 'fit', 'test' or 'predict'. + """ + if self.metrics_to_track is None and trainer.batch_size is None: + self.metrics_to_track = ["train_loss"] + elif self.metrics_to_track is None: + self.metrics_to_track = ["train_loss_epoch"] + return super().setup(trainer, pl_module, stage) def on_train_epoch_end(self, trainer, pl_module): """ @@ -71,26 +86,28 @@ def metrics(self): class PINAProgressBar(TQDMProgressBar): + """ + PINA Implementation of a Lightning Callback for enriching the progress bar. + """ - BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]" + BAR_FORMAT = ( + "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, " + "{rate_noinv_fmt}{postfix}]" + ) def __init__(self, metrics="val", **kwargs): """ - PINA Implementation of a Lightning Callback for enriching the progress - bar. + This class enables the display of only relevant metrics during training. - This class provides functionality to display only relevant metrics - during the training process. - - :param metrics: Logged metrics to display during the training. It should - be a subset of the conditions keys defined in + :param metrics: Logged metrics to be shown during the training. + Must be a subset of the conditions keys defined in :obj:`pina.condition.Condition`. :type metrics: str | list(str) | tuple(str) :Keyword Arguments: - The additional keyword arguments specify the progress bar - and can be choosen from the `pytorch-lightning - TQDMProgressBar API `_ + The additional keyword arguments specify the progress bar and can be + choosen from the `pytorch-lightning TQDMProgressBar API + `_ Example: >>> pbar = PINAProgressBar(['mean']) @@ -105,9 +122,9 @@ def __init__(self, metrics="val", **kwargs): self._sorted_metrics = metrics def get_metrics(self, trainer, pl_module): - r"""Combines progress bar metrics collected from the trainer with + r"""Combine progress bar metrics collected from the trainer with standard metrics from get_standard_metrics. - Implement this to override the items displayed in the progress bar. + Override this method to customize the items shown in the progress bar. The progress bar metrics are sorted according to ``metrics``. Here is an example of how to override the defaults: @@ -122,20 +139,20 @@ def get_metrics(self, trainer, model): :return: Dictionary with the items to be displayed in the progress bar. :rtype: tuple(dict) - """ standard_metrics = get_standard_metrics(trainer) pbar_metrics = trainer.progress_bar_metrics if pbar_metrics: pbar_metrics = { - key: pbar_metrics[key] for key in self._sorted_metrics + key: pbar_metrics[key] + for key in pbar_metrics + if key in self._sorted_metrics } return {**standard_metrics, **pbar_metrics} - def on_fit_start(self, trainer, pl_module): + def setup(self, trainer, pl_module, stage): """ - Check that the metrics defined in the initialization are available, - i.e. 
are correctly logged. + Check that the initialized metrics are available and correctly logged. :param trainer: The trainer object managing the training process. :type trainer: pytorch_lightning.Trainer @@ -150,7 +167,11 @@ def on_fit_start(self, trainer, pl_module): ): raise KeyError(f"Key '{key}' is not present in the dictionary") # add the loss pedix + if trainer.batch_size is not None: + pedix = "_loss_epoch" + else: + pedix = "_loss" self._sorted_metrics = [ - metric + "_loss" for metric in self._sorted_metrics + metric + pedix for metric in self._sorted_metrics ] - return super().on_fit_start(trainer, pl_module) + return super().setup(trainer, pl_module, stage) diff --git a/pina/trainer.py b/pina/trainer.py index 81abfbd17..3fb73bf07 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -64,8 +64,7 @@ def __init__( :Keyword Arguments: The additional keyword arguments specify the training setup and can be choosen from the `pytorch-lightning - Trainer API `_ + Trainer API `_ """ # check consistency for init types self._check_input_consistency( @@ -96,7 +95,6 @@ def __init__( # Setting default kwargs, overriding lightning defaults kwargs.setdefault("enable_progress_bar", True) - kwargs.setdefault("logger", None) super().__init__(**kwargs) @@ -127,9 +125,6 @@ def __init__( # logging self.logging_kwargs = { - "logger": bool( - kwargs["logger"] is not None or kwargs["logger"] is True - ), "sync_dist": bool( len(self._accelerator_connector._parallel_devices) > 1 ), diff --git a/tests/test_callback/test_metric_tracker.py b/tests/test_callback/test_metric_tracker.py index de14694e5..3e6fa4407 100644 --- a/tests/test_callback/test_metric_tracker.py +++ b/tests/test_callback/test_metric_tracker.py @@ -23,17 +23,18 @@ def test_metric_tracker_constructor(): MetricTracker() -# def test_metric_tracker_routine(): #TODO revert -# # make the trainer -# trainer = Trainer(solver=solver, -# callback=[ -# MetricTracker() -# ], -# accelerator='cpu', -# max_epochs=5) -# trainer.train() -# # get the tracked metrics -# metrics = trainer.callback[0].metrics -# # assert the logged metrics are correct -# logged_metrics = sorted(list(metrics.keys())) -# assert logged_metrics == ['train_loss_epoch', 'train_loss_step', 'val_loss'] +def test_metric_tracker_routine(): + # make the trainer + trainer = Trainer( + solver=solver, + callbacks=[MetricTracker()], + accelerator="cpu", + max_epochs=5, + log_every_n_steps=1, + ) + trainer.train() + # get the tracked metrics + metrics = trainer.callbacks[0].metrics + # assert the logged metrics are correct + logged_metrics = sorted(list(metrics.keys())) + assert logged_metrics == ["train_loss"] diff --git a/tests/test_callback/test_optimizer_callback.py b/tests/test_callback/test_optimizer_callback.py index 6250c7ace..785a9c3f4 100644 --- a/tests/test_callback/test_optimizer_callback.py +++ b/tests/test_callback/test_optimizer_callback.py @@ -21,19 +21,25 @@ # make the solver solver = PINN(problem=poisson_problem, model=model) -adam_optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01) -lbfgs_optimizer = TorchOptimizer(torch.optim.LBFGS, lr=0.001) +adam = TorchOptimizer(torch.optim.Adam, lr=0.01) +lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=0.001) def test_switch_optimizer_constructor(): - SwitchOptimizer(adam_optimizer, epoch_switch=10) - - -# def test_switch_optimizer_routine(): #TODO revert -# # make the trainer -# switch_opt_callback = SwitchOptimizer(lbfgs_optimizer, epoch_switch=3) -# trainer = Trainer(solver=solver, -# callback=[switch_opt_callback], -# 
accelerator='cpu', -# max_epochs=5) -# trainer.train() + SwitchOptimizer(adam, epoch_switch=10) + + +def test_switch_optimizer_routine(): + # check initial optimizer + solver.configure_optimizers() + assert solver.optimizer.instance.__class__ == torch.optim.Adam + # make the trainer + switch_opt_callback = SwitchOptimizer(lbfgs, epoch_switch=3) + trainer = Trainer( + solver=solver, + callbacks=[switch_opt_callback], + accelerator="cpu", + max_epochs=5, + ) + trainer.train() + assert solver.optimizer.instance.__class__ == torch.optim.LBFGS diff --git a/tests/test_callback/test_progress_bar.py b/tests/test_callback/test_progress_bar.py index cba623780..d77408c42 100644 --- a/tests/test_callback/test_progress_bar.py +++ b/tests/test_callback/test_progress_bar.py @@ -5,29 +5,32 @@ from pina.problem.zoo import Poisson2DSquareProblem as Poisson -# # make the problem -# poisson_problem = Poisson() -# boundaries = ['nil_g1', 'nil_g2', 'nil_g3', 'nil_g4'] -# n = 10 -# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) -# poisson_problem.discretise_domain(n, 'grid', locations='laplace_D') -# model = FeedForward(len(poisson_problem.input_variables), -# len(poisson_problem.output_variables)) +# make the problem +poisson_problem = Poisson() +boundaries = ["g1", "g2", "g3", "g4"] +n = 10 +condition_names = list(poisson_problem.conditions.keys()) +poisson_problem.discretise_domain(n, "grid", domains=boundaries) +poisson_problem.discretise_domain(n, "grid", domains="D") +model = FeedForward( + len(poisson_problem.input_variables), len(poisson_problem.output_variables) +) -# # make the solver -# solver = PINN(problem=poisson_problem, model=model) +# make the solver +solver = PINN(problem=poisson_problem, model=model) -# def test_progress_bar_constructor(): -# PINAProgressBar(['mean']) +def test_progress_bar_constructor(): + PINAProgressBar() -# def test_progress_bar_routine(): -# # make the trainer -# trainer = Trainer(solver=solver, -# callback=[ -# PINAProgressBar(['mean', 'laplace_D']) -# ], -# accelerator='cpu', -# max_epochs=5) -# trainer.train() -# # TODO there should be a check that the correct metrics are displayed + +def test_progress_bar_routine(): + # make the trainer + trainer = Trainer( + solver=solver, + callbacks=[PINAProgressBar(["val", condition_names[0]])], + accelerator="cpu", + max_epochs=5, + ) + trainer.train() + # TODO there should be a check that the correct metrics are displayed
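
Since the R3Refinement implementation above is disabled pending the refactor (its body now raises NotImplementedError and the old routine is kept only as commented-out code), a minimal, framework-agnostic sketch of the retain-resample step its docstring describes might look as follows. This is an illustration only: `residual_fn`, `domain_lo`, `domain_hi`, and the point tensor are hypothetical placeholders, not PINA API.

import torch

def r3_resample_step(pts, residual_fn, domain_lo, domain_hi):
    # Pointwise |PDE residual| at the current collocation points.
    res = residual_fn(pts).abs().flatten()
    # Retain the points whose residual exceeds the average ...
    keep = res > res.mean()
    retained = pts[keep]
    # ... and resample the released ones uniformly in the domain,
    # so the total number of collocation points stays constant.
    n_new = pts.shape[0] - retained.shape[0]
    resampled = domain_lo + (domain_hi - domain_lo) * torch.rand(
        n_new, pts.shape[1], dtype=pts.dtype, device=pts.device
    )
    return torch.cat([retained, resampled], dim=0)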
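
The LinearWeightUpdate changes above only rename the hook argument from `solver` to `pl_module` and clean up the docstrings; the underlying idea is a linear ramp of a scalar condition weight from an initial value to a target value over the first epochs. Schematically, as a standalone sketch with hypothetical names rather than the PINA class itself:

def linear_weight(epoch, initial_value, target_value, target_epoch):
    # Linearly interpolate the weight from initial_value (at epoch 0)
    # to target_value (from target_epoch onwards).
    if epoch >= target_epoch:
        return target_value
    return initial_value + (target_value - initial_value) * epoch / target_epoch

# e.g. linear_weight(5, 1.0, 10.0, 10) -> 5.5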