Skip to content

Commit 8696366

Browse files
committed
Implement graph conditions
1 parent 8d1f387 commit 8696366

File tree

5 files changed

+169
-29
lines changed

5 files changed

+169
-29
lines changed

pina/condition/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,15 @@
44
"DomainEquationCondition",
55
"InputPointsEquationCondition",
66
"InputOutputPointsCondition",
7+
"GraphInputOutputCondition",
8+
"GraphDataCondition",
9+
"GraphInputEquationCondition",
710
]
811

912
from .condition_interface import ConditionInterface
1013
from .domain_equation_condition import DomainEquationCondition
1114
from .input_equation_condition import InputPointsEquationCondition
1215
from .input_output_condition import InputOutputPointsCondition
16+
from .graph_condition import GraphInputOutputCondition
17+
from .graph_condition import GraphDataCondition
18+
from .graph_condition import GraphInputEquationCondition

pina/condition/condition.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from .input_equation_condition import InputPointsEquationCondition
55
from .input_output_condition import InputOutputPointsCondition
66
from .data_condition import DataConditionInterface
7+
from .graph_condition import (
8+
GraphInputOutputCondition,
9+
GraphInputEquationCondition,
10+
)
711
import warnings
812
from ..utils import custom_warning_format
913

@@ -82,5 +86,9 @@ def __new__(cls, *args, **kwargs):
8286
return DataConditionInterface(**kwargs)
8387
elif sorted_keys == DataConditionInterface.__slots__[0]:
8488
return DataConditionInterface(**kwargs)
89+
elif sorted_keys == sorted(GraphInputOutputCondition.__slots__):
90+
return GraphInputOutputCondition(**kwargs)
91+
elif sorted_keys == sorted(GraphInputEquationCondition.__slots__):
92+
return GraphInputEquationCondition(**kwargs)
8593
else:
8694
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")

pina/condition/condition_interface.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@
33

44
class ConditionInterface(metaclass=ABCMeta):
55

6-
condition_types = ["physics", "supervised", "unsupervised"]
7-
86
def __init__(self, *args, **kwargs):
9-
self._condition_type = None
107
self._problem = None
118

129
@property
@@ -15,20 +12,4 @@ def problem(self):
1512

1613
@problem.setter
1714
def problem(self, value):
18-
self._problem = value
19-
20-
@property
21-
def condition_type(self):
22-
return self._condition_type
23-
24-
@condition_type.setter
25-
def condition_type(self, values):
26-
if not isinstance(values, (list, tuple)):
27-
values = [values]
28-
for value in values:
29-
if value not in ConditionInterface.condition_types:
30-
raise ValueError(
31-
"Unavailable type of condition, expected one of"
32-
f" {ConditionInterface.condition_types}."
33-
)
34-
self._condition_type = values
15+
self._problem = value

pina/condition/graph_condition.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from .condition_interface import ConditionInterface
2+
from ..graph import Graph
3+
from ..utils import check_consistency
4+
from torch_geometric.data import Data
5+
from ..equation.equation_interface import EquationInterface
6+
7+
8+
class GraphCondition(ConditionInterface):
9+
"""
10+
TODO
11+
"""
12+
13+
__slots__ = ["graph"]
14+
15+
def __new__(cls, graph):
16+
"""
17+
TODO : add docstring
18+
"""
19+
check_consistency(graph, (Graph, Data))
20+
graph = [graph] if isinstance(graph, Data) else graph
21+
22+
if all(g.y is not None for g in graph):
23+
return super().__new__(GraphInputOutputCondition)
24+
else:
25+
return super().__new__(GraphDataCondition)
26+
27+
def __init__(self, graph):
28+
29+
super().__init__()
30+
self.graph = graph
31+
32+
def __setattr__(self, key, value):
33+
if key == "graph":
34+
check_consistency(value, (Graph, Data))
35+
GraphCondition.__dict__[key].__set__(self, value)
36+
elif key in ("_problem", "_condition_type"):
37+
super().__setattr__(key, value)
38+
39+
40+
class GraphInputEquationCondition(ConditionInterface):
41+
42+
__slots__ = ["graph", "equation"]
43+
44+
def __init__(self, graph, equation):
45+
super().__init__()
46+
self.graph = graph
47+
self.equation = equation
48+
49+
def __setattr__(self, key, value):
50+
if key == "graph":
51+
check_consistency(value, (Graph, Data))
52+
GraphInputEquationCondition.__dict__[key].__set__(self, value)
53+
elif key == "equation":
54+
check_consistency(value, (EquationInterface))
55+
GraphInputEquationCondition.__dict__[key].__set__(self, value)
56+
elif key in ("_problem", "_condition_type"):
57+
super().__setattr__(key, value)
58+
59+
60+
# The split between GraphInputOutputCondition and GraphDataCondition
61+
# distinguishes different types of graph conditions passed to problems.
62+
# This separation simplifies consistency checks during problem creation.
63+
class GraphDataCondition(GraphCondition):
64+
pass
65+
66+
67+
class GraphInputOutputCondition(GraphCondition):
68+
pass

