22This module provides basic data management functionalities
33"""
44
5- import functools
6- import torch
7- from torch .utils .data import Dataset
85from abc import abstractmethod
9- from torch_geometric .data import Batch , Data
10- from pina import LabelTensor
6+ from torch .utils .data import Dataset
7+ from torch_geometric .data import Data
8+ from ..graph import Graph , LabelBatch
119
1210
1311class PinaDatasetFactory :
@@ -19,38 +17,53 @@ class PinaDatasetFactory:
1917 """
2018
2119 def __new__ (cls , conditions_dict , ** kwargs ):
20+ # Check if conditions_dict is empty
2221 if len (conditions_dict ) == 0 :
2322 raise ValueError ("No conditions provided" )
24- if all (
25- [
26- isinstance (v ["input" ], torch .Tensor )
27- for v in conditions_dict .values ()
28- ]
29- ):
30- return PinaTensorDataset (conditions_dict , ** kwargs )
31- elif all (
32- [isinstance (v ["input" ], list ) for v in conditions_dict .values ()]
33- ):
23+
24+ # Check is a Graph is present in the conditions
25+ is_graph = cls ._is_graph_dataset (conditions_dict )
26+ if is_graph :
27+ # If a Graph is present, return a PinaGraphDataset
3428 return PinaGraphDataset (conditions_dict , ** kwargs )
35- raise ValueError (
36- "Conditions must be either torch.Tensor or list of Data " "objects."
37- )
29+ # If no Graph is present, return a PinaTensorDataset
30+ return PinaTensorDataset (conditions_dict , ** kwargs )
31+
32+ @staticmethod
33+ def _is_graph_dataset (conditions_dict ):
34+ for v in conditions_dict .values ():
35+ for cond in v .values ():
36+ if isinstance (cond , (Data , Graph , list )):
37+ return True
38+ return False
3839
3940
4041class PinaDataset (Dataset ):
4142 """
4243 Abstract class for the PINA dataset
4344 """
4445
45- def __init__ (self , conditions_dict , max_conditions_lengths ):
46+ def __init__ (
47+ self , conditions_dict , max_conditions_lengths , automatic_batching
48+ ):
49+ # Store the conditions dictionary
4650 self .conditions_dict = conditions_dict
51+ # Store the maximum number of conditions to consider
4752 self .max_conditions_lengths = max_conditions_lengths
53+ # Store length of each condition
4854 self .conditions_length = {
4955 k : len (v ["input" ]) for k , v in self .conditions_dict .items ()
5056 }
57+ # Store the maximum length of the dataset
5158 self .length = max (self .conditions_length .values ())
59+ # Dynamically set the getitem function based on automatic batching
60+ if automatic_batching :
61+ self ._getitem_func = self ._getitem_int
62+ else :
63+ self ._getitem_func = self ._getitem_dummy
5264
5365 def _get_max_len (self ):
66+ """"""
5467 max_len = 0
5568 for condition in self .conditions_dict .values ():
5669 max_len = max (max_len , len (condition ["input" ]))
@@ -59,50 +72,66 @@ def _get_max_len(self):
5972 def __len__ (self ):
6073 return self .length
6174
62- @abstractmethod
63- def __getitem__ (self , item ):
64- pass
65-
66-
67- class PinaTensorDataset (PinaDataset ):
68- def __init__ (
69- self , conditions_dict , max_conditions_lengths , automatic_batching
70- ):
71- super ().__init__ (conditions_dict , max_conditions_lengths )
75+ def __getitem__ (self , idx ):
76+ return self ._getitem_func (idx )
7277
73- if automatic_batching :
74- self ._getitem_func = self ._getitem_int
75- else :
76- self ._getitem_func = self ._getitem_dummy
78+ def _getitem_dummy (self , idx ):
79+ # If automatic batching is disabled, return the data at the given index
80+ return idx
7781
7882 def _getitem_int (self , idx ):
83+ # If automatic batching is enabled, return the data at the given index
7984 return {
8085 k : {k_data : v [k_data ][idx % len (v ["input" ])] for k_data in v .keys ()}
8186 for k , v in self .conditions_dict .items ()
8287 }
8388
89+ def get_all_data (self ):
90+ """
91+ Return all data in the dataset
92+
93+ :return: All data in the dataset
94+ :rtype: dict
95+ """
96+ index = list (range (len (self )))
97+ return self .fetch_from_idx_list (index )
98+
8499 def fetch_from_idx_list (self , idx ):
100+ """
101+ Return data from the dataset given a list of indices
102+
103+ :param idx: List of indices
104+ :type idx: list
105+ :return: Data from the dataset
106+ :rtype: dict
107+ """
85108 to_return_dict = {}
86109 for condition , data in self .conditions_dict .items ():
110+ # Get the indices for the current condition
87111 cond_idx = idx [: self .max_conditions_lengths [condition ]]
112+ # Get the length of the current condition
88113 condition_len = self .conditions_length [condition ]
114+ # If the length of the dataset is greater than the length of the
115+ # current condition, repeat the indices
89116 if self .length > condition_len :
90117 cond_idx = [idx % condition_len for idx in cond_idx ]
91- to_return_dict [condition ] = {
92- k : v [cond_idx ] for k , v in data .items ()
93- }
118+ # Retrieve the data from the current condition
119+ to_return_dict [condition ] = self ._retrive_data (data , cond_idx )
94120 return to_return_dict
95121
96- @staticmethod
97- def _getitem_dummy ( idx ):
98- return idx
122+ @abstractmethod
123+ def _retrive_data ( self , data , idx_list ):
124+ pass
99125
100- def get_all_data (self ):
101- index = [i for i in range (len (self ))]
102- return self .fetch_from_idx_list (index )
103126
104- def __getitem__ (self , idx ):
105- return self ._getitem_func (idx )
127+ class PinaTensorDataset (PinaDataset ):
128+ """
129+ Class for the PINA dataset with torch.Tensor data
130+ """
131+
132+ # Override _retrive_data method for torch.Tensor data
133+ def _retrive_data (self , data , idx_list ):
134+ return {k : v [idx_list ] for k , v in data .items ()}
106135
107136 @property
108137 def input (self ):
@@ -112,129 +141,42 @@ def input(self):
112141 return {k : v ["input" ] for k , v in self .conditions_dict .items ()}
113142
114143
115- class PinaBatch ( Batch ):
class PinaGraphDataset(PinaDataset):
    """
    Class for the PINA dataset with torch_geometric.data.Data data
    """

    def _create_graph_batch_from_list(self, data):
        """
        Build a LabelBatch from a list of graph objects.

        :param data: List of objects accepted by
            ``LabelBatch.from_data_list``
        :type data: list
        :return: The resulting batch
        :rtype: LabelBatch
        """
        return LabelBatch.from_data_list(data)

    def _create_output_batch(self, data):
        """
        Merge the first two dimensions of ``data`` into a single one.

        :param data: Object exposing ``shape`` and ``reshape``
        :return: ``data`` reshaped to ``(-1, *data.shape[2:])``
        """
        return data.reshape(-1, *data.shape[2:])

    def create_graph_batch(self, data):
        """
        Create a batch from the given data.

        :param data: List of Data objects, or any other indexable object
            exposing ``shape`` and ``reshape``
        :type data: list
        :return: A LabelBatch if the first element of ``data`` is a Data
            object, otherwise ``data`` with its first two dimensions merged
        """
        if isinstance(data[0], Data):
            return self._create_graph_batch_from_list(data)
        return self._create_output_batch(data)

    # Override _retrive_data method for graph handling
    def _retrive_data(self, data, idx_list):
        """
        Retrieve the items of one condition at the given indices.

        :param data: Data dictionary of a single condition
        :type data: dict
        :param idx_list: Indices to retrieve
        :type idx_list: list
        :return: Dictionary with the same keys as ``data``; list entries
            are collated into a LabelBatch, every other entry is indexed
            with ``idx_list`` and has its first two dimensions merged
        :rtype: dict
        """
        return {
            key: (
                self._create_graph_batch_from_list(
                    [values[i] for i in idx_list]
                )
                if isinstance(values, list)
                else self._create_output_batch(values[idx_list])
            )
            for key, values in data.items()
        }
0 commit comments