Skip to content

Commit aa8a9b4

Browse files
committed
Refactor datasets and implement LabelBatch
1 parent 7a97866 commit aa8a9b4

File tree

3 files changed

+156
-157
lines changed

3 files changed

+156
-157
lines changed

pina/data/dataset.py

Lines changed: 98 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22
This module provides basic data management functionalities
33
"""
44

5-
import functools
6-
import torch
7-
from torch.utils.data import Dataset
85
from abc import abstractmethod
9-
from torch_geometric.data import Batch, Data
10-
from pina import LabelTensor
6+
from torch.utils.data import Dataset
7+
from torch_geometric.data import Data
8+
from ..graph import Graph, LabelBatch
119

1210

1311
class PinaDatasetFactory:
@@ -19,38 +17,53 @@ class PinaDatasetFactory:
1917
"""
2018

2119
def __new__(cls, conditions_dict, **kwargs):
    """
    Build the dataset matching the kind of data stored in the conditions.

    :param conditions_dict: Dictionary mapping each condition name to a
        dictionary of data
    :type conditions_dict: dict
    :param kwargs: Keyword arguments forwarded unchanged to the selected
        dataset class
    :return: A PinaGraphDataset if graph-like data is detected, otherwise
        a PinaTensorDataset
    :raises ValueError: If ``conditions_dict`` is empty
    """
    # An empty dictionary cannot be turned into a dataset
    if len(conditions_dict) == 0:
        raise ValueError("No conditions provided")

    # Check if a Graph is present in the conditions and pick the dataset
    # class accordingly
    if cls._is_graph_dataset(conditions_dict):
        dataset_class = PinaGraphDataset
    else:
        dataset_class = PinaTensorDataset
    return dataset_class(conditions_dict, **kwargs)
31+
32+
@staticmethod
def _is_graph_dataset(conditions_dict):
    """
    Tell whether any entry of any condition holds graph-like data.

    :param conditions_dict: Dictionary mapping each condition name to a
        dictionary of data
    :type conditions_dict: dict
    :return: True if at least one value inside a condition is a Data
        object, a Graph object or a list, False otherwise
    :rtype: bool
    """
    return any(
        isinstance(entry, (Data, Graph, list))
        for condition in conditions_dict.values()
        for entry in condition.values()
    )
3839

3940

4041
class PinaDataset(Dataset):
4142
"""
4243
Abstract class for the PINA dataset
4344
"""
4445

45-
def __init__(self, conditions_dict, max_conditions_lengths):
46+
def __init__(
    self, conditions_dict, max_conditions_lengths, automatic_batching
):
    """
    Store the conditions and select how single items are retrieved.

    :param conditions_dict: Dictionary mapping each condition name to a
        dictionary of data; every condition must have an ``"input"`` entry
    :type conditions_dict: dict
    :param max_conditions_lengths: Per-condition maximum lengths, stored
        as given
    :param automatic_batching: If truthy, items are retrieved with
        ``_getitem_int``; otherwise with ``_getitem_dummy``
    :type automatic_batching: bool
    """
    self.conditions_dict = conditions_dict
    self.max_conditions_lengths = max_conditions_lengths

    # Number of samples of each condition, measured on its "input" entry
    lengths = {}
    for name, condition in conditions_dict.items():
        lengths[name] = len(condition["input"])
    self.conditions_length = lengths

    # The dataset is as long as its longest condition
    self.length = max(lengths.values())

    # Bind the item getter once, so __getitem__ does not branch per call
    self._getitem_func = (
        self._getitem_int if automatic_batching else self._getitem_dummy
    )
5264

5365
def _get_max_len(self):
66+
""""""
5467
max_len = 0
5568
for condition in self.conditions_dict.values():
5669
max_len = max(max_len, len(condition["input"]))
@@ -59,50 +72,66 @@ def _get_max_len(self):
5972
def __len__(self):
6073
return self.length
6174

