|
3 | 3 | """ |
4 | 4 |
|
5 | 5 | import torch |
6 | | - |
| 6 | +from torch_geometric.data import Data |
7 | 7 | from .condition_interface import ConditionInterface |
8 | 8 | from ..label_tensor import LabelTensor |
9 | 9 | from ..graph import Graph |
10 | | -from torch_geometric.data import Data |
11 | 10 |
|
12 | 11 |
|
13 | 12 | class InputTargetCondition(ConditionInterface): |
@@ -66,56 +65,39 @@ def _get_subclass(input, target): |
66 | 65 |
|
67 | 66 | if is_tensor_input and is_tensor_target: |
68 | 67 | return TensorInputTensorTargetCondition |
69 | | - elif is_tensor_input and is_graph_target: |
| 68 | + if is_tensor_input and is_graph_target: |
70 | 69 | return TensorInputGraphTargetCondition |
71 | | - elif is_graph_input and is_tensor_target: |
| 70 | + if is_graph_input and is_tensor_target: |
72 | 71 | return GraphInputTensorTargetCondition |
73 | | - elif is_graph_input and is_graph_target: |
| 72 | + if is_graph_input and is_graph_target: |
74 | 73 | return GraphInputGraphTargetCondition |
75 | | - else: |
76 | | - raise ValueError( |
77 | | - "Invalid input and target types. " |
78 | | - "Please provide either torch.Tensor or Graph objects." |
79 | | - ) |
80 | | - |
81 | | - def __init__(self, input, target): |
82 | | - """ |
83 | | - TODO : add docstring |
84 | | - """ |
85 | | - super().__init__() |
86 | | - self.input = input |
87 | | - self.target = target |
| 74 | + raise ValueError( |
| 75 | + "Invalid input and target types. " |
| 76 | + "Please provide either torch.Tensor or Graph objects." |
| 77 | + ) |
88 | 78 |
|
89 | 79 |
|
90 | 80 | class TensorInputTensorTargetCondition(InputTargetCondition): |
91 | 81 | """ |
92 | 82 | InputTargetCondition subclass for torch.Tensor input and target data. |
93 | 83 | """ |
94 | 84 |
|
95 | | - pass |
96 | | - |
97 | 85 |
|
98 | 86 | class TensorInputGraphTargetCondition(InputTargetCondition): |
99 | 87 | """ |
100 | 88 | InputTargetCondition subclass for torch.Tensor input and Graph/Data target |
101 | 89 | data. |
102 | 90 | """ |
103 | 91 |
|
104 | | - pass |
105 | | - |
106 | 92 |
|
107 | 93 | class GraphInputTensorTargetCondition(InputTargetCondition): |
108 | 94 | """ |
109 | 95 | InputTargetCondition subclass for Graph/Data input and torch.Tensor target |
110 | 96 | data. |
111 | 97 | """ |
112 | 98 |
|
113 | | - pass |
114 | | - |
115 | 99 |
|
116 | 100 | class GraphInputGraphTargetCondition(InputTargetCondition): |
117 | 101 | """ |
118 | 102 | InputTargetCondition subclass for Graph/Data input and target data. |
119 | 103 | """ |
120 | | - |
121 | | - pass |
0 commit comments