Skip to content

Commit fb523d7

Browse files
authored
Add Supervised Problem (#451)
* Add SuperviedProblem class in problem zoo
1 parent d9738f3 commit fb523d7

File tree

3 files changed

+73
-2
lines changed

3 files changed

+73
-2
lines changed

pina/problem/zoo/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
__all__ = [
2-
'Poisson2DSquareProblem'
2+
'Poisson2DSquareProblem',
3+
'SupervisedProblem'
4+
35
]
46

5-
from .poisson_2d_square import Poisson2DSquareProblem
7+
from .poisson_2d_square import Poisson2DSquareProblem
8+
from .supervised_problem import SupervisedProblem
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from pina.problem import AbstractProblem
2+
from pina import Condition
3+
from pina import Graph
4+
5+
class SupervisedProblem(AbstractProblem):
6+
"""
7+
A problem definition for supervised learning in PINA.
8+
9+
This class allows an easy and straightforward definition of a Supervised problem,
10+
based on a single condition of type `InputOutputPointsCondition`
11+
12+
:Example:
13+
>>> import torch
14+
>>> input_data = torch.rand((100, 10))
15+
>>> output_data = torch.rand((100, 10))
16+
>>> problem = SupervisedProblem(input_data, output_data)
17+
"""
18+
conditions = dict()
19+
output_variables = None
20+
21+
def __init__(self, input_, output_):
22+
"""
23+
Initialize the SupervisedProblem class
24+
25+
:param input_: Input data of the problem
26+
:type input_: torch.Tensor | Graph
27+
:param output_: Output data of the problem
28+
:type output_: torch.Tensor
29+
"""
30+
if isinstance(input_, Graph):
31+
input_ = input_.data
32+
self.conditions['data'] = Condition(
33+
input_points=input_,
34+
output_points = output_
35+
)
36+
super().__init__()
37+
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import torch
2+
from pina.problem import AbstractProblem
3+
from pina.condition import InputOutputPointsCondition
4+
from pina.problem.zoo.supervised_problem import SupervisedProblem
5+
from pina import RadiusGraph
6+
7+
def test_constructor():
8+
input_ = torch.rand((100,10))
9+
output_ = torch.rand((100,10))
10+
problem = SupervisedProblem(input_=input_, output_=output_)
11+
assert isinstance(problem, AbstractProblem)
12+
assert hasattr(problem, "conditions")
13+
assert isinstance(problem.conditions, dict)
14+
assert list(problem.conditions.keys()) == ['data']
15+
assert isinstance(problem.conditions['data'], InputOutputPointsCondition)
16+
17+
def test_constructor_graph():
18+
x = torch.rand((20,100,10))
19+
pos = torch.rand((20,100,2))
20+
input_ = RadiusGraph(
21+
x=x, pos=pos, r=.2, build_edge_attr=True
22+
)
23+
output_ = torch.rand((100,10))
24+
problem = SupervisedProblem(input_=input_, output_=output_)
25+
assert isinstance(problem, AbstractProblem)
26+
assert hasattr(problem, "conditions")
27+
assert isinstance(problem.conditions, dict)
28+
assert list(problem.conditions.keys()) == ['data']
29+
assert isinstance(problem.conditions['data'], InputOutputPointsCondition)
30+
assert isinstance(problem.conditions['data'].input_points, list)
31+
assert isinstance(problem.conditions['data'].output_points, torch.Tensor)

0 commit comments

Comments
 (0)