diff --git a/pina/collector.py b/pina/collector.py
index 381e9499c..93ea18254 100644
--- a/pina/collector.py
+++ b/pina/collector.py
@@ -1,3 +1,7 @@
+"""
+Module for the Collector, which gathers data from the problem conditions.
+"""
+from .graph import Graph
 from .utils import check_consistency
 
 
@@ -52,6 +56,8 @@ def store_fixed_data(self):
            # get data
            keys = condition.__slots__
            values = [getattr(condition, name) for name in keys]
+           values = [value.data if isinstance(
+               value, Graph) else value for value in values]
            self.data_collections[condition_name] = dict(zip(keys, values))
            # condition now is ready
            self._is_conditions_ready[condition_name] = True
diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index 20b3c1c29..9ecfaa5ad 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -2,11 +2,11 @@
 import warnings
 from lightning.pytorch import LightningDataModule
 import torch
-from ..label_tensor import LabelTensor
-from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
-    RandomSampler
+from torch_geometric.data import Data
+from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
-from .dataset import PinaDatasetFactory
+from ..label_tensor import LabelTensor
+from .dataset import PinaDatasetFactory, PinaTensorDataset
 from ..collector import Collector
 
 
@@ -61,6 +61,10 @@ def __init__(self, max_conditions_lengths, dataset=None):
             max_conditions_lengths is None else (
             self._collate_standard_dataloader)
         self.dataset = dataset
+        if isinstance(self.dataset, PinaTensorDataset):
+            self._collate = self._collate_tensor_dataset
+        else:
+            self._collate = self._collate_graph_dataset
 
     def _collate_custom_dataloader(self, batch):
         return self.dataset.fetch_from_idx_list(batch)
@@ -73,7 +77,6 @@ def _collate_standard_dataloader(self, batch):
         if isinstance(batch, dict):
             return batch
         conditions_names = batch[0].keys()
-
         # Condition names
         for condition_name in conditions_names:
             single_cond_dict = {}
@@ -82,16 +85,28 @@ def _collate_standard_dataloader(self, batch):
                 data_list = [batch[idx][condition_name][arg] for idx in range(
                     min(len(batch),
                         self.max_conditions_lengths[condition_name]))]
-                if isinstance(data_list[0], LabelTensor):
-                    single_cond_dict[arg] = LabelTensor.stack(data_list)
-                elif isinstance(data_list[0], torch.Tensor):
-                    single_cond_dict[arg] = torch.stack(data_list)
-                else:
-                    raise NotImplementedError(
-                        f"Data type {type(data_list[0])} not supported")
+                single_cond_dict[arg] = self._collate(data_list)
+
             batch_dict[condition_name] = single_cond_dict
         return batch_dict
 
+    @staticmethod
+    def _collate_tensor_dataset(data_list):
+        if isinstance(data_list[0], LabelTensor):
+            return LabelTensor.stack(data_list)
+        if isinstance(data_list[0], torch.Tensor):
+            return torch.stack(data_list)
+        raise RuntimeError("Data must be Tensor or LabelTensor")
+
+    def _collate_graph_dataset(self, data_list):
+        if isinstance(data_list[0], LabelTensor):
+            return LabelTensor.cat(data_list)
+        if isinstance(data_list[0], torch.Tensor):
+            return torch.cat(data_list)
+        if isinstance(data_list[0], Data):
+            return self.dataset.create_graph_batch(data_list)
+        raise RuntimeError("Data must be Tensor, LabelTensor or PyG Data")
+
     def __call__(self, batch):
         return self.callable_function(batch)
 
@@ -125,7 +140,7 @@ def __init__(self,
                  batch_size=None,
                  shuffle=True,
                  repeat=False,
-                 automatic_batching=False,
+                 automatic_batching=None,
                  num_workers=0,
                  pin_memory=False,
                  ):
@@ -158,15 +173,35 @@ def __init__(self,
         logging.debug('Start initialization of Pina DataModule')
         logging.info('Start initialization of Pina DataModule')
         super().__init__()
-        self.automatic_batching = automatic_batching
+
+        # Store fixed attributes
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.repeat = repeat
+        self.automatic_batching = automatic_batching
+        if batch_size is None and num_workers != 0:
+            warnings.warn(
+                "Setting num_workers when batch_size is None has no effect on "
+                "the DataLoading process.")
+            self.num_workers = 0
+        else:
+            self.num_workers = num_workers
+        if batch_size is None and pin_memory:
+            warnings.warn("Setting pin_memory to True has no effect when "
+                          "batch_size is None.")
+            self.pin_memory = False
+        else:
+            self.pin_memory = pin_memory
+
+        # Collect data
+        collector = Collector(problem)
+        collector.store_fixed_data()
+        collector.store_sample_domains()
 
         # Check if the splits are correct
         self._check_slit_sizes(train_size, test_size, val_size, predict_size)
 
-        # Begin Data splitting
+        # Split input data into subsets
         splits_dict = {}
         if train_size > 0:
             splits_dict['train'] = train_size
@@ -188,19 +223,6 @@ def __init__(self,
             self.predict_dataset = None
         else:
             self.predict_dataloader = super().predict_dataloader
-
-        collector = Collector(problem)
-        collector.store_fixed_data()
-        collector.store_sample_domains()
-        if batch_size is None and num_workers != 0:
-            warnings.warn(
-                "Setting num_workers when batch_size is None has no effect on "
-                "the DataLoading process.")
-        if batch_size is None and pin_memory:
-            warnings.warn("Setting pin_memory to True has no effect when "
-                          "batch_size is None.")
-        self.num_workers = num_workers
-        self.pin_memory = pin_memory
         self.collector_splits = self._create_splits(collector, splits_dict)
         self.transfer_batch_to_device = self._transfer_batch_to_device
@@ -316,10 +338,10 @@ def _create_dataloader(self, split, dataset):
         if self.batch_size is not None:
             sampler = PinaSampler(dataset, shuffle)
             if self.automatic_batching:
-                collate = Collator(self.find_max_conditions_lengths(split))
-
+                collate = Collator(self.find_max_conditions_lengths(split),
+                                   dataset=dataset)
             else:
-                collate = Collator(None, dataset)
+                collate = Collator(None, dataset=dataset)
             return DataLoader(dataset, self.batch_size,
                               collate_fn=collate, sampler=sampler,
                               num_workers=self.num_workers)
diff --git a/pina/data/dataset.py b/pina/data/dataset.py
index 02450400b..2fecb9348 100644
--- a/pina/data/dataset.py
+++ b/pina/data/dataset.py
@@ -1,10 +1,12 @@
 """
 This module provide basic data management functionalities
 """
+import functools
 import torch
 from torch.utils.data import Dataset
 from abc import abstractmethod
-from torch_geometric.data import Batch
+from torch_geometric.data import Batch, Data
+from ..label_tensor import LabelTensor
 
 
 class PinaDatasetFactory:
@@ -62,7 +64,7 @@ def __init__(self, conditions_dict, max_conditions_lengths,
         if automatic_batching:
             self._getitem_func = self._getitem_int
         else:
-            self._getitem_func = self._getitem_list
+            self._getitem_func = self._getitem_dummy
 
     def _getitem_int(self, idx):
         return {
@@ -82,7 +84,7 @@ def fetch_from_idx_list(self, idx):
         return to_return_dict
 
     @staticmethod
-    def _getitem_list(idx):
+    def _getitem_dummy(idx):
         return idx
 
     def get_all_data(self):
@@ -102,15 +104,56 @@ def input_points(self):
         }
 
 
+class PinaBatch(Batch):
+    """
+    A torch_geometric Batch extended with label extraction on node features.
+    """
+    def __init__(self):
+
+        super().__init__()
+
+    def extract(self, labels):
+        """
+        Extract the given labels from the node features (x).
+
+        :param labels: Labels to extract
+        :type labels: list[str] | tuple[str] | str
+        :return: Batch object with extraction performed on x
+        :rtype: PinaBatch
+        """
+        self.x = self.x.extract(labels)
+        return self
+
+
 class PinaGraphDataset(PinaDataset):
 
     def __init__(self, conditions_dict, max_conditions_lengths,
                  automatic_batching):
         super().__init__(conditions_dict, max_conditions_lengths)
+        self.in_labels = {}
+        self.out_labels = None
         if automatic_batching:
             self._getitem_func = self._getitem_int
         else:
-            self._getitem_func = self._getitem_list
+            self._getitem_func = self._getitem_dummy
+
+        # Record the labels of an example datapoint of the first condition
+        first_cond = conditions_dict[list(conditions_dict.keys())[0]]
+        ex_data = first_cond['input_points'][0]
+        for name, attr in ex_data.items():
+            if isinstance(attr, LabelTensor):
+                self.in_labels[name] = attr.stored_labels
+        ex_data = first_cond['output_points'][0]
+        if isinstance(ex_data, LabelTensor):
+            self.out_labels = ex_data.labels
+
+        self._create_graph_batch_from_list = self._labelise_batch(
+            self._base_create_graph_batch_from_list) if self.in_labels \
+            else self._base_create_graph_batch_from_list
+
+        self._create_output_batch = self._labelise_tensor(
+            self._base_create_output_batch) if self.out_labels is not None \
+            else self._base_create_output_batch
 
     def fetch_from_idx_list(self, idx):
         to_return_dict = {}
@@ -119,17 +162,24 @@ def fetch_from_idx_list(self, idx):
         for condition, data in self.conditions_dict.items():
             cond_idx = idx[condition]
             condition_len = self.conditions_length[condition]
             if self.length > condition_len:
                 cond_idx = [idx % condition_len for idx in cond_idx]
-            to_return_dict[condition] = {k: Batch.from_data_list([
-                v[i] for i in cond_idx])
-                if isinstance(v, list)
-                else v[
-                cond_idx].reshape(
-                -1, *v[cond_idx].shape[2:])
-                for k, v in data.items()
-            }
+            to_return_dict[condition] = {
+                k: self._create_graph_batch_from_list([v[i] for i in cond_idx])
+                if isinstance(v, list)
+                else self._create_output_batch(v[cond_idx])
+                for k, v in data.items()
+            }
+
         return to_return_dict
 
-    def _getitem_list(self, idx):
+    def _base_create_graph_batch_from_list(self, data):
+        batch = PinaBatch.from_data_list(data)
+        return batch
+
+    def _base_create_output_batch(self, data):
+        out = data.reshape(-1, *data.shape[2:])
+        return out
+
+    def _getitem_dummy(self, idx):
         return idx
 
     def _getitem_int(self, idx):
@@ -144,3 +194,31 @@ def get_all_data(self):
 
     def __getitem__(self, idx):
         return self._getitem_func(idx)
+
+    def _labelise_batch(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            batch = func(*args, **kwargs)
+            for k, v in self.in_labels.items():
+                tmp = batch[k]
+                tmp.labels = v
+                batch[k] = tmp
+            return batch
+        return wrapper
+
+    def _labelise_tensor(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            out = func(*args, **kwargs)
+            if isinstance(out, LabelTensor):
+                out.labels = self.out_labels
+            return out
+        return wrapper
+
+    def create_graph_batch(self, data):
+        """
+        Create a batch from a list of PyG Data objects or from a tensor.
+        """
+        if isinstance(data[0], Data):
+            return self._create_graph_batch_from_list(data)
+        return self._create_output_batch(data)
diff --git a/pina/graph.py b/pina/graph.py
index 959bd9cc0..ca92ab435 100644
--- a/pina/graph.py
+++ b/pina/graph.py
@@ -108,16 +108,14 @@ def __init__(
             x)
 
         # Perform the graph construction
-        self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
+        self._build_graph_list(
+            x, pos, edge_index, edge_attr, additional_params)
 
     def _build_graph_list(self, x, pos, edge_index, edge_attr,
                           additional_params):
         for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)):
-            if isinstance(x_, LabelTensor):
-                x_ = x_.tensor
             add_params_local = {k: v[i] for k, v in additional_params.items()}
             if edge_attr is not None:
-                self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_,
-                                      edge_attr=edge_attr[i],
-                                      **add_params_local))
+                self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_,
+                                      edge_attr=edge_attr[i], **add_params_local))
             else:
@@ -127,7 +125,8 @@
     @staticmethod
     def _build_edge_attr(x, pos, edge_index):
-        distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
+        distance = torch.abs(pos[edge_index[0]] -
+                             pos[edge_index[1]]).as_subclass(torch.Tensor)
         return distance
 
     @staticmethod
@@ -165,7 +164,8 @@ def _check_input_consistency(x, pos, edge_index=None):
         # If edge_index is a 3D tensor, we split it into a list of 2D tensors
         if edge_index is not None:
             if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
-                edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
+                edge_index = [edge_index[i]
+                              for i in range(edge_index.shape[0])]
             elif not (isinstance(edge_index, list) and all(
                     t.ndim == 2 for t in edge_index)) and not (
                     isinstance(edge_index,
@@ -219,7 +219,7 @@ def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len,
         if isinstance(edge_attr, list):
             if len(edge_attr) != data_len:
                 raise TypeError("edge_attr must have the same length as x "
-                                 "and pos.")
+                                "and pos.")
             return [edge_attr] * data_len
 
         if build_edge_attr:
@@ -258,6 +258,8 @@ def _radius_graph(points, r):
         """
         dist = torch.cdist(points, points, p=2)
         edge_index = torch.nonzero(dist <= r, as_tuple=False).t()
+        if isinstance(edge_index, LabelTensor):
+            edge_index = edge_index.tensor
         return edge_index
 
 
@@ -293,4 +295,6 @@ def _knn_graph(points, k):
         row = torch.arange(points.size(0)).repeat_interleave(k)
         col = knn_indices.flatten()
         edge_index = torch.stack([row, col], dim=0)
+        if isinstance(edge_index, LabelTensor):
+            edge_index = edge_index.tensor
         return edge_index
diff --git a/pina/trainer.py b/pina/trainer.py
index 0d15e7699..eb8639e16 100644
--- a/pina/trainer.py
+++ b/pina/trainer.py
@@ -105,9 +105,9 @@ def __init__(self,
         # checking compilation and automatic batching
         if compile is None or sys.platform == "win32":
             compile = False
-        if automatic_batching is None:
-            automatic_batching = False
+        self.automatic_batching = (
+            automatic_batching if automatic_batching is not None else False)
 
         # set attributes
         self.compile = compile
         self.solver = solver
@@ -115,7 +115,7 @@ def __init__(self,
         self._move_to_device()
         self.data_module = None
         self._create_datamodule(train_size, test_size, val_size, predict_size,
-                                batch_size, automatic_batching, pin_memory, 
+                                batch_size, automatic_batching, pin_memory,
                                 num_workers)
 
         # logging
diff --git a/tests/test_data/test_datamodule.py b/tests/test_data/test_datamodule.py
index 866eebc69..f475c0498 100644
--- a/tests/test_data/test_datamodule.py
+++ b/tests/test_data/test_datamodule.py
@@ -13,10 +13,10 @@
 input_tensor = torch.rand((100, 10))
 output_tensor = torch.rand((100, 2))
 
-x = torch.rand((100, 50 , 10))
-pos = torch.rand((100, 50 , 2))
+x = torch.rand((100, 50, 10))
+pos = torch.rand((100, 50, 2))
 input_graph = RadiusGraph(x, pos, r=.1, build_edge_attr=True)
-output_graph = torch.rand((100, 50 , 10))
+output_graph = torch.rand((100, 50, 10))
 
 
 @pytest.mark.parametrize(
@@ -30,6 +30,7 @@ def test_constructor(input_, output_):
     problem = SupervisedProblem(input_=input_, output_=output_)
     PinaDataModule(problem)
 
+
 @pytest.mark.parametrize(
     "input_, output_",
     [
@@ -46,14 +47,15 @@
 )
 def test_setup_train(input_, output_, train_size, val_size, test_size):
     problem = SupervisedProblem(input_=input_, output_=output_)
-    dm = PinaDataModule(problem, train_size=train_size, val_size=val_size, test_size=test_size)
+    dm = PinaDataModule(problem, train_size=train_size,
+                        val_size=val_size, test_size=test_size)
     dm.setup()
     assert hasattr(dm, "train_dataset")
     if isinstance(input_, torch.Tensor):
         assert isinstance(dm.train_dataset, PinaTensorDataset)
     else:
         assert isinstance(dm.train_dataset, PinaGraphDataset)
-    #assert len(dm.train_dataset) == int(len(input_) * train_size)
+    # assert len(dm.train_dataset) == int(len(input_) * train_size)
     if test_size > 0:
         assert hasattr(dm, "test_dataset")
         assert dm.test_dataset is None
@@ -64,7 +66,8 @@ def test_setup_train(input_, output_, train_size, val_size, test_size):
         assert isinstance(dm.val_dataset, PinaTensorDataset)
     else:
         assert isinstance(dm.val_dataset, PinaGraphDataset)
-    #assert len(dm.val_dataset) == int(len(input_) * val_size)
+    # assert len(dm.val_dataset) == int(len(input_) * val_size)
+
 
 @pytest.mark.parametrize(
     "input_, output_",
@@ -82,7 +85,8 @@
 )
 def test_setup_test(input_, output_, train_size, val_size, test_size):
     problem = SupervisedProblem(input_=input_, output_=output_)
-    dm = PinaDataModule(problem, train_size=train_size, val_size=val_size, test_size=test_size)
+    dm = PinaDataModule(problem, train_size=train_size,
+                        val_size=val_size, test_size=test_size)
     dm.setup(stage='test')
     if train_size > 0:
         assert hasattr(dm, "train_dataset")
@@ -94,13 +98,14 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
         assert dm.val_dataset is None
     else:
         assert not hasattr(dm, "val_dataset")
-    
+
     assert hasattr(dm, "test_dataset")
     if isinstance(input_, torch.Tensor):
         assert isinstance(dm.test_dataset, PinaTensorDataset)
     else:
         assert isinstance(dm.test_dataset, PinaGraphDataset)
-    #assert len(dm.test_dataset) == int(len(input_) * test_size)
+    # assert len(dm.test_dataset) == int(len(input_) * test_size)
+
 
 @pytest.mark.parametrize(
     "input_, output_",
@@ -112,7 +117,8 @@
 def test_dummy_dataloader(input_, output_):
     problem = SupervisedProblem(input_=input_, output_=output_)
     solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
-    trainer = Trainer(solver, batch_size=None, train_size=.7, val_size=.3, test_size=0.)
+    trainer = Trainer(solver, batch_size=None, train_size=.7,
+                      val_size=.3, test_size=0.)
     dm = trainer.data_module
     dm.setup()
     dm.trainer = trainer
@@ -140,6 +146,7 @@ def test_dummy_dataloader(input_, output_):
     assert isinstance(data[0][1]['input_points'], torch.Tensor)
     assert isinstance(data[0][1]['output_points'], torch.Tensor)
 
+
 @pytest.mark.parametrize(
     "input_, output_",
     [
@@ -147,10 +154,17 @@
         (input_graph, output_graph)
     ]
 )
-def test_dataloader(input_, output_):
+@pytest.mark.parametrize(
+    "automatic_batching",
+    [
+        True, False
+    ]
+)
+def test_dataloader(input_, output_, automatic_batching):
     problem = SupervisedProblem(input_=input_, output_=output_)
     solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
-    trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3, test_size=0.)
+    trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3,
+                      test_size=0., automatic_batching=automatic_batching)
     dm = trainer.data_module
     dm.setup()
     dm.trainer = trainer
@@ -176,3 +190,67 @@ def test_dataloader(input_, output_):
     assert isinstance(data['data']['input_points'], torch.Tensor)
     assert isinstance(data['data']['output_points'], torch.Tensor)
 
+
+from pina import LabelTensor
+
+input_tensor = LabelTensor(torch.rand((100, 3)), ['u', 'v', 'w'])
+output_tensor = LabelTensor(torch.rand((100, 3)), ['u', 'v', 'w'])
+
+x = LabelTensor(torch.rand((100, 50, 3)), ['u', 'v', 'w'])
+pos = LabelTensor(torch.rand((100, 50, 2)), ['x', 'y'])
+input_graph = RadiusGraph(x, pos, r=.1, build_edge_attr=True)
+output_graph = LabelTensor(torch.rand((100, 50, 3)), ['u', 'v', 'w'])
+
+
+@pytest.mark.parametrize(
+    "input_, output_",
+    [
+        (input_tensor, output_tensor),
+        (input_graph, output_graph)
+    ]
+)
+@pytest.mark.parametrize(
+    "automatic_batching",
+    [
+        True, False
+    ]
+)
+def test_dataloader_labels(input_, output_, automatic_batching):
+    problem = SupervisedProblem(input_=input_, output_=output_)
+    solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
+    trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3,
+                      test_size=0., automatic_batching=automatic_batching)
+    dm = trainer.data_module
+    dm.setup()
+    dm.trainer = trainer
+    dataloader = dm.train_dataloader()
+    assert isinstance(dataloader, DataLoader)
+    assert len(dataloader) == 7
+    data = next(iter(dataloader))
+    assert isinstance(data, dict)
+    if isinstance(input_, RadiusGraph):
+        assert isinstance(data['data']['input_points'], Batch)
+        assert isinstance(data['data']['input_points'].x, LabelTensor)
+        assert data['data']['input_points'].x.labels == ['u', 'v', 'w']
+        assert data['data']['input_points'].pos.labels == ['x', 'y']
+    else:
+        assert isinstance(data['data']['input_points'], LabelTensor)
+        assert data['data']['input_points'].labels == ['u', 'v', 'w']
+    assert isinstance(data['data']['output_points'], LabelTensor)
+    assert data['data']['output_points'].labels == ['u', 'v', 'w']
+
+    dataloader = dm.val_dataloader()
+    assert isinstance(dataloader, DataLoader)
+    assert len(dataloader) == 3
+    data = next(iter(dataloader))
+    assert isinstance(data, dict)
+    if isinstance(input_, RadiusGraph):
+        assert isinstance(data['data']['input_points'], Batch)
+        assert isinstance(data['data']['input_points'].x, LabelTensor)
+        assert data['data']['input_points'].x.labels == ['u', 'v', 'w']
+        assert data['data']['input_points'].pos.labels == ['x', 'y']
+    else:
+        assert isinstance(data['data']['input_points'], LabelTensor)
+        assert data['data']['input_points'].labels == ['u', 'v', 'w']
+    assert isinstance(data['data']['output_points'], LabelTensor)
+    assert data['data']['output_points'].labels == ['u', 'v', 'w']
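
For reference, a minimal sketch of how the pieces in this diff fit together end to end, mirroring test_dataloader_labels above. The import paths for SupervisedProblem and SupervisedSolver are assumptions (the test file's import block is outside these hunks); everything else is taken from the code above.

import torch
from pina import LabelTensor, Trainer
from pina.graph import RadiusGraph
# NOTE: assumed import paths; the test file's own imports are not shown
# in the hunks above.
from pina.problem.zoo import SupervisedProblem
from pina.solver import SupervisedSolver

# Labelled node features and positions: PinaGraphDataset records these
# labels at construction and re-attaches them after PyG batching.
x = LabelTensor(torch.rand((100, 50, 3)), ['u', 'v', 'w'])
pos = LabelTensor(torch.rand((100, 50, 2)), ['x', 'y'])
input_graph = RadiusGraph(x, pos, r=.1, build_edge_attr=True)
output_graph = LabelTensor(torch.rand((100, 50, 3)), ['u', 'v', 'w'])

problem = SupervisedProblem(input_=input_graph, output_=output_graph)
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(3, 3))

# automatic_batching may now be left as None; Trainer resolves it to False.
trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3,
                  test_size=0., automatic_batching=None)

dm = trainer.data_module
dm.setup()
dm.trainer = trainer
batch = next(iter(dm.train_dataloader()))
print(batch['data']['input_points'].x.labels)  # ['u', 'v', 'w']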