19 changes: 19 additions & 0 deletions .github/workflows/black-formatter-test.yml
@@ -0,0 +1,19 @@
name: Black Formatter (PR Check)

on:
  pull_request:
    branches:
      - "**" # Run on pull requests for all branches

jobs:
  linter:
    name: runner / black
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3

      - name: Run Black formatter (check mode)
        uses: psf/black@stable
        with:
          options: "--check -l 80"
          src: "./pina"
2 changes: 1 addition & 1 deletion .github/workflows/black-formatter.yml
@@ -27,4 +27,4 @@ jobs:
There appear to be some python formatting errors in ${{ github.sha }}. This pull request
uses the [psf/black](https://github.com/psf/black) formatter to fix these issues.
base: ${{ github.head_ref }} # Creates pull request onto pull request or commit branch
branch: actions/black
branch: actions/black
20 changes: 9 additions & 11 deletions pina/__init__.py
@@ -1,18 +1,16 @@
__all__ = [
"Trainer", "LabelTensor", "Plotter", "Condition",
"PinaDataModule", 'TorchOptimizer', 'Graph',
"RadiusGraph", "KNNGraph"
"Trainer",
"LabelTensor",
"Condition",
"PinaDataModule",
"Graph",
"SolverInterface",
"MultiSolverInterface",
]

from .meta import *
from .label_tensor import LabelTensor
from .solvers.solver import SolverInterface
from .graph import Graph
from .solver import SolverInterface, MultiSolverInterface
from .trainer import Trainer
from .plotter import Plotter
from .condition.condition import Condition

from .data import PinaDataModule

from .optim import TorchOptimizer
from .optim import TorchScheduler
from .graph import Graph, RadiusGraph, KNNGraph
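
With this reorganization, everything listed in `__all__` is importable from the top level; a minimal sketch of the new public surface (names taken verbatim from the new `__all__`):

    from pina import (
        Trainer,
        LabelTensor,
        Condition,
        PinaDataModule,
        Graph,
        SolverInterface,
        MultiSolverInterface,
    )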
31 changes: 31 additions & 0 deletions pina/adaptive_function/__init__.py
@@ -0,0 +1,31 @@
__all__ = [
    "AdaptiveActivationFunctionInterface",
    "AdaptiveReLU",
    "AdaptiveSigmoid",
    "AdaptiveTanh",
    "AdaptiveSiLU",
    "AdaptiveMish",
    "AdaptiveELU",
    "AdaptiveCELU",
    "AdaptiveGELU",
    "AdaptiveSoftmin",
    "AdaptiveSoftmax",
    "AdaptiveSIREN",
    "AdaptiveExp",
]

from .adaptive_function import (
    AdaptiveReLU,
    AdaptiveSigmoid,
    AdaptiveTanh,
    AdaptiveSiLU,
    AdaptiveMish,
    AdaptiveELU,
    AdaptiveCELU,
    AdaptiveGELU,
    AdaptiveSoftmin,
    AdaptiveSoftmax,
    AdaptiveSIREN,
    AdaptiveExp,
)
from .adaptive_function_interface import AdaptiveActivationFunctionInterface
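
The new subpackage exposes the adaptive activations directly; a minimal usage sketch, assuming AdaptiveTanh can be constructed with no arguments:

    import torch
    from pina.adaptive_function import AdaptiveTanh

    act = AdaptiveTanh()  # trainable activation (default construction assumed)
    y = act(torch.linspace(-1.0, 1.0, 5))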
@@ -1,8 +1,8 @@
""" Module for adaptive functions. """
"""Module for adaptive functions."""

import torch
from ..utils import check_consistency
from .adaptive_func_interface import AdaptiveActivationFunctionInterface
from .adaptive_function_interface import AdaptiveActivationFunctionInterface


class AdaptiveReLU(AdaptiveActivationFunctionInterface):
@@ -1,4 +1,4 @@
""" Module for adaptive functions. """
"""Module for adaptive functions."""

import torch

@@ -9,7 +9,7 @@
class AdaptiveActivationFunctionInterface(torch.nn.Module, metaclass=ABCMeta):
r"""
The
:class:`~pina.adaptive_functions.adaptive_func_interface.AdaptiveActivationFunctionInterface`
:class:`~pina.adaptive_function.adaptive_func_interface.AdaptiveActivationFunctionInterface`
class makes a :class:`torch.nn.Module` activation function into an adaptive
trainable activation function. If one wants to create an adaptive activation
function, this class must be used as a base class.
41 changes: 12 additions & 29 deletions pina/adaptive_functions/__init__.py
@@ -1,31 +1,14 @@
__all__ = [
"AdaptiveActivationFunctionInterface",
"AdaptiveReLU",
"AdaptiveSigmoid",
"AdaptiveTanh",
"AdaptiveSiLU",
"AdaptiveMish",
"AdaptiveELU",
"AdaptiveCELU",
"AdaptiveGELU",
"AdaptiveSoftmin",
"AdaptiveSoftmax",
"AdaptiveSIREN",
"AdaptiveExp",
]
import warnings

from .adaptive_func import (
AdaptiveReLU,
AdaptiveSigmoid,
AdaptiveTanh,
AdaptiveSiLU,
AdaptiveMish,
AdaptiveELU,
AdaptiveCELU,
AdaptiveGELU,
AdaptiveSoftmin,
AdaptiveSoftmax,
AdaptiveSIREN,
AdaptiveExp,
from ..adaptive_function import *
from ..utils import custom_warning_format

# back-compatibility 0.1
# Set the custom format for warnings
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=DeprecationWarning)
warnings.warn(
f"'pina.adaptive_functions' is deprecated and will be removed "
f"in future versions. Please use 'pina.adaptive_function' instead.",
DeprecationWarning,
)
from .adaptive_func_interface import AdaptiveActivationFunctionInterface
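
The shim keeps old imports working but emits a deprecation warning at import time; a sketch of the expected behaviour, assuming a fresh interpreter session (the warning fires only on first import of the module):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        from pina.adaptive_functions import AdaptiveReLU  # deprecated path
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)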
10 changes: 10 additions & 0 deletions pina/callback/__init__.py
@@ -0,0 +1,10 @@
__all__ = [
    "SwitchOptimizer",
    "R3Refinement",
    "MetricTracker",
    "PINAProgressBar",
]

from .optimizer_callback import SwitchOptimizer
from .adaptive_refinement_callback import R3Refinement
from .processing_callback import MetricTracker, PINAProgressBar
@@ -67,7 +67,7 @@ def _compute_residual(self, trainer):
# compute residual
res_loss = {}
tot_loss = []
for location in self._sampling_locations: #TODO fix for new collector
for location in self._sampling_locations: # TODO fix for new collector
condition = solver.problem.conditions[location]
pts = solver.problem.input_pts[location]
# send points to correct device
@@ -19,7 +19,7 @@ def __init__(self, new_optimizers, epoch_switch):

:param new_optimizers: The model optimizers to switch to. Can be a
single :class:`torch.optim.Optimizer` or a list of them for multiple
model solvers.
model solver.
:type new_optimizers: pina.optim.TorchOptimizer | list
:param epoch_switch: The epoch at which to switch to the new optimizer.
:type epoch_switch: int
@@ -26,7 +26,7 @@ def __init__(self, metrics_to_track=None):
super().__init__()
self._collection = []
# Default to tracking 'train_loss' and 'val_loss' if not specified
self.metrics_to_track = metrics_to_track or ['train_loss', 'val_loss']
self.metrics_to_track = metrics_to_track or ["train_loss", "val_loss"]

def on_train_epoch_end(self, trainer, pl_module):
"""
@@ -40,7 +40,8 @@ def on_train_epoch_end(self, trainer, pl_module):
if trainer.current_epoch > 0:
# Append only the tracked metrics to avoid unnecessary data
tracked_metrics = {
k: v for k, v in trainer.logged_metrics.items()
k: v
for k, v in trainer.logged_metrics.items()
if k in self.metrics_to_track
}
self._collection.append(copy.deepcopy(tracked_metrics))
@@ -57,16 +58,18 @@ def metrics(self):
return {}

# Get intersection of keys across all collected dictionaries
common_keys = set(self._collection[0]).intersection(*self._collection[1:])

common_keys = set(self._collection[0]).intersection(
*self._collection[1:]
)

# Stack the metric values for common keys and return
return {
k: torch.stack([dic[k] for dic in self._collection])
for k in common_keys if k in self.metrics_to_track
for k in common_keys
if k in self.metrics_to_track
}



class PINAProgressBar(TQDMProgressBar):

BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
@@ -87,7 +90,7 @@ def __init__(self, metrics="val", **kwargs):
:Keyword Arguments:
The additional keyword arguments specify the progress bar
and can be chosen from the `pytorch-lightning
TQDMProgressBar API <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callbacks/progress/tqdm_progress.html#TQDMProgressBar>`_
TQDMProgressBar API <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callback/progress/tqdm_progress.html#TQDMProgressBar>`_

Example:
>>> pbar = PINAProgressBar(['mean'])
@@ -142,7 +145,8 @@ def on_fit_start(self, trainer, pl_module):
for key in self._sorted_metrics:
if (
key not in trainer.solver.problem.conditions.keys()
and key != "train" and key != "val"
and key != "train"
and key != "val"
):
raise KeyError(f"Key '{key}' is not present in the dictionary")
# add the loss pedix
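
Taken together, MetricTracker stores only the requested keys at each epoch and its metrics property stacks them per key over training; a hypothetical usage sketch (the Trainer wiring and the solver object are illustrative, not taken from this diff):

    from pina import Trainer
    from pina.callback import MetricTracker

    tracker = MetricTracker(metrics_to_track=["train_loss", "val_loss"])
    trainer = Trainer(solver, callbacks=[tracker], max_epochs=10)  # `solver` assumed defined
    trainer.train()
    history = tracker.metrics  # e.g. {"train_loss": tensor([...]), "val_loss": tensor([...])}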
22 changes: 13 additions & 9 deletions pina/callbacks/__init__.py
@@ -1,10 +1,14 @@
__all__ = [
"SwitchOptimizer",
"R3Refinement",
"MetricTracker",
"PINAProgressBar",
]
import warnings

from .optimizer_callbacks import SwitchOptimizer
from .adaptive_refinment_callbacks import R3Refinement
from .processing_callbacks import MetricTracker, PINAProgressBar
from ..callback import *
from ..utils import custom_warning_format

# back-compatibility 0.1
# Set the custom format for warnings
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=DeprecationWarning)
warnings.warn(
f"'pina.callbacks' is deprecated and will be removed "
f"in future versions. Please use 'pina.callback' instead.",
DeprecationWarning,
)
22 changes: 12 additions & 10 deletions pina/collector.py
@@ -1,6 +1,7 @@
"""
# TODO
"""

from .graph import Graph
from .utils import check_consistency

@@ -14,14 +15,12 @@ def __init__(self, problem):
# those variables are used for the dataloading
self._data_collections = {name: {} for name in self.problem.conditions}
self.conditions_name = {
i: name
for i, name in enumerate(self.problem.conditions)
i: name for i, name in enumerate(self.problem.conditions)
}

# variables used to check that all conditions are sampled
self._is_conditions_ready = {
name: False
for name in self.problem.conditions
name: False for name in self.problem.conditions
}
self.full = False

@@ -51,13 +50,16 @@ def store_fixed_data(self):
for condition_name, condition in self.problem.conditions.items():
# if the condition is not ready and domain is not attribute
# of condition, we get and store the data
if (not self._is_conditions_ready[condition_name]) and (not hasattr(
condition, "domain")):
if (not self._is_conditions_ready[condition_name]) and (
not hasattr(condition, "domain")
):
# get data
keys = condition.__slots__
values = [getattr(condition, name) for name in keys]
values = [value.data if isinstance(
value, Graph) else value for value in values]
values = [
value.data if isinstance(value, Graph) else value
for value in values
]
self.data_collections[condition_name] = dict(zip(keys, values))
# condition now is ready
self._is_conditions_ready[condition_name] = True
@@ -74,6 +76,6 @@ def store_sample_domains(self):
samples = self.problem.discretised_domains[condition.domain]

self.data_collections[condition_name] = {
'input_points': samples,
'equation': condition.equation
"input_points": samples,
"equation": condition.equation,
}
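
The collector thus splits storage along whether a condition carries a domain: fixed data is read from the condition's __slots__, while sampled conditions pull points from problem.discretised_domains. A hypothetical driver, assuming the class is named Collector and the problem's domains have already been discretised:

    from pina.collector import Collector  # class name and import path assumed

    collector = Collector(problem)    # `problem` assumed defined and already sampled
    collector.store_fixed_data()      # conditions without a `domain` attribute
    collector.store_sample_domains()  # conditions referencing a sampled domain
    data = collector.data_collections  # {condition_name: {"input_points": ..., "equation": ...}}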
12 changes: 6 additions & 6 deletions pina/condition/__init__.py
@@ -1,12 +1,12 @@
__all__ = [
'Condition',
'ConditionInterface',
'DomainEquationCondition',
'InputPointsEquationCondition',
'InputOutputPointsCondition',
"Condition",
"ConditionInterface",
"DomainEquationCondition",
"InputPointsEquationCondition",
"InputOutputPointsCondition",
]

from .condition_interface import ConditionInterface
from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputPointsEquationCondition
from .input_output_condition import InputOutputPointsCondition
from .input_output_condition import InputOutputPointsCondition