Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ logging-modules=logging


[FORMAT]

# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=

Expand Down Expand Up @@ -250,6 +249,8 @@ single-line-if-stmt=no


[BASIC]
# Allow redefinition of input builtins
allowed-redefined-builtins=input

# Naming hint for argument names
argument-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$
Expand Down Expand Up @@ -401,7 +402,7 @@ max-returns=6
max-statements=50

# Minimum number of public methods for a class (see R0903).
min-public-methods=2
min-public-methods=0


[CLASSES]
Expand Down
56 changes: 12 additions & 44 deletions code_formatter.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,19 @@

#######################################

required_command="yapf unexpand"
code_directories="pina tests"
required_command="black"
code_directories=("pina" "tests")

#######################################

usage() {
echo
echo -e "\tUsage: $0 [files]"
echo
echo -e "\tIf not files are specified, script formats all ".py" files"
echo -e "\tin code directories ($code_directories); otherwise, formats"
echo -e "\tall given files"
echo
echo -e "\tRequired command: $required_command"
echo
exit 0
}


[[ $1 == "-h" ]] && usage

# Test for required program
for comm in $required_command; do
command -v $comm >/dev/null 2>&1 || {
echo "I require $comm but it's not installed. Aborting." >&2;
exit 1
}
done

# Find all python files in code directories
python_files=""
for dir in $code_directories; do
python_files="$python_files $(find $dir -name '*.py')"
done
[[ $# != 0 ]] && python_files=$@


# Here the important part: yapf format the files.
for file in $python_files; do
echo "Making beatiful $file..."
[[ ! -f $file ]] && echo "$file does not exist; $0 -h for more info" && exit

yapf --style='{
based_on_style: pep8,
indent_width: 4,
column_limit: 80
}' -i $file
done
if ! command -v $required_command >/dev/null 2>&1; then
echo "I require $required_command but it's not installed. Install dev dependencies."
echo "Aborting." >&2
exit 1
fi

# Run black formatter
for dir in "${code_directories[@]}"; do
python -m black --line-length 80 "$dir"
done
2 changes: 1 addition & 1 deletion pina/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,6 @@ def store_sample_domains(self):
samples = self.problem.discretised_domains[condition.domain]

self.data_collections[condition_name] = {
"input_points": samples,
"input": samples,
"equation": condition.equation,
}
37 changes: 33 additions & 4 deletions pina/condition/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,41 @@
"""
Module for conditions.
"""

__all__ = [
"Condition",
"ConditionInterface",
"DomainEquationCondition",
"InputPointsEquationCondition",
"InputOutputPointsCondition",
"InputTargetCondition",
"TensorInputTensorTargetCondition",
"TensorInputGraphTargetCondition",
"GraphInputTensorTargetCondition",
"GraphInputGraphTargetCondition",
"InputEquationCondition",
"InputTensorEquationCondition",
"InputGraphEquationCondition",
"DataCondition",
"GraphDataCondition",
"TensorDataCondition",
]

from .condition_interface import ConditionInterface
from .condition import Condition
from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputPointsEquationCondition
from .input_output_condition import InputOutputPointsCondition
from .input_target_condition import (
InputTargetCondition,
TensorInputTensorTargetCondition,
TensorInputGraphTargetCondition,
GraphInputTensorTargetCondition,
GraphInputGraphTargetCondition,
)
from .input_equation_condition import (
InputEquationCondition,
InputTensorEquationCondition,
InputGraphEquationCondition,
)
from .data_condition import (
DataCondition,
GraphDataCondition,
TensorDataCondition,
)
88 changes: 62 additions & 26 deletions pina/condition/condition.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
"""Condition module."""

from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputPointsEquationCondition
from .input_output_condition import InputOutputPointsCondition
from .data_condition import DataConditionInterface
import warnings
from .data_condition import DataCondition
from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputEquationCondition
from .input_target_condition import InputTargetCondition
from ..utils import custom_warning_format

# Set the custom format for warnings
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=DeprecationWarning)


def warning_function(new, old):
"""Handle the deprecation warning.

:param new: Object to use instead of the old one.
:type new: str
:param old: Object to deprecate.
:type old: str
"""
warnings.warn(
f"'{old}' is deprecated and will be removed "
f"in future versions. Please use '{new}' instead.",
DeprecationWarning,
)


class Condition:
"""
The class ``Condition`` is used to represent the constraints (physical
Expand Down Expand Up @@ -40,16 +55,32 @@ class Condition:

Example::

>>> TODO
>>> from pina import Condition
>>> condition = Condition(
... input=input,
... target=target
... )
>>> condition = Condition(
... domain=location,
... equation=equation
... )
>>> condition = Condition(
... input=input,
... equation=equation
... )
>>> condition = Condition(
... input=data,
... conditional_variables=conditional_variables
... )

"""

__slots__ = list(
set(
InputOutputPointsCondition.__slots__
+ InputPointsEquationCondition.__slots__
InputTargetCondition.__slots__
+ InputEquationCondition.__slots__
+ DomainEquationCondition.__slots__
+ DataConditionInterface.__slots__
+ DataCondition.__slots__
)
)

Expand All @@ -62,25 +93,30 @@ def __new__(cls, *args, **kwargs):
)

# back-compatibility 0.1
if "location" in kwargs.keys():
keys = list(kwargs.keys())
if "location" in keys:
kwargs["domain"] = kwargs.pop("location")
warnings.warn(
f"'location' is deprecated and will be removed "
f"in future versions. Please use 'domain' instead.",
DeprecationWarning,
)
warning_function(new="domain", old="location")

sorted_keys = sorted(kwargs.keys())
if "input_points" in keys:
kwargs["input"] = kwargs.pop("input_points")
warning_function(new="input", old="input_points")

if sorted_keys == sorted(InputOutputPointsCondition.__slots__):
return InputOutputPointsCondition(**kwargs)
elif sorted_keys == sorted(InputPointsEquationCondition.__slots__):
return InputPointsEquationCondition(**kwargs)
elif sorted_keys == sorted(DomainEquationCondition.__slots__):
if "output_points" in keys:
kwargs["target"] = kwargs.pop("output_points")
warning_function(new="target", old="output_points")

sorted_keys = sorted(kwargs.keys())
if sorted_keys == sorted(InputTargetCondition.__slots__):
return InputTargetCondition(**kwargs)
if sorted_keys == sorted(InputEquationCondition.__slots__):
return InputEquationCondition(**kwargs)
if sorted_keys == sorted(DomainEquationCondition.__slots__):
return DomainEquationCondition(**kwargs)
elif sorted_keys == sorted(DataConditionInterface.__slots__):
return DataConditionInterface(**kwargs)
elif sorted_keys == DataConditionInterface.__slots__[0]:
return DataConditionInterface(**kwargs)
else:
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
if (
sorted_keys == sorted(DataCondition.__slots__)
or sorted_keys[0] == DataCondition.__slots__[0]
):
return DataCondition(**kwargs)

raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
84 changes: 67 additions & 17 deletions pina/condition/condition_interface.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,84 @@
"""
Module that defines the ConditionInterface class.
"""

from abc import ABCMeta
from torch_geometric.data import Data
from ..label_tensor import LabelTensor
from ..graph import Graph


class ConditionInterface(metaclass=ABCMeta):
"""
Abstract class which defines a common interface for all the conditions.
"""

condition_types = ["physics", "supervised", "unsupervised"]

def __init__(self, *args, **kwargs):
self._condition_type = None
def __init__(self):
self._problem = None

@property
def problem(self):
"""
Return the problem to which the condition is associated.

:return: Problem to which the condition is associated
:rtype: pina.problem.AbstractProblem
"""
return self._problem

@problem.setter
def problem(self, value):
self._problem = value

@property
def condition_type(self):
return self._condition_type

@condition_type.setter
def condition_type(self, values):
if not isinstance(values, (list, tuple)):
values = [values]
for value in values:
if value not in ConditionInterface.condition_types:
@staticmethod
def _check_graph_list_consistency(data_list):

# If the data is a Graph or Data object, return (do not need to check
# anything)
if isinstance(data_list, (Graph, Data)):
return

# check all elements in the list are of the same type
if not all(isinstance(i, (Graph, Data)) for i in data_list):
raise ValueError(
"Invalid input types. "
"Please provide either Data or Graph objects."
)
data = data_list[0]
# Store the keys of the first element in the list
keys = sorted(list(data.keys()))

# Store the type of each tensor inside first element Data/Graph object
data_types = {name: tensor.__class__ for name, tensor in data.items()}

# Store the labels of each LabelTensor inside first element Data/Graph
# object
labels = {
name: tensor.labels
for name, tensor in data.items()
if isinstance(tensor, LabelTensor)
}

# Iterate over the list of Data/Graph objects
for data in data_list[1:]:
# Check if the keys of the current element are the same as the first
# element
if sorted(list(data.keys())) != keys:
raise ValueError(
"Unavailable type of condition, expected one of"
f" {ConditionInterface.condition_types}."
"All elements in the list must have the same keys."
)
self._condition_type = values
for name, tensor in data.items():
# Check if the type of each tensor inside the current element
# is the same as the first element
if tensor.__class__ is not data_types[name]:
raise ValueError(
f"Data {name} must be a {data_types[name]}, got "
f"{tensor.__class__}"
)
# If the tensor is a LabelTensor, check if the labels are the
# same as the first element
if isinstance(tensor, LabelTensor):
if tensor.labels != labels[name]:
raise ValueError(
"LabelTensor must have the same labels"
)
Loading
Loading