Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pina/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,6 @@ def store_fixed_data(self):
# get data
keys = condition.__slots__
values = [getattr(condition, name) for name in keys]
values = [
value.data if isinstance(value, Graph) else value
for value in values
]
self.data_collections[condition_name] = dict(zip(keys, values))
# condition now is ready
self._is_conditions_ready[condition_name] = True
Expand Down
6 changes: 6 additions & 0 deletions pina/condition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@
"DomainEquationCondition",
"InputPointsEquationCondition",
"InputOutputPointsCondition",
"GraphInputOutputCondition",
"GraphDataCondition",
"GraphInputEquationCondition",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this really useful? I think this is useful only if you want to differentiate wrt graph tensors (node features, or others). Is it correct?

]

from .condition_interface import ConditionInterface
from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputPointsEquationCondition
from .input_output_condition import InputOutputPointsCondition
from .graph_condition import GraphInputOutputCondition
from .graph_condition import GraphDataCondition
from .graph_condition import GraphInputEquationCondition
8 changes: 8 additions & 0 deletions pina/condition/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from .input_equation_condition import InputPointsEquationCondition
from .input_output_condition import InputOutputPointsCondition
from .data_condition import DataConditionInterface
from .graph_condition import (
GraphInputOutputCondition,
GraphInputEquationCondition,
)
import warnings
from ..utils import custom_warning_format

Expand Down Expand Up @@ -82,5 +86,9 @@ def __new__(cls, *args, **kwargs):
return DataConditionInterface(**kwargs)
elif sorted_keys == DataConditionInterface.__slots__[0]:
return DataConditionInterface(**kwargs)
elif sorted_keys == sorted(GraphInputOutputCondition.__slots__):
return GraphInputOutputCondition(**kwargs)
elif sorted_keys == sorted(GraphInputEquationCondition.__slots__):
return GraphInputEquationCondition(**kwargs)
else:
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
19 changes: 0 additions & 19 deletions pina/condition/condition_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@

class ConditionInterface(metaclass=ABCMeta):

condition_types = ["physics", "supervised", "unsupervised"]

def __init__(self, *args, **kwargs):
self._condition_type = None
self._problem = None

@property
Expand All @@ -16,19 +13,3 @@ def problem(self):
@problem.setter
def problem(self, value):
self._problem = value

@property
def condition_type(self):
return self._condition_type

@condition_type.setter
def condition_type(self, values):
if not isinstance(values, (list, tuple)):
values = [values]
for value in values:
if value not in ConditionInterface.condition_types:
raise ValueError(
"Unavailable type of condition, expected one of"
f" {ConditionInterface.condition_types}."
)
self._condition_type = values
68 changes: 68 additions & 0 deletions pina/condition/graph_condition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from .condition_interface import ConditionInterface
from ..graph import Graph
from ..utils import check_consistency
from torch_geometric.data import Data
from ..equation.equation_interface import EquationInterface


class GraphCondition(ConditionInterface):
"""
TODO
"""

__slots__ = ["graph"]

def __new__(cls, graph):
"""
TODO : add docstring
"""
check_consistency(graph, (Graph, Data))
graph = [graph] if isinstance(graph, Data) else graph

if all(g.y is not None for g in graph):
return super().__new__(GraphInputOutputCondition)
else:
return super().__new__(GraphDataCondition)

def __init__(self, graph):

super().__init__()
self.graph = graph

def __setattr__(self, key, value):
if key == "graph":
check_consistency(value, (Graph, Data))
GraphCondition.__dict__[key].__set__(self, value)
elif key in ("_problem", "_condition_type"):
super().__setattr__(key, value)


class GraphInputEquationCondition(ConditionInterface):

__slots__ = ["graph", "equation"]

def __init__(self, graph, equation):
super().__init__()
self.graph = graph
self.equation = equation

def __setattr__(self, key, value):
if key == "graph":
check_consistency(value, (Graph, Data))
GraphInputEquationCondition.__dict__[key].__set__(self, value)
elif key == "equation":
check_consistency(value, (EquationInterface))
GraphInputEquationCondition.__dict__[key].__set__(self, value)
elif key in ("_problem", "_condition_type"):
super().__setattr__(key, value)


# The split between GraphInputOutputCondition and GraphDataCondition
# distinguishes different types of graph conditions passed to problems.
# This separation simplifies consistency checks during problem creation.
class GraphDataCondition(GraphCondition):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the difference between graphDatacondition and GraphCondition?

pass


class GraphInputOutputCondition(GraphCondition):
pass
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not implemented?

Loading
Loading