diff --git a/deeplay/applications/__init__.py b/deeplay/applications/__init__.py
index 0818a8f0..0f20ce3a 100644
--- a/deeplay/applications/__init__.py
+++ b/deeplay/applications/__init__.py
@@ -3,6 +3,7 @@
 from .regression import *
 from .detection import *
 from .autoencoders import *
+from .clustering import *
 
 # from .classification import *
 # from .segmentation import ImageSegmentor
diff --git a/deeplay/applications/clustering/__init__.py b/deeplay/applications/clustering/__init__.py
new file mode 100644
index 00000000..a665019a
--- /dev/null
+++ b/deeplay/applications/clustering/__init__.py
@@ -0,0 +1 @@
+from .miro import MIRO
diff --git a/deeplay/applications/clustering/miro.py b/deeplay/applications/clustering/miro.py
new file mode 100644
index 00000000..622c1d23
--- /dev/null
+++ b/deeplay/applications/clustering/miro.py
@@ -0,0 +1,270 @@
+"""MIRO: Multimodal Integration through Relational Optimization
+
+This module provides the MIRO framework for point cloud clustering, leveraging
+advanced geometric deep learning techniques. MIRO transforms complex point
+clouds into optimized representations, enabling more effective clustering
+using traditional algorithms.
+
+Based on the original MIRO paper by Pineda et al. [1], this implementation
+offers easy-to-use methods for training the MIRO model and performing
+geometry-aware clustering. It integrates recurrent graph neural networks to
+refine point cloud data and enhance clustering accuracy.
+
+[1] Pineda, Jesús, et al. "Spatial Clustering of Molecular Localizations with
+    Graph Neural Networks." arXiv preprint arXiv:2412.00173 (2024).
+"""
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch_geometric.data import Data
+from sklearn.cluster import DBSCAN
+from typing import Optional, List
+
+from deeplay.applications import Application
+from deeplay.external import Adam, Optimizer
+from deeplay.models import RecurrentMessagePassingModel
+
+
+class MIRO(Application):
+    """Point cloud clustering using MIRO (Multimodal Integration through
+    Relational Optimization).
+
+    MIRO is a geometric deep learning framework that enhances clustering
+    algorithms by transforming complex point clouds into an optimized
+    structure amenable to conventional clustering methods. MIRO employs
+    recurrent graph neural networks (rGNNs) to learn a transformation that
+    squeezes localizations belonging to the same cluster toward a common
+    center, resulting in a compact representation of clusters within the
+    point cloud.
+
+    Parameters
+    ----------
+    num_outputs : int
+        Dimensionality of the output features, representing a displacement
+        vector in Cartesian space for each node. This vector points toward
+        the center of each cluster.
+    connectivity_radius : float
+        Maximum distance between two nodes to consider them connected in the
+        graph.
+    model : nn.Module, optional
+        A model implementing the forward method. It should return a list of
+        tensors of shape `(num_nodes, num_outputs)` representing the predicted
+        displacement vectors for each node at each recurrent iteration. If not
+        specified, a default model resembling the one from the original MIRO
+        paper is used.
+    nd_loss_weight : float
+        Weight for the auxiliary loss that enforces preservation of pairwise
+        distances between connected nodes. Default is 10.
+    loss : torch.nn.Module
+        Loss function for training. Default is `torch.nn.L1Loss`.
+    optimizer : Optimizer
+        Optimizer for training. Default is Adam with a learning rate of 1e-4.
+
+    Methods
+    -------
+    forward : method
+        Computes and returns the predicted displacement vectors for each node
+        in the input graph. The output is a list of tensors representing the
+        displacement vectors at each recurrent iteration.
+
+    squeeze : method
+        Applies the predicted displacement vectors from the last recurrent
+        iteration (by default) to the input point cloud. This operation
+        optimizes the point cloud for clustering by aligning nodes closer to
+        their respective cluster centers.
+
+    clustering : method
+        Groups nodes into clusters using the DBSCAN algorithm, based on the
+        predicted displacement vectors. Each node is assigned a cluster label,
+        where -1 indicates background noise. Returns an array of cluster labels
+        for the nodes.
+
+    Example
+    -------
+    >>> # Predicts displacement vectors for each node in a point cloud at each
+    >>> # recurrent iteration
+    >>> displacement_vectors = model(test_graph)
+    >>> print(type(displacement_vectors))
+    <class 'list'>
+
+    >>> # Applies the predicted displacement vectors to the input point cloud
+    >>> squeezed = model.squeeze(test_graph)
+    >>> print(squeezed.shape)
+    (num_nodes, 2)
+
+    >>> # Performs clustering using DBSCAN after MIRO squeezing
+    >>> eps = 0.3  # Maximum distance for cluster connection
+    >>> min_samples = 5  # Minimum points in a neighborhood for core points
+    >>> clusters = model.clustering(test_graph, eps, min_samples)
+
+    >>> # Output cluster labels
+    >>> print(clusters)
+    array([ 0, 0, 1, 1, 1, -1, 2, 2, 2, ...])
+    # Nodes in cluster 0, 1, 2, etc.; -1 are outliers
+    """
+
+    num_outputs: int
+    connectivity_radius: float
+    model: nn.Module
+    nd_loss_weight: float
+    loss: torch.nn.Module
+    metrics: list
+    optimizer: Optimizer
+
+    def __init__(
+        self,
+        num_outputs: int = 2,
+        connectivity_radius: float = 1.0,
+        model: Optional[nn.Module] = None,
+        nd_loss_weight: float = 10,
+        loss: torch.nn.Module = torch.nn.L1Loss(),
+        optimizer: Optional[Optimizer] = None,
+        **kwargs,
+    ):
+
+        self.num_outputs = num_outputs
+        self.connectivity_radius = connectivity_radius
+        self.model = model or self._get_default_model()
+        self.nd_loss_weight = nd_loss_weight
+
+        super().__init__(loss=loss, optimizer=optimizer or Adam(lr=1e-4), **kwargs)
+
+    def _get_default_model(self):
+        rgnn = RecurrentMessagePassingModel(
+            hidden_features=256, out_features=self.num_outputs, num_iter=20
+        )
+        return rgnn
+
+    def forward(self, x: Data) -> List[torch.Tensor]:
+        """Forward pass to compute predicted displacement vectors for each node.
+
+        Parameters
+        ----------
+        x : torch_geometric.data.Data
+            Input graph data. It is expected to have the attributes:
+            `x` (node features), `edge_index` (graph connectivity),
+            `edge_attr` (edge features), and `position` (node spatial
+            coordinates).
+
+        Returns
+        -------
+        list[torch.Tensor]
+            Predicted displacement vectors at each recurrent iteration.
+        """
+        return self.model(x)
+
+    def squeeze(
+        self,
+        x: Data,
+        from_iter: int = -1,
+        scaling: np.ndarray = np.array([1.0, 1.0]),
+    ) -> np.ndarray:
+        """Computes and applies the predicted displacement vectors to the
+        input point cloud.
+
+        Parameters
+        ----------
+        x : torch_geometric.data.Data
+            Input graph data. It is expected to have the attributes:
+            `x` (node features), `edge_index` (graph connectivity),
+            `edge_attr` (edge features), and `position` (node spatial
+            coordinates).
+        from_iter : int, optional
+            Index of the recurrent iteration to be used as displacement vectors.
+            Default is -1 (last iteration).
+        scaling : np.ndarray, optional
+            Scaling factors for each dimension. Default is [1.0, 1.0].
+
+        Returns
+        -------
+        np.ndarray
+            Squeezed point cloud with optimized cluster alignment.
+        """
+        predicted_displacements = self(x)[from_iter].detach().cpu().numpy()
+        positions = x.position.cpu().numpy()
+        squeezed_positions = (
+            positions - predicted_displacements * self.connectivity_radius
+        )
+        return squeezed_positions * scaling
+
+    def clustering(
+        self,
+        x: Data,
+        eps: float,
+        min_samples: int,
+        from_iter: int = -1,
+        **kwargs,
+    ) -> np.ndarray:
+        """Performs clustering using DBSCAN after applying MIRO squeezing.
+
+        Parameters
+        ----------
+        x : torch_geometric.data.Data
+            Input graph data.
+        eps : float
+            The maximum distance between two samples for one to be considered
+            as in the neighborhood of the other. This is not a maximum bound
+            on the distances of points within a cluster. This is the most
+            important DBSCAN parameter to choose appropriately for your data
+            set and distance function.
+        min_samples : int
+            The number of samples (or total weight) in a neighborhood for a
+            point to be considered as a core point. This includes the point
+            itself.
+        from_iter : int, optional
+            Index of the recurrent iteration to be used as displacement vectors.
+            Default is -1 (last iteration).
+        **kwargs
+            Additional keyword arguments forwarded to `squeeze` (e.g.,
+            `scaling`).
+
+        Returns
+        -------
+        np.ndarray
+            Cluster labels for each node. -1 indicates outliers.
+        """
+        squeezed = self.squeeze(x, from_iter, **kwargs)
+        clusters = DBSCAN(eps=eps, min_samples=min_samples).fit(squeezed)
+        return clusters.labels_
+
+    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
+        """Defines the training step for a single batch."""
+        x, y = self.train_preprocess(batch)
+        y_hat = self(x)
+        loss = self.compute_loss(y_hat, y, x.edge_index, x.position)
+
+        self.log(
+            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
+        )
+
+        self.log_metrics(
+            "train", y_hat, y, on_step=True, on_epoch=True, prog_bar=True, logger=True
+        )
+        return loss
+
+    def compute_loss(
+        self,
+        y_hat: List[torch.Tensor],
+        y: torch.Tensor,
+        edges: torch.Tensor,
+        position: torch.Tensor,
+    ) -> torch.Tensor:
+        """Computes the total loss for the model."""
+        loss = 0
+        for pred in y_hat:
+            loss += self.loss(pred, y) + self.nd_loss_weight * self.compute_nd_loss(
+                pred, y, edges, position
+            )
+        return loss / len(y_hat)
+
+    def compute_nd_loss(
+        self,
+        y_hat: torch.Tensor,
+        y: torch.Tensor,
+        edges: torch.Tensor,
+        position: torch.Tensor,
+    ) -> torch.Tensor:
+        """Computes auxiliary loss for pairwise distance preservation."""
+        squeezed_gt = position - y * self.connectivity_radius
+        squeezed_gt_distances = torch.norm(
+            squeezed_gt[edges[0]] - squeezed_gt[edges[1]], dim=1
+        )
+        squeezed_pred = position - y_hat * self.connectivity_radius
+        squeezed_pred_distances = torch.norm(
+            squeezed_pred[edges[0]] - squeezed_pred[edges[1]], dim=1
+        )
+        return self.loss(squeezed_pred_distances, squeezed_gt_distances)
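# An end-to-end sketch of the API defined above, run on an untrained model
# with illustrative shapes and hyperparameters; the graph follows the
# attribute conventions documented in `MIRO.forward` (`x`, `edge_index`,
# `edge_attr`, `position`).
import torch
from torch_geometric.data import Data
from deeplay.applications import MIRO

graph = Data(
    x=torch.randn(100, 2),                       # node features
    edge_index=torch.randint(0, 100, (2, 400)),  # graph connectivity
    edge_attr=torch.randn(400, 1),               # edge features
    position=torch.randn(100, 2),                # spatial coordinates
)

miro = MIRO(num_outputs=2, connectivity_radius=1.0).create()
displacements = miro(graph)     # one (100, 2) tensor per recurrent iteration
squeezed = miro.squeeze(graph)  # (100, 2) numpy array, clusters compacted
labels = miro.clustering(graph, eps=0.3, min_samples=5)  # -1 marks noise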
diff --git a/deeplay/components/dict.py b/deeplay/components/dict.py
index 125cd678..120ee435 100644
--- a/deeplay/components/dict.py
+++ b/deeplay/components/dict.py
@@ -1,4 +1,19 @@
-from typing import Dict, Any, Union, Tuple, overload
+"""Utility Modules for Dictionary and Graph Data Operations
+
+This module contains utilities for operations involving dictionary-like
+structures or PyTorch Geometric `Data` objects. These operations are useful in
+geometric deep learning pipelines, where input data is often stored in a
+structured format with various attributes.
+
+Classes:
+- FromDict: Extracts specified keys from a dictionary or `Data` object.
+- AddDict: Performs element-wise addition for specified keys in two
+  dictionaries or `Data` objects.
+- CatDictElements: Concatenates specified elements within a dictionary or
+  `Data` object along a given dimension.
+"""
+
+from typing import Dict, Any, Union, Tuple
 
 from deeplay import DeeplayModule
 
@@ -7,6 +22,25 @@
 
 
 class FromDict(DeeplayModule):
+    """Extract specified keys from a dictionary-like structure.
+
+    Parameters
+    ----------
+    keys : str
+        The keys to extract from the input dictionary.
+
+    Returns
+    -------
+    Any or Tuple[Any, ...]
+        The values corresponding to the specified keys.
+
+    Example
+    -------
+    >>> extractor = FromDict("key1", "key2").create()
+    >>> result = extractor({"key1": value1, "key2": value2})
+    (value1, value2)
+    """
+
     def __init__(self, *keys: str):
         super().__init__()
         self.keys = keys
@@ -23,8 +57,7 @@ def extra_repr(self) -> str:
 
 
 class AddDict(DeeplayModule):
-    """
-    Element-wise addition of two dictionaries.
+    """Element-wise addition of two dictionaries.
 
     Parameters
     ----------
@@ -33,10 +66,19 @@ class AddDict(DeeplayModule):
 
     Constraints
     -----------
-    - Both dictionaries 'x' (base) and 'y' (addition) must contain the same keys for the addition operation.
+    - Both dictionaries `x` (base) and `y` (addition) must contain the same
+      keys for the addition operation.
+
+    - Input types:
+        - `x`: Dict[str, Any] or `torch_geometric.data.Data`
+        - `y`: Dict[str, Any] or `torch_geometric.data.Data`
 
-    - 'x': Dict[str, Any] or torch_geometric.data.Data.
-    - 'y': Dict[str, Any] or torch_geometric.data.Data.
+    Example
+    -------
+    >>> adder = AddDict("key1", "key2").create()
+    >>> result = adder({"key1": value1, "key2": value2},
+    ...                {"key1": 1, "key2": 2})
+    {"key1": value1 + 1, "key2": value2 + 2}
     """
 
     def __init__(self, *keys: str):
@@ -54,3 +96,42 @@ def forward(
         x.update({key: torch.add(x[key], y[key]) for key in self.keys})
 
         return x
+
+
+class CatDictElements(DeeplayModule):
+    """Concatenates specified elements within a dictionary-like structure along
+    a given dimension.
+
+    Parameters
+    ----------
+    keys : Tuple[tuple]
+        Key pairs to concatenate, given as `(source, target)` tuples. For each
+        pair, the value at the source key is concatenated onto the value at
+        the target key, and the result is stored under the target key.
+    dim : int, optional
+        Specifies the dimension along which the concatenation is performed.
+        Default is -1.
+
+    Example
+    -------
+    >>> concat = CatDictElements(("key1", "key2"), ("key3", "key4")).create()
+    >>> result = concat({"key1": tensor1, "key2": tensor2,
+    ...                  "key3": tensor3, "key4": tensor4})
+    {"key2": torch.cat([tensor2, tensor1], dim=-1),
+     "key4": torch.cat([tensor4, tensor3], dim=-1)}
+    """
+
+    def __init__(self, *keys: Tuple[tuple], dim: int = -1):
+        super().__init__()
+        self.source, self.target = zip(*keys)
+        self.dim = dim
+
+    def forward(self, x: Union[Dict[str, Any], Data]) -> Union[Dict[str, Any], Data]:
+        x = x.clone() if isinstance(x, Data) else x.copy()
+        x.update(
+            {
+                t: torch.cat([x[t], x[s]], dim=self.dim)
+                for t, s in zip(self.target, self.source)
+            }
+        )
+        return x
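# A self-contained sketch of the three utilities in this file, run on plain
# dicts; the key names and tensor shapes are illustrative.
import torch
from deeplay.components.dict import FromDict, AddDict, CatDictElements

data = {"x": torch.ones(4, 2), "edge_attr": torch.zeros(4, 2)}

x, edge_attr = FromDict("x", "edge_attr").create()(data)     # pick values out
summed = AddDict("x").create()(dict(data), {"x": x})         # "x" -> x + x
catted = CatDictElements(("x", "edge_attr")).create()(data)  # cat "x" onto "edge_attr"
print(summed["x"][0], catted["edge_attr"].shape)  # tensor([2., 2.]) torch.Size([4, 4])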
diff --git a/deeplay/components/gnn/__init__.py b/deeplay/components/gnn/__init__.py
index 1f0bd47c..fd418303 100644
--- a/deeplay/components/gnn/__init__.py
+++ b/deeplay/components/gnn/__init__.py
@@ -1,3 +1,4 @@
 from .gcn import *
 from .mpn import *
 from .tpu import *
+from .rgb import *
diff --git a/deeplay/components/gnn/rgb.py b/deeplay/components/gnn/rgb.py
new file mode 100644
index 00000000..a898a63e
--- /dev/null
+++ b/deeplay/components/gnn/rgb.py
@@ -0,0 +1,158 @@
+"""Recurrent Graph Block Module
+
+This module defines the `RecurrentGraphBlock` class, a component designed for
+recurrent graph-based computations. It employs a recurrent structure to process
+graph data iteratively using specified combine, layer, and head modules.
+"""
+
+import torch
+from torch_geometric.data import Data
+from deeplay import DeeplayModule
+from deeplay.components.dict import CatDictElements
+
+
+class RecurrentGraphBlock(DeeplayModule):
+    """Recurrent graph processing block for iterative feature transformation.
+
+    This module combines graph data features and hidden states, processes
+    them through a recurrent structure, and generates outputs using a
+    specified head module. It supports modular design for flexibility in
+    defining the combine, layer, and head operations.
+
+    Parameters
+    ----------
+    layer : DeeplayModule
+        Module that applies transformations to the graph data at each
+        iteration.
+    head : DeeplayModule
+        Module that processes the output from the layer and generates
+        final predictions.
+    hidden_features : int
+        The number of hidden features for the recurrent block.
+    num_iter : int
+        The number of recurrent iterations.
+    combine : DeeplayModule, optional
+        The module responsible for combining graph features with hidden
+        states. Default is `CatDictElements(("x", "hidden"))`.
+
+    Returns
+    -------
+    list
+        A list of outputs generated at each recurrent iteration.
+
+    Raises
+    ------
+    AttributeError
+        If the `combine` module does not contain `source` and `target`
+        attributes. Graph `combine` modules compatible with
+        `RecurrentGraphBlock` must have `source` and `target` attributes to
+        specify the keys to concatenate. A catalog of `combine` operations
+        is available in the `deeplay.components.dict` module.
+
+    Example
+    -------
+    >>> combine = CatDictElements(("x", "hidden"))
+    >>> layer = MessagePassingNeuralNetwork([], 128)
+    >>> head = MultiLayerPerceptron(128, [], 128)
+    >>> block = RecurrentGraphBlock(
+    ...     layer=layer,
+    ...     head=head,
+    ...     combine=combine,
+    ...     hidden_features=128,
+    ...     num_iter=10
+    ... )
+
+    >>> # Set maps for input and output keys
+    >>> block.head.set_input_map("x")
+    >>> block.head.set_output_map("x")
+    >>> block = block.create()
+
+    >>> # Create input graph data
+    >>> data = Data(
+    ...     x=torch.randn(3, 128),
+    ...     edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
+    ...     edge_attr=torch.randn(4, 128)
+    ... )
+    >>> outputs = block(data)
+    >>> len(outputs)
+    10
+    """
+
+    combine: DeeplayModule
+    layer: DeeplayModule
+    head: DeeplayModule
+
+    def __init__(
+        self,
+        layer: DeeplayModule,
+        head: DeeplayModule,
+        hidden_features: int,
+        num_iter: int,
+        combine: DeeplayModule = CatDictElements(("x", "hidden")),
+    ):
+        super().__init__()
+        self.combine = combine
+        self.layer = layer
+        self.head = head
+        self.hidden_features = hidden_features
+        self.num_iter = num_iter
+
+        if not all(hasattr(self.combine, attr) for attr in ("source", "target")):
+            raise AttributeError(
+                "The 'combine' module must have 'source' and 'target' attributes. "
+                "These specify the keys to concatenate. Ensure the 'combine' "
+                "module is initialized with valid 'source' and 'target' keys. "
+                "Refer to the `CatDictElements` class in the `deeplay.components.dict` "
+                "module for guidance."
+            )
+        self.hidden_variables_name = self.combine.target
+
+    def initialize_hidden(self, x):
+        """Initialize hidden states for the graph nodes if not already provided
+        in the input data.
+
+        Parameters
+        ----------
+        x : Data or dict
+            The input graph data or dictionary-like structure.
+
+        Returns
+        -------
+        Data or dict
+            The input graph data with initialized hidden states.
+        """
+        x = x.clone() if isinstance(x, Data) else x.copy()
+        for source, hidden_variable_name in zip(
+            self.combine.source, self.hidden_variables_name
+        ):
+            if hidden_variable_name not in x:
+                x.update(
+                    {
+                        hidden_variable_name: torch.zeros(
+                            x[source].size(0), self.hidden_features
+                        ).to(x[source].device)
+                    }
+                )
+        return x
+
+    def forward(self, x):
+        """Forward pass to process the graph data through recurrent iterations.
+
+        Parameters
+        ----------
+        x : Data or dict
+            The input graph data or dictionary-like structure.
+
+        Returns
+        -------
+        list
+            A list of outputs generated at each recurrent iteration.
+        """
+        x = self.initialize_hidden(x)
+        outputs = []
+        for _ in range(self.num_iter):
+            x = self.combine(x)
+            x = self.layer(x)
+            outputs.append(self.head(x))
+
+        return outputs
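# A compact sketch of the hidden-state bookkeeping above, mirroring the unit
# test added in deeplay/tests/test_gnn.py: with identity layer and head
# modules, the block zero-initializes "hidden" when it is absent and
# concatenates it with "x" at each iteration.
import torch
import torch.nn as nn
from deeplay import RecurrentGraphBlock, CatDictElements, Layer

block = RecurrentGraphBlock(
    combine=CatDictElements(("x", "hidden")),
    layer=Layer(nn.Identity),
    head=Layer(nn.Identity),
    hidden_features=64,
    num_iter=1,
).create()

out = block({"x": torch.ones(10, 64)})
print(out[0]["hidden"].shape)  # torch.Size([10, 128]): zeros cat'ed with "x"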
diff --git a/deeplay/models/gnn/__init__.py b/deeplay/models/gnn/__init__.py
index 2facd6f0..47d1651b 100644
--- a/deeplay/models/gnn/__init__.py
+++ b/deeplay/models/gnn/__init__.py
@@ -1,4 +1,5 @@
 from .mpm import MPM
+from .rmpm import RecurrentMessagePassingModel
 from .gtogmpm import GraphToGlobalMPM, GlobalMeanPool
 from .gtonmpm import GraphToNodeMPM
 from .gtoempm import GraphToEdgeMPM
diff --git a/deeplay/models/gnn/rmpm.py b/deeplay/models/gnn/rmpm.py
new file mode 100644
index 00000000..1b675061
--- /dev/null
+++ b/deeplay/models/gnn/rmpm.py
@@ -0,0 +1,206 @@
+"""Recurrent Message Passing Model (RMPM)
+
+This module defines the `RecurrentMessagePassingModel`, a neural network for
+recurrent graph-based computations. It processes graph data iteratively using
+message-passing mechanisms combined with recurrent structures.
+"""
+
+from typing import Type, Union
+
+from deeplay import (
+    DeeplayModule,
+    Parallel,
+    MultiLayerPerceptron,
+    MessagePassingNeuralNetwork,
+    Sequential,
+    RecurrentGraphBlock,
+    CatDictElements,
+)
+
+import torch.nn as nn
+
+
+class RecurrentMessagePassingModel(DeeplayModule):
+    """Recurrent Message Passing Model (RMPM).
+
+    The RMPM processes graph data iteratively through a combination of an
+    encoder and a recurrent message passing layer.
+    The encoder transforms input node and edge features into a common hidden
+    representation. The recurrent message passing layer updates hidden node
+    and edge representations by concatenating them with the input features
+    and processing them through a message-passing neural network. Outputs for
+    each iteration are collected in a list and returned as the model's final
+    output.
+
+    Parameters
+    ----------
+    hidden_features : int
+        Number of hidden units in the recurrent message passing layer.
+    out_features : int
+        Number of output features.
+    num_iter : int
+        Number of iterations of the recurrent message passing layer.
+    out_activation : template-like, optional
+        Activation function applied to the output. Default is `nn.Identity`.
+
+    Raises
+    ------
+    ValueError
+        If `out_features` or `hidden_features` is non-positive, or if
+        `hidden_features` is not an integer.
+
+    Configurables
+    -------------
+    - hidden_features (int): Number of hidden units in the recurrent message
+      passing layer.
+    - out_features (int): Number of output features.
+    - out_activation (template-like): Specification for the output activation
+      of the model. Default: nn.Identity.
+    - encoder (template-like): Specification for the encoder of the model.
+      Default: dl.Parallel consisting of two MLPs that process node and edge
+      features.
+    - backbone (template-like): Specification for the backbone of the model.
+      Default: dl.RecurrentGraphBlock combining a dl.MessagePassingNeuralNetwork
+      with an MLP head.
+
+    Constraints
+    -----------
+    - Input graph data must include:
+        - `x`: Node features of shape (num_nodes, node_in_features).
+        - `edge_index`: Edge connectivity of shape (2, num_edges).
+        - `edge_attr`: Edge features of shape (num_edges, edge_in_features).
+    - Optional attributes:
+        - `hidden_x`: Node hidden states of shape (num_nodes, hidden_features).
+        - `hidden_edge_attr`: Edge hidden states of shape
+          (num_edges, hidden_features).
+      If not provided, they are initialized as zeros.
+    - Input can be provided as a dictionary or a `torch_geometric.data.Data`
+      object.
+
+    Returns
+    -------
+    List[torch.Tensor]
+        List of tensors where each tensor corresponds to the output at an
+        iteration step, with shape (num_nodes, out_features).
+
+    Example
+    -------
+    >>> model = RecurrentMessagePassingModel(hidden_features=96,
+    ...                                      out_features=2,
+    ...                                      num_iter=3).create()
+    >>> graph_data = {
+    ...     "x": torch.randn(10, 5),
+    ...     "edge_index": torch.randint(0, 10, (2, 20)),
+    ...     "edge_attr": torch.randn(20, 3),
+    ... }
+    >>> outputs = model(graph_data)
+    >>> print(len(outputs))
+    3
+    >>> print(outputs[0].shape)
+    torch.Size([10, 2])
+    """
+
+    hidden_features: int
+    out_features: int
+    num_iter: int
+
+    def __init__(
+        self,
+        hidden_features: int,
+        out_features: int,
+        num_iter: int,
+        out_activation: Union[Type[nn.Module], nn.Module, None] = None,
+    ):
+        super().__init__()
+
+        self.hidden_features = hidden_features
+        self.out_features = out_features
+        self.num_iter = num_iter
+
+        if out_features <= 0:
+            raise ValueError(f"out_features must be positive, got {out_features}")
+
+        if not isinstance(hidden_features, int):
+            raise ValueError(
+                f"hidden_features must be an integer, got {hidden_features}"
+            )
+
+        if hidden_features <= 0:
+            raise ValueError(f"hidden_features must be positive, got {hidden_features}")
+
+        self.encoder = Parallel(
+            **{
+                key: MultiLayerPerceptron(
+                    in_features=None,
+                    hidden_features=[],
+                    out_features=hidden_features,
+                    flatten_input=False,
+                ).set_input_map(key)
+                for key in ("x", "edge_attr")
+            }
+        )
+
+        combine = CatDictElements(("x", "hidden_x"), ("edge_attr", "hidden_edge_attr"))
+        backbone_layer = Sequential(
+            [
+                Parallel(
+                    **{
+                        key: MultiLayerPerceptron(
+                            in_features=None,
+                            hidden_features=[],
+                            out_features=hidden_features,
+                            flatten_input=False,
+                        ).set_input_map(key)
+                        for key in ("hidden_x", "hidden_edge_attr")
+                    }
+                ),
+                MessagePassingNeuralNetwork([], hidden_features),
+            ]
+        )
+        head = MultiLayerPerceptron(
+            hidden_features,
+            [],
+            out_features,
+            out_activation=out_activation,
+            flatten_input=False,
+        )
+
+        self.backbone = RecurrentGraphBlock(
+            combine=combine,
+            layer=backbone_layer,
+            head=head,
+            hidden_features=hidden_features,
+            num_iter=self.num_iter,
+        )
+
+        self.backbone.layer[1].transform.set_input_map(
+            "hidden_x", "edge_index", "hidden_edge_attr"
+        )
+        self.backbone.layer[1].transform.set_output_map("hidden_edge_attr")
+
+        self.backbone.layer[1].propagate.set_input_map(
+            "hidden_x", "edge_index", "hidden_edge_attr"
+        )
+        self.backbone.layer[1].propagate.set_output_map("aggregate")
+
+        update = MultiLayerPerceptron(None, [], hidden_features, flatten_input=False)
+        update.set_input_map("aggregate")
+        update.set_output_map("hidden_x")
+        self.backbone.layer[1].blocks[0].replace("update", update)
+
+        self.backbone.layer[1][..., "activation"].configure(nn.ReLU)
+
+        self.backbone.head.set_input_map("hidden_x")
+        self.backbone.head.set_output_map()
+
+    def forward(self, x):
+        """Forward pass.
+
+        Parameters
+        ----------
+        x : dict or torch_geometric.data.Data
+            Input graph data containing node and edge features.
+
+        Returns
+        -------
+        list
+            A list of tensors representing the model's output at each iteration.
+ """ + x = self.encoder(x) + x = self.backbone(x) + return x diff --git a/deeplay/tests/test_dict.py b/deeplay/tests/test_dict.py index 4098e824..cb86d94b 100644 --- a/deeplay/tests/test_dict.py +++ b/deeplay/tests/test_dict.py @@ -5,6 +5,7 @@ Layer, LayerSkip, AddDict, + CatDictElements, Parallel, DeeplayModule, ) @@ -92,3 +93,14 @@ def test_add_with_base_dict(self): # Checks that the base dict is correctly passed self.assertEqual(inp.y, 3) self.assertEqual(len(out.y), 10) + + def test_cat_dict_elems(self): + inp = {} + inp["x"] = torch.Tensor([1]) + inp["y"] = torch.Tensor([1, 1]) + + block = CatDictElements(("x", "y")).create() + out = block(inp) + + self.assertEqual(out["x"].shape, torch.Size([1])) + self.assertEqual(out["y"].shape, torch.Size([3])) diff --git a/deeplay/tests/test_gnn.py b/deeplay/tests/test_gnn.py index 91862992..906c43d0 100644 --- a/deeplay/tests/test_gnn.py +++ b/deeplay/tests/test_gnn.py @@ -12,6 +12,8 @@ GraphToEdgeMAGIK, MessagePassingNeuralNetwork, ResidualMessagePassingNeuralNetwork, + RecurrentMessagePassingModel, + RecurrentGraphBlock, MultiLayerPerceptron, dense_laplacian_normalization, Sum, @@ -22,6 +24,7 @@ Max, Layer, GlobalMeanPool, + CatDictElements, ) import itertools @@ -729,3 +732,107 @@ def test_gtoempm_defaults(self): out = model(inp) self.assertEqual(out.shape, (20, 1)) + + +class TestModelRecurrentMPM(unittest.TestCase): + def test_recurrent_graph_block_defaults(self): + model = RecurrentGraphBlock( + combine=CatDictElements(("x", "hidden")), + layer=Layer(nn.Identity), + head=Layer(nn.Identity), + hidden_features=64, + num_iter=1, + ) + model = model.create() + + self.assertIsInstance(model.combine, CatDictElements) + self.assertIsInstance(model.layer, nn.Identity) + self.assertIsInstance(model.head, nn.Identity) + + self.assertEqual(model.hidden_features, 64) + self.assertEqual(model.num_iter, 1) + + # assess the case where hidden is provided + inp = {} + inp["x"] = torch.ones(10, 64) + out = model(inp) + + self.assertTrue(torch.all(out[0]["hidden"][:, :64] == torch.zeros(10, 64))) + + # assess the case where hidden is provided + inp["hidden"] = torch.ones(10, 64) * 2 + out = model(inp) + + self.assertTrue(torch.all(out[0]["hidden"][:, :64] == torch.ones(10, 64) * 2)) + + def test_RMPM_defaults(self): + model = RecurrentMessagePassingModel(96, 2, num_iter=10) + model = model.create() + + self.assertEqual(len(model.encoder[0].blocks), 1) + self.assertEqual(len(model.encoder[1].blocks), 1) + + self.assertEqual(model.encoder[0].blocks[0].layer.in_features, 0) + self.assertEqual(model.encoder[0].blocks[0].layer.out_features, 96) + self.assertEqual(model.encoder[1].blocks[0].layer.in_features, 0) + self.assertEqual(model.encoder[1].blocks[0].layer.out_features, 96) + + self.assertIsInstance(model.backbone, RecurrentGraphBlock) + self.assertIsInstance(model.backbone.combine, CatDictElements) + self.assertIsInstance(model.backbone.layer[0][0], MultiLayerPerceptron) + self.assertIsInstance(model.backbone.layer[0][1], MultiLayerPerceptron) + self.assertIsInstance(model.backbone.layer[1], MessagePassingNeuralNetwork) + self.assertIsInstance(model.backbone.head, MultiLayerPerceptron) + + self.assertEqual(model.backbone.head.in_features, 96) + self.assertEqual(model.backbone.head.out_features, 2) + + # check default mapping + self.assertEqual(model.encoder[0].input_args, ("x",)) + self.assertEqual(model.encoder[0].output_args.keys(), {"x"}) + self.assertEqual(model.encoder[1].input_args, ("edge_attr",)) + 
diff --git a/deeplay/tests/test_dict.py b/deeplay/tests/test_dict.py
index 4098e824..cb86d94b 100644
--- a/deeplay/tests/test_dict.py
+++ b/deeplay/tests/test_dict.py
@@ -5,6 +5,7 @@
     Layer,
     LayerSkip,
     AddDict,
+    CatDictElements,
     Parallel,
     DeeplayModule,
 )
@@ -92,3 +93,14 @@ def test_add_with_base_dict(self):
         # Checks that the base dict is correctly passed
         self.assertEqual(inp.y, 3)
         self.assertEqual(len(out.y), 10)
+
+    def test_cat_dict_elems(self):
+        inp = {}
+        inp["x"] = torch.Tensor([1])
+        inp["y"] = torch.Tensor([1, 1])
+
+        block = CatDictElements(("x", "y")).create()
+        out = block(inp)
+
+        self.assertEqual(out["x"].shape, torch.Size([1]))
+        self.assertEqual(out["y"].shape, torch.Size([3]))
diff --git a/deeplay/tests/test_gnn.py b/deeplay/tests/test_gnn.py
index 91862992..906c43d0 100644
--- a/deeplay/tests/test_gnn.py
+++ b/deeplay/tests/test_gnn.py
@@ -12,6 +12,8 @@
     GraphToEdgeMAGIK,
     MessagePassingNeuralNetwork,
     ResidualMessagePassingNeuralNetwork,
+    RecurrentMessagePassingModel,
+    RecurrentGraphBlock,
     MultiLayerPerceptron,
     dense_laplacian_normalization,
     Sum,
@@ -22,6 +24,7 @@
     Max,
     Layer,
     GlobalMeanPool,
+    CatDictElements,
 )
 
 import itertools
@@ -729,3 +732,107 @@ def test_gtoempm_defaults(self):
         out = model(inp)
 
         self.assertEqual(out.shape, (20, 1))
+
+
+class TestModelRecurrentMPM(unittest.TestCase):
+    def test_recurrent_graph_block_defaults(self):
+        model = RecurrentGraphBlock(
+            combine=CatDictElements(("x", "hidden")),
+            layer=Layer(nn.Identity),
+            head=Layer(nn.Identity),
+            hidden_features=64,
+            num_iter=1,
+        )
+        model = model.create()
+
+        self.assertIsInstance(model.combine, CatDictElements)
+        self.assertIsInstance(model.layer, nn.Identity)
+        self.assertIsInstance(model.head, nn.Identity)
+
+        self.assertEqual(model.hidden_features, 64)
+        self.assertEqual(model.num_iter, 1)
+
+        # assess the case where hidden is not provided
+        inp = {}
+        inp["x"] = torch.ones(10, 64)
+        out = model(inp)
+
+        self.assertTrue(torch.all(out[0]["hidden"][:, :64] == torch.zeros(10, 64)))
+
+        # assess the case where hidden is provided
+        inp["hidden"] = torch.ones(10, 64) * 2
+        out = model(inp)
+
+        self.assertTrue(torch.all(out[0]["hidden"][:, :64] == torch.ones(10, 64) * 2))
+
+    def test_RMPM_defaults(self):
+        model = RecurrentMessagePassingModel(96, 2, num_iter=10)
+        model = model.create()
+
+        self.assertEqual(len(model.encoder[0].blocks), 1)
+        self.assertEqual(len(model.encoder[1].blocks), 1)
+
+        self.assertEqual(model.encoder[0].blocks[0].layer.in_features, 0)
+        self.assertEqual(model.encoder[0].blocks[0].layer.out_features, 96)
+        self.assertEqual(model.encoder[1].blocks[0].layer.in_features, 0)
+        self.assertEqual(model.encoder[1].blocks[0].layer.out_features, 96)
+
+        self.assertIsInstance(model.backbone, RecurrentGraphBlock)
+        self.assertIsInstance(model.backbone.combine, CatDictElements)
+        self.assertIsInstance(model.backbone.layer[0][0], MultiLayerPerceptron)
+        self.assertIsInstance(model.backbone.layer[0][1], MultiLayerPerceptron)
+        self.assertIsInstance(model.backbone.layer[1], MessagePassingNeuralNetwork)
+        self.assertIsInstance(model.backbone.head, MultiLayerPerceptron)
+
+        self.assertEqual(model.backbone.head.in_features, 96)
+        self.assertEqual(model.backbone.head.out_features, 2)
+
+        # check default mapping
+        self.assertEqual(model.encoder[0].input_args, ("x",))
+        self.assertEqual(model.encoder[0].output_args.keys(), {"x"})
+        self.assertEqual(model.encoder[1].input_args, ("edge_attr",))
+        self.assertEqual(model.encoder[1].output_args.keys(), {"edge_attr"})
+
+        self.assertEqual(model.backbone.combine.source, ("x", "edge_attr"))
+        self.assertEqual(
+            model.backbone.combine.target, ("hidden_x", "hidden_edge_attr")
+        )
+
+        self.assertEqual(model.backbone.layer[0][0].input_args, ("hidden_x",))
+        self.assertEqual(model.backbone.layer[0][0].output_args.keys(), {"hidden_x"})
+        self.assertEqual(model.backbone.layer[0][1].input_args, ("hidden_edge_attr",))
+        self.assertEqual(
+            model.backbone.layer[0][1].output_args.keys(), {"hidden_edge_attr"}
+        )
+
+        self.assertEqual(
+            model.backbone.layer[1].transform[0].input_args,
+            ("hidden_x", "edge_index", "hidden_edge_attr"),
+        )
+        self.assertEqual(
+            model.backbone.layer[1].transform[0].output_args.keys(),
+            {"hidden_edge_attr"},
+        )
+        self.assertEqual(
+            model.backbone.layer[1].propagate[0].input_args,
+            ("hidden_x", "edge_index", "hidden_edge_attr"),
+        )
+        self.assertEqual(
+            model.backbone.layer[1].propagate[0].output_args.keys(), {"aggregate"}
+        )
+
+        self.assertEqual(model.backbone.layer[1].update[0].input_args, ("aggregate",))
+        self.assertEqual(
+            model.backbone.layer[1].update[0].output_args.keys(), {"hidden_x"}
+        )
+
+        inp = {}
+        inp["x"] = torch.randn(10, 5)
+        inp["edge_index"] = torch.randint(0, 10, (2, 20))
+        inp["edge_attr"] = torch.randn(20, 3)
+
+        out = model(inp)
+
+        self.assertEqual(len(out), 10)
+        for o in out:
+            self.assertEqual(o.shape, (10, 2))
diff --git a/requirements.txt b/requirements.txt
index aeda09ad..98732d26 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,5 +7,6 @@ torch-geometric
 kornia
 scipy
 scikit-image
+scikit-learn
 rich
 dill
\ No newline at end of file