From f8f64fd915057ac1f3ab05758c7156cf73727a13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Fri, 15 Nov 2024 12:44:59 +0100 Subject: [PATCH 01/11] Update dict.py --- deeplay/components/dict.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/deeplay/components/dict.py b/deeplay/components/dict.py index 125cd678..7bb32782 100644 --- a/deeplay/components/dict.py +++ b/deeplay/components/dict.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Union, Tuple, overload +from typing import Dict, Any, Union, Tuple, Dict from deeplay import DeeplayModule @@ -54,3 +54,28 @@ def forward( x.update({key: torch.add(x[key], y[key]) for key in self.keys}) return x + + +class CatDictElements(DeeplayModule): + """ + Concatenates specified elements within a dictionary-like structure along a given dimension. + + Parameters: + - keys: Tuple of tuples, where each tuple contains the source and target keys to concatenate. + - dim: Dimension along which concatenation occurs. Default is -1. + """ + + def __init__(self, *keys: Tuple[tuple], dim: int = -1): + super().__init__() + self.source, self.target = zip(*keys) + self.dim = dim + + def forward(self, x: Union[Dict[str, Any], Data]) -> Union[Dict[str, Any], Data]: + x = x.clone() if isinstance(x, Data) else x.copy() + x.update( + { + t: torch.cat([x[t], x[s]], dim=self.dim) + for t, s in zip(self.target, self.source) + } + ) + return x From 3ee7cdacb99c798f7515c255783877ed57786df8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Fri, 15 Nov 2024 12:45:25 +0100 Subject: [PATCH 02/11] Create RecurrentGraphBlock --- deeplay/components/gnn/__init__.py | 1 + deeplay/components/gnn/rgb.py | 56 ++++++++++++++++++++++++++++++ deeplay/models/gnn/__init__.py | 1 + 3 files changed, 58 insertions(+) create mode 100644 deeplay/components/gnn/rgb.py diff --git a/deeplay/components/gnn/__init__.py b/deeplay/components/gnn/__init__.py index 1f0bd47c..fd418303 100644 --- a/deeplay/components/gnn/__init__.py +++ b/deeplay/components/gnn/__init__.py @@ -1,3 +1,4 @@ from .gcn import * from .mpn import * from .tpu import * +from .rgb import * diff --git a/deeplay/components/gnn/rgb.py b/deeplay/components/gnn/rgb.py new file mode 100644 index 00000000..2e628f71 --- /dev/null +++ b/deeplay/components/gnn/rgb.py @@ -0,0 +1,56 @@ +import torch +from deeplay import DeeplayModule +from deeplay.components.dict import CatDictElements + + +class RecurrentGraphBlock(DeeplayModule): + combine: DeeplayModule + layer: DeeplayModule + head: DeeplayModule + + def __init__( + self, + layer: DeeplayModule, + head: DeeplayModule, + hidden_features: int, + num_iter: int, + combine: DeeplayModule = CatDictElements(("x", "hidden")), + ): + super().__init__() + self.combine = combine + self.layer = layer + self.head = head + self.hidden_features = hidden_features + self.num_iter = num_iter + + if not all(hasattr(self.combine, attr) for attr in ("source", "target")): + raise AttributeError( + "The 'combine' module must have 'source' and 'target' attributes to specify " + "the keys to concatenate. Found None. Ensure that the 'combine' module is initialized " + "with valid 'source' and 'target' keys. Check CatDictElements for reference." + ) + self.hidden_variables_name = self.combine.target + + def initialize_hidden(self, x): + for source, hidden_variable_name in zip( + self.combine.source, self.hidden_variables_name + ): + if hidden_variable_name not in x: + x.update( + { + hidden_variable_name: torch.zeros( + x[source].size(0), self.hidden_features + ).to(x[source].device) + } + ) + return x + + def forward(self, x): + x = self.initialize_hidden(x) + outputs = [] + for _ in range(self.num_iter): + x = self.combine(x) + x = self.layer(x) + outputs.append(self.head(x)) + + return outputs diff --git a/deeplay/models/gnn/__init__.py b/deeplay/models/gnn/__init__.py index 2facd6f0..47d1651b 100644 --- a/deeplay/models/gnn/__init__.py +++ b/deeplay/models/gnn/__init__.py @@ -1,4 +1,5 @@ from .mpm import MPM +from .rmpm import RecurrentMessagePassingModel from .gtogmpm import GraphToGlobalMPM, GlobalMeanPool from .gtonmpm import GraphToNodeMPM from .gtoempm import GraphToEdgeMPM From e36721e5b3989cfc8a816d191f91eb6206244062 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Fri, 15 Nov 2024 12:45:38 +0100 Subject: [PATCH 03/11] Create rmpm.py --- deeplay/models/gnn/rmpm.py | 115 +++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 deeplay/models/gnn/rmpm.py diff --git a/deeplay/models/gnn/rmpm.py b/deeplay/models/gnn/rmpm.py new file mode 100644 index 00000000..31a27bbf --- /dev/null +++ b/deeplay/models/gnn/rmpm.py @@ -0,0 +1,115 @@ +from typing import Type, Union + +from deeplay import ( + DeeplayModule, + Parallel, + MultiLayerPerceptron, + MessagePassingNeuralNetwork, + Sequential, + RecurrentGraphBlock, + CatDictElements, +) + +import torch.nn as nn + + +class RecurrentMessagePassingModel(DeeplayModule): + """Recurrent Message Passing Neural Network (RMPN) model.""" + + hidden_features: int + out_features: int + num_iter: int + + def __init__( + self, + hidden_features: int, + out_features: int, + num_iter: int, + out_activation: Union[Type[nn.Module], nn.Module, None] = None, + ): + super().__init__() + + self.hidden_features = hidden_features + self.out_features = out_features + self.num_iter = num_iter + + if out_features <= 0: + raise ValueError(f"out_features must be positive, got {out_features}") + + if not isinstance(hidden_features, int): + raise ValueError( + f"hidden_features must be an integer, got {hidden_features}" + ) + + if hidden_features <= 0: + raise ValueError(f"hidden_features must be positive, got {hidden_features}") + + self.encoder = Parallel( + **{ + key: MultiLayerPerceptron( + in_features=None, + hidden_features=[], + out_features=hidden_features, + flatten_input=False, + ).set_input_map(key) + for key in ("x", "edge_attr") + } + ) + + combine = CatDictElements(("x", "hidden_x"), ("edge_attr", "hidden_edge_attr")) + backbone_layer = Sequential( + [ + Parallel( + **{ + key: MultiLayerPerceptron( + in_features=None, + hidden_features=[], + out_features=hidden_features, + flatten_input=False, + ).set_input_map(key) + for key in ("hidden_x", "hidden_edge_attr") + } + ), + MessagePassingNeuralNetwork([], hidden_features), + ] + ) + head = MultiLayerPerceptron( + hidden_features, + [], + out_features, + out_activation=out_activation, + flatten_input=False, + ) + + self.backbone = RecurrentGraphBlock( + combine=combine, + layer=backbone_layer, + head=head, + hidden_features=hidden_features, + num_iter=self.num_iter, + ) + + self.backbone.layer[1].transform.set_input_map( + "hidden_x", "edge_index", "hidden_edge_attr" + ) + self.backbone.layer[1].transform.set_output_map("hidden_edge_attr") + + self.backbone.layer[1].propagate.set_input_map( + "hidden_x", "edge_index", "hidden_edge_attr" + ) + self.backbone.layer[1].propagate.set_output_map("aggregate") + + update = MultiLayerPerceptron(None, [], hidden_features, flatten_input=False) + update.set_input_map("aggregate") + update.set_output_map("hidden_x") + self.backbone.layer[1].blocks[0].replace("update", update) + + self.backbone.layer[1][..., "activation"].configure(nn.ReLU) + + self.backbone.head.set_input_map("hidden_x") + self.backbone.head.set_output_map() + + def forward(self, x): + x = self.encoder(x) + x = self.backbone(x) + return x From 3083a3215ee426ea4884ccd6de75fcd19eef9b86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Fri, 15 Nov 2024 12:45:53 +0100 Subject: [PATCH 04/11] Create MIRO --- deeplay/applications/__init__.py | 1 + deeplay/applications/clustering/__init__.py | 1 + deeplay/applications/clustering/miro.py | 103 ++++++++++++++++++++ 3 files changed, 105 insertions(+) create mode 100644 deeplay/applications/clustering/__init__.py create mode 100644 deeplay/applications/clustering/miro.py diff --git a/deeplay/applications/__init__.py b/deeplay/applications/__init__.py index 0818a8f0..0f20ce3a 100644 --- a/deeplay/applications/__init__.py +++ b/deeplay/applications/__init__.py @@ -3,6 +3,7 @@ from .regression import * from .detection import * from .autoencoders import * +from .clustering import * # from .classification import * # from .segmentation import ImageSegmentor diff --git a/deeplay/applications/clustering/__init__.py b/deeplay/applications/clustering/__init__.py new file mode 100644 index 00000000..a665019a --- /dev/null +++ b/deeplay/applications/clustering/__init__.py @@ -0,0 +1 @@ +from .miro import MIRO diff --git a/deeplay/applications/clustering/miro.py b/deeplay/applications/clustering/miro.py new file mode 100644 index 00000000..665db61d --- /dev/null +++ b/deeplay/applications/clustering/miro.py @@ -0,0 +1,103 @@ +from typing import Callable, Optional + +import numpy as np +import torch +import torch.nn as nn + +from deeplay.models import RecurrentMessagePassingModel +from deeplay.applications import Application +from deeplay.external import Optimizer, Adam + +from sklearn.cluster import DBSCAN + + +class MIRO(Application): + num_outputs: int + connectivity_radius: float + model: nn.Module + nd_loss_weight: float + loss: torch.nn.Module + metrics: list + optimizer: Optimizer + + def __init__( + self, + num_outputs: int = 2, + connectivity_radius: float = 1.0, + model: Optional[nn.Module] = None, + nd_loss_weight: float = 10, + loss: torch.nn.Module = torch.nn.L1Loss(), + optimizer=None, + **kwargs, + ): + + self.num_outputs = num_outputs + self.connectivity_radius = connectivity_radius + self.model = model or self._get_default_model() + self.nd_loss_weight = nd_loss_weight + + super().__init__(loss=loss, optimizer=optimizer or Adam(lr=1e-4), **kwargs) + + def _get_default_model(self): + rgnn = RecurrentMessagePassingModel( + hidden_features=256, out_features=self.num_outputs, num_iter=20 + ) + return rgnn + + def forward(self, x): + return self.model(x) + + def compute_loss(self, y_hat, y, edges, position): + loss = 0 + for pred in y_hat: + loss += self.loss(pred, y) + self.nd_loss_weight * self.compute_nd_loss( + pred, y, edges, position + ) + return loss / len(y_hat) + + def compute_nd_loss(self, y_hat, y, edges, position): + compressed_gt = position - y * self.connectivity_radius + compressed_gt_distances = torch.norm( + compressed_gt[edges[0]] - compressed_gt[edges[1]], dim=1 + ) + compressed_pred = position - y_hat * self.connectivity_radius + compressed_pred_distances = torch.norm( + compressed_pred[edges[0]] - compressed_pred[edges[1]], dim=1 + ) + return self.loss(compressed_pred_distances, compressed_gt_distances) + + def clustering( + self, x, eps, min_samples, from_iter=-1, scaling=np.array([1.0, 1.0]) + ): + pred = self(x)[from_iter].detach().cpu().numpy() + + squeezed = x.position.cpu() - pred * self.connectivity_radius + squeezed = squeezed.numpy() * scaling + clusters = DBSCAN(eps=eps, min_samples=min_samples).fit(squeezed) + + return clusters.labels_ + + def training_step(self, batch, batch_idx): + x, y = self.train_preprocess(batch) + y_hat = self(x) + loss = self.compute_loss(y_hat, y, x.edge_index, x.position) + + self.log( + "train_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + self.log_metrics( + "train", + y_hat, + y, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return loss From cca00596614d92d4fb1a06a414e885c4f466f197 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Fri, 15 Nov 2024 15:28:26 +0100 Subject: [PATCH 05/11] Update miro.py --- deeplay/applications/clustering/miro.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/deeplay/applications/clustering/miro.py b/deeplay/applications/clustering/miro.py index 665db61d..6b1adfee 100644 --- a/deeplay/applications/clustering/miro.py +++ b/deeplay/applications/clustering/miro.py @@ -66,13 +66,19 @@ def compute_nd_loss(self, y_hat, y, edges, position): ) return self.loss(compressed_pred_distances, compressed_gt_distances) - def clustering( - self, x, eps, min_samples, from_iter=-1, scaling=np.array([1.0, 1.0]) - ): + def squeeze(self, x, from_iter=-1, scaling=np.array([1.0, 1.0])): pred = self(x)[from_iter].detach().cpu().numpy() + return (x.position.cpu() - pred * self.connectivity_radius).numpy() * scaling - squeezed = x.position.cpu() - pred * self.connectivity_radius - squeezed = squeezed.numpy() * scaling + def clustering( + self, + x, + eps, + min_samples, + from_iter=-1, + **kwargs, + ): + squeezed = self.squeeze(x, from_iter, **kwargs) clusters = DBSCAN(eps=eps, min_samples=min_samples).fit(squeezed) return clusters.labels_ From af7a4c726689de868ecf220c5238b36061b1e9c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Fri, 15 Nov 2024 16:12:30 +0100 Subject: [PATCH 06/11] Update requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index aeda09ad..98732d26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,6 @@ torch-geometric kornia scipy scikit-image +scikit-learn rich dill \ No newline at end of file From 36b9d5772ad839a0bce662d346b3f4107d968f50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Fri, 22 Nov 2024 17:36:15 +0100 Subject: [PATCH 07/11] Create documentation --- deeplay/applications/clustering/miro.py | 53 +++++++++++++++++++++++ deeplay/components/dict.py | 7 ++- deeplay/models/gnn/rmpm.py | 57 ++++++++++++++++++++++++- 3 files changed, 114 insertions(+), 3 deletions(-) diff --git a/deeplay/applications/clustering/miro.py b/deeplay/applications/clustering/miro.py index 6b1adfee..0b2827e4 100644 --- a/deeplay/applications/clustering/miro.py +++ b/deeplay/applications/clustering/miro.py @@ -12,6 +12,41 @@ class MIRO(Application): + """ + Point cloud clustering using MIRO (Multimodal Integration through Relational Optimization). + + Parameters + ---------- + num_outputs : int + Dimensionality of the output features, representing a displacement vector in Cartesian space for each node. This vector points toward the center of each cluster. + connectivity_radius : float + Maximum distance between two nodes to consider them connected in the graph. + model : nn.Module + A model implementing the forward method. It should return a tensor of shape `(num_nodes, num_outputs)` representing the predicted displacement vectors for each node, + or a list of tensors of the same shape for predictions at each recurrent iteration (default). If not specified, a default model resembling the one from the original MIRO paper is used. + nd_loss_weight : float + Weight for the auxiliary loss that enforces preservation of pairwise distances between connected nodes. + loss : torch.nn.Module + Loss function for training. Default is `torch.nn.L1Loss`. + optimizer : Optimizer + Optimizer for training. Default is Adam with a learning rate of 1e-4. + + Clustering + ------------------ + The clustering method `clustering` leverages the predicted displacement vectors to group nodes into clusters using the DBSCAN algorithm. The displacement vector points each node toward its corresponding cluster center, enabling robust identification of clusters in the point cloud. + + Example + -------- + >>> # Perform clustering + >>> eps = 0.3 # Maximum distance for cluster connection + >>> min_samples = 5 # Minimum points to form a cluster + >>> clusters = model.clustering(test_graph, eps, min_samples) + + >>> # Output cluster labels + >>> print(clusters) + array([ 0, 0, 1, 1, 1, -1, 2, 2, 2, ...]) # Nodes in cluster 0, 1, 2, etc.; -1 are outliers + """ + num_outputs: int connectivity_radius: float model: nn.Module @@ -78,6 +113,24 @@ def clustering( from_iter=-1, **kwargs, ): + """ + Perform clustering using the DBSCAN algorithm, with MIRO preprocessing + to optimize the input point cloud for effective clustering. + + Parameters + ---------- + x : torch_geometric.data.Data + Input graph data. + eps : float + The maximum distance between two samples for one to be considered + as in the neighborhood of the other. This is not a maximum bound + on the distances of points within a cluster. This is the most + important DBSCAN parameter to choose appropriately for your data set + and distance function. + min_samples : int + The number of samples (or total weight) in a neighborhood for a point + to be considered as a core point. This includes the point itself. + """ squeezed = self.squeeze(x, from_iter, **kwargs) clusters = DBSCAN(eps=eps, min_samples=min_samples).fit(squeezed) diff --git a/deeplay/components/dict.py b/deeplay/components/dict.py index 7bb32782..cb2d1d32 100644 --- a/deeplay/components/dict.py +++ b/deeplay/components/dict.py @@ -61,8 +61,11 @@ class CatDictElements(DeeplayModule): Concatenates specified elements within a dictionary-like structure along a given dimension. Parameters: - - keys: Tuple of tuples, where each tuple contains the source and target keys to concatenate. - - dim: Dimension along which concatenation occurs. Default is -1. + - keys: Tuple[tuple] + Specifies the keys to be concatenated as tuples. Each tuple contains two keys: source and target. + The source key is the key to be concatenated with the target key. + - dim: int + Specifies the dimension along which the concatenation is performed. """ def __init__(self, *keys: Tuple[tuple], dim: int = -1): diff --git a/deeplay/models/gnn/rmpm.py b/deeplay/models/gnn/rmpm.py index 31a27bbf..c64745af 100644 --- a/deeplay/models/gnn/rmpm.py +++ b/deeplay/models/gnn/rmpm.py @@ -14,7 +14,62 @@ class RecurrentMessagePassingModel(DeeplayModule): - """Recurrent Message Passing Neural Network (RMPN) model.""" + """Recurrent Message Passing Neural Network (RMPN) model. + + Parameters + ---------- + hidden_features: int + Number of hidden units in the recurrent message passing layer. + out_features: int + Number of output features. + num_iter: int + Number of iterations of the recurrent message passing layer. + out_activation: template-like + Specification for the output activation of the model. Default: nn.Identity. + + + Configurables + ------------- + - hidden_features (int): Number of hidden units in the recurrent message passing layer. + - out_features (int): Number of output features. + - out_activation (template-like): Specification for the output activation of the model. Default: nn.Identity. + - encoder (template-like): Specification for the encoder of the model. Default: dl.Parallel consisting of two MLPs to process node and edge features. + - backbone (template-like): Specification for the backbone of the model. Default: dl.RecurrentGraphBlock consisting of dl.MessagePassingNeuralNetwork and a MLP head. + + Constraints + ----------- + - input: Dict[str, Any] or torch-geometric Data object containing the following attributes: + - x: torch.Tensor of shape (num_nodes, node_in_features). + - edge_index: torch.Tensor of shape (2, num_edges). + - edge_attr: torch.Tensor of shape (num_edges, edge_in_features). + - hidden_x: (Optional) torch.Tensor of shape (num_nodes, hidden_features). + - hidden_edge_attr: (Optional) torch.Tensor of shape (num_edges, hidden_features). + + NOTE: node_in_features and edge_in_features are inferred from the input data. + + - output: List[torch.Tensor] where each tensor has shape (num_nodes, out_features). + + Examples + -------- + >>> # Define a RMPN model with 96 hidden features, 2 output features, and 3 iterations + >>> model = RecurrentMessagePassingModel(hidden_features=96, out_features=2, num_iter=3) + + >>> # Input graph data + >>> inp = { + >>> "x": torch.randn(10, 5), # Node features + >>> "edge_index": torch.randint(0, 10, (2, 20)), # Edge connectivity + >>> "edge_attr": torch.randn(20, 3), # Edge features + >>> } + + >>> # Model forward pass + >>> out = model(inp) + + >>> # Output shape + >>> print(len(out)) + 3 + >>> print(out[0].shape) + torch.Size([10, 2]) + """ hidden_features: int out_features: int From 474ae4ba9726c6ccffbd525f437356e0ade4fa20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Fri, 22 Nov 2024 17:36:25 +0100 Subject: [PATCH 08/11] Update rgb.py --- deeplay/components/gnn/rgb.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deeplay/components/gnn/rgb.py b/deeplay/components/gnn/rgb.py index 2e628f71..d25dbdf8 100644 --- a/deeplay/components/gnn/rgb.py +++ b/deeplay/components/gnn/rgb.py @@ -1,4 +1,5 @@ import torch +from torch_geometric.data import Data from deeplay import DeeplayModule from deeplay.components.dict import CatDictElements @@ -32,6 +33,7 @@ def __init__( self.hidden_variables_name = self.combine.target def initialize_hidden(self, x): + x = x.clone() if isinstance(x, Data) else x.copy() for source, hidden_variable_name in zip( self.combine.source, self.hidden_variables_name ): From 7a1b8dbf813b1677615f7cd5eed11f21d61699d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Fri, 22 Nov 2024 17:37:05 +0100 Subject: [PATCH 09/11] Create unittests --- deeplay/tests/test_dict.py | 12 +++++ deeplay/tests/test_gnn.py | 107 +++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) diff --git a/deeplay/tests/test_dict.py b/deeplay/tests/test_dict.py index 4098e824..cb86d94b 100644 --- a/deeplay/tests/test_dict.py +++ b/deeplay/tests/test_dict.py @@ -5,6 +5,7 @@ Layer, LayerSkip, AddDict, + CatDictElements, Parallel, DeeplayModule, ) @@ -92,3 +93,14 @@ def test_add_with_base_dict(self): # Checks that the base dict is correctly passed self.assertEqual(inp.y, 3) self.assertEqual(len(out.y), 10) + + def test_cat_dict_elems(self): + inp = {} + inp["x"] = torch.Tensor([1]) + inp["y"] = torch.Tensor([1, 1]) + + block = CatDictElements(("x", "y")).create() + out = block(inp) + + self.assertEqual(out["x"].shape, torch.Size([1])) + self.assertEqual(out["y"].shape, torch.Size([3])) diff --git a/deeplay/tests/test_gnn.py b/deeplay/tests/test_gnn.py index 91862992..906c43d0 100644 --- a/deeplay/tests/test_gnn.py +++ b/deeplay/tests/test_gnn.py @@ -12,6 +12,8 @@ GraphToEdgeMAGIK, MessagePassingNeuralNetwork, ResidualMessagePassingNeuralNetwork, + RecurrentMessagePassingModel, + RecurrentGraphBlock, MultiLayerPerceptron, dense_laplacian_normalization, Sum, @@ -22,6 +24,7 @@ Max, Layer, GlobalMeanPool, + CatDictElements, ) import itertools @@ -729,3 +732,107 @@ def test_gtoempm_defaults(self): out = model(inp) self.assertEqual(out.shape, (20, 1)) + + +class TestModelRecurrentMPM(unittest.TestCase): + def test_recurrent_graph_block_defaults(self): + model = RecurrentGraphBlock( + combine=CatDictElements(("x", "hidden")), + layer=Layer(nn.Identity), + head=Layer(nn.Identity), + hidden_features=64, + num_iter=1, + ) + model = model.create() + + self.assertIsInstance(model.combine, CatDictElements) + self.assertIsInstance(model.layer, nn.Identity) + self.assertIsInstance(model.head, nn.Identity) + + self.assertEqual(model.hidden_features, 64) + self.assertEqual(model.num_iter, 1) + + # assess the case where hidden is provided + inp = {} + inp["x"] = torch.ones(10, 64) + out = model(inp) + + self.assertTrue(torch.all(out[0]["hidden"][:, :64] == torch.zeros(10, 64))) + + # assess the case where hidden is provided + inp["hidden"] = torch.ones(10, 64) * 2 + out = model(inp) + + self.assertTrue(torch.all(out[0]["hidden"][:, :64] == torch.ones(10, 64) * 2)) + + def test_RMPM_defaults(self): + model = RecurrentMessagePassingModel(96, 2, num_iter=10) + model = model.create() + + self.assertEqual(len(model.encoder[0].blocks), 1) + self.assertEqual(len(model.encoder[1].blocks), 1) + + self.assertEqual(model.encoder[0].blocks[0].layer.in_features, 0) + self.assertEqual(model.encoder[0].blocks[0].layer.out_features, 96) + self.assertEqual(model.encoder[1].blocks[0].layer.in_features, 0) + self.assertEqual(model.encoder[1].blocks[0].layer.out_features, 96) + + self.assertIsInstance(model.backbone, RecurrentGraphBlock) + self.assertIsInstance(model.backbone.combine, CatDictElements) + self.assertIsInstance(model.backbone.layer[0][0], MultiLayerPerceptron) + self.assertIsInstance(model.backbone.layer[0][1], MultiLayerPerceptron) + self.assertIsInstance(model.backbone.layer[1], MessagePassingNeuralNetwork) + self.assertIsInstance(model.backbone.head, MultiLayerPerceptron) + + self.assertEqual(model.backbone.head.in_features, 96) + self.assertEqual(model.backbone.head.out_features, 2) + + # check default mapping + self.assertEqual(model.encoder[0].input_args, ("x",)) + self.assertEqual(model.encoder[0].output_args.keys(), {"x"}) + self.assertEqual(model.encoder[1].input_args, ("edge_attr",)) + self.assertEqual(model.encoder[1].output_args.keys(), {"edge_attr"}) + + self.assertEqual(model.backbone.combine.source, ("x", "edge_attr")) + self.assertEqual( + model.backbone.combine.target, ("hidden_x", "hidden_edge_attr") + ) + + self.assertEqual(model.backbone.layer[0][0].input_args, ("hidden_x",)) + self.assertEqual(model.backbone.layer[0][0].output_args.keys(), {"hidden_x"}) + self.assertEqual(model.backbone.layer[0][1].input_args, ("hidden_edge_attr",)) + self.assertEqual( + model.backbone.layer[0][1].output_args.keys(), {"hidden_edge_attr"} + ) + + self.assertEqual( + model.backbone.layer[1].transform[0].input_args, + ("hidden_x", "edge_index", "hidden_edge_attr"), + ) + self.assertEqual( + model.backbone.layer[1].transform[0].output_args.keys(), + {"hidden_edge_attr"}, + ) + self.assertEqual( + model.backbone.layer[1].propagate[0].input_args, + ("hidden_x", "edge_index", "hidden_edge_attr"), + ) + self.assertEqual( + model.backbone.layer[1].propagate[0].output_args.keys(), {"aggregate"} + ) + + self.assertEqual(model.backbone.layer[1].update[0].input_args, ("aggregate",)) + self.assertEqual( + model.backbone.layer[1].update[0].output_args.keys(), {"hidden_x"} + ) + + inp = {} + inp["x"] = torch.randn(10, 5) + inp["edge_index"] = torch.randint(0, 10, (2, 20)) + inp["edge_attr"] = torch.randn(20, 3) + + out = model(inp) + + self.assertEqual(len(out), 10) + for o in out: + self.assertEqual(o.shape, (10, 2)) From 4d28297c6207b6f7ed6acedd75f6dd6097d2147c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Fri, 13 Dec 2024 14:40:20 +0100 Subject: [PATCH 10/11] update documentation and readability --- deeplay/applications/clustering/miro.py | 234 +++++++++++++++++------- deeplay/components/dict.py | 81 ++++++-- deeplay/components/gnn/rgb.py | 106 ++++++++++- deeplay/models/gnn/rmpm.py | 108 +++++++---- 4 files changed, 412 insertions(+), 117 deletions(-) diff --git a/deeplay/applications/clustering/miro.py b/deeplay/applications/clustering/miro.py index 0b2827e4..622c1d23 100644 --- a/deeplay/applications/clustering/miro.py +++ b/deeplay/applications/clustering/miro.py @@ -1,50 +1,106 @@ -from typing import Callable, Optional +"""MIRO: Multimodal Integration through Relational Optimization + +This module provides the MIRO framework for point cloud clustering, leveraging +advanced geometric deep learning techniques. MIRO transforms complex point +clouds into optimized representations, enabling more effective clustering +using traditional algorithms. + +Based on the original MIRO paper by Pineda et al. [1], this implementation offers +easy-to-use methods for training the MIRO model and performing geometric-aware +clustering. It integrates recurrent graph neural networks to refine point +cloud data and enhance clustering accuracy. + +[1] Pineda, Jesús, et al. "Spatial Clustering of Molecular Localizations with + Graph Neural Networks." arXiv preprint arXiv:2412.00173 (2024). +""" import numpy as np import torch import torch.nn as nn +from torch_geometric.data import Data +from sklearn.cluster import DBSCAN +from typing import Callable, Optional, List -from deeplay.models import RecurrentMessagePassingModel from deeplay.applications import Application -from deeplay.external import Optimizer, Adam - -from sklearn.cluster import DBSCAN +from deeplay.external import Adam, Optimizer +from deeplay.models import RecurrentMessagePassingModel class MIRO(Application): - """ - Point cloud clustering using MIRO (Multimodal Integration through Relational Optimization). + """Point cloud clustering using MIRO (Multimodal Integration through + Relational Optimization). + + MIRO is a geometric deep learning framework that enhances clustering + algorithms by transforming complex point clouds into an optimized structure + amenable to conventional clustering methods. MIRO employs recurrent graph + neural networks (rGNNs) to learn a transformation that squeezes localization + belonging to the same cluster toward a common center, resulting in a compact + representation of clusters within the point cloud. Parameters ---------- num_outputs : int - Dimensionality of the output features, representing a displacement vector in Cartesian space for each node. This vector points toward the center of each cluster. + Dimensionality of the output features, representing a displacement + vector in Cartesian space for each node. This vector points toward + the center of each cluster. connectivity_radius : float - Maximum distance between two nodes to consider them connected in the graph. + Maximum distance between two nodes to consider them connected in the + graph. model : nn.Module - A model implementing the forward method. It should return a tensor of shape `(num_nodes, num_outputs)` representing the predicted displacement vectors for each node, - or a list of tensors of the same shape for predictions at each recurrent iteration (default). If not specified, a default model resembling the one from the original MIRO paper is used. + A model implementing the forward method. It should return a list of + tensors of shape `(num_nodes, num_outputs)` representing the predicted + displacement vectors for each node at each recurrent iteration. If not + specified, a default model resembling the one from the original MIRO + paper is used. nd_loss_weight : float - Weight for the auxiliary loss that enforces preservation of pairwise distances between connected nodes. + Weight for the auxiliary loss that enforces preservation of pairwise + distances between connected nodes. Default is 10. loss : torch.nn.Module Loss function for training. Default is `torch.nn.L1Loss`. optimizer : Optimizer Optimizer for training. Default is Adam with a learning rate of 1e-4. - Clustering - ------------------ - The clustering method `clustering` leverages the predicted displacement vectors to group nodes into clusters using the DBSCAN algorithm. The displacement vector points each node toward its corresponding cluster center, enabling robust identification of clusters in the point cloud. + Returns + ------- + forward : method + Computes and returns the predicted displacement vectors for each node + in the input graph. The output is a list of tensors representing the + displacement vectors at each recurrent iteration. + + squeeze : method + Applies the predicted displacement vectors from the last recurrent + iteration (by default) to the input point cloud. This operation + optimizes the point cloud for clustering by aligning nodes closer to + their respective cluster centers. + + clustering : method + Groups nodes into clusters using the DBSCAN algorithm, based on the + predicted displacement vectors. Each node is assigned a cluster label, + where -1 indicates background noise. Returns an array of cluster labels + for the nodes. Example - -------- - >>> # Perform clustering + ------- + >>> # Predicts displacement vectors for each node in a point cloud at each + >>> # recurrent iteration + >>> displacement_vectors = model(test_graph) + >>> print(type(displacement_vectors)) + + + >>> # Applies the predicted displacement vectors to the input point cloud + >>> squeezed = model.squeeze(test_graph) + >>> print(squeezed.shape) + (num_nodes, 2) + + >>> # Performs clustering using DBSCAN after MIRO squeezing >>> eps = 0.3 # Maximum distance for cluster connection - >>> min_samples = 5 # Minimum points to form a cluster + >>> min_samples = 5 # Minimum points in a neighborhood for core points >>> clusters = model.clustering(test_graph, eps, min_samples) >>> # Output cluster labels >>> print(clusters) - array([ 0, 0, 1, 1, 1, -1, 2, 2, 2, ...]) # Nodes in cluster 0, 1, 2, etc.; -1 are outliers + array([ 0, 0, 1, 1, 1, -1, 2, 2, 2, ...]) + # Nodes in cluster 0, 1, 2, etc.; -1 are outliers """ num_outputs: int @@ -79,43 +135,65 @@ def _get_default_model(self): ) return rgnn - def forward(self, x): + def forward(self, x: Data) -> List[torch.Tensor]: + """Forward pass to compute predicted displacement vectors for each node. + + Parameters + ---------- + x : torch_geometric.data.Data + Input graph data. It is expected to have the attributes: + `x` (node features), `edge_index` (graph connectivity), + `edge_attr` (edge features), and `positions` (node spatial coordinates). + + Returns + ------- + list[torch.Tensor] + Predicted displacement vectors at each recurrent iteration. + """ return self.model(x) - def compute_loss(self, y_hat, y, edges, position): - loss = 0 - for pred in y_hat: - loss += self.loss(pred, y) + self.nd_loss_weight * self.compute_nd_loss( - pred, y, edges, position - ) - return loss / len(y_hat) + def squeeze( + self, + x: Data, + from_iter: int = -1, + scaling: np.ndarray = np.array([1.0, 1.0]), + ) -> np.ndarray: + """Computes and applies the predicted displacement vectors to the + input point cloud. - def compute_nd_loss(self, y_hat, y, edges, position): - compressed_gt = position - y * self.connectivity_radius - compressed_gt_distances = torch.norm( - compressed_gt[edges[0]] - compressed_gt[edges[1]], dim=1 - ) - compressed_pred = position - y_hat * self.connectivity_radius - compressed_pred_distances = torch.norm( - compressed_pred[edges[0]] - compressed_pred[edges[1]], dim=1 - ) - return self.loss(compressed_pred_distances, compressed_gt_distances) + Parameters + ---------- + x : torch_geometric.data.Data + Input graph data. It is expected to have the attributes: + `x` (node features), `edge_index` (graph connectivity), + `edge_attr` (edge features), and `positions` (node spatial coordinates). + from_iter : int, optional + Index of the recurrent iteration to be used as displacement vectors. + Default is -1 (last iteration). + scaling : np.ndarray, optional + Scaling factors for each dimension. Default is [1.0, 1.0]. - def squeeze(self, x, from_iter=-1, scaling=np.array([1.0, 1.0])): - pred = self(x)[from_iter].detach().cpu().numpy() - return (x.position.cpu() - pred * self.connectivity_radius).numpy() * scaling + Returns + ------- + np.ndarray + Squeezed point cloud with optimized cluster alignment. + """ + predicted_displacements = self(x)[from_iter].detach().cpu().numpy() + positions = x.position.cpu().numpy() + squeezed_positions = ( + positions - predicted_displacements * self.connectivity_radius + ) + return squeezed_positions * scaling def clustering( self, - x, - eps, - min_samples, - from_iter=-1, + x: Data, + eps: float, + min_samples: int, + from_iter: int = -1, **kwargs, - ): - """ - Perform clustering using the DBSCAN algorithm, with MIRO preprocessing - to optimize the input point cloud for effective clustering. + ) -> np.ndarray: + """Perform clustering using DBSCAN after applying MIRO squeezing. Parameters ---------- @@ -130,33 +208,63 @@ def clustering( min_samples : int The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. This includes the point itself. + from_iter : int, optional + Index of the recurrent iteration to be used as displacement vectors. + Default is -1 (last iteration). + + Returns + ------- + np.ndarray + Cluster labels for each node. -1 indicates outliers. """ squeezed = self.squeeze(x, from_iter, **kwargs) clusters = DBSCAN(eps=eps, min_samples=min_samples).fit(squeezed) - return clusters.labels_ - def training_step(self, batch, batch_idx): + def training_step(self, batch, batch_idx: int) -> torch.Tensor: + """Defines the training step for a single batch.""" x, y = self.train_preprocess(batch) y_hat = self(x) loss = self.compute_loss(y_hat, y, x.edge_index, x.position) self.log( - "train_loss", - loss, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, + "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True ) self.log_metrics( - "train", - y_hat, - y, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, + "train", y_hat, y, on_step=True, on_epoch=True, prog_bar=True, logger=True ) return loss + + def compute_loss( + self, + y_hat: List[torch.Tensor], + y: torch.Tensor, + edges: torch.Tensor, + position: torch.Tensor, + ) -> torch.Tensor: + """Computes the total loss for the model.""" + loss = 0 + for pred in y_hat: + loss += self.loss(pred, y) + self.nd_loss_weight * self.compute_nd_loss( + pred, y, edges, position + ) + return loss / len(y_hat) + + def compute_nd_loss( + self, + y_hat: torch.Tensor, + y: torch.Tensor, + edges: torch.Tensor, + position: torch.Tensor, + ) -> torch.Tensor: + """Computes auxiliary loss for pairwise distance preservation.""" + squeezed_gt = position - y * self.connectivity_radius + squeezed_gt_distances = torch.norm( + squeezed_gt[edges[0]] - squeezed_gt[edges[1]], dim=1 + ) + squeezed_pred = position - y_hat * self.connectivity_radius + squeezed_pred_distances = torch.norm( + squeezed_pred[edges[0]] - squeezed_pred[edges[1]], dim=1 + ) + return self.loss(squeezed_pred_distances, squeezed_gt_distances) diff --git a/deeplay/components/dict.py b/deeplay/components/dict.py index cb2d1d32..120ee435 100644 --- a/deeplay/components/dict.py +++ b/deeplay/components/dict.py @@ -1,4 +1,19 @@ -from typing import Dict, Any, Union, Tuple, Dict +"""Utility Modules for Dictionary and Graph Data Operations + +This module contains utilities for operations involving dictionary-like +structures or PyTorch Geometric `Data` objects. These operations are useful in +geometric deep learning pipelines, where input data is often stored in a +structured format with various attributes. + +Classes: +- FromDict: Extracts specified keys from a dictionary or `Data` object. +- AddDict: Performs element-wise addition for specified keys in two + dictionaries or `Data` objects. +- CatDictElements: Concatenates specified elements within a dictionary or + `Data` object along a given dimension. +""" + +from typing import Dict, Any, Union, Tuple from deeplay import DeeplayModule @@ -7,6 +22,25 @@ class FromDict(DeeplayModule): + """Extract specified keys from a dictionary-like structure. + + Parameters + ---------- + keys : str + The keys to extract from the input dictionary. + + Returns + ------- + Any or Tuple[Any, ...] + The values corresponding to the specified keys. + + Example + ------- + >>> extractor = FromDict("key1", "key2").create() + >>> result = extractor({"key1": value1, "key2": value2}) + (value1, value2) + """ + def __init__(self, *keys: str): super().__init__() self.keys = keys @@ -23,8 +57,7 @@ def extra_repr(self) -> str: class AddDict(DeeplayModule): - """ - Element-wise addition of two dictionaries. + """Element-wise addition of two dictionaries. Parameters ---------- @@ -33,10 +66,19 @@ class AddDict(DeeplayModule): Constraints ----------- - - Both dictionaries 'x' (base) and 'y' (addition) must contain the same keys for the addition operation. - - - 'x': Dict[str, Any] or torch_geometric.data.Data. - - 'y': Dict[str, Any] or torch_geometric.data.Data. + - Both dictionaries `x` (base) and `y` (addition) must contain the same + keys for the addition operation. + + - Input types: + - `x`: Dict[str, Any] or `torch_geometric.data.Data` + - `y`: Dict[str, Any] or `torch_geometric.data.Data` + + Example + ------- + >>> adder = AddDict("key1", "key2").create() + >>> result = adder({"key1": value1, "key2": value2}, + {"key1": 1, "key2": 2}) + {"key1": value1 + 1, "key2": value2 + 2} """ def __init__(self, *keys: str): @@ -57,15 +99,26 @@ def forward( class CatDictElements(DeeplayModule): - """ - Concatenates specified elements within a dictionary-like structure along a given dimension. + """Concatenates specified elements within a dictionary-like structure along + a given dimension. - Parameters: - - keys: Tuple[tuple] - Specifies the keys to be concatenated as tuples. Each tuple contains two keys: source and target. - The source key is the key to be concatenated with the target key. - - dim: int + Parameters + ---------- + keys : Tuple[tuple] + Specifies the keys to be concatenated as tuples. Each tuple contains + two keys: source and target. The source key is the key to be + concatenated with the target key. + dim : int, optional Specifies the dimension along which the concatenation is performed. + Default is -1. + + Example + ------- + >>> concat = CatDictElements(("key1", "key2"), ("key3", "key4")).create() + >>> result = concat({"key1": tensor1, "key2": tensor2, + "key3": tensor3, "key4": tensor4}) + {"key2": torch.cat([tensor2, tensor1], dim=-1), + "key4": torch.cat([tensor4, tensor3], dim=-1)} """ def __init__(self, *keys: Tuple[tuple], dim: int = -1): diff --git a/deeplay/components/gnn/rgb.py b/deeplay/components/gnn/rgb.py index d25dbdf8..a898a63e 100644 --- a/deeplay/components/gnn/rgb.py +++ b/deeplay/components/gnn/rgb.py @@ -1,3 +1,10 @@ +"""Recurrent Graph Block Module + +This module defines the `RecurrentGraphBlock` class, a component designed for +recurrent graph-based computations. It employs a recurrent structure to process +graph data iteratively using specified combine, layer, and head modules. +""" + import torch from torch_geometric.data import Data from deeplay import DeeplayModule @@ -5,6 +12,72 @@ class RecurrentGraphBlock(DeeplayModule): + """Recurrent graph processing block for iterative feature transformation. + + This module combines graph data features and hidden states, processes + them through a recurrent structure, and generates outputs using a + specified head module. It supports modular design for flexibility in + defining the combine, layer, and head operations. + + Parameters + ---------- + layer : DeeplayModule + Module that applies transformations to the graph data at each + iteration. + head : DeeplayModule + Module that processes the output from the layer and generates + final predictions. + hidden_features : int + The number of hidden features for the recurrent block. + num_iter : int + The number of recurrent iterations. + combine : DeeplayModule, optional + The module responsible for combining graph features with hidden + states. Default is `CatDictElements(("x", "hidden"))`. + + Returns + ------- + list + A list of outputs generated at each recurrent iteration. + + Raises + ------ + AttributeError + If the `combine` module does not contain `source` and `target` + attributes. Graph `combine` modules compatible with + `RecurrentGraphBlock` must have `source` and `target` attributes to + specify the keys to concatenate. A catalog of `combine` operations + is available in the `deeplay.components.dict` module. + + Example + ------- + >>> combine = CatDictElements(("x", "hidden")) + >>> layer = dl.MessagePassingNeuralNetwork([], 128) + >>> head = MultiLayerPerceptron(128, [], 128) + >>> block = RecurrentGraphBlock( + ... layer=layer, + ... head=head, + ... combine=combine, + ... hidden_features=128, + ... num_iter=10 + ... ) + + >>> # Set maps for input and output keys + >>> block.head.set_input_map("x") + >>> block.head.set_output_map("x") + >>> block = block.create() + + >>> # Create input graph data + >>> data = Data( + ... x=torch.randn(3, 128), + ... edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), + ... edge_attr=torch.randn(4, 128) + ... ) + >>> outputs = block(data) + >>> len(outputs) + 10 + """ + combine: DeeplayModule layer: DeeplayModule head: DeeplayModule @@ -26,13 +99,28 @@ def __init__( if not all(hasattr(self.combine, attr) for attr in ("source", "target")): raise AttributeError( - "The 'combine' module must have 'source' and 'target' attributes to specify " - "the keys to concatenate. Found None. Ensure that the 'combine' module is initialized " - "with valid 'source' and 'target' keys. Check CatDictElements for reference." + "The 'combine' module must have 'source' and 'target' attributes. " + "These specify the keys to concatenate. Ensure the 'combine' " + "module is initialized with valid 'source' and 'target' keys. " + "Refer to the `CatDictElements` class in the `deeplay.components.dict` " + "module for guidance." ) self.hidden_variables_name = self.combine.target def initialize_hidden(self, x): + """Initialize hidden states for the graph nodes if not already provided + in the input data. + + Parameters + ---------- + x : Data or dict + The input graph data or dictionary-like structure. + + Returns + ------- + Data or dict + The input graph data with initialized hidden states. + """ x = x.clone() if isinstance(x, Data) else x.copy() for source, hidden_variable_name in zip( self.combine.source, self.hidden_variables_name @@ -48,6 +136,18 @@ def initialize_hidden(self, x): return x def forward(self, x): + """Forward pass to process the graph data through recurrent iterations. + + Parameters + ---------- + x : Data or dict + The input graph data or dictionary-like structure. + + Returns + ------- + list + A list of outputs generated at each recurrent iteration. + """ x = self.initialize_hidden(x) outputs = [] for _ in range(self.num_iter): diff --git a/deeplay/models/gnn/rmpm.py b/deeplay/models/gnn/rmpm.py index c64745af..97ad9195 100644 --- a/deeplay/models/gnn/rmpm.py +++ b/deeplay/models/gnn/rmpm.py @@ -1,3 +1,10 @@ +"""Recurrent Message Passing Neural Network (RMPN) Model + +This module defines the `RecurrentMessagePassingModel`, a ,neural network for +recurrent graph-based computations. It processes graph data iteratively using +message-passing mechanisms combined with recurrent structures. +""" + from typing import Type, Union from deeplay import ( @@ -16,58 +23,73 @@ class RecurrentMessagePassingModel(DeeplayModule): """Recurrent Message Passing Neural Network (RMPN) model. + RMPN processes graph data iteratively through a combination of an encoder + and a recurrent message passing layer. The encoder transforms input node + and edge features into a common hidden representation. The recurrent + message passing layer updates hidden node and edge representations by + concatenating them with the input features and processing them through a + message-passing neural network. Outputs for each iteration are collected + in a list and returned as the model's final output. + Parameters ---------- - hidden_features: int + hidden_features : int Number of hidden units in the recurrent message passing layer. - out_features: int + out_features : int Number of output features. - num_iter: int + num_iter : int Number of iterations of the recurrent message passing layer. - out_activation: template-like - Specification for the output activation of the model. Default: nn.Identity. + out_activation : template-like, optional + Activation function applied to the output. Default is `nn.Identity`. + Raises + ------ + ValueError + If `out_features` or `hidden_features` are non-positive. Configurables ------------- - - hidden_features (int): Number of hidden units in the recurrent message passing layer. + - hidden_features (int): Number of hidden units in the recurrent message + passing layer. - out_features (int): Number of output features. - - out_activation (template-like): Specification for the output activation of the model. Default: nn.Identity. - - encoder (template-like): Specification for the encoder of the model. Default: dl.Parallel consisting of two MLPs to process node and edge features. - - backbone (template-like): Specification for the backbone of the model. Default: dl.RecurrentGraphBlock consisting of dl.MessagePassingNeuralNetwork and a MLP head. + - out_activation (template-like): Specification for the output activation + of the model. Default: nn.Identity. + - encoder (template-like): Specification for the encoder of the model. + Default: dl.Parallel consisting of two MLPs to process node and edge features. + - backbone (template-like): Specification for the backbone of the model. + Default: dl.RecurrentGraphBlock consisting of dl.MessagePassingNeuralNetwork and + a MLP head. Constraints ----------- - - input: Dict[str, Any] or torch-geometric Data object containing the following attributes: - - x: torch.Tensor of shape (num_nodes, node_in_features). - - edge_index: torch.Tensor of shape (2, num_edges). - - edge_attr: torch.Tensor of shape (num_edges, edge_in_features). - - hidden_x: (Optional) torch.Tensor of shape (num_nodes, hidden_features). - - hidden_edge_attr: (Optional) torch.Tensor of shape (num_edges, hidden_features). - - NOTE: node_in_features and edge_in_features are inferred from the input data. - - - output: List[torch.Tensor] where each tensor has shape (num_nodes, out_features). - - Examples - -------- - >>> # Define a RMPN model with 96 hidden features, 2 output features, and 3 iterations + - Input graph data must include: + - `x`: Node features of shape (num_nodes, node_in_features). + - `edge_index`: Edge connectivity of shape (2, num_edges). + - `edge_attr`: Edge features of shape (num_edges, edge_in_features). + - Optional attributes: + - `hidden_x`: Node hidden states of shape (num_nodes, hidden_features). + - `hidden_edge_attr`: Edge hidden states of shape (num_edges, hidden_features). + If not provided, they are initialized as zeros. + - Input can be provided as a dictionary or a `torch_geometric.data.Data` object. + + Returns + ------- + List[torch.Tensor] + List of tensors where each tensor corresponds to the output at an + iteration step, with shape (num_nodes, out_features). + + Example + ------- >>> model = RecurrentMessagePassingModel(hidden_features=96, out_features=2, num_iter=3) - - >>> # Input graph data - >>> inp = { - >>> "x": torch.randn(10, 5), # Node features - >>> "edge_index": torch.randint(0, 10, (2, 20)), # Edge connectivity - >>> "edge_attr": torch.randn(20, 3), # Edge features - >>> } - - >>> # Model forward pass - >>> out = model(inp) - - >>> # Output shape - >>> print(len(out)) + >>> graph_data = { + ... "x": torch.randn(10, 5), + ... "edge_index": torch.randint(0, 10, (2, 20)), + ... "edge_attr": torch.randn(20, 3), + ... } + >>> outputs = model(graph_data) + >>> print(len(outputs)) 3 - >>> print(out[0].shape) + >>> print(outputs[0].shape) torch.Size([10, 2]) """ @@ -165,6 +187,18 @@ def __init__( self.backbone.head.set_output_map() def forward(self, x): + """Forward pass. + + Parameters + ---------- + x : dict or torch_geometric.data.Data + Input graph data containing node and edge features. + + Returns + ------- + list + A list of tensors representing the model's output at each iteration. + """ x = self.encoder(x) x = self.backbone(x) return x From 22b2d7a2643a2d0337c6934f180e5249d215596b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Fri, 13 Dec 2024 14:50:56 +0100 Subject: [PATCH 11/11] Update rmpm.py --- deeplay/models/gnn/rmpm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deeplay/models/gnn/rmpm.py b/deeplay/models/gnn/rmpm.py index 97ad9195..1b675061 100644 --- a/deeplay/models/gnn/rmpm.py +++ b/deeplay/models/gnn/rmpm.py @@ -80,7 +80,9 @@ class RecurrentMessagePassingModel(DeeplayModule): Example ------- - >>> model = RecurrentMessagePassingModel(hidden_features=96, out_features=2, num_iter=3) + >>> model = RecurrentMessagePassingModel(hidden_features=96, + out_features=2, + num_iter=3).create() >>> graph_data = { ... "x": torch.randn(10, 5), ... "edge_index": torch.randint(0, 10, (2, 20)),