Skip to content

Commit 85b9edc

Browse files
authored
fix tensor getitem in graph_dataset (#633)
1 parent 7469543 commit 85b9edc

File tree

5 files changed

+39
-51
lines changed

5 files changed

+39
-51
lines changed

pina/data/dataset.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -276,20 +276,6 @@ def _create_graph_batch(self, data):
276276
batch = LabelBatch.from_data_list(data)
277277
return batch
278278

279-
def _create_tensor_batch(self, data):
280-
"""
281-
Reshape properly ``data`` tensor to be processed handle by the graph
282-
based models.
283-
284-
:param data: torch.Tensor object of shape ``(N, ...)`` where ``N`` is
285-
the number of data objects.
286-
:type data: torch.Tensor | LabelTensor
287-
:return: Reshaped tensor object.
288-
:rtype: torch.Tensor | LabelTensor
289-
"""
290-
out = data.reshape(-1, *data.shape[2:])
291-
return out
292-
293279
def create_batch(self, data):
294280
"""
295281
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
@@ -324,7 +310,7 @@ def _retrive_data(self, data, idx_list):
324310
k: (
325311
self._create_graph_batch([v[i] for i in idx_list])
326312
if isinstance(v, list)
327-
else self._create_tensor_batch(v[idx_list])
313+
else v[idx_list]
328314
)
329315
for k, v in data.items()
330316
}

tests/test_data/test_graph_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_getitem(conditions_dict, max_conditions_lengths):
101101
[d["input"].x.shape == torch.Size((400, 10)) for d in data.values()]
102102
)
103103
assert all(
104-
[d["target"].shape == torch.Size((400, 10)) for d in data.values()]
104+
[d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()]
105105
)
106106
assert all(
107107
[

tests/test_solver/test_ensemble_supervised_solver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
from torch._dynamo.eval_frame import OptimizedModule
44
from torch_geometric.nn import GCNConv
5+
from torch_geometric.utils import to_dense_batch
56
from pina import Condition, LabelTensor
67
from pina.condition import InputTargetCondition
78
from pina.problem import AbstractProblem
@@ -82,7 +83,7 @@ def forward(self, batch):
8283
y = self.conv(y, edge_index)
8384
y = self.activation(y)
8485
y = self.output(y)
85-
return y
86+
return to_dense_batch(y, batch.batch)[0]
8687

8788

8889
graph_models = [Models() for i in range(10)]

tests/test_solver/test_supervised_solver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
from torch._dynamo.eval_frame import OptimizedModule
44
from torch_geometric.nn import GCNConv
5+
from torch_geometric.utils import to_dense_batch
56
from pina import Condition, LabelTensor
67
from pina.condition import InputTargetCondition
78
from pina.problem import AbstractProblem
@@ -82,7 +83,7 @@ def forward(self, batch):
8283
y = self.conv(y, edge_index)
8384
y = self.activation(y)
8485
y = self.output(y)
85-
return y
86+
return to_dense_batch(y, batch.batch)[0]
8687

8788

8889
graph_model = Model()

tutorials/tutorial15/tutorial.ipynb

Lines changed: 33 additions & 33 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)