62-
@abstractmethod
63-
def __getitem__(self, item):
64-
pass
65-
66-
67-
class PinaTensorDataset(PinaDataset):
68-
def __init__(
69-
self, conditions_dict, max_conditions_lengths, automatic_batching
70-
):
71-
super().__init__(conditions_dict, max_conditions_lengths)
75+
def __getitem__(self, idx):
76+
return self._getitem_func(idx)
7277

73-
if automatic_batching:
74-
self._getitem_func = self._getitem_int
75-
else:
76-
self._getitem_func = self._getitem_dummy
78+
def _getitem_dummy(self, idx):
79+
# If automatic batching is disabled, return the index unchanged (no data lookup is performed here)
80+
return idx
7781

7882
def _getitem_int(self, idx):
    """
    Return one sample per condition for the given integer index.

    The index is taken modulo the length of each condition's ``"input"``
    entry, so conditions shorter than the dataset wrap around.

    :param idx: Index of the sample
    :type idx: int
    :return: Dictionary mapping each condition name to a dictionary with
        the selected element of every entry of that condition
    :rtype: dict
    """
    sample = {}
    for name, condition in self.conditions_dict.items():
        entry = {}
        for field in condition:
            entry[field] = condition[field][idx % len(condition["input"])]
        sample[name] = entry
    return sample
8388

89+
def get_all_data(self):
    """
    Return all data in the dataset

    The indices ``0 .. len(self) - 1`` are built and passed to
    ``fetch_from_idx_list``.

    :return: All data in the dataset
    :rtype: dict
    """
    every_index = [*range(len(self))]
    return self.fetch_from_idx_list(every_index)
98+
8499
def fetch_from_idx_list(self, idx):
    """
    Return data from the dataset given a list of indices

    :param idx: List of indices
    :type idx: list
    :return: Dictionary mapping each condition name to the data retrieved
        by ``_retrive_data`` for that condition
    :rtype: dict
    """
    to_return_dict = {}
    for condition, data in self.conditions_dict.items():
        # Keep at most max_conditions_lengths[condition] indices for the
        # current condition
        cond_idx = idx[: self.max_conditions_lengths[condition]]
        # Get the length of the current condition
        condition_len = self.conditions_length[condition]
        # If the dataset is longer than the current condition, wrap the
        # indices so that they stay inside the condition
        if self.length > condition_len:
            cond_idx = [i % condition_len for i in cond_idx]
        # Retrieve the data through the subclass hook. The hook defined
        # by this class and overridden by its subclasses is named
        # _retrive_data.
        to_return_dict[condition] = self._retrive_data(data, cond_idx)
    return to_return_dict
95121

96-
@staticmethod
97-
def _getitem_dummy(idx):
98-
return idx
122+
@abstractmethod
123+
def _retrive_data(self, data, idx_list):
124+
pass
99125

100-
def get_all_data(self):
101-
index = [i for i in range(len(self))]
102-
return self.fetch_from_idx_list(index)
103126

104-
def __getitem__(self, idx):
105-
return self._getitem_func(idx)
127+
class PinaTensorDataset(PinaDataset):
128+
"""
129+
Class for the PINA dataset with torch.Tensor data
130+
"""
131+
132+
# Override _retrive_data method for torch.Tensor data
133+
def _retrive_data(self, data, idx_list):
134+
return {k: v[idx_list] for k, v in data.items()}
106135

107136
@property
108137
def input(self):
@@ -112,129 +141,42 @@ def input(self):
112141
return {k: v["input"] for k, v in self.conditions_dict.items()}
113142

114143

115-
class PinaBatch(Batch):
144+
class PinaGraphDataset(PinaDataset):
116145
"""
117-
Add extract function to torch_geometric Batch object
146+
Class for the PINA dataset with torch_geometric.data.Data data
118147
"""
119148

