44
55import torch
66from torch_geometric .data import Data
7- from .condition_interface import ConditionInterface
87from ..label_tensor import LabelTensor
98from ..graph import Graph
10- from .. utils import check_consistency
9+ from .condition_interface import ConditionInterface
1110
1211
1312class InputTargetCondition (ConditionInterface ):
@@ -34,43 +33,41 @@ def __new__(cls, input, target):
3433 TensorInputGraphTargetCondition or GraphInputTensorTargetCondition
3534 or GraphInputGraphTargetCondition
3635 """
37- if (
38- cls == InputTargetCondition
39- and isinstance (input , (torch .Tensor , LabelTensor ))
40- and isinstance (target , (torch .Tensor , LabelTensor ))
36+ if cls != InputTargetCondition :
37+ return super ().__new__ (cls )
38+
39+ if isinstance (input , (torch .Tensor , LabelTensor )) and isinstance (
40+ target , (torch .Tensor , LabelTensor )
4141 ):
4242 subclass = TensorInputTensorTargetCondition
4343 return subclass .__new__ (subclass , input , target )
44-
45- elif (
46- cls == InputTargetCondition
47- and isinstance (input , (torch .Tensor , LabelTensor ))
48- and isinstance (target , (Graph , Data , list , tuple ))
44+ if isinstance (input , (torch .Tensor , LabelTensor )) and isinstance (
45+ target , (Graph , Data , list , tuple )
4946 ):
5047 cls ._check_graph_list_consistency (target )
5148 subclass = TensorInputGraphTargetCondition
5249 return subclass .__new__ (subclass , input , target )
5350
54- elif (
55- cls == InputTargetCondition
56- and isinstance (input , (Graph , Data , list , tuple ))
57- and isinstance (target , (torch .Tensor , LabelTensor ))
51+ if isinstance (input , (Graph , Data , list , tuple )) and isinstance (
52+ target , (torch .Tensor , LabelTensor )
5853 ):
5954 cls ._check_graph_list_consistency (input )
6055 subclass = GraphInputTensorTargetCondition
6156 return subclass .__new__ (subclass , input , target )
6257
63- elif (
64- cls == InputTargetCondition
65- and isinstance (input , (Graph , Data , list , tuple ))
66- and isinstance (target , (Graph , Data , list , tuple ))
58+ if isinstance (input , (Graph , Data , list , tuple )) and isinstance (
59+ target , (Graph , Data , list , tuple )
6760 ):
6861 cls ._check_graph_list_consistency (input )
6962 cls ._check_graph_list_consistency (target )
7063 subclass = GraphInputGraphTargetCondition
7164 return subclass .__new__ (subclass , input , target )
7265
73- return super ().__new__ (cls )
66+ raise ValueError (
67+ "Invalid input/target types. "
68+ "Please provide either Data, Graph, LabelTensor or torch.Tensor "
69+ "objects."
70+ )
7471
7572 def __init__ (self , input , target ):
7673 """
@@ -82,28 +79,16 @@ def __init__(self, input, target):
8279 :type target: torch.Tensor or Graph or Data
8380 """
8481 super ().__init__ ()
85- if isinstance (input , (list , tuple )) or isinstance (
86- target , (list , tuple )
87- ):
88- self ._check_input_target_len (input , target )
82+ self ._check_input_target_len (input , target )
8983 self .input = input
9084 self .target = target
9185
92- def __setattr__ (self , key , value ):
93- if key == "input" :
94- check_consistency (value , (torch .Tensor , LabelTensor , Data , Graph ))
95- InputTargetCondition .__dict__ [key ].__set__ (self , value )
96- elif key == "target" :
97- if value is not None :
98- check_consistency (
99- value , (torch .Tensor , LabelTensor , Data , Graph )
100- )
101- InputTargetCondition .__dict__ [key ].__set__ (self , value )
102- elif key in ("_problem" ):
103- super ().__setattr__ (key , value )
104-
10586 @staticmethod
10687 def _check_input_target_len (input , target ):
88+ if isinstance (input , (Graph , Data )) or isinstance (
89+ target , (Graph , Data )
90+ ):
91+ return
10792 if len (input ) != len (target ):
10893 raise ValueError (
10994 "The input and target lists must have the same length."
0 commit comments