tests/test_condition.py

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,37 @@
33

44
from pina import LabelTensor, Condition
55
from pina.domain import CartesianDomain
6+
from pina.condition import (
7+
GraphInputOutputCondition,
8+
GraphInputEquationCondition,
9+
)
610
from pina.equation.equation_factory import FixedValue
11+
from pina.graph import RadiusGraph
12+
from torch_geometric.data import Data
13+
from pina.operator import laplacian
14+
from pina.equation.equation import Equation
715

8-
example_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
9-
example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
10-
example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
16+
example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]})
17+
example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ["x", "y", "z"])
18+
example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ["a", "b"])
1119

1220

1321
def test_init_inputoutput():
1422
Condition(input_points=example_input_pts, output_points=example_output_pts)
1523
with pytest.raises(ValueError):
1624
Condition(example_input_pts, example_output_pts)
1725
with pytest.raises(ValueError):
18-
Condition(input_points=3., output_points='example')
26+
Condition(input_points=3.0, output_points="example")
1927
with pytest.raises(ValueError):
2028
Condition(input_points=example_domain, output_points=example_domain)
2129

2230

23-
test_init_inputoutput()
24-
25-
2631
def test_init_domainfunc():
2732
Condition(domain=example_domain, equation=FixedValue(0.0))
2833
with pytest.raises(ValueError):
2934
Condition(example_domain, FixedValue(0.0))
3035
with pytest.raises(ValueError):
31-
Condition(domain=3., equation='example')
36+
Condition(domain=3.0, equation="example")
3237
with pytest.raises(ValueError):
3338
Condition(domain=example_input_pts, equation=example_output_pts)
3439

@@ -38,6 +43,78 @@ def test_init_inputfunc():
3843
with pytest.raises(ValueError):
3944
Condition(example_domain, FixedValue(0.0))
4045
with pytest.raises(ValueError):
41-
Condition(input_points=3., equation='example')
46+
Condition(input_points=3.0, equation="example")
4247
with pytest.raises(ValueError):
4348
Condition(input_points=example_domain, equation=example_output_pts)
49+
50+
51+
def test_graph_io_condition():
52+
x = torch.rand(10, 10, 4)
53+
pos = torch.rand(10, 10, 2)
54+
y = torch.rand(10, 10, 2)
55+
graph = [
56+
RadiusGraph(x=x_, pos=pos_, radius=0.1, build_edge_attr=True, y=y_)
57+
for x_, pos_, y_ in zip(x, pos, y)
58+
]
59+
condition = Condition(graph=graph)
60+
assert isinstance(condition, GraphInputOutputCondition)
61+
assert isinstance(condition.graph, list)
62+
63+
x = x[0]
64+
pos = pos[0]
65+
y = y[0]
66+
edge_index = graph[0].edge_index
67+
graph = Data(x=x, pos=pos, edge_index=edge_index, y=y)
68+
condition = Condition(graph=graph)
69+
assert isinstance(condition, GraphInputOutputCondition)
70+
assert isinstance(condition.graph, Data)
71+
72+
73+
def laplace_equation(input_, output_):
74+
"""
75+
Implementation of the laplace equation.
76+
"""
77+
force_term = torch.sin(input_.extract(["x"]) * torch.pi) * torch.sin(
78+
input_.extract(["y"]) * torch.pi
79+
)
80+
delta_u = laplacian(output_.extract(["u"]), input_)
81+
return delta_u - force_term
82+
83+
84+
def test_graph_eq_condition():
85+
def laplace(input_, output_):
86+
"""
87+
Implementation of the laplace equation.
88+
"""
89+
force_term = torch.sin(input_.extract(["x"]) * torch.pi) * torch.sin(
90+
input_.extract(["y"]) * torch.pi
91+
)
92+
delta_u = laplacian(output_.extract(["u"]), input_)
93+
return delta_u - force_term
94+
95+
x = torch.rand(10, 10, 4)
96+
pos = torch.rand(10, 10, 2)
97+
graph = [
98+
RadiusGraph(
99+
x=x_,
100+
pos=pos_,
101+
radius=0.1,
102+
build_edge_attr=True,
103+
)
104+
for x_, pos_, in zip(
105+
x,
106+
pos,
107+
)
108+
]
109+
laplace_equation = Equation(laplace)
110+
condition = Condition(graph=graph, equation=laplace_equation)
111+
assert isinstance(condition, GraphInputEquationCondition)
112+
assert isinstance(condition.graph, list)
113+
114+
x = x[0]
115+
pos = pos[0]
116+
edge_index = graph[0].edge_index
117+
graph = Data(x=x, pos=pos, edge_index=edge_index)
118+
condition = Condition(graph=graph, equation=laplace_equation)
119+
assert isinstance(condition, GraphInputEquationCondition)
120+
assert isinstance(condition.graph, Data)

0 commit comments

Comments
 (0)