120-
def __init__(self):
121-
122-
super().__init__(self)
123-
124-
def extract(self, labels):
125-
"""
126-
Perform extraction of labels on node features (x)
127-
128-
:param labels: Labels to extract
129-
:type labels: list[str] | tuple[str] | str
130-
:return: Batch object with extraction performed on x
131-
:rtype: PinaBatch
132-
"""
133-
self.x = self.x.extract(labels)
134-
return self
135-
136-
137-
class PinaGraphDataset(PinaDataset):
138-
139-
def __init__(
140-
self, conditions_dict, max_conditions_lengths, automatic_batching
141-
):
142-
super().__init__(conditions_dict, max_conditions_lengths)
143-
self.in_labels = {}
144-
self.out_labels = None
145-
if automatic_batching:
146-
self._getitem_func = self._getitem_int
147-
else:
148-
self._getitem_func = self._getitem_dummy
149-
150-
ex_data = conditions_dict[list(conditions_dict.keys())[0]]["input"][0]
151-
for name, attr in ex_data.items():
152-
if isinstance(attr, LabelTensor):
153-
self.in_labels[name] = attr.stored_labels
154-
ex_data = conditions_dict[list(conditions_dict.keys())[0]]["target"][0]
155-
if isinstance(ex_data, LabelTensor):
156-
self.out_labels = ex_data.labels
157-
158-
self._create_graph_batch_from_list = (
159-
self._labelise_batch(self._base_create_graph_batch_from_list)
160-
if self.in_labels
161-
else self._base_create_graph_batch_from_list
162-
)
163-
164-
self._create_output_batch = (
165-
self._labelise_tensor(self._base_create_output_batch)
166-
if self.out_labels is not None
167-
else self._base_create_output_batch
168-
)
169-
170-
def fetch_from_idx_list(self, idx):
171-
to_return_dict = {}
172-
for condition, data in self.conditions_dict.items():
173-
cond_idx = idx[: self.max_conditions_lengths[condition]]
174-
condition_len = self.conditions_length[condition]
175-
if self.length > condition_len:
176-
cond_idx = [idx % condition_len for idx in cond_idx]
177-
to_return_dict[condition] = {
178-
k: (
179-
self._create_graph_batch_from_list([v[i] for i in idx])
180-
if isinstance(v, list)
181-
else self._create_output_batch(v[idx])
182-
)
183-
for k, v in data.items()
184-
}
185-
186-
return to_return_dict
187-
188-
def _base_create_graph_batch_from_list(self, data):
189-
batch = PinaBatch.from_data_list(data)
149+
def _create_graph_batch_from_list(self, data):
    """
    Collate a list of graph objects into a single LabelBatch.

    :param data: List of objects accepted by ``LabelBatch.from_data_list``
    :type data: list
    :return: The object returned by ``LabelBatch.from_data_list``
    """
    return LabelBatch.from_data_list(data)
191152

192-
def _base_create_output_batch(self, data):
153+
def _create_output_batch(self, data):
    """
    Merge the first two dimensions of ``data`` into one.

    :param data: Array-like object exposing ``shape`` and ``reshape``
        (presumably a torch.Tensor with at least two dimensions; verify
        against the caller)
    :return: ``data`` reshaped to ``(-1, *data.shape[2:])``
    """
    trailing_shape = data.shape[2:]
    return data.reshape(-1, *trailing_shape)
195156

