Skip to content

Commit 12b7787

Browse files
dario-cosciaFilippoOlivo
authored andcommitted
Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
1 parent 4295738 commit 12b7787

File tree

77 files changed

+1167
-921
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+1167
-921
lines changed

pina/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
__all__ = [
22
"Trainer",
3-
"LabelTensor",
3+
"LabelTensor",
44
"Condition",
55
"PinaDataModule",
6-
'Graph',
6+
"Graph",
77
"SolverInterface",
8-
"MultiSolverInterface"
8+
"MultiSolverInterface",
99
]
1010

1111
from .label_tensor import LabelTensor
1212
from .graph import Graph
1313
from .solver import SolverInterface, MultiSolverInterface
1414
from .trainer import Trainer
1515
from .condition.condition import Condition
16-
from .data import PinaDataModule
16+
from .data import PinaDataModule

pina/adaptive_function/adaptive_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" Module for adaptive functions. """
1+
"""Module for adaptive functions."""
22

33
import torch
44
from ..utils import check_consistency

pina/adaptive_function/adaptive_function_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" Module for adaptive functions. """
1+
"""Module for adaptive functions."""
22

33
import torch
44

pina/adaptive_functions/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
warnings.formatwarning = custom_warning_format
99
warnings.filterwarnings("always", category=DeprecationWarning)
1010
warnings.warn(
11-
f"'pina.adaptive_functions' is deprecated and will be removed "
12-
f"in future versions. Please use 'pina.adaptive_function' instead.",
13-
DeprecationWarning)
11+
f"'pina.adaptive_functions' is deprecated and will be removed "
12+
f"in future versions. Please use 'pina.adaptive_function' instead.",
13+
DeprecationWarning,
14+
)

pina/callback/adaptive_refinement_callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def _compute_residual(self, trainer):
6767
# compute residual
6868
res_loss = {}
6969
tot_loss = []
70-
for location in self._sampling_locations: #TODO fix for new collector
70+
for location in self._sampling_locations: # TODO fix for new collector
7171
condition = solver.problem.conditions[location]
7272
pts = solver.problem.input_pts[location]
7373
# send points to correct device

pina/callback/processing_callback.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, metrics_to_track=None):
2626
super().__init__()
2727
self._collection = []
2828
# Default to tracking 'train_loss' and 'val_loss' if not specified
29-
self.metrics_to_track = metrics_to_track or ['train_loss', 'val_loss']
29+
self.metrics_to_track = metrics_to_track or ["train_loss", "val_loss"]
3030

3131
def on_train_epoch_end(self, trainer, pl_module):
3232
"""
@@ -40,7 +40,8 @@ def on_train_epoch_end(self, trainer, pl_module):
4040
if trainer.current_epoch > 0:
4141
# Append only the tracked metrics to avoid unnecessary data
4242
tracked_metrics = {
43-
k: v for k, v in trainer.logged_metrics.items()
43+
k: v
44+
for k, v in trainer.logged_metrics.items()
4445
if k in self.metrics_to_track
4546
}
4647
self._collection.append(copy.deepcopy(tracked_metrics))
@@ -57,16 +58,18 @@ def metrics(self):
5758
return {}
5859

5960
# Get intersection of keys across all collected dictionaries
60-
common_keys = set(self._collection[0]).intersection(*self._collection[1:])
61-
61+
common_keys = set(self._collection[0]).intersection(
62+
*self._collection[1:]
63+
)
64+
6265
# Stack the metric values for common keys and return
6366
return {
6467
k: torch.stack([dic[k] for dic in self._collection])
65-
for k in common_keys if k in self.metrics_to_track
68+
for k in common_keys
69+
if k in self.metrics_to_track
6670
}
6771

6872

69-
7073
class PINAProgressBar(TQDMProgressBar):
7174

7275
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
@@ -142,7 +145,8 @@ def on_fit_start(self, trainer, pl_module):
142145
for key in self._sorted_metrics:
143146
if (
144147
key not in trainer.solver.problem.conditions.keys()
145-
and key != "train" and key != "val"
148+
and key != "train"
149+
and key != "val"
146150
):
147151
raise KeyError(f"Key '{key}' is not present in the dictionary")
148152
# add the loss pedix

pina/callbacks/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
warnings.formatwarning = custom_warning_format
99
warnings.filterwarnings("always", category=DeprecationWarning)
1010
warnings.warn(
11-
f"'pina.callbacks' is deprecated and will be removed "
12-
f"in future versions. Please use 'pina.callback' instead.",
13-
DeprecationWarning)
11+
f"'pina.callbacks' is deprecated and will be removed "
12+
f"in future versions. Please use 'pina.callback' instead.",
13+
DeprecationWarning,
14+
)

pina/collector.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
# TODO
33
"""
4+
45
from .graph import Graph
56
from .utils import check_consistency
67

@@ -14,14 +15,12 @@ def __init__(self, problem):
1415
# those variables are used for the dataloading
1516
self._data_collections = {name: {} for name in self.problem.conditions}
1617
self.conditions_name = {
17-
i: name
18-
for i, name in enumerate(self.problem.conditions)
18+
i: name for i, name in enumerate(self.problem.conditions)
1919
}
2020

2121
# variables used to check that all conditions are sampled
2222
self._is_conditions_ready = {
23-
name: False
24-
for name in self.problem.conditions
23+
name: False for name in self.problem.conditions
2524
}
2625
self.full = False
2726

@@ -51,13 +50,16 @@ def store_fixed_data(self):
5150
for condition_name, condition in self.problem.conditions.items():
5251
# if the condition is not ready and domain is not attribute
5352
# of condition, we get and store the data
54-
if (not self._is_conditions_ready[condition_name]) and (not hasattr(
55-
condition, "domain")):
53+
if (not self._is_conditions_ready[condition_name]) and (
54+
not hasattr(condition, "domain")
55+
):
5656
# get data
5757
keys = condition.__slots__
5858
values = [getattr(condition, name) for name in keys]
59-
values = [value.data if isinstance(
60-
value, Graph) else value for value in values]
59+
values = [
60+
value.data if isinstance(value, Graph) else value
61+
for value in values
62+
]
6163
self.data_collections[condition_name] = dict(zip(keys, values))
6264
# condition now is ready
6365
self._is_conditions_ready[condition_name] = True
@@ -74,6 +76,6 @@ def store_sample_domains(self):
7476
samples = self.problem.discretised_domains[condition.domain]
7577

7678
self.data_collections[condition_name] = {
77-
'input_points': samples,
78-
'equation': condition.equation
79+
"input_points": samples,
80+
"equation": condition.equation,
7981
}

pina/condition/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
__all__ = [
2-
'Condition',
3-
'ConditionInterface',
4-
'DomainEquationCondition',
5-
'InputPointsEquationCondition',
6-
'InputOutputPointsCondition',
2+
"Condition",
3+
"ConditionInterface",
4+
"DomainEquationCondition",
5+
"InputPointsEquationCondition",
6+
"InputOutputPointsCondition",
77
]
88

99
from .condition_interface import ConditionInterface
1010
from .domain_equation_condition import DomainEquationCondition
1111
from .input_equation_condition import InputPointsEquationCondition
12-
from .input_output_condition import InputOutputPointsCondition
12+
from .input_output_condition import InputOutputPointsCondition

pina/condition/condition.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" Condition module. """
1+
"""Condition module."""
22

33
from .domain_equation_condition import DomainEquationCondition
44
from .input_equation_condition import InputPointsEquationCondition
@@ -11,6 +11,7 @@
1111
warnings.formatwarning = custom_warning_format
1212
warnings.filterwarnings("always", category=DeprecationWarning)
1313

14+
1415
class Condition:
1516
"""
1617
The class ``Condition`` is used to represent the constraints (physical
@@ -44,24 +45,30 @@ class Condition:
4445
"""
4546

4647
__slots__ = list(
47-
set(InputOutputPointsCondition.__slots__ +
48-
InputPointsEquationCondition.__slots__ +
49-
DomainEquationCondition.__slots__ +
50-
DataConditionInterface.__slots__))
48+
set(
49+
InputOutputPointsCondition.__slots__
50+
+ InputPointsEquationCondition.__slots__
51+
+ DomainEquationCondition.__slots__
52+
+ DataConditionInterface.__slots__
53+
)
54+
)
5155

5256
def __new__(cls, *args, **kwargs):
5357

5458
if len(args) != 0:
55-
raise ValueError("Condition takes only the following keyword "
56-
f"arguments: {Condition.__slots__}.")
59+
raise ValueError(
60+
"Condition takes only the following keyword "
61+
f"arguments: {Condition.__slots__}."
62+
)
5763

5864
# back-compatibility 0.1
59-
if 'location' in kwargs.keys():
60-
kwargs['domain'] = kwargs.pop('location')
65+
if "location" in kwargs.keys():
66+
kwargs["domain"] = kwargs.pop("location")
6167
warnings.warn(
62-
f"'location' is deprecated and will be removed "
63-
f"in future versions. Please use 'domain' instead.",
64-
DeprecationWarning)
68+
f"'location' is deprecated and will be removed "
69+
f"in future versions. Please use 'domain' instead.",
70+
DeprecationWarning,
71+
)
6572

6673
sorted_keys = sorted(kwargs.keys())
6774

0 commit comments

Comments
 (0)