diff --git a/pina/collector.py b/pina/collector.py index c8e816039..09e3df130 100644 --- a/pina/collector.py +++ b/pina/collector.py @@ -56,10 +56,6 @@ def store_fixed_data(self): # get data keys = condition.__slots__ values = [getattr(condition, name) for name in keys] - values = [ - value.data if isinstance(value, Graph) else value - for value in values - ] self.data_collections[condition_name] = dict(zip(keys, values)) # condition now is ready self._is_conditions_ready[condition_name] = True diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 3893e34cf..f75bde065 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -4,9 +4,15 @@ "DomainEquationCondition", "InputPointsEquationCondition", "InputOutputPointsCondition", + "GraphInputOutputCondition", + "GraphDataCondition", + "GraphInputEquationCondition", ] from .condition_interface import ConditionInterface from .domain_equation_condition import DomainEquationCondition from .input_equation_condition import InputPointsEquationCondition from .input_output_condition import InputOutputPointsCondition +from .graph_condition import GraphInputOutputCondition +from .graph_condition import GraphDataCondition +from .graph_condition import GraphInputEquationCondition diff --git a/pina/condition/condition.py b/pina/condition/condition.py index e01db1f0f..ae912f1e6 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -4,6 +4,10 @@ from .input_equation_condition import InputPointsEquationCondition from .input_output_condition import InputOutputPointsCondition from .data_condition import DataConditionInterface +from .graph_condition import ( + GraphInputOutputCondition, + GraphInputEquationCondition, +) import warnings from ..utils import custom_warning_format @@ -82,5 +86,9 @@ def __new__(cls, *args, **kwargs): return DataConditionInterface(**kwargs) elif sorted_keys == DataConditionInterface.__slots__[0]: return DataConditionInterface(**kwargs) + elif sorted_keys == sorted(GraphInputOutputCondition.__slots__): + return GraphInputOutputCondition(**kwargs) + elif sorted_keys == sorted(GraphInputEquationCondition.__slots__): + return GraphInputEquationCondition(**kwargs) else: raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.") diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index a9d62fd45..929bea06c 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -3,10 +3,7 @@ class ConditionInterface(metaclass=ABCMeta): - condition_types = ["physics", "supervised", "unsupervised"] - def __init__(self, *args, **kwargs): - self._condition_type = None self._problem = None @property @@ -16,19 +13,3 @@ def problem(self): @problem.setter def problem(self, value): self._problem = value - - @property - def condition_type(self): - return self._condition_type - - @condition_type.setter - def condition_type(self, values): - if not isinstance(values, (list, tuple)): - values = [values] - for value in values: - if value not in ConditionInterface.condition_types: - raise ValueError( - "Unavailable type of condition, expected one of" - f" {ConditionInterface.condition_types}." - ) - self._condition_type = values diff --git a/pina/condition/graph_condition.py b/pina/condition/graph_condition.py new file mode 100644 index 000000000..0b623b12f --- /dev/null +++ b/pina/condition/graph_condition.py @@ -0,0 +1,68 @@ +from .condition_interface import ConditionInterface +from ..graph import Graph +from ..utils import check_consistency +from torch_geometric.data import Data +from ..equation.equation_interface import EquationInterface + + +class GraphCondition(ConditionInterface): + """ + TODO + """ + + __slots__ = ["graph"] + + def __new__(cls, graph): + """ + TODO : add docstring + """ + check_consistency(graph, (Graph, Data)) + graph = [graph] if isinstance(graph, Data) else graph + + if all(g.y is not None for g in graph): + return super().__new__(GraphInputOutputCondition) + else: + return super().__new__(GraphDataCondition) + + def __init__(self, graph): + + super().__init__() + self.graph = graph + + def __setattr__(self, key, value): + if key == "graph": + check_consistency(value, (Graph, Data)) + GraphCondition.__dict__[key].__set__(self, value) + elif key in ("_problem", "_condition_type"): + super().__setattr__(key, value) + + +class GraphInputEquationCondition(ConditionInterface): + + __slots__ = ["graph", "equation"] + + def __init__(self, graph, equation): + super().__init__() + self.graph = graph + self.equation = equation + + def __setattr__(self, key, value): + if key == "graph": + check_consistency(value, (Graph, Data)) + GraphInputEquationCondition.__dict__[key].__set__(self, value) + elif key == "equation": + check_consistency(value, (EquationInterface)) + GraphInputEquationCondition.__dict__[key].__set__(self, value) + elif key in ("_problem", "_condition_type"): + super().__setattr__(key, value) + + +# The split between GraphInputOutputCondition and GraphDataCondition +# distinguishes different types of graph conditions passed to problems. +# This separation simplifies consistency checks during problem creation. +class GraphDataCondition(GraphCondition): + pass + + +class GraphInputOutputCondition(GraphCondition): + pass diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 288e00e2e..260688ec5 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -1,8 +1,12 @@ +""" +This module contains the PinaDataModule class, which is used to manage the +datasets and dataloaders in the PINA. +""" + import logging import warnings from lightning.pytorch import LightningDataModule import torch -from torch_geometric.data import Data from torch.utils.data import DataLoader, SequentialSampler, RandomSampler from torch.utils.data.distributed import DistributedSampler from ..label_tensor import LabelTensor @@ -11,7 +15,7 @@ class DummyDataloader: - """ " + """ Dummy dataloader used when batch size is None. It callects all the data in self.dataset and returns it when it is called a single batch. """ @@ -58,7 +62,18 @@ def __next__(self): class Collator: + """ + Class used to collate retrieved data from the dataset. + """ + def __init__(self, max_conditions_lengths, dataset=None): + """ + Initialize the object, setting the right function to collate the data + based on the input dataset. + + :param dict max_conditions_lengths: maximum length of each condition + :param PinaDataset dataset: The dataset object to be processed. + """ self.max_conditions_lengths = max_conditions_lengths self.callable_function = ( self._collate_custom_dataloader @@ -84,58 +99,46 @@ def _collate_standard_dataloader(self, batch): conditions_names = batch[0].keys() # Condition names for condition_name in conditions_names: - single_cond_dict = {} condition_args = batch[0][condition_name].keys() - for arg in condition_args: - data_list = [ - batch[idx][condition_name][arg] - for idx in range( - min( - len(batch), - self.max_conditions_lengths[condition_name], - ) - ) - ] - single_cond_dict[arg] = self._collate(data_list) - - batch_dict[condition_name] = single_cond_dict + batch_dict[condition_name] = self._collate( + condition_args, condition_name, batch + ) return batch_dict - @staticmethod - def _collate_tensor_dataset(data_list): - if isinstance(data_list[0], LabelTensor): - return LabelTensor.stack(data_list) - if isinstance(data_list[0], torch.Tensor): - return torch.stack(data_list) - raise RuntimeError("Data must be Tensors or LabelTensor ") - - def _collate_graph_dataset(self, data_list): - if isinstance(data_list[0], LabelTensor): - return LabelTensor.cat(data_list) - if isinstance(data_list[0], torch.Tensor): - return torch.cat(data_list) - if isinstance(data_list[0], Data): - return self.dataset.create_graph_batch(data_list) - raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data") - - def __call__(self, batch): - return self.callable_function(batch) + def _collate_tensor_dataset(self, condition_args, condition_name, batch): + to_return_dict = {} + for arg in condition_args: + data_list = [ + batch[idx][condition_name][arg] + for idx in range( + min(len(batch), self.max_conditions_lengths[condition_name]) + ) + ] + if isinstance(data_list[0], LabelTensor): + data = LabelTensor.stack(data_list) + elif isinstance(data_list[0], torch.Tensor): + data = torch.stack(data_list) + else: + raise ValueError( + f"Data type {type(data_list[0])} not supported" + ) + to_return_dict[arg] = data + return to_return_dict -class PinaSampler: - def __new__(cls, dataset, shuffle): + def _collate_graph_dataset(self, condition_args, condition_name, batch): + data_list = [ + batch[idx][condition_name] + for idx in range( + min(len(batch), self.max_conditions_lengths[condition_name]) + ) + ] + return self.dataset.divide_batch( + batch=self.dataset.create_graph_batch(data_list) + ) - if ( - torch.distributed.is_available() - and torch.distributed.is_initialized() - ): - sampler = DistributedSampler(dataset, shuffle=shuffle) - else: - if shuffle: - sampler = RandomSampler(dataset) - else: - sampler = SequentialSampler(dataset) - return sampler + def __call__(self, batch): + return self.callable_function(batch) class PinaDataModule(LightningDataModule): @@ -169,9 +172,11 @@ def __init__( :type test_size: float :param val_size: Fraction or number of elements in the validation split. :type val_size: float - :param predict_size: Fraction or number of elements in the prediction split. + :param predict_size: Fraction or number of elements in the prediction + split. :type predict_size: float - :param batch_size: Batch size used for training. If None, the entire dataset is used per batch. + :param batch_size: Batch size used for training. If None, the entire + dataset is used per batch. :type batch_size: int or None :param shuffle: Whether to shuffle the dataset before splitting. :type shuffle: bool @@ -179,9 +184,11 @@ def __init__( :type repeat: bool :param automatic_batching: Whether to enable automatic batching. :type automatic_batching: bool - :param num_workers: Number of worker threads for data loading. Default 0 (serial loading) + :param num_workers: Number of worker threads for data loading. Default + 0 (serial loading) :type num_workers: int - :param pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False) + :param pin_memory: Whether to use pinned memory for faster data + transfer to GPU. (Default False) :type pin_memory: bool """ logging.debug("Start initialization of Pina DataModule") @@ -251,7 +258,7 @@ def setup(self, stage=None): if stage == "fit" or stage is None: self.train_dataset = PinaDatasetFactory( self.collector_splits["train"], - max_conditions_lengths=self.find_max_conditions_lengths( + max_conditions_lengths=self._find_max_conditions_lengths( "train" ), automatic_batching=self.automatic_batching, @@ -259,7 +266,7 @@ def setup(self, stage=None): if "val" in self.collector_splits.keys(): self.val_dataset = PinaDatasetFactory( self.collector_splits["val"], - max_conditions_lengths=self.find_max_conditions_lengths( + max_conditions_lengths=self._find_max_conditions_lengths( "val" ), automatic_batching=self.automatic_batching, @@ -267,13 +274,15 @@ def setup(self, stage=None): elif stage == "test": self.test_dataset = PinaDatasetFactory( self.collector_splits["test"], - max_conditions_lengths=self.find_max_conditions_lengths("test"), + max_conditions_lengths=self._find_max_conditions_lengths( + "test" + ), automatic_batching=self.automatic_batching, ) elif stage == "predict": self.predict_dataset = PinaDatasetFactory( self.collector_splits["predict"], - max_conditions_lengths=self.find_max_conditions_lengths( + max_conditions_lengths=self._find_max_conditions_lengths( "predict" ), automatic_batching=self.automatic_batching, @@ -285,7 +294,7 @@ def setup(self, stage=None): @staticmethod def _split_condition(condition_dict, splits_dict): - len_condition = len(condition_dict["input_points"]) + len_condition = len(list(condition_dict.values())[0]) lengths = [ int(len_condition * length) for length in splits_dict.values() @@ -308,7 +317,7 @@ def _split_condition(condition_dict, splits_dict): if k != "equation" # Equations are NEVER dataloaded } - if offset + stage_len >= len_condition: + if offset + stage_len > len_condition: offset = len_condition - 1 continue offset += stage_len @@ -343,7 +352,7 @@ def _apply_shuffle(condition_dict, len_data): condition_name, condition_dict, ) in collector.data_collections.items(): - len_data = len(condition_dict["input_points"]) + len_data = len(list(condition_dict.values())[0]) if self.shuffle: _apply_shuffle(condition_dict, len_data) for key, data in self._split_condition( @@ -355,20 +364,22 @@ def _apply_shuffle(condition_dict, len_data): def _create_dataloader(self, split, dataset): shuffle = self.shuffle if split == "train" else False # Suppress the warning about num_workers. - # In many cases, especially for PINNs, serial data loading can outperform parallel data loading. + # In many cases, especially for PINNs, serial data loading can + # outperform parallel data loading. warnings.filterwarnings( "ignore", message=( - r"The '(train|val|test)_dataloader' does not have many workers which may be a bottleneck." + r"The '(train|val|test)_dataloader' does not have many workers " + "which may be a bottleneck." ), module="lightning.pytorch.trainer.connectors.data_connector", ) # Use custom batching (good if batch size is large) if self.batch_size is not None: - sampler = PinaSampler(dataset, shuffle) + sampler = self.sampler(dataset, shuffle) if self.automatic_batching: collate = Collator( - self.find_max_conditions_lengths(split), dataset=dataset + self._find_max_conditions_lengths(split), dataset=dataset ) else: collate = Collator(None, dataset=dataset) @@ -386,16 +397,16 @@ def _create_dataloader(self, split, dataset): self.transfer_batch_to_device = self._transfer_batch_to_device_dummy return dataloader - def find_max_conditions_lengths(self, split): + def _find_max_conditions_lengths(self, split): max_conditions_lengths = {} for k, v in self.collector_splits[split].items(): if self.batch_size is None: - max_conditions_lengths[k] = len(v["input_points"]) + max_conditions_lengths[k] = len(list(v.values())[0]) elif self.repeat: max_conditions_lengths[k] = self.batch_size else: max_conditions_lengths[k] = min( - len(v["input_points"]), self.batch_size + len(list(v.values())[0]), self.batch_size ) return max_conditions_lengths @@ -457,7 +468,10 @@ def _check_slit_sizes(train_size, test_size, val_size, predict_size): @property def input_points(self): """ - # TODO + Return the input points of the datasets + + :return: The input points of the datasets + :rtype dict """ to_return = {} if hasattr(self, "train_dataset") and self.train_dataset is not None: @@ -467,3 +481,20 @@ def input_points(self): if hasattr(self, "test_dataset") and self.test_dataset is not None: to_return = self.test_dataset.input_points return to_return + + @staticmethod + def sampler(dataset, shuffle): + """ + # TODO + """ + if ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + ): + sampler = DistributedSampler(dataset, shuffle=shuffle) + else: + if shuffle: + sampler = RandomSampler(dataset) + else: + sampler = SequentialSampler(dataset) + return sampler diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 3944ef436..7deaa5452 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -3,10 +3,9 @@ """ import functools -import torch -from torch.utils.data import Dataset from abc import abstractmethod -from torch_geometric.data import Batch, Data +from torch.utils.data import Dataset +from torch_geometric.data import Batch from pina import LabelTensor @@ -21,22 +20,15 @@ class PinaDatasetFactory: def __new__(cls, conditions_dict, **kwargs): if len(conditions_dict) == 0: raise ValueError("No conditions provided") + print(conditions_dict) + if all("graph" in list(v.keys()) for v in conditions_dict.values()): + return PinaGraphDataset(conditions_dict, **kwargs) if all( - [ - isinstance(v["input_points"], torch.Tensor) - for v in conditions_dict.values() - ] + "input_points" in list(v.keys()) for v in conditions_dict.values() ): return PinaTensorDataset(conditions_dict, **kwargs) - elif all( - [ - isinstance(v["input_points"], list) - for v in conditions_dict.values() - ] - ): - return PinaGraphDataset(conditions_dict, **kwargs) raise ValueError( - "Conditions must be either torch.Tensor or list of Data " "objects." + "Conditions must be either torch.Tensor or list of Data objects." ) @@ -48,10 +40,6 @@ class PinaDataset(Dataset): def __init__(self, conditions_dict, max_conditions_lengths): self.conditions_dict = conditions_dict self.max_conditions_lengths = max_conditions_lengths - self.conditions_length = { - k: len(v["input_points"]) for k, v in self.conditions_dict.items() - } - self.length = max(self.conditions_length.values()) def _get_max_len(self): max_len = 0 @@ -66,13 +54,41 @@ def __len__(self): def __getitem__(self, item): pass + def get_all_data(self): + """ + Get all the data from the dataset + + :return: dictionary with the data for each condition + :rtype: dict + """ + index = list(range(len(self))) + return self.fetch_from_idx_list(index) + class PinaTensorDataset(PinaDataset): + """ + Dataset class for torch.Tensor conditions + """ + def __init__( self, conditions_dict, max_conditions_lengths, automatic_batching ): + """ + Initialize the dataset, assign the conditions and maximum lengths + for each condition. Moreover, it sets the right function to get + the data from the dataset. + + :param dict conditions_dict: dictionary with conditions + :param dict max_conditions_lengths: maximum length of each condition + :param bool automatic_batching: if True, the dataset will return + a single condition for each index, otherwise it will return the + index itself + """ super().__init__(conditions_dict, max_conditions_lengths) - + self.conditions_length = { + k: len(v["input_points"]) for k, v in self.conditions_dict.items() + } + self.length = max(self.conditions_length.values()) if automatic_batching: self._getitem_func = self._getitem_int else: @@ -88,12 +104,26 @@ def _getitem_int(self, idx): } def fetch_from_idx_list(self, idx): + """ + Retrive data from the dataset given a list of indexes. + + :param list idx: list of indexes + :return: dictionary with the data for each condition + :rtype: dict + """ + # Set empty dictionary to_return_dict = {} + # Loop over conditions for condition, data in self.conditions_dict.items(): + # Get the indexes for the current condition cond_idx = idx[: self.max_conditions_lengths[condition]] + # Get the length of the current condition condition_len = self.conditions_length[condition] + # If the length of the dataset is greater than the length of the + # condition, we need to take the modulo of the index if self.length > condition_len: cond_idx = [idx % condition_len for idx in cond_idx] + # Store the data for the current condition to_return_dict[condition] = { k: v[cond_idx] for k, v in data.items() } @@ -103,10 +133,6 @@ def fetch_from_idx_list(self, idx): def _getitem_dummy(idx): return idx - def get_all_data(self): - index = [i for i in range(len(self))] - return self.fetch_from_idx_list(index) - def __getitem__(self, idx): return self._getitem_func(idx) @@ -118,34 +144,31 @@ def input_points(self): return {k: v["input_points"] for k, v in self.conditions_dict.items()} -class PinaBatch(Batch): +class PinaGraphDataset(PinaDataset): """ - Add extract function to torch_geometric Batch object + Dataset class for torch_geometric.data.Data and Graph conditions """ - def __init__(self): - - super().__init__(self) - - def extract(self, labels): - """ - Perform extraction of labels on node features (x) - - :param labels: Labels to extract - :type labels: list[str] | tuple[str] | str - :return: Batch object with extraction performed on x - :rtype: PinaBatch - """ - self.x = self.x.extract(labels) - return self - - -class PinaGraphDataset(PinaDataset): - def __init__( self, conditions_dict, max_conditions_lengths, automatic_batching ): + """ + Initialize the dataset, assign the conditions and maximum lengths + for each condition. Moreover, it sets the right function to get + the data from the dataset. + + :param dict conditions_dict: dictionary with conditions + :param dict max_conditions_lengths: maximum length of each condition + :param bool automatic_batching: if True, the dataset will return + a single condition for each index, otherwise it will return the + index itself + """ super().__init__(conditions_dict, max_conditions_lengths) + self.conditions_length = { + k: len(v["graph"]) for k, v in self.conditions_dict.items() + } + self.length = max(self.conditions_length.values()) + self.in_labels = {} self.out_labels = None if automatic_batching: @@ -153,72 +176,63 @@ def __init__( else: self._getitem_func = self._getitem_dummy - ex_data = conditions_dict[list(conditions_dict.keys())[0]][ - "input_points" - ][0] + ex_data = conditions_dict[list(conditions_dict.keys())[0]]["graph"][0] + for name, attr in ex_data.items(): if isinstance(attr, LabelTensor): self.in_labels[name] = attr.stored_labels - ex_data = conditions_dict[list(conditions_dict.keys())[0]][ - "output_points" - ][0] - if isinstance(ex_data, LabelTensor): - self.out_labels = ex_data.labels self._create_graph_batch_from_list = ( self._labelise_batch(self._base_create_graph_batch_from_list) if self.in_labels else self._base_create_graph_batch_from_list ) - - self._create_output_batch = ( - self._labelise_tensor(self._base_create_output_batch) - if self.out_labels is not None - else self._base_create_output_batch - ) + if hasattr(ex_data, "y"): + self.divide_batch = self._extract_output(self._divide_batch) + else: + self.divide_batch = self._divide_batch def fetch_from_idx_list(self, idx): + """ + Retrive data from the dataset given a list of indexes. + + :param list idx: list of indexes + :return: dictionary with the data for each condition + :rtype dict + """ to_return_dict = {} for condition, data in self.conditions_dict.items(): cond_idx = idx[: self.max_conditions_lengths[condition]] condition_len = self.conditions_length[condition] if self.length > condition_len: cond_idx = [idx % condition_len for idx in cond_idx] - to_return_dict[condition] = { - k: ( - self._create_graph_batch_from_list([v[i] for i in idx]) - if isinstance(v, list) - else self._create_output_batch(v[idx]) - ) - for k, v in data.items() - } + batch = self._create_graph_batch_from_list( + [data["graph"][i] for i in idx] + ) + to_return_dict[condition] = self.divide_batch(batch=batch) + return to_return_dict + def _divide_batch(self, batch): + """ + Divide the batch into input and output points + """ + to_return_dict = {} + to_return_dict["input_points"] = batch return to_return_dict def _base_create_graph_batch_from_list(self, data): - batch = PinaBatch.from_data_list(data) + batch = Batch.from_data_list(data) return batch - def _base_create_output_batch(self, data): - out = data.reshape(-1, *data.shape[2:]) - return out - def _getitem_dummy(self, idx): return idx def _getitem_int(self, idx): return { - k: { - k_data: v[k_data][idx % len(v["input_points"])] - for k_data in v.keys() - } + k: v["graph"][idx % len(v["graph"])] for k, v in self.conditions_dict.items() } - def get_all_data(self): - index = [i for i in range(len(self))] - return self.fetch_from_idx_list(index) - def __getitem__(self, idx): return self._getitem_func(idx) @@ -246,8 +260,43 @@ def wrapper(*args, **kwargs): def create_graph_batch(self, data): """ - # TODO + Create a graph batch from a list of Data objects. This method is + to be called from the outside. + + :param list data: list of Data or Graph objects + :return: Batch object + :rtype: torch_geometric.data.Batch """ - if isinstance(data[0], Data): - return self._create_graph_batch_from_list(data) - return self._create_output_batch(data) + return self._create_graph_batch_from_list(data) + + @staticmethod + def _extract_output(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + out = func(*args, **kwargs) + batch = kwargs["batch"] + # Copying y into ouput_points + out["output_points"] = batch.y + # Deleting y from batch + batch.y = None + # Store new batch withou y + out["input_points"] = batch + return out + + return wrapper + + @staticmethod + def _extract_cond_vars(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + out = func(*args, **kwargs) + batch = kwargs["batch"] + # Copying conditional_vars into conditional_vars dict item + out["conditional_variables"] = batch.conditional_vars + # Deleting conditional_vars from batch + batch.conditional_vars = None + # Store new batch withou conditional_vars + out["input_points"] = batch + return out + + return wrapper diff --git a/pina/graph.py b/pina/graph.py index 77e426e1d..11bc8412b 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -95,9 +95,10 @@ def _check_pos_consistency(pos): :param torch.Tensor pos: The position tensor. """ if pos is not None: - check_consistency(pos, (torch.Tensor, LabelTensor)) - if pos.ndim != 2: - raise ValueError("pos must be a 2D tensor.") + return + check_consistency(pos, (torch.Tensor, LabelTensor)) + if pos.ndim != 2: + raise ValueError("pos must be a 2D tensor.") @staticmethod def _check_edge_index_consistency(edge_index): @@ -120,16 +121,17 @@ def _check_edge_attr_consistency(edge_attr, edge_index): :param torch.Tensor edge_index: The edge index tensor. """ if edge_attr is not None: - check_consistency(edge_attr, (torch.Tensor, LabelTensor)) - if edge_attr.ndim != 2: - raise ValueError("edge_attr must be a 2D tensor.") - if edge_attr.size(0) != edge_index.size(1): - raise ValueError( - "edge_attr must have shape " - "[num_edges, num_edge_features], expected " - f"num_edges {edge_index.size(1)} " - f"got {edge_attr.size(0)}." - ) + return + check_consistency(edge_attr, (torch.Tensor, LabelTensor)) + if edge_attr.ndim != 2: + raise ValueError("edge_attr must be a 2D tensor.") + if edge_attr.size(0) != edge_index.size(1): + raise ValueError( + "edge_attr must have shape " + "[num_edges, num_edge_features], expected " + f"num_edges {edge_index.size(1)} " + f"got {edge_attr.size(0)}." + ) @staticmethod def _check_x_consistency(x, pos=None): @@ -139,15 +141,16 @@ def _check_x_consistency(x, pos=None): :param torch.Tensor pos: The position tensor. """ if x is not None: - check_consistency(x, (torch.Tensor, LabelTensor)) - if x.ndim != 2: - raise ValueError("x must be a 2D tensor.") - if pos is not None: - if x.size(0) != pos.size(0): - raise ValueError("Inconsistent number of nodes.") - if pos is not None: - if x.size(0) != pos.size(0): - raise ValueError("Inconsistent number of nodes.") + return + check_consistency(x, (torch.Tensor, LabelTensor)) + if x.ndim != 2: + raise ValueError("x must be a 2D tensor.") + if pos is not None: + if x.size(0) != pos.size(0): + raise ValueError("Inconsistent number of nodes.") + if pos is not None: + if x.size(0) != pos.size(0): + raise ValueError("Inconsistent number of nodes.") @staticmethod def _preprocess_edge_index(edge_index, undirected): @@ -162,6 +165,18 @@ def _preprocess_edge_index(edge_index, undirected): edge_index = to_undirected(edge_index) return edge_index + def extract(self, labels): + """ + Perform extraction of labels on node features (x) + + :param labels: Labels to extract + :type labels: list[str] | tuple[str] | str + :return: Batch object with extraction performed on x + :rtype: PinaBatch + """ + self.x = self.x.extract(labels) + return self + class GraphBuilder: """ @@ -198,28 +213,32 @@ def __new__( :return: A Graph instance constructed using the provided information. :rtype: Graph """ - edge_attr = cls._create_edge_attr( - pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr - ) - return Graph( + tmp = Graph( x=x, - edge_index=edge_index, - edge_attr=edge_attr, pos=pos, + edge_index=edge_index, **kwargs, ) + edge_attr = cls._create_edge_attr( + tmp, edge_attr, custom_edge_func or cls._build_edge_attr + ) + tmp.edge_attr = edge_attr + return tmp + @staticmethod - def _create_edge_attr(pos, edge_index, edge_attr, func): + def _create_edge_attr(graph, edge_attr, func): check_consistency(edge_attr, bool) - if edge_attr: - if is_function(func): - return func(pos, edge_index) - raise ValueError("custom_edge_func must be a function.") - return None + if not edge_attr: + return None + if is_function(func): + return func(graph) + raise ValueError("custom_edge_func must be a function.") @staticmethod - def _build_edge_attr(pos, edge_index): + def _build_edge_attr(graph): + pos = graph.pos + edge_index = graph.edge_index return ( (pos[edge_index[0]] - pos[edge_index[1]]) .abs() diff --git a/pina/problem/zoo/supervised_problem.py b/pina/problem/zoo/supervised_problem.py index ef0406225..6529f4059 100644 --- a/pina/problem/zoo/supervised_problem.py +++ b/pina/problem/zoo/supervised_problem.py @@ -1,14 +1,27 @@ +import torch from pina.problem import AbstractProblem from pina import Condition from pina import Graph +from pina import LabelTensor -class SupervisedProblem(AbstractProblem): +class SupervisedProblem: + + def __new__(cls, *args, **kwargs): + + if sorted(list(kwargs.keys())) == sorted(["input_", "output_"]): + return SupervisedTensorProblem(**kwargs) + elif sorted(list(kwargs.keys())) == sorted(["graph_"]): + return SupervisedGraphProblem(**kwargs) + raise RuntimeError("Invalid arguments for SupervisedProblem") + + +class SupervisedTensorProblem(AbstractProblem): """ A problem definition for supervised learning in PINA. - This class allows an easy and straightforward definition of a Supervised problem, - based on a single condition of type `InputOutputPointsCondition` + This class allows an easy and straightforward definition of a Supervised + problem, based on a single condition of type `InputOutputPointsCondition` :Example: >>> import torch @@ -29,9 +42,54 @@ def __init__(self, input_, output_): :param output_: Output data of the problem :type output_: torch.Tensor """ - if isinstance(input_, Graph): - input_ = input_.data + if not isinstance(input_, (torch.Tensor, LabelTensor)): + raise ValueError( + "The input data must be a torch.Tensor or a LabelTensor" + ) + if not isinstance(output_, (torch.Tensor, LabelTensor)): + raise ValueError( + "The output data must be a torch.Tensor or a LabelTensor" + ) + if isinstance(output_, LabelTensor): + self.output_variables = output_.labels + self.conditions["data"] = Condition( input_points=input_, output_points=output_ ) super().__init__() + + +class SupervisedGraphProblem(AbstractProblem): + """ + A problem definition for supervised learning in PINA. + + This class allows an easy and straightforward definition of a Supervised problem, + based on a single condition of type `InputOutputPointsCondition` + + :Example: + >>> import torch + >>> from pina.graph import RadiusGraph + >>> x = torch.rand((10, 100, 10)) + >>> pos = torch.rand((10, 100, 2)) + >>> y = torch.rand((10, 100, 2)) + >>> input_data = RadiusGraph(x=x, pos=pos, r=.2, y=y) + >>> problem = SupervisedProblem(graph_=input_data) + """ + + conditions = dict() + output_variables = None + + def __init__(self, graph_): + """ + Initialize the SupervisedProblem class + + :param graph_: Input data of the problem + :type graph_: Graph + """ + if not isinstance(graph_, list) or not all( + isinstance(g, Graph) for g in graph_ + ): + raise ValueError("The input data must be a Graph") + + self.conditions["data"] = Condition(graph=graph_) + super().__init__() diff --git a/pina/solver/supervised.py b/pina/solver/supervised.py index 56771b84e..fb4ad27e1 100644 --- a/pina/solver/supervised.py +++ b/pina/solver/supervised.py @@ -5,7 +5,7 @@ from .solver import SingleSolverInterface from ..utils import check_consistency from ..loss.loss_interface import LossInterface -from ..condition import InputOutputPointsCondition +from ..condition import InputOutputPointsCondition, GraphInputOutputCondition class SupervisedSolver(SingleSolverInterface): @@ -37,7 +37,10 @@ class SupervisedSolver(SingleSolverInterface): multiple (discretised) input functions. """ - accepted_conditions_types = InputOutputPointsCondition + accepted_conditions_types = ( + InputOutputPointsCondition, + GraphInputOutputCondition, + ) def __init__( self, diff --git a/tests/test_condition.py b/tests/test_condition.py index f5842b978..95b3b7dff 100644 --- a/tests/test_condition.py +++ b/tests/test_condition.py @@ -3,11 +3,19 @@ from pina import LabelTensor, Condition from pina.domain import CartesianDomain +from pina.condition import ( + GraphInputOutputCondition, + GraphInputEquationCondition, +) from pina.equation.equation_factory import FixedValue +from pina.graph import RadiusGraph +from torch_geometric.data import Data +from pina.operator import laplacian +from pina.equation.equation import Equation -example_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) -example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ['x', 'y', 'z']) -example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b']) +example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) +example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ["x", "y", "z"]) +example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ["a", "b"]) def test_init_inputoutput(): @@ -15,20 +23,17 @@ def test_init_inputoutput(): with pytest.raises(ValueError): Condition(example_input_pts, example_output_pts) with pytest.raises(ValueError): - Condition(input_points=3., output_points='example') + Condition(input_points=3.0, output_points="example") with pytest.raises(ValueError): Condition(input_points=example_domain, output_points=example_domain) -test_init_inputoutput() - - def test_init_domainfunc(): Condition(domain=example_domain, equation=FixedValue(0.0)) with pytest.raises(ValueError): Condition(example_domain, FixedValue(0.0)) with pytest.raises(ValueError): - Condition(domain=3., equation='example') + Condition(domain=3.0, equation="example") with pytest.raises(ValueError): Condition(domain=example_input_pts, equation=example_output_pts) @@ -38,6 +43,78 @@ def test_init_inputfunc(): with pytest.raises(ValueError): Condition(example_domain, FixedValue(0.0)) with pytest.raises(ValueError): - Condition(input_points=3., equation='example') + Condition(input_points=3.0, equation="example") with pytest.raises(ValueError): Condition(input_points=example_domain, equation=example_output_pts) + + +def test_graph_io_condition(): + x = torch.rand(10, 10, 4) + pos = torch.rand(10, 10, 2) + y = torch.rand(10, 10, 2) + graph = [ + RadiusGraph(x=x_, pos=pos_, radius=0.1, build_edge_attr=True, y=y_) + for x_, pos_, y_ in zip(x, pos, y) + ] + condition = Condition(graph=graph) + assert isinstance(condition, GraphInputOutputCondition) + assert isinstance(condition.graph, list) + + x = x[0] + pos = pos[0] + y = y[0] + edge_index = graph[0].edge_index + graph = Data(x=x, pos=pos, edge_index=edge_index, y=y) + condition = Condition(graph=graph) + assert isinstance(condition, GraphInputOutputCondition) + assert isinstance(condition.graph, Data) + + +def laplace_equation(input_, output_): + """ + Implementation of the laplace equation. + """ + force_term = torch.sin(input_.extract(["x"]) * torch.pi) * torch.sin( + input_.extract(["y"]) * torch.pi + ) + delta_u = laplacian(output_.extract(["u"]), input_) + return delta_u - force_term + + +def test_graph_eq_condition(): + def laplace(input_, output_): + """ + Implementation of the laplace equation. + """ + force_term = torch.sin(input_.extract(["x"]) * torch.pi) * torch.sin( + input_.extract(["y"]) * torch.pi + ) + delta_u = laplacian(output_.extract(["u"]), input_) + return delta_u - force_term + + x = torch.rand(10, 10, 4) + pos = torch.rand(10, 10, 2) + graph = [ + RadiusGraph( + x=x_, + pos=pos_, + radius=0.1, + build_edge_attr=True, + ) + for x_, pos_, in zip( + x, + pos, + ) + ] + laplace_equation = Equation(laplace) + condition = Condition(graph=graph, equation=laplace_equation) + assert isinstance(condition, GraphInputEquationCondition) + assert isinstance(condition.graph, list) + + x = x[0] + pos = pos[0] + edge_index = graph[0].edge_index + graph = Data(x=x, pos=pos, edge_index=edge_index) + condition = Condition(graph=graph, equation=laplace_equation) + assert isinstance(condition, GraphInputEquationCondition) + assert isinstance(condition.graph, Data) diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index 2d7de9dc4..703ad42f4 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -15,30 +15,37 @@ x = torch.rand((100, 50, 10)) pos = torch.rand((100, 50, 2)) +output_graph = torch.rand((100, 50, 10)) input_graph = [ - RadiusGraph(x=x_, pos=pos_, radius=0.2) for x_, pos_, in zip(x, pos) + RadiusGraph(x=x_, pos=pos_, radius=0.2, y=y_) + for x_, pos_, y_ in zip(x, pos, output_graph) ] -output_graph = torch.rand((100, 50, 10)) @pytest.mark.parametrize( "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], + [(input_tensor, output_tensor), (input_graph, None)], ) def test_constructor(input_, output_): - problem = SupervisedProblem(input_=input_, output_=output_) + if output_ is None: + problem = SupervisedProblem(graph_=input_) + else: + problem = SupervisedProblem(input_=input_, output_=output_) PinaDataModule(problem) @pytest.mark.parametrize( "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], + [(input_tensor, output_tensor), (input_graph, None)], ) @pytest.mark.parametrize( "train_size, val_size, test_size", [(0.7, 0.2, 0.1), (0.7, 0.3, 0)] ) def test_setup_train(input_, output_, train_size, val_size, test_size): - problem = SupervisedProblem(input_=input_, output_=output_) + if output_ is None: + problem = SupervisedProblem(graph_=input_) + else: + problem = SupervisedProblem(input_=input_, output_=output_) dm = PinaDataModule( problem, train_size=train_size, val_size=val_size, test_size=test_size ) @@ -64,13 +71,16 @@ def test_setup_train(input_, output_, train_size, val_size, test_size): @pytest.mark.parametrize( "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], + [(input_tensor, output_tensor), (input_graph, None)], ) @pytest.mark.parametrize( "train_size, val_size, test_size", [(0.7, 0.2, 0.1), (0.0, 0.0, 1.0)] ) def test_setup_test(input_, output_, train_size, val_size, test_size): - problem = SupervisedProblem(input_=input_, output_=output_) + if output_ is None: + problem = SupervisedProblem(graph_=input_) + else: + problem = SupervisedProblem(input_=input_, output_=output_) dm = PinaDataModule( problem, train_size=train_size, val_size=val_size, test_size=test_size ) @@ -96,10 +106,13 @@ def test_setup_test(input_, output_, train_size, val_size, test_size): @pytest.mark.parametrize( "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], + [(input_tensor, output_tensor), (input_graph, None)], ) def test_dummy_dataloader(input_, output_): - problem = SupervisedProblem(input_=input_, output_=output_) + if output_ is None: + problem = SupervisedProblem(graph_=input_) + else: + problem = SupervisedProblem(input_=input_, output_=output_) solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) trainer = Trainer( solver, batch_size=None, train_size=0.7, val_size=0.3, test_size=0.0 @@ -134,11 +147,14 @@ def test_dummy_dataloader(input_, output_): @pytest.mark.parametrize( "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], + [(input_tensor, output_tensor), (input_graph, None)], ) @pytest.mark.parametrize("automatic_batching", [True, False]) def test_dataloader(input_, output_, automatic_batching): - problem = SupervisedProblem(input_=input_, output_=output_) + if output_ is None: + problem = SupervisedProblem(graph_=input_) + else: + problem = SupervisedProblem(input_=input_, output_=output_) solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) trainer = Trainer( solver, @@ -181,19 +197,21 @@ def test_dataloader(input_, output_, automatic_batching): x = LabelTensor(torch.rand((100, 50, 3)), ["u", "v", "w"]) pos = LabelTensor(torch.rand((100, 50, 2)), ["x", "y"]) -input_graph = [ - RadiusGraph(x=x[i], pos=pos[i], radius=0.1) for i in range(len(x)) -] output_graph = LabelTensor(torch.rand((100, 50, 3)), ["u", "v", "w"]) +input_graph = [RadiusGraph(x=x[i], pos=pos[i], radius=0.1, y=output_graph[i]) for i in range(len(x))] + @pytest.mark.parametrize( "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], + [(input_tensor, output_tensor), (input_graph, None)], ) @pytest.mark.parametrize("automatic_batching", [True, False]) def test_dataloader_labels(input_, output_, automatic_batching): - problem = SupervisedProblem(input_=input_, output_=output_) + if output_ is None: + problem = SupervisedProblem(graph_=input_) + else: + problem = SupervisedProblem(input_=input_, output_=output_) solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) trainer = Trainer( solver, diff --git a/tests/test_data/test_graph_dataset.py b/tests/test_data/test_graph_dataset.py index 4acb19ebd..bc950699c 100644 --- a/tests/test_data/test_graph_dataset.py +++ b/tests/test_data/test_graph_dataset.py @@ -6,26 +6,25 @@ x = torch.rand((100, 20, 10)) pos = torch.rand((100, 20, 2)) +output_ = torch.rand((100, 20, 10)) input_ = [ - KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) - for x_, pos_ in zip(x, pos) + KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True, y=y_) + for x_, pos_, y_ in zip(x, pos, output_) ] -output_ = torch.rand((100, 20, 10)) x_2 = torch.rand((50, 20, 10)) pos_2 = torch.rand((50, 20, 2)) +output_2_ = torch.rand((50, 20, 10)) input_2_ = [ - KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) - for x_, pos_ in zip(x_2, pos_2) + KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True, y=y_) + for x_, pos_, y_ in zip(x_2, pos_2, output_2_) ] -output_2_ = torch.rand((50, 20, 10)) # Problem with a single condition conditions_dict_single = { "data": { - "input_points": input_, - "output_points": output_, + "graph": input_, } } max_conditions_lengths_single = {"data": 100} @@ -33,12 +32,10 @@ # Problem with multiple conditions conditions_dict_single_multi = { "data_1": { - "input_points": input_, - "output_points": output_, + "graph": input_, }, "data_2": { - "input_points": input_2_, - "output_points": output_2_, + "graph": input_2_, }, } @@ -77,30 +74,27 @@ def test_getitem(conditions_dict, max_conditions_lengths): ) data = dataset[50] assert isinstance(data, dict) - assert all([isinstance(d["input_points"], Data) for d in data.values()]) - assert all( - [isinstance(d["output_points"], torch.Tensor) for d in data.values()] - ) + assert all([isinstance(d, Data) for d in data.values()]) assert all( [ - d["input_points"].x.shape == torch.Size((20, 10)) + d.x.shape == torch.Size((20, 10)) for d in data.values() ] ) assert all( [ - d["output_points"].shape == torch.Size((20, 10)) + d.y.shape == torch.Size((20, 10)) for d in data.values() ] ) assert all( [ - d["input_points"].edge_index.shape == torch.Size((2, 60)) + d.edge_index.shape == torch.Size((2, 60)) for d in data.values() ] ) assert all( - [d["input_points"].edge_attr.shape[0] == 60 for d in data.values()] + [d.edge_attr.shape[0] == 60 for d in data.values()] ) data = dataset.fetch_from_idx_list([i for i in range(20)]) diff --git a/tests/test_graph.py b/tests/test_graph.py index bf053a89f..4b083a375 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -5,7 +5,9 @@ from torch_geometric.data import Data -def build_edge_attr(pos, edge_index): +def build_edge_attr(graph): + pos = graph.pos + edge_index = graph.edge_index return torch.cat([pos[edge_index[0]], pos[edge_index[1]]], dim=-1) diff --git a/tests/test_problem_zoo/test_supervised_problem.py b/tests/test_problem_zoo/test_supervised_problem.py index 06241fa91..635b4a72f 100644 --- a/tests/test_problem_zoo/test_supervised_problem.py +++ b/tests/test_problem_zoo/test_supervised_problem.py @@ -1,6 +1,6 @@ import torch from pina.problem import AbstractProblem -from pina.condition import InputOutputPointsCondition +from pina.condition import InputOutputPointsCondition, GraphInputOutputCondition from pina.problem.zoo.supervised_problem import SupervisedProblem from pina.graph import RadiusGraph @@ -19,16 +19,16 @@ def test_constructor(): def test_constructor_graph(): x = torch.rand((20, 100, 10)) pos = torch.rand((20, 100, 2)) + output_ = torch.rand((20, 100, 10)) input_ = [ - RadiusGraph(x=x_, pos=pos_, radius=0.2, edge_attr=True) - for x_, pos_ in zip(x, pos) + RadiusGraph(x=x_, pos=pos_, radius=0.2, edge_attr=True, y=y_) + for x_, pos_, y_ in zip(x, pos, output_) ] - output_ = torch.rand((100, 10)) - problem = SupervisedProblem(input_=input_, output_=output_) + + problem = SupervisedProblem(graph_=input_) assert isinstance(problem, AbstractProblem) assert hasattr(problem, "conditions") assert isinstance(problem.conditions, dict) assert list(problem.conditions.keys()) == ["data"] - assert isinstance(problem.conditions["data"], InputOutputPointsCondition) - assert isinstance(problem.conditions["data"].input_points, list) - assert isinstance(problem.conditions["data"].output_points, torch.Tensor) + assert isinstance(problem.conditions["data"], GraphInputOutputCondition) + assert isinstance(problem.conditions["data"].graph, list) diff --git a/tests/test_solver/test_reduced_order_model_solver.py b/tests/test_solver/test_reduced_order_model_solver.py index a53719255..2b3d87a6a 100644 --- a/tests/test_solver/test_reduced_order_model_solver.py +++ b/tests/test_solver/test_reduced_order_model_solver.py @@ -3,7 +3,7 @@ from pina import Condition, LabelTensor from pina.problem import AbstractProblem -from pina.condition import InputOutputPointsCondition +from pina.condition import InputOutputPointsCondition, GraphInputOutputCondition from pina.solver import ReducedOrderModelSolver from pina.trainer import Trainer from pina.model import FeedForward @@ -12,22 +12,23 @@ class LabelTensorProblem(AbstractProblem): - input_variables = ['u_0', 'u_1'] - output_variables = ['u'] + input_variables = ["u_0", "u_1"] + output_variables = ["u"] conditions = { - 'data': Condition( - input_points=LabelTensor(torch.randn(20, 2), ['u_0', 'u_1']), - output_points=LabelTensor(torch.randn(20, 1), ['u'])), + "data": Condition( + input_points=LabelTensor(torch.randn(20, 2), ["u_0", "u_1"]), + output_points=LabelTensor(torch.randn(20, 1), ["u"]), + ), } class TensorProblem(AbstractProblem): - input_variables = ['u_0', 'u_1'] - output_variables = ['u'] + input_variables = ["u_0", "u_1"] + output_variables = ["u"] conditions = { - 'data': Condition( - input_points=torch.randn(20, 2), - output_points=torch.randn(20, 1)) + "data": Condition( + input_points=torch.randn(20, 2), output_points=torch.randn(20, 1) + ) } @@ -35,23 +36,27 @@ class AE(torch.nn.Module): def __init__(self, input_dimensions, rank): super().__init__() self.encode = FeedForward( - input_dimensions, rank, layers=[input_dimensions//4]) + input_dimensions, rank, layers=[input_dimensions // 4] + ) self.decode = FeedForward( - rank, input_dimensions, layers=[input_dimensions//4]) + rank, input_dimensions, layers=[input_dimensions // 4] + ) class AE_missing_encode(torch.nn.Module): def __init__(self, input_dimensions, rank): super().__init__() self.encode = FeedForward( - input_dimensions, rank, layers=[input_dimensions//4]) + input_dimensions, rank, layers=[input_dimensions // 4] + ) class AE_missing_decode(torch.nn.Module): def __init__(self, input_dimensions, rank): super().__init__() self.decode = FeedForward( - rank, input_dimensions, layers=[input_dimensions//4]) + rank, input_dimensions, layers=[input_dimensions // 4] + ) rank = 10 @@ -62,26 +67,41 @@ def __init__(self, input_dimensions, rank): def test_constructor(): problem = TensorProblem() - ReducedOrderModelSolver(problem=problem, - interpolation_network=interpolation_net, - reduction_network=reduction_net) - ReducedOrderModelSolver(problem=LabelTensorProblem(), - reduction_network=reduction_net, - interpolation_network=interpolation_net) - assert ReducedOrderModelSolver.accepted_conditions_types == InputOutputPointsCondition + ReducedOrderModelSolver( + problem=problem, + interpolation_network=interpolation_net, + reduction_network=reduction_net, + ) + ReducedOrderModelSolver( + problem=LabelTensorProblem(), + reduction_network=reduction_net, + interpolation_network=interpolation_net, + ) + assert ReducedOrderModelSolver.accepted_conditions_types == ( + InputOutputPointsCondition, + GraphInputOutputCondition, + ) with pytest.raises(SyntaxError): - ReducedOrderModelSolver(problem=problem, - reduction_network=AE_missing_encode( - len(problem.output_variables), rank), - interpolation_network=interpolation_net) - ReducedOrderModelSolver(problem=problem, - reduction_network=AE_missing_decode( - len(problem.output_variables), rank), - interpolation_network=interpolation_net) + ReducedOrderModelSolver( + problem=problem, + reduction_network=AE_missing_encode( + len(problem.output_variables), rank + ), + interpolation_network=interpolation_net, + ) + ReducedOrderModelSolver( + problem=problem, + reduction_network=AE_missing_decode( + len(problem.output_variables), rank + ), + interpolation_network=interpolation_net, + ) with pytest.raises(ValueError): - ReducedOrderModelSolver(problem=Poisson2DSquareProblem(), - reduction_network=reduction_net, - interpolation_network=interpolation_net) + ReducedOrderModelSolver( + problem=Poisson2DSquareProblem(), + reduction_network=reduction_net, + interpolation_network=interpolation_net, + ) @pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) @@ -89,99 +109,122 @@ def test_constructor(): @pytest.mark.parametrize("compile", [True, False]) def test_solver_train(use_lt, batch_size, compile): problem = LabelTensorProblem() if use_lt else TensorProblem() - solver = ReducedOrderModelSolver(problem=problem, - reduction_network=reduction_net, - interpolation_network=interpolation_net, use_lt=use_lt) - trainer = Trainer(solver=solver, - max_epochs=2, - accelerator='cpu', - batch_size=batch_size, - train_size=1., - test_size=0., - val_size=0., - compile=compile) + solver = ReducedOrderModelSolver( + problem=problem, + reduction_network=reduction_net, + interpolation_network=interpolation_net, + use_lt=use_lt, + ) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=1.0, + test_size=0.0, + val_size=0.0, + compile=compile, + ) trainer.train() if trainer.compile: for v in solver.model.values(): - assert (isinstance(v, OptimizedModule)) + assert isinstance(v, OptimizedModule) @pytest.mark.parametrize("use_lt", [True, False]) @pytest.mark.parametrize("compile", [True, False]) def test_solver_validation(use_lt, compile): problem = LabelTensorProblem() if use_lt else TensorProblem() - solver = ReducedOrderModelSolver(problem=problem, - reduction_network=reduction_net, - interpolation_network=interpolation_net, use_lt=use_lt) - trainer = Trainer(solver=solver, - max_epochs=2, - accelerator='cpu', - batch_size=None, - train_size=0.9, - val_size=0.1, - test_size=0., - compile=compile) + solver = ReducedOrderModelSolver( + problem=problem, + reduction_network=reduction_net, + interpolation_network=interpolation_net, + use_lt=use_lt, + ) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=None, + train_size=0.9, + val_size=0.1, + test_size=0.0, + compile=compile, + ) trainer.train() if trainer.compile: for v in solver.model.values(): - assert (isinstance(v, OptimizedModule)) + assert isinstance(v, OptimizedModule) @pytest.mark.parametrize("use_lt", [True, False]) @pytest.mark.parametrize("compile", [True, False]) def test_solver_test(use_lt, compile): problem = LabelTensorProblem() if use_lt else TensorProblem() - solver = ReducedOrderModelSolver(problem=problem, - reduction_network=reduction_net, - interpolation_network=interpolation_net, use_lt=use_lt) - trainer = Trainer(solver=solver, - max_epochs=2, - accelerator='cpu', - batch_size=None, - train_size=0.8, - val_size=0.1, - test_size=0.1, - compile=compile) + solver = ReducedOrderModelSolver( + problem=problem, + reduction_network=reduction_net, + interpolation_network=interpolation_net, + use_lt=use_lt, + ) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=None, + train_size=0.8, + val_size=0.1, + test_size=0.1, + compile=compile, + ) trainer.train() if trainer.compile: for v in solver.model.values(): - assert (isinstance(v, OptimizedModule)) + assert isinstance(v, OptimizedModule) def test_train_load_restore(): dir = "tests/test_solver/tmp/" problem = LabelTensorProblem() - solver = ReducedOrderModelSolver(problem=problem, - - reduction_network=reduction_net, - interpolation_network=interpolation_net) - trainer = Trainer(solver=solver, - max_epochs=5, - accelerator='cpu', - batch_size=None, - train_size=0.9, - test_size=0.1, - val_size=0., - default_root_dir=dir) + solver = ReducedOrderModelSolver( + problem=problem, + reduction_network=reduction_net, + interpolation_network=interpolation_net, + ) + trainer = Trainer( + solver=solver, + max_epochs=5, + accelerator="cpu", + batch_size=None, + train_size=0.9, + test_size=0.1, + val_size=0.0, + default_root_dir=dir, + ) trainer.train() # restore - ntrainer = Trainer(solver=solver, - max_epochs=5, - accelerator='cpu',) + ntrainer = Trainer( + solver=solver, + max_epochs=5, + accelerator="cpu", + ) ntrainer.train( - ckpt_path=f'{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt') + ckpt_path=f"{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt" + ) # loading new_solver = ReducedOrderModelSolver.load_from_checkpoint( - f'{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt', + f"{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt", problem=problem, reduction_network=reduction_net, - interpolation_network=interpolation_net) + interpolation_network=interpolation_net, + ) test_pts = LabelTensor(torch.rand(20, 2), problem.input_variables) assert new_solver.forward(test_pts).shape == (20, 1) assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape torch.testing.assert_close( - new_solver.forward(test_pts), - solver.forward(test_pts)) + new_solver.forward(test_pts), solver.forward(test_pts) + ) # rm directories import shutil - shutil.rmtree('tests/test_solver/tmp') + + shutil.rmtree("tests/test_solver/tmp") diff --git a/tests/test_solver/test_supervised_solver.py b/tests/test_solver/test_supervised_solver.py index 2b8d60878..ddb44b251 100644 --- a/tests/test_solver/test_supervised_solver.py +++ b/tests/test_solver/test_supervised_solver.py @@ -1,42 +1,109 @@ import torch import pytest from pina import Condition, LabelTensor -from pina.condition import InputOutputPointsCondition +from pina.condition import InputOutputPointsCondition, GraphInputOutputCondition from pina.problem import AbstractProblem from pina.solver import SupervisedSolver from pina.model import FeedForward from pina.trainer import Trainer +from pina.graph import KNNGraph from torch._dynamo.eval_frame import OptimizedModule +from torch_geometric.nn import GCNConv class LabelTensorProblem(AbstractProblem): - input_variables = ['u_0', 'u_1'] - output_variables = ['u'] + input_variables = ["u_0", "u_1"] + output_variables = ["u"] conditions = { - 'data': Condition( - input_points=LabelTensor(torch.randn(20, 2), ['u_0', 'u_1']), - output_points=LabelTensor(torch.randn(20, 1), ['u'])), + "data": Condition( + input_points=LabelTensor(torch.randn(20, 2), ["u_0", "u_1"]), + output_points=LabelTensor(torch.randn(20, 1), ["u"]), + ), } class TensorProblem(AbstractProblem): - input_variables = ['u_0', 'u_1'] - output_variables = ['u'] + input_variables = ["u_0", "u_1"] + output_variables = ["u"] conditions = { - 'data': Condition( - input_points=torch.randn(20, 2), - output_points=torch.randn(20, 1)) + "data": Condition( + input_points=torch.randn(20, 2), output_points=torch.randn(20, 1) + ) + } + + +x = torch.rand((100, 20, 5)) +pos = torch.rand((100, 20, 2)) +output_ = torch.rand((100, 20, 1)) +input_ = [ + KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True, y=y_) + for x_, pos_, y_ in zip(x, pos, output_) +] + + +class GraphProblem(AbstractProblem): + output_variables = None + conditions = { + "data": Condition( + graph=input_, + ) + } + + +x = LabelTensor(torch.rand((100, 20, 5)), ["a", "b", "c", "d", "e"]) +pos = LabelTensor(torch.rand((100, 20, 2)), ["x", "y"]) +output_ = LabelTensor(torch.rand((100, 20, 1)), ["u"]) +input_ = [ + KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True, y=output_[i]) + for i in range(len(x)) +] + + +class GraphProblemLT(AbstractProblem): + output_variables = ["u"] + input_variables = ["a", "b", "c", "d", "e"] + conditions = { + "data": Condition( + graph=input_, + ) } model = FeedForward(2, 1) +class Model(torch.nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.lift = torch.nn.Linear(5, 10) + self.activation = torch.nn.Tanh() + self.output = torch.nn.Linear(10, 1) + + self.conv = GCNConv(10, 10) + + def forward(self, batch): + + x = batch.x + edge_index = batch.edge_index + for _ in range(1): + y = self.lift(x) + y = self.activation(y) + y = self.conv(y, edge_index) + y = self.activation(y) + y = self.output(y) + return y + + +graph_model = Model() + + def test_constructor(): SupervisedSolver(problem=TensorProblem(), model=model) SupervisedSolver(problem=LabelTensorProblem(), model=model) assert SupervisedSolver.accepted_conditions_types == ( - InputOutputPointsCondition + InputOutputPointsCondition, + GraphInputOutputCondition, ) @@ -46,18 +113,38 @@ def test_constructor(): def test_solver_train(use_lt, batch_size, compile): problem = LabelTensorProblem() if use_lt else TensorProblem() solver = SupervisedSolver(problem=problem, model=model, use_lt=use_lt) - trainer = Trainer(solver=solver, - max_epochs=2, - accelerator='cpu', - batch_size=batch_size, - train_size=1., - test_size=0., - val_size=0., - compile=compile) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=1.0, + test_size=0.0, + val_size=0.0, + compile=compile, + ) trainer.train() if trainer.compile: - assert (isinstance(solver.model, OptimizedModule)) + assert isinstance(solver.model, OptimizedModule) + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +@pytest.mark.parametrize("use_lt", [True, False]) +def test_solver_train_graph(batch_size, use_lt): + problem = GraphProblemLT() if use_lt else GraphProblem() + solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=1.0, + test_size=0.0, + val_size=0.0, + ) + + trainer.train() @pytest.mark.parametrize("use_lt", [True, False]) @@ -65,17 +152,37 @@ def test_solver_train(use_lt, batch_size, compile): def test_solver_validation(use_lt, compile): problem = LabelTensorProblem() if use_lt else TensorProblem() solver = SupervisedSolver(problem=problem, model=model, use_lt=use_lt) - trainer = Trainer(solver=solver, - max_epochs=2, - accelerator='cpu', - batch_size=None, - train_size=0.9, - val_size=0.1, - test_size=0., - compile=compile) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=None, + train_size=0.9, + val_size=0.1, + test_size=0.0, + compile=compile, + ) trainer.train() if trainer.compile: - assert (isinstance(solver.model, OptimizedModule)) + assert isinstance(solver.model, OptimizedModule) + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +@pytest.mark.parametrize("use_lt", [True, False]) +def test_solver_val_graph(batch_size, use_lt): + problem = GraphProblemLT() if use_lt else GraphProblem() + solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=0.9, + val_size=0.1, + test_size=0.0, + ) + + trainer.train() @pytest.mark.parametrize("use_lt", [True, False]) @@ -83,51 +190,126 @@ def test_solver_validation(use_lt, compile): def test_solver_test(use_lt, compile): problem = LabelTensorProblem() if use_lt else TensorProblem() solver = SupervisedSolver(problem=problem, model=model, use_lt=use_lt) - trainer = Trainer(solver=solver, - max_epochs=2, - accelerator='cpu', - batch_size=None, - train_size=0.8, - val_size=0.1, - test_size=0.1, - compile=compile) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=None, + train_size=0.8, + val_size=0.1, + test_size=0.1, + compile=compile, + ) trainer.test() if trainer.compile: - assert (isinstance(solver.model, OptimizedModule)) + assert isinstance(solver.model, OptimizedModule) + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +@pytest.mark.parametrize("use_lt", [True, False]) +def test_solver_test_graph(batch_size, use_lt): + problem = GraphProblemLT() if use_lt else GraphProblem() + solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=0.8, + val_size=0.1, + test_size=0.1, + ) + + trainer.test() def test_train_load_restore(): dir = "tests/test_solver/tmp/" problem = LabelTensorProblem() solver = SupervisedSolver(problem=problem, model=model) - trainer = Trainer(solver=solver, - max_epochs=5, - accelerator='cpu', - batch_size=None, - train_size=0.9, - test_size=0.1, - val_size=0., - default_root_dir=dir) + trainer = Trainer( + solver=solver, + max_epochs=5, + accelerator="cpu", + batch_size=None, + train_size=0.9, + test_size=0.1, + val_size=0.0, + default_root_dir=dir, + ) trainer.train() # restore - new_trainer = Trainer(solver=solver, max_epochs=5, accelerator='cpu') + new_trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") new_trainer.train( - ckpt_path=f'{dir}/lightning_logs/version_0/checkpoints/' + - 'epoch=4-step=5.ckpt') + ckpt_path=f"{dir}/lightning_logs/version_0/checkpoints/" + + "epoch=4-step=5.ckpt" + ) # loading new_solver = SupervisedSolver.load_from_checkpoint( - f'{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt', - problem=problem, model=model) + f"{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt", + problem=problem, + model=model, + ) test_pts = LabelTensor(torch.rand(20, 2), problem.input_variables) assert new_solver.forward(test_pts).shape == (20, 1) assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape torch.testing.assert_close( - new_solver.forward(test_pts), - solver.forward(test_pts)) + new_solver.forward(test_pts), solver.forward(test_pts) + ) + + # rm directories + import shutil + + shutil.rmtree("tests/test_solver/tmp") + + +def test_train_load_restore_graph(): + dir = "tests/test_solver/tmp/" + problem = GraphProblemLT() + solver = SupervisedSolver(problem=problem, model=graph_model) + trainer = Trainer( + solver=solver, + max_epochs=5, + accelerator="cpu", + batch_size=None, + train_size=0.9, + test_size=0.1, + val_size=0.0, + default_root_dir=dir, + ) + trainer.train() + + # restore + new_trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") + new_trainer.train( + ckpt_path=f"{dir}/lightning_logs/version_0/checkpoints/" + + "epoch=4-step=5.ckpt" + ) + + # loading + new_solver = SupervisedSolver.load_from_checkpoint( + f"{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt", + problem=problem, + model=graph_model, + ) + + test_pts = KNNGraph( + x=LabelTensor(torch.rand(20, 5), ["a", "b", "c", "d", "e"]), + pos=LabelTensor(torch.rand(20, 2), ["x", "y"]), + neighbours=3, + edge_attr=True, + ) + + assert new_solver.forward(test_pts).shape == (20, 1) + assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape + torch.testing.assert_close( + new_solver.forward(test_pts), solver.forward(test_pts) + ) # rm directories import shutil - shutil.rmtree('tests/test_solver/tmp') + + shutil.rmtree("tests/test_solver/tmp")