196-
def _getitem_dummy(self, idx):
197-
return idx
198-
199-
def _getitem_int(self, idx):
200-
return {
201-
k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()}
202-
for k, v in self.conditions_dict.items()
203-
}
204-
205-
def get_all_data(self):
206-
index = [i for i in range(len(self))]
207-
return self.fetch_from_idx_list(index)
208-
209-
def __getitem__(self, idx):
210-
return self._getitem_func(idx)
211-
212-
def _labelise_batch(self, func):
213-
@functools.wraps(func)
214-
def wrapper(*args, **kwargs):
215-
batch = func(*args, **kwargs)
216-
for k, v in self.in_labels.items():
217-
tmp = batch[k]
218-
tmp.labels = v
219-
batch[k] = tmp
220-
return batch
221-
222-
return wrapper
223-
224-
def _labelise_tensor(self, func):
225-
@functools.wraps(func)
226-
def wrapper(*args, **kwargs):
227-
out = func(*args, **kwargs)
228-
if isinstance(out, LabelTensor):
229-
out.labels = self.out_labels
230-
return out
231-
232-
return wrapper
233-
234157
def create_graph_batch(self, data):
235158
"""
236-
# TODO
159+
Create a Batch object from a list of Data objects.
160+
161+
:param data: List of Data objects
162+
:type data: list
163+
:return: Batch object
164+
:rtype: LabelBatch | torch.Tensor
237165
"""
238166
if isinstance(data[0], Data):
239167
return self._create_graph_batch_from_list(data)
240168
return self._create_output_batch(data)
169+
170+
# Override _retrive_data method for graph handling
171+
def _retrive_data(self, data, idx_list):
    """
    Retrieve the entries of one condition at the given indices.

    List entries are indexed element by element and collated with
    ``_create_graph_batch_from_list``. Any other entry is indexed with
    the whole index list at once and passed to ``_create_output_batch``.

    :param data: Dictionary with the data of a single condition
    :type data: dict
    :param idx_list: Indices to retrieve
    :type idx_list: list
    :return: Dictionary with the same keys as ``data`` and the batched
        values
    :rtype: dict
    """
    retrieved = {}
    for key, value in data.items():
        if isinstance(value, list):
            # List of graph objects: pick the requested ones and batch them
            selected = [value[i] for i in idx_list]
            retrieved[key] = self._create_graph_batch_from_list(selected)
        else:
            # Non-list data (presumably a tensor): index it in one shot
            retrieved[key] = self._create_output_batch(value[idx_list])
    return retrieved

pina/graph.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import torch
6-
from torch_geometric.data import Data
6+
from torch_geometric.data import Data, Batch
77
from torch_geometric.utils import to_undirected
88
from . import LabelTensor
99
from .utils import check_consistency, is_function
@@ -162,6 +162,18 @@ def _preprocess_edge_index(edge_index, undirected):
162162
edge_index = to_undirected(edge_index)
163163
return edge_index
164164

165+
def extract(self, labels):
166+
"""
167+
Perform extraction of labels on node features (x)
168+
169+
:param labels: Labels to extract
170+
:type labels: list[str] | tuple[str] | str
171+
:return: The same object (``self``), with extraction performed on x in place
172+
:rtype: same type as ``self``
173+
"""
174+
self.x = self.x.extract(labels)
175+
return self
176+
165177

166178
class GraphBuilder:
167179
"""
@@ -317,3 +329,31 @@ def compute_knn_graph(points, k):
317329
row = torch.arange(points.size(0)).repeat_interleave(k)
318330
col = knn_indices.flatten()
319331
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
332+
333+
334+
class LabelBatch(Batch):
    """
    torch_geometric Batch subclass that keeps the labels of LabelTensor
    attributes when a list of Data objects is collated
    """

    @classmethod
    def from_data_list(cls, data_list, **kwargs):
        """
        Create a Batch object from a list of Data objects.

        :param data_list: Non-empty list of Data/Graph objects. The labels
            are read from the first element only, so all elements are
            assumed to share the same labels
        :type data_list: list
        :param kwargs: Optional keyword arguments forwarded unchanged to
            ``Batch.from_data_list`` (presumably ``follow_batch`` and
            ``exclude_keys``; confirm against the installed
            torch_geometric version)
        :return: The batch built by ``Batch.from_data_list``, with the
            labels restored on every attribute that was a LabelTensor
        """
        # Store the labels of Data/Graph objects (all data have the same
        # labels). If the data do not contain labels, this dictionary is
        # empty and nothing is restored afterwards.
        labels = {
            k: v.labels
            for k, v in data_list[0].items()
            if isinstance(v, LabelTensor)
        }

        # Create a Batch object from the list of Data objects, keeping any
        # optional argument supported by the parent implementation
        batch = super().from_data_list(data_list, **kwargs)

        # Put the labels back in the Batch object
        for k, v in labels.items():
            batch[k].labels = v
        return batch

0 commit comments

Comments
 (0)