Skip to content

Commit 7a97866

Browse files
committed
Add check consistency in InputTargetCondition
1 parent 0acdba2 commit 7a97866

File tree

5 files changed

+65
-17
lines changed

5 files changed

+65
-17
lines changed

pina/condition/data_condition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def __new__(cls, input, conditional_variables=None):
3535
:rtype: TensorDataCondition or GraphDataCondition
3636
"""
3737

38-
subclass = cls._get_subclass(input, conditional_variables)
39-
if subclass is not cls:
38+
if cls == DataCondition:
39+
subclass = cls._get_subclass(input, conditional_variables)
4040
return subclass.__new__(subclass, input, conditional_variables)
4141
return super().__new__(cls)
4242

pina/condition/input_equation_condition.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
InputEquationCondition class definition.
2+
Module to define InputEquationCondition class and its subclasses.
33
"""
44

55
import torch
@@ -31,8 +31,10 @@ def __new__(cls, input, equation):
3131
:return: InputEquationCondition subclass
3232
:rtype: InputTensorEquationCondition or InputGraphEquationCondition
3333
"""
34-
subclass = cls._get_subclass(input, equation)
35-
if subclass is not cls:
34+
check_consistency(equation, (EquationInterface))
35+
36+
if cls == InputEquationCondition:
37+
subclass = cls._get_subclass(input)
3638
return subclass.__new__(subclass, input, equation)
3739
return super().__new__(cls)
3840

@@ -50,8 +52,7 @@ def __init__(self, input, equation):
5052
self.equation = equation
5153

5254
@staticmethod
53-
def _get_subclass(input, equation):
54-
check_consistency(equation, (EquationInterface))
55+
def _get_subclass(input):
5556
is_tensor_input = isinstance(input, (LabelTensor, torch.Tensor))
5657
is_graph_input = isinstance(input, (Data, Graph)) or (
5758
isinstance(input, list)

pina/condition/input_target_condition.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
TODO docstring
2+
This module contains condition classes for supervised learning tasks.
33
"""
44

55
import torch
@@ -31,8 +31,9 @@ def __new__(cls, input, target):
3131
TensorInputGraphTargetCondition or GraphInputTensorTargetCondition
3232
or GraphInputGraphTargetCondition
3333
"""
34-
subclass = cls._get_subclass(input, target)
35-
if subclass is not cls:
34+
35+
if cls == InputTargetCondition:
36+
subclass = cls._get_subclass(input, target)
3637
return subclass.__new__(subclass, input, target)
3738
return super().__new__(cls)
3839

@@ -46,6 +47,8 @@ def __init__(self, input, target):
4647
:type target: torch.Tensor or Graph or Data
4748
"""
4849
super().__init__()
50+
if hasattr(self, "_check_input_target_consistency"):
51+
self._check_input_target_consistency(input, target)
4952
self.input = input
5053
self.target = target
5154

@@ -82,22 +85,65 @@ class TensorInputTensorTargetCondition(InputTargetCondition):
8285
InputTargetCondition subclass for torch.Tensor input and target data.
8386
"""
8487

88+
@staticmethod
89+
def _check_input_target_consistency(input, target):
90+
if len(input) != len(target):
91+
raise ValueError(
92+
"The input and target lists must have the same length."
93+
)
94+
8595

8696
class TensorInputGraphTargetCondition(InputTargetCondition):
8797
"""
8898
InputTargetCondition subclass for torch.Tensor input and Graph/Data target
8999
data.
90100
"""
91101

102+
@staticmethod
103+
def _check_input_target_consistency(input, target):
104+
if isinstance(target, (Graph, Data)):
105+
return
106+
if len(input) != len(target):
107+
raise ValueError(
108+
"The input and target lists must have the same length."
109+
)
110+
92111

93112
class GraphInputTensorTargetCondition(InputTargetCondition):
94113
"""
95114
InputTargetCondition subclass for Graph/Data input and torch.Tensor target
96115
data.
97116
"""
98117

118+
@staticmethod
119+
def _check_input_target_consistency(input, target):
120+
if isinstance(input, (Graph, Data)):
121+
return
122+
if len(input) != len(target):
123+
raise ValueError(
124+
"The input and target lists must have the same length."
125+
)
126+
99127

100128
class GraphInputGraphTargetCondition(InputTargetCondition):
101129
"""
102130
InputTargetCondition subclass for Graph/Data input and target data.
103131
"""
132+
133+
@staticmethod
134+
def _check_input_target_consistency(input, target):
135+
if isinstance(input, list) and isinstance(target, list):
136+
if len(input) != len(target):
137+
raise ValueError(
138+
"The input and target lists must have the same length."
139+
)
140+
return
141+
if isinstance(target, (Graph, Data)) and isinstance(
142+
input, (Graph, Data)
143+
):
144+
return
145+
raise ValueError(
146+
"Invalid input and target types. "
147+
"input and target must be either both lists or both Graph/Data "
148+
"objects."
149+
)

tests/test_condition.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
DomainEquationCondition,
1515
)
1616
from pina.condition import (
17-
DataCondition,
1817
TensorDataCondition,
1918
GraphDataCondition,
2019
)
@@ -24,10 +23,10 @@
2423

2524
example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]})
2625

27-
input_tensor = torch.tensor([[0, 0, 0]])
28-
target_tensor = torch.tensor([[1, 2]])
29-
input_lt = LabelTensor(torch.tensor([[0, 0, 0]]), ["x", "y", "z"])
30-
target_lt = LabelTensor(torch.tensor([[1, 2]]), ["a", "b"])
26+
input_tensor = torch.rand((10,3))
27+
target_tensor = torch.rand((10,2))
28+
input_lt = LabelTensor(torch.rand((10,3)), ["x", "y", "z"])
29+
target_lt = LabelTensor(torch.rand((10,2)), ["a", "b"])
3130

3231
x = torch.rand(10, 20, 2)
3332
pos = torch.rand(10, 20, 2)
@@ -69,7 +68,9 @@ def test_init_input_target():
6968
assert isinstance(cond, TensorInputGraphTargetCondition)
7069
cond = Condition(input=input_single_graph, target=target_lt)
7170
assert isinstance(cond, GraphInputTensorTargetCondition)
72-
cond = Condition(input=input_graph, target=target_single_graph)
71+
cond = Condition(input=input_graph, target=target_graph)
72+
assert isinstance(cond, GraphInputGraphTargetCondition)
73+
cond = Condition(input=input_single_graph, target=target_single_graph)
7374
assert isinstance(cond, GraphInputGraphTargetCondition)
7475

7576
with pytest.raises(ValueError):

tests/test_problem_zoo/test_supervised_problem.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_constructor_graph():
2323
RadiusGraph(x=x_, pos=pos_, radius=0.2, edge_attr=True)
2424
for x_, pos_ in zip(x, pos)
2525
]
26-
output_ = torch.rand((100, 10))
26+
output_ = torch.rand((20, 100, 10))
2727
problem = SupervisedProblem(input_=input_, output_=output_)
2828
assert isinstance(problem, AbstractProblem)
2929
assert hasattr(problem, "conditions")

0 commit comments

Comments
 (0)