6 changes: 6 additions & 0 deletions pina/collector.py
@@ -1,3 +1,7 @@
"""
Module for the Collector class, which gathers and stores the data
defined in the problem's conditions.
"""
from .graph import Graph
from .utils import check_consistency


@@ -52,6 +56,8 @@ def store_fixed_data(self):
# get data
keys = condition.__slots__
values = [getattr(condition, name) for name in keys]
values = [value.data if isinstance(
value, Graph) else value for value in values]
self.data_collections[condition_name] = dict(zip(keys, values))
# condition now is ready
self._is_conditions_ready[condition_name] = True
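The new lines in `store_fixed_data` unwrap `Graph` objects into their underlying data before storage. A minimal sketch of that step, using a hypothetical stand-in for pina's `Graph` class rather than the real one:

```python
import torch

# Hypothetical stand-in for pina.graph.Graph: it exposes the same
# .data attribute the comprehension above relies on.
class Graph:
    def __init__(self, data):
        self.data = data  # list of underlying graph objects

values = [torch.rand(4, 2), Graph(data=["g0", "g1"])]
# Graph instances are replaced by their underlying .data list;
# plain tensors pass through untouched.
values = [v.data if isinstance(v, Graph) else v for v in values]
print(values[1])  # ['g0', 'g1']
```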
84 changes: 53 additions & 31 deletions pina/data/data_module.py
@@ -2,11 +2,11 @@
import warnings
from lightning.pytorch import LightningDataModule
import torch
from ..label_tensor import LabelTensor
from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
RandomSampler
from torch_geometric.data import Data
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from .dataset import PinaDatasetFactory
from ..label_tensor import LabelTensor
from .dataset import PinaDatasetFactory, PinaTensorDataset
from ..collector import Collector


@@ -61,6 +61,10 @@ def __init__(self, max_conditions_lengths, dataset=None):
max_conditions_lengths is None else (
self._collate_standard_dataloader)
self.dataset = dataset
if isinstance(self.dataset, PinaTensorDataset):
self._collate = self._collate_tensor_dataset
else:
self._collate = self._collate_graph_dataset

def _collate_custom_dataloader(self, batch):
return self.dataset.fetch_from_idx_list(batch)
@@ -73,7 +77,6 @@ def _collate_standard_dataloader(self, batch):
if isinstance(batch, dict):
return batch
conditions_names = batch[0].keys()

# Condition names
for condition_name in conditions_names:
single_cond_dict = {}
@@ -82,16 +85,28 @@
data_list = [batch[idx][condition_name][arg] for idx in range(
min(len(batch),
self.max_conditions_lengths[condition_name]))]
if isinstance(data_list[0], LabelTensor):
single_cond_dict[arg] = LabelTensor.stack(data_list)
elif isinstance(data_list[0], torch.Tensor):
single_cond_dict[arg] = torch.stack(data_list)
else:
raise NotImplementedError(
f"Data type {type(data_list[0])} not supported")
single_cond_dict[arg] = self._collate(data_list)

batch_dict[condition_name] = single_cond_dict
return batch_dict

@staticmethod
def _collate_tensor_dataset(data_list):
if isinstance(data_list[0], LabelTensor):
return LabelTensor.stack(data_list)
if isinstance(data_list[0], torch.Tensor):
return torch.stack(data_list)
raise RuntimeError("Data must be Tensors or LabelTensor ")

def _collate_graph_dataset(self, data_list):
if isinstance(data_list[0], LabelTensor):
return LabelTensor.cat(data_list)
if isinstance(data_list[0], torch.Tensor):
return torch.cat(data_list)
if isinstance(data_list[0], Data):
return self.dataset.create_graph_batch(data_list)
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")

def __call__(self, batch):
return self.callable_function(batch)
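The two collate helpers merge samples differently: tensor data gains a new leading batch dimension via `stack`, while graph outputs are concatenated along the node dimension via `cat`, matching how PyG merges graphs into one large disconnected graph. A self-contained sketch of the difference:

```python
import torch

samples = [torch.rand(3, 2), torch.rand(3, 2)]

# Tensor-dataset collation: stack adds a leading batch dimension.
print(torch.stack(samples).shape)  # torch.Size([2, 3, 2])

# Graph-dataset collation: cat joins samples along the node dimension.
print(torch.cat(samples).shape)    # torch.Size([6, 2])
```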

@@ -125,7 +140,7 @@ def __init__(self,
batch_size=None,
shuffle=True,
repeat=False,
automatic_batching=False,
automatic_batching=None,
num_workers=0,
pin_memory=False,
):
@@ -158,15 +173,35 @@ def __init__(self,
logging.debug('Start initialization of Pina DataModule')
logging.info('Start initialization of Pina DataModule')
super().__init__()
self.automatic_batching = automatic_batching

# Store fixed attributes
self.batch_size = batch_size
self.shuffle = shuffle
self.repeat = repeat
self.automatic_batching = automatic_batching
if batch_size is None and num_workers != 0:
warnings.warn(
"Setting num_workers when batch_size is None has no effect on "
"the DataLoading process.")
self.num_workers = 0
else:
self.num_workers = num_workers
if batch_size is None and pin_memory:
warnings.warn("Setting pin_memory to True has no effect when "
"batch_size is None.")
self.pin_memory = False
else:
self.pin_memory = pin_memory

# Collect data
collector = Collector(problem)
collector.store_fixed_data()
collector.store_sample_domains()

# Check if the splits are correct
self._check_slit_sizes(train_size, test_size, val_size, predict_size)

# Begin Data splitting
# Split input data into subsets
splits_dict = {}
if train_size > 0:
splits_dict['train'] = train_size
@@ -188,19 +223,6 @@
self.predict_dataset = None
else:
self.predict_dataloader = super().predict_dataloader

collector = Collector(problem)
collector.store_fixed_data()
collector.store_sample_domains()
if batch_size is None and num_workers != 0:
warnings.warn(
"Setting num_workers when batch_size is None has no effect on "
"the DataLoading process.")
if batch_size is None and pin_memory:
warnings.warn("Setting pin_memory to True has no effect when "
"batch_size is None.")
self.num_workers = num_workers
self.pin_memory = pin_memory
self.collector_splits = self._create_splits(collector, splits_dict)
self.transfer_batch_to_device = self._transfer_batch_to_device

@@ -316,10 +338,10 @@ def _create_dataloader(self, split, dataset):
if self.batch_size is not None:
sampler = PinaSampler(dataset, shuffle)
if self.automatic_batching:
collate = Collator(self.find_max_conditions_lengths(split))

collate = Collator(self.find_max_conditions_lengths(split),
dataset=dataset)
else:
collate = Collator(None, dataset)
collate = Collator(None, dataset=dataset)
return DataLoader(dataset, self.batch_size,
collate_fn=collate, sampler=sampler,
num_workers=self.num_workers)
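The constructor now also normalises `num_workers` and `pin_memory` when `batch_size` is None, since the whole dataset is then served as a single batch and neither option buys anything. A sketch of that guard, lifted out of `__init__` into a standalone function for illustration:

```python
import warnings

def resolve_loader_args(batch_size, num_workers, pin_memory):
    # With batch_size=None the full dataset is one batch, so worker
    # processes and pinned-memory buffers provide no benefit.
    if batch_size is None and num_workers != 0:
        warnings.warn("Setting num_workers when batch_size is None has "
                      "no effect on the DataLoading process.")
        num_workers = 0
    if batch_size is None and pin_memory:
        warnings.warn("Setting pin_memory to True has no effect when "
                      "batch_size is None.")
        pin_memory = False
    return num_workers, pin_memory

print(resolve_loader_args(None, 4, True))  # (0, False)
```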
104 changes: 91 additions & 13 deletions pina/data/dataset.py
@@ -1,10 +1,12 @@
"""
This module provides basic data management functionalities.
"""
import functools
import torch
from torch.utils.data import Dataset
from abc import abstractmethod
from torch_geometric.data import Batch
from torch_geometric.data import Batch, Data
from pina import LabelTensor


class PinaDatasetFactory:
@@ -62,7 +64,7 @@ def __init__(self, conditions_dict, max_conditions_lengths,
if automatic_batching:
self._getitem_func = self._getitem_int
else:
self._getitem_func = self._getitem_list
self._getitem_func = self._getitem_dummy

def _getitem_int(self, idx):
return {
@@ -82,7 +84,7 @@ def fetch_from_idx_list(self, idx):
return to_return_dict

@staticmethod
def _getitem_list(idx):
def _getitem_dummy(idx):
return idx

def get_all_data(self):
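The `_getitem_list` → `_getitem_dummy` rename makes the intent explicit: without automatic batching, `__getitem__` merely echoes the incoming indices, and the actual data access happens later in a single vectorised `fetch_from_idx_list` call. A toy illustration of the dispatch, with a hypothetical stand-in class:

```python
class ToyDataset:
    """Hypothetical stand-in for PinaTensorDataset's access dispatch."""

    def __init__(self, data, automatic_batching):
        self.data = data
        # Pick the access strategy once, at construction time.
        self._getitem_func = (
            self._getitem_int if automatic_batching else self._getitem_dummy
        )

    def _getitem_int(self, idx):
        return self.data[idx]  # materialise one sample per index

    @staticmethod
    def _getitem_dummy(idx):
        return idx  # defer the real fetch to fetch_from_idx_list

    def __getitem__(self, idx):
        return self._getitem_func(idx)

ds = ToyDataset([10, 20, 30], automatic_batching=False)
print(ds[1])  # 1 -- just the index; slicing happens in one batched call
```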
@@ -102,15 +104,56 @@ def input_points(self):
}


class PinaBatch(Batch):
"""
Adds an extract method to the torch_geometric Batch object.
"""
def __init__(self):
super().__init__()

def extract(self, labels):
"""
Perform extraction of labels on node features (x)

:param labels: Labels to extract
:type labels: list[str] | tuple[str] | str
:return: Batch object with extraction performed on x
:rtype: PinaBatch
"""
self.x = self.x.extract(labels)
return self
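`extract` narrows the node features to the requested labels and returns the batch itself, so the call can be chained. A self-contained analogue of that behaviour, with a plain dict standing in for pina's LabelTensor:

```python
import torch

class ToyBatch:
    """Hypothetical analogue of PinaBatch, with x as a label->column dict."""

    def __init__(self, x):
        self.x = x

    def extract(self, labels):
        # Keep only the requested labels on x, then return self so the
        # call can be chained, mirroring PinaBatch.extract above.
        self.x = {k: self.x[k] for k in labels}
        return self

batch = ToyBatch({"u": torch.rand(5, 1), "v": torch.rand(5, 1)})
print(batch.extract(["u"]).x.keys())  # dict_keys(['u'])
```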


class PinaGraphDataset(PinaDataset):

def __init__(self, conditions_dict, max_conditions_lengths,
automatic_batching):
super().__init__(conditions_dict, max_conditions_lengths)
self.in_labels = {}
self.out_labels = None
if automatic_batching:
self._getitem_func = self._getitem_int
else:
self._getitem_func = self._getitem_list
self._getitem_func = self._getitem_dummy

ex_data = conditions_dict[list(conditions_dict.keys())[
0]]['input_points'][0]
for name, attr in ex_data.items():
if isinstance(attr, LabelTensor):
self.in_labels[name] = attr.stored_labels
ex_data = conditions_dict[list(conditions_dict.keys())[
0]]['output_points'][0]
if isinstance(ex_data, LabelTensor):
self.out_labels = ex_data.labels

self._create_graph_batch_from_list = self._labelise_batch(
self._base_create_graph_batch_from_list) if self.in_labels \
else self._base_create_graph_batch_from_list

self._create_output_batch = self._labelise_tensor(
self._base_create_output_batch) if self.out_labels is not None \
else self._base_create_output_batch

def fetch_from_idx_list(self, idx):
to_return_dict = {}
@@ -119,17 +162,24 @@ def fetch_from_idx_list(self, idx):
condition_len = self.conditions_length[condition]
if self.length > condition_len:
cond_idx = [idx % condition_len for idx in cond_idx]
to_return_dict[condition] = {k: Batch.from_data_list([
v[i] for i in cond_idx])
if isinstance(v, list)
else v[
cond_idx].reshape(
-1, *v[cond_idx].shape[2:])
for k, v in data.items()
}
to_return_dict[condition] = {
k: self._create_graph_batch_from_list([v[i] for i in cond_idx])
if isinstance(v, list)
else self._create_output_batch(v[cond_idx])
for k, v in data.items()
}

return to_return_dict
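The modulo wrap above lets a condition with fewer samples than the longest one keep contributing to every batch: indices past its end simply cycle back to the start, oversampling the short condition instead of exhausting it. In isolation:

```python
# Indices beyond a condition's length wrap around.
condition_len = 3
idx = [0, 1, 2, 3, 4]
cond_idx = [i % condition_len for i in idx]
print(cond_idx)  # [0, 1, 2, 0, 1]
```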

def _getitem_list(self, idx):
def _base_create_graph_batch_from_list(self, data):
batch = PinaBatch.from_data_list(data)
return batch

def _base_create_output_batch(self, data):
out = data.reshape(-1, *data.shape[2:])
return out

def _getitem_dummy(self, idx):
return idx

def _getitem_int(self, idx):
@@ -144,3 +194,31 @@ def get_all_data(self):

def __getitem__(self, idx):
return self._getitem_func(idx)

def _labelise_batch(self, func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
batch = func(*args, **kwargs)
for k, v in self.in_labels.items():
tmp = batch[k]
tmp.labels = v
batch[k] = tmp
return batch
return wrapper

def _labelise_tensor(self, func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
out = func(*args, **kwargs)
if isinstance(out, LabelTensor):
out.labels = self.out_labels
return out
return wrapper

def create_graph_batch(self, data):
"""
Batch a list of Data objects into a single graph batch; plain
tensors are instead reshaped into an output batch.
"""
if isinstance(data[0], Data):
return self._create_graph_batch_from_list(data)
return self._create_output_batch(data)
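The `_labelise_*` wrappers exist because stacking and concatenation drop LabelTensor metadata, so the dataset decorates its batch builders once, at construction, to re-attach the stored labels to every result. A self-contained sketch of the same pattern with toy types:

```python
import functools

def relabel(labels):
    """Decorator factory: re-attach label metadata after batching."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            out = func(*args, **kwargs)
            out["labels"] = labels  # metadata lost during batching
            return out
        return wrapper
    return decorator

@relabel(["u", "v"])
def build_batch(items):
    return {"data": items}

print(build_batch([1, 2]))  # {'data': [1, 2], 'labels': ['u', 'v']}
```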
18 changes: 11 additions & 7 deletions pina/graph.py
@@ -108,16 +108,14 @@ def __init__(
x)

# Perform the graph construction
self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
self._build_graph_list(
x, pos, edge_index, edge_attr, additional_params)

def _build_graph_list(self, x, pos, edge_index, edge_attr,
additional_params):
for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)):
if isinstance(x_, LabelTensor):
x_ = x_.tensor
add_params_local = {k: v[i] for k, v in additional_params.items()}
if edge_attr is not None:

self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_,
edge_attr=edge_attr[i],
**add_params_local))
Expand All @@ -127,7 +125,8 @@ def _build_graph_list(self, x, pos, edge_index, edge_attr,

@staticmethod
def _build_edge_attr(x, pos, edge_index):
distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
distance = torch.abs(pos[edge_index[0]] -
pos[edge_index[1]]).as_subclass(torch.Tensor)
return distance

@staticmethod
@@ -165,7 +164,8 @@ def _check_input_consistency(x, pos, edge_index=None):
# If edge_index is a 3D tensor, we split it into a list of 2D tensors
if edge_index is not None:
if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
edge_index = [edge_index[i]
for i in range(edge_index.shape[0])]
elif not (isinstance(edge_index, list) and all(
t.ndim == 2 for t in edge_index)) and not (
isinstance(edge_index,
@@ -219,7 +219,7 @@ def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len,
if isinstance(edge_attr, list):
if len(edge_attr) != data_len:
raise TypeError("edge_attr must have the same length as x "
"and pos.")
"and pos.")
return [edge_attr] * data_len

if build_edge_attr:
@@ -258,6 +258,8 @@ def _radius_graph(points, r):
"""
dist = torch.cdist(points, points, p=2)
edge_index = torch.nonzero(dist <= r, as_tuple=False).t()
if isinstance(edge_index, LabelTensor):
edge_index = edge_index.tensor
return edge_index


@@ -293,4 +295,6 @@ def _knn_graph(points, k):
row = torch.arange(points.size(0)).repeat_interleave(k)
col = knn_indices.flatten()
edge_index = torch.stack([row, col], dim=0)
if isinstance(edge_index, LabelTensor):
edge_index = edge_index.tensor
return edge_index
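The `.tensor` / `.as_subclass(torch.Tensor)` casts added throughout this file serve one purpose: edge indices and edge attributes must be plain tensors, not LabelTensors, before being handed to torch_geometric. The radius-graph construction itself is small enough to verify standalone:

```python
import torch

# Edge (i, j) exists whenever dist(p_i, p_j) <= r; self-loops appear
# because the distance matrix has zeros on its diagonal.
points = torch.tensor([[0.0, 0.0], [0.0, 1.0], [3.0, 0.0]])
dist = torch.cdist(points, points, p=2)
edge_index = torch.nonzero(dist <= 1.5, as_tuple=False).t()
print(edge_index)
# tensor([[0, 0, 1, 1, 2],
#         [0, 1, 0, 1, 2]])
```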