Skip to content

Commit 78e6482

Browse files
committed
Implement graph conditions and enhance Graph class
1 parent 1bbba45 commit 78e6482

17 files changed

+951
-416
lines changed

pina/collector.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,6 @@ def store_fixed_data(self):
5656
# get data
5757
keys = condition.__slots__
5858
values = [getattr(condition, name) for name in keys]
59-
values = [
60-
value.data if isinstance(value, Graph) else value
61-
for value in values
62-
]
6359
self.data_collections[condition_name] = dict(zip(keys, values))
6460
# condition now is ready
6561
self._is_conditions_ready[condition_name] = True

pina/condition/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,15 @@
44
"DomainEquationCondition",
55
"InputPointsEquationCondition",
66
"InputOutputPointsCondition",
7+
"GraphInputOutputCondition",
8+
"GraphDataCondition",
9+
"GraphInputEquationCondition",
710
]
811

912
from .condition_interface import ConditionInterface
1013
from .domain_equation_condition import DomainEquationCondition
1114
from .input_equation_condition import InputPointsEquationCondition
1215
from .input_output_condition import InputOutputPointsCondition
16+
from .graph_condition import GraphInputOutputCondition
17+
from .graph_condition import GraphDataCondition
18+
from .graph_condition import GraphInputEquationCondition

pina/condition/condition.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from .input_equation_condition import InputPointsEquationCondition
55
from .input_output_condition import InputOutputPointsCondition
66
from .data_condition import DataConditionInterface
7+
from .graph_condition import (
8+
GraphInputOutputCondition,
9+
GraphInputEquationCondition,
10+
)
711
import warnings
812
from ..utils import custom_warning_format
913

@@ -82,5 +86,9 @@ def __new__(cls, *args, **kwargs):
8286
return DataConditionInterface(**kwargs)
8387
elif sorted_keys == DataConditionInterface.__slots__[0]:
8488
return DataConditionInterface(**kwargs)
89+
elif sorted_keys == sorted(GraphInputOutputCondition.__slots__):
90+
return GraphInputOutputCondition(**kwargs)
91+
elif sorted_keys == sorted(GraphInputEquationCondition.__slots__):
92+
return GraphInputEquationCondition(**kwargs)
8593
else:
8694
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")

pina/condition/condition_interface.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@
33

44
class ConditionInterface(metaclass=ABCMeta):
55

6-
condition_types = ["physics", "supervised", "unsupervised"]
7-
86
def __init__(self, *args, **kwargs):
9-
self._condition_type = None
107
self._problem = None
118

129
@property
@@ -16,19 +13,3 @@ def problem(self):
1613
@problem.setter
1714
def problem(self, value):
1815
self._problem = value
19-
20-
@property
21-
def condition_type(self):
22-
return self._condition_type
23-
24-
@condition_type.setter
25-
def condition_type(self, values):
26-
if not isinstance(values, (list, tuple)):
27-
values = [values]
28-
for value in values:
29-
if value not in ConditionInterface.condition_types:
30-
raise ValueError(
31-
"Unavailable type of condition, expected one of"
32-
f" {ConditionInterface.condition_types}."
33-
)
34-
self._condition_type = values

pina/condition/graph_condition.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from .condition_interface import ConditionInterface
2+
from ..graph import Graph
3+
from ..utils import check_consistency
4+
from torch_geometric.data import Data
5+
from ..equation.equation_interface import EquationInterface
6+
7+
8+
class GraphCondition(ConditionInterface):
9+
"""
10+
TODO
11+
"""
12+
13+
__slots__ = ["graph"]
14+
15+
def __new__(cls, graph):
16+
"""
17+
TODO : add docstring
18+
"""
19+
check_consistency(graph, (Graph, Data))
20+
graph = [graph] if isinstance(graph, Data) else graph
21+
22+
if all(g.y is not None for g in graph):
23+
return super().__new__(GraphInputOutputCondition)
24+
else:
25+
return super().__new__(GraphDataCondition)
26+
27+
def __init__(self, graph):
28+
29+
super().__init__()
30+
self.graph = graph
31+
32+
def __setattr__(self, key, value):
33+
if key == "graph":
34+
check_consistency(value, (Graph, Data))
35+
GraphCondition.__dict__[key].__set__(self, value)
36+
elif key in ("_problem", "_condition_type"):
37+
super().__setattr__(key, value)
38+
39+
40+
class GraphInputEquationCondition(ConditionInterface):
41+
42+
__slots__ = ["graph", "equation"]
43+
44+
def __init__(self, graph, equation):
45+
super().__init__()
46+
self.graph = graph
47+
self.equation = equation
48+
49+
def __setattr__(self, key, value):
50+
if key == "graph":
51+
check_consistency(value, (Graph, Data))
52+
GraphInputEquationCondition.__dict__[key].__set__(self, value)
53+
elif key == "equation":
54+
check_consistency(value, (EquationInterface))
55+
GraphInputEquationCondition.__dict__[key].__set__(self, value)
56+
elif key in ("_problem", "_condition_type"):
57+
super().__setattr__(key, value)
58+
59+
60+
# The split between GraphInputOutputCondition and GraphDataCondition
61+
# distinguishes different types of graph conditions passed to problems.
62+
# This separation simplifies consistency checks during problem creation.
63+
class GraphDataCondition(GraphCondition):
64+
pass
65+
66+
67+
class GraphInputOutputCondition(GraphCondition):
68+
pass

0 commit comments

Comments
 (0)