diff --git a/deeplay/applications/autoencoders/__init__.py b/deeplay/applications/autoencoders/__init__.py
index 191aac42..d4f255af 100644
--- a/deeplay/applications/autoencoders/__init__.py
+++ b/deeplay/applications/autoencoders/__init__.py
@@ -1,2 +1,3 @@
 from .vae import VariationalAutoEncoder
 from .wae import WassersteinAutoEncoder
+from .vgae import VariationalGraphAutoEncoder
diff --git a/deeplay/applications/autoencoders/vgae.py b/deeplay/applications/autoencoders/vgae.py
new file mode 100644
index 00000000..b0aeae85
--- /dev/null
+++ b/deeplay/applications/autoencoders/vgae.py
@@ -0,0 +1,160 @@
+from typing import Optional, Sequence, Callable, List
+
+from deeplay.components import ConvolutionalEncoder2d, ConvolutionalDecoder2d
+from deeplay.applications import Application
+from deeplay.external import External, Optimizer, Adam
+
+from deeplay import (
+    DeeplayModule,
+    Layer,
+)
+
+import torch
+import torch.nn as nn
+
+
+class VariationalGraphAutoEncoder(Application):
+    """Variational autoencoder for graphs.
+
+    Parameters
+    ----------
+    encoder : nn.Module
+    decoder : nn.Module
+    hidden_features : int
+        Number of features of the hidden layers.
+    latent_dim : int
+        Number of latent dimensions.
+    alpha : float
+        Weighting for the node feature reconstruction loss. Defaults to 1.
+    beta : float
+        Weighting for the KL loss. Defaults to 1e-7.
+    gamma : float
+        Weighting for the edge feature reconstruction loss. Defaults to 1.
+    delta : float
+        Weighting for the MinCut loss. Defaults to 1.
+    reconstruction_loss : torch.nn.Module
+        Loss metric for the reconstruction of the node and edge features.
+        Defaults to L1 (mean absolute error).
+    optimizer : Optimizer
+        Optimizer to use for training.
+    """
+
+    encoder: torch.nn.Module
+    decoder: torch.nn.Module
+    hidden_features: int
+    latent_dim: int
+    alpha: float
+    beta: float
+    gamma: float
+    delta: float
+    reconstruction_loss: torch.nn.Module
+    optimizer: Optimizer
+
+    def __init__(
+        self,
+        encoder: Optional[nn.Module] = None,
+        decoder: Optional[nn.Module] = None,
+        hidden_features: Optional[int] = 96,
+        latent_dim: Optional[int] = 32,
+        alpha: Optional[float] = 1,
+        beta: Optional[float] = 1e-7,
+        gamma: Optional[float] = 1,
+        delta: Optional[float] = 1,
+        reconstruction_loss: Optional[Callable] = None,
+        optimizer=None,
+        **kwargs,
+    ):
+        self.encoder = encoder
+
+        # Each fully connected layer reads from and writes to named keys of the
+        # input mapping: fc_mu and fc_var both read "x" and write "mu" and
+        # "log_var" respectively; fc_dec maps the latent "z" back to "x".
+        self.fc_mu = Layer(nn.Linear, hidden_features, latent_dim)
+        self.fc_mu.set_input_map('x')
+        self.fc_mu.set_output_map('mu')
+
+        self.fc_var = Layer(nn.Linear, hidden_features, latent_dim)
+        self.fc_var.set_input_map('x')
+        self.fc_var.set_output_map('log_var')
+
+        self.fc_dec = Layer(nn.Linear, latent_dim, hidden_features)
+        self.fc_dec.set_input_map('z')
+        self.fc_dec.set_output_map('x')
+
+        self.decoder = decoder
+
+        self.reconstruction_loss = reconstruction_loss or nn.L1Loss()
+        self.latent_dim = latent_dim
+        self.alpha = alpha
+        self.beta = beta
+        self.gamma = gamma
+        self.delta = delta
+
+        super().__init__(**kwargs)
+
+        class Reparameterize(DeeplayModule):
+            def forward(self, mu, log_var):
+                std = torch.exp(0.5 * log_var)
+                eps = torch.randn_like(std)
+                return eps * std + mu
+
+        self.reparameterize = Reparameterize()
+        self.reparameterize.set_input_map('mu', 'log_var')
+        self.reparameterize.set_output_map('z')
+
+        self.optimizer = optimizer or Adam(lr=1e-3)
+
+        @self.optimizer.params
+        def params(self):
+            return self.parameters()
+
+    def encode(self, x):
+        x = self.encoder(x)
+        x = self.fc_mu(x)
+        x = self.fc_var(x)
+        return x
+
+    def decode(self, x):
+        x = self.fc_dec(x)
+        x = self.decoder(x)
+        return x
+
+    def training_step(self, batch, batch_idx):
+        x, y = self.train_preprocess(batch)
+        node_features, edge_features = y
+        x = self(x)
+        node_features_hat = x['x']
+        edge_features_hat = x['edge_attr']
+        mu = x['mu']
+        log_var = x['log_var']
+        mincut_cut_loss = sum(
+            value for key, value in x.items() if key.startswith('L_cut')
+        )
+        mincut_ortho_loss = sum(
+            value for key, value in x.items() if key.startswith('L_ortho')
+        )
+        rec_loss_nodes, rec_loss_edges, KLD = self.compute_loss(
+            node_features_hat, node_features,
+            edge_features_hat, edge_features,
+            mu, log_var,
+        )
+
+        tot_loss = (
+            self.alpha * rec_loss_nodes
+            + self.gamma * rec_loss_edges
+            + self.beta * KLD
+            + self.delta * (mincut_cut_loss + mincut_ortho_loss)
+        )
+
+        loss = {
+            "rec_loss_nodes": rec_loss_nodes,
+            "rec_loss_edges": rec_loss_edges,
+            "KL": KLD,
+            "MinCut cut loss": mincut_cut_loss,
+            "MinCut orthogonality loss": mincut_ortho_loss,
+            "total_loss": tot_loss,
+        }
+        for name, v in loss.items():
+            self.log(
+                f"train_{name}",
+                v,
+                on_step=True,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+        return tot_loss
+
+    def compute_loss(self, n_hat, n, e_hat, e, mu, log_var):
+        rec_loss_nodes = self.reconstruction_loss(n_hat, n)
+        rec_loss_edges = self.reconstruction_loss(e_hat, e)
+
+        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
+
+        return rec_loss_nodes, rec_loss_edges, KLD
+
+    def forward(self, x):
+        x = self.encode(x)
+        x = self.reparameterize(x)
+        x = self.decode(x)
+        return x
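+
+# Sanity-check sketch (not part of the module, assumed shapes): the closed-form
+# KL term in `compute_loss` matches torch.distributions for a diagonal Gaussian
+# posterior against a standard normal prior:
+#
+#   mu, log_var = torch.randn(4, 8), torch.randn(4, 8)
+#   q = torch.distributions.Normal(mu, torch.exp(0.5 * log_var))
+#   p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(mu))
+#   kld_ref = torch.distributions.kl_divergence(q, p).sum()
+#   kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
+#   assert torch.allclose(kld, kld_ref, atol=1e-5)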
diff --git a/deeplay/components/gnn/__init__.py b/deeplay/components/gnn/__init__.py
index 1f0bd47c..f6b57a88 100644
--- a/deeplay/components/gnn/__init__.py
+++ b/deeplay/components/gnn/__init__.py
@@ -1,3 +1,5 @@
 from .gcn import *
 from .mpn import *
 from .tpu import *
+from .pooling import *
+from .graphencdec import GraphEncoderBlock, GraphDecoderBlock, GraphEncoder, GraphDecoder
\ No newline at end of file
diff --git a/deeplay/components/gnn/gcn/__init__.py b/deeplay/components/gnn/gcn/__init__.py
index 1271fefc..0d62ac12 100644
--- a/deeplay/components/gnn/gcn/__init__.py
+++ b/deeplay/components/gnn/gcn/__init__.py
@@ -1,2 +1,3 @@
 from .gcn import GraphConvolutionalNeuralNetwork
 from .normalization import *
+from .gcn_concat import GraphConvolutionalNeuralNetworkConcat
\ No newline at end of file
diff --git a/deeplay/components/gnn/gcn/gcn_concat.py b/deeplay/components/gnn/gcn/gcn_concat.py
new file mode 100644
index 00000000..2b983213
--- /dev/null
+++ b/deeplay/components/gnn/gcn/gcn_concat.py
@@ -0,0 +1,184 @@
+from typing import List, Optional, Literal, Any, Sequence, Type, overload, Union
+
+from deeplay import DeeplayModule, Layer, LayerList
+
+from ..tpu import TransformPropagateUpdate
+from deeplay.ops import Cat
+
+import torch
+import torch.nn as nn
+
+
+class GraphConvolutionalNeuralNetworkConcat(DeeplayModule):
+    in_features: Optional[int]
+    hidden_features: Sequence[Optional[int]]
+    out_features: int
+    blocks: LayerList[TransformPropagateUpdate]
+
+    @property
+    def input(self):
+        """Return the input layer of the network. Equivalent to `.blocks[0]`."""
+        return self.blocks[0]
+
+    @property
+    def hidden(self):
+        """Return the hidden layers of the network. Equivalent to `.blocks[:-1]`."""
+        return self.blocks[:-1]
+
+    @property
+    def output(self):
+        """Return the last layer of the network. Equivalent to `.blocks[-1]`."""
+        return self.blocks[-1]
+
+    @property
+    def transform(self) -> LayerList[Layer]:
+        """Return the transform layers of the network. Equivalent to `.blocks.transform`."""
+        return self.blocks.transform
+
+    @property
+    def propagate(self) -> LayerList[Layer]:
+        """Return the propagate layers of the network. Equivalent to `.blocks.propagate`."""
+        return self.blocks.propagate
+
+    @property
+    def update(self) -> LayerList[Layer]:
+        """Return the update layers of the network. Equivalent to `.blocks.update`."""
+        return self.blocks.update
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Sequence[int],
+        out_features: int,
+        out_activation: Union[Type[nn.Module], nn.Module, None] = None,
+    ):
+        super().__init__()
+
+        self.in_features = in_features
+        self.hidden_features = hidden_features
+        self.out_features = out_features
+
+        if in_features is None:
+            raise ValueError("in_features must be specified")
+
+        if out_features is None:
+            raise ValueError("out_features must be specified")
+
+        if in_features <= 0:
+            raise ValueError(f"in_features must be positive, got {in_features}")
+
+        if out_features <= 0:
+            raise ValueError(f"out_features must be positive, got {out_features}")
+
+        if any(h <= 0 for h in hidden_features):
+            raise ValueError(
+                f"all hidden_features must be positive, got {hidden_features}"
+            )
+
+        if out_activation is None:
+            out_activation = Layer(nn.Identity)
+        elif isinstance(out_activation, type) and issubclass(out_activation, nn.Module):
+            out_activation = Layer(out_activation)
+
+        class Propagate(DeeplayModule):
+            def forward(self, x, A):
+                if A.is_sparse:
+                    return torch.spmm(A, x)
+                elif not A.is_sparse and A.size(0) == 2:
+                    A = torch.sparse_coo_tensor(
+                        A,
+                        torch.ones(A.size(1), device=A.device),
+                        (x.size(0),) * 2,
+                        device=A.device,
+                    )
+                    return torch.spmm(A, x)
+                elif not A.is_sparse and len({A.size(0), A.size(1), x.size(0)}) == 1:
+                    return A.type(x.dtype) @ x
+                else:
+                    raise ValueError(
+                        "Unsupported adjacency matrix format.",
+                        "Ensure it is a pytorch sparse tensor, an edge index tensor, or a square dense tensor.",
+                        "Consider updating the propagate layer to handle alternative formats.",
+                    )
+
+        self.blocks = LayerList()
+
+        for i, (c_in, c_out) in enumerate(
+            zip([in_features, *hidden_features], [*hidden_features, out_features])
+        ):
+            transform = Layer(nn.Linear, c_in, c_out)
+            transform.set_input_map("x")
+            transform.set_output_map("x_prime")
+
+            propagate = Layer(Propagate)
+            propagate.set_input_map("x_prime", "edge_index")
+            propagate.set_output_map("x_prime")
+
+            update = Layer(nn.ReLU) if i < len(self.hidden_features) else out_activation
+            update.set_input_map("x_prime")
+            update.set_output_map("x_prime")
+
+            block = TransformPropagateUpdate(
+                transform=transform,
+                propagate=propagate,
+                update=update,
+                order=["transform", "update", "propagate"],
+            )
+            self.blocks.append(block)
+
+        self.concat = Cat()
+        self.concat.set_input_map('x_prime', 'x')
+        self.concat.set_output_map('x')
+
+        self.dense = Layer(nn.Linear, out_features * 2, out_features)
+        self.dense.set_input_map('x')
+        self.dense.set_output_map('x')
+
+        self.activate = Layer(nn.ReLU)
+        self.activate.set_input_map('x')
+        self.activate.set_output_map('x')
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+
+        x = self.concat(x)
+        x = self.dense(x)
+        x = self.activate(x)
+
+        return x
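+
+    # Format sketch (assumed shapes, N nodes, F features): `Propagate` accepts
+    # the adjacency A as a torch sparse tensor, as a (2, E) dense edge-index
+    # tensor (converted to sparse internally), or as a dense square (N, N)
+    # matrix; in all three cases the result is the aggregation A @ x of shape
+    # (N, F):
+    #
+    #   x = torch.randn(4, 3)
+    #   edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
+    #   A_sparse = torch.sparse_coo_tensor(edge_index, torch.ones(3), (4, 4))
+    #   A_dense = A_sparse.to_dense()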
+
+    @overload
+    def configure(
+        self,
+        /,
+        in_features: Optional[int] = None,
+        hidden_features: Optional[List[int]] = None,
+        out_features: Optional[int] = None,
+        out_activation: Union[Type[nn.Module], nn.Module, None] = None,
+    ) -> None: ...
+
+    @overload
+    def configure(
+        self,
+        name: Literal["blocks"],
+        order: Optional[Sequence[str]] = None,
+        transform: Optional[Type[nn.Module]] = None,
+        propagate: Optional[Type[nn.Module]] = None,
+        update: Optional[Type[nn.Module]] = None,
+        **kwargs: Any,
+    ) -> None: ...
+
+    @overload
+    def configure(
+        self,
+        name: Literal["blocks"],
+        index: Union[int, slice, List[Union[int, slice]]],
+        order: Optional[Sequence[str]] = None,
+        transform: Optional[Type[nn.Module]] = None,
+        propagate: Optional[Type[nn.Module]] = None,
+        update: Optional[Type[nn.Module]] = None,
+        **kwargs: Any,
+    ) -> None: ...
+
+    configure = DeeplayModule.configure
diff --git a/deeplay/components/gnn/gcn/normalization.py b/deeplay/components/gnn/gcn/normalization.py
index 28607349..7d44226e 100644
--- a/deeplay/components/gnn/gcn/normalization.py
+++ b/deeplay/components/gnn/gcn/normalization.py
@@ -11,8 +11,11 @@ def add_self_loops(self, A, num_nodes):
     """
     loop_index = torch.arange(num_nodes, device=A.device)
     loop_index = loop_index.unsqueeze(0).repeat(2, 1)
-
-    A = torch.cat([A, loop_index], dim=1)
+
+    if A.is_sparse:
+        A = torch.cat([A.indices(), loop_index], dim=1)
+    elif not A.is_sparse and A.size(0) == 2:
+        A = torch.cat([A, loop_index], dim=1)
 
     return A
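+
+# Sketch: for num_nodes = 3, `loop_index` is [[0, 1, 2], [0, 1, 2]], i.e. one
+# (i, i) self-loop appended per node; sparse inputs are first reduced to their
+# (2, E) index form so the loops can be concatenated along the edge dimension.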
diff --git a/deeplay/components/gnn/graphencdec.py b/deeplay/components/gnn/graphencdec.py
new file mode 100644
index 00000000..e3c87a34
--- /dev/null
+++ b/deeplay/components/gnn/graphencdec.py
@@ -0,0 +1,601 @@
+from __future__ import annotations
+from typing import Optional, Sequence, Type, Union
+import warnings
+
+from deeplay import (
+    DeeplayModule,
+    Layer,
+    LayerList,
+)
+from deeplay.components.gnn import MessagePassingNeuralNetworkGAUDI, GraphConvolutionalNeuralNetworkConcat
+from deeplay.components.gnn.pooling import MinCutPooling, MinCutUpsampling
+from deeplay.ops import Cat
+from deeplay.components.gnn.pooling import GlobalGraphPooling, GlobalGraphUpsampling
+from deeplay.components.mlp import MultiLayerPerceptron
+from deeplay.ops import GetEdgeFeaturesNew
+
+import torch.nn as nn
+
+
+class GraphEncoder(DeeplayModule):
+    """A Graph Encoder module that leverages multiple graph processing blocks to learn representations
+    from graph-structured data. This module supports graph convolution and pooling operations, enabling
+    effective encoding of graph information for downstream tasks.
+
+    Parameters
+    ----------
+    hidden_features: int
+        The number of hidden features in the hidden layers, both in the gcn and pooling, of the encoder.
+    num_blocks: int
+        The number of processing blocks in the encoder.
+    num_clusters: list[int]
+        The number of clusters the graph is pooled to in each processing block.
+    thresholds: list[float]
+        The threshold values for the connectivity in the clustering process.
+
+    Configurables
+    -------------
+    - hidden_features (int): Number of features of the hidden layers.
+    - num_blocks (int): Number of processing blocks in the encoder.
+    - num_clusters (list[int]): Number of clusters the graph is pooled to in each processing block.
+    - thresholds (list[float]): The threshold values for the connectivity in the clustering process.
+    - poolings (template-like): A list of pooling layers to use. Default is MinCut pooling for all
+      layers except the last, which is global pooling.
+    - save_intermediates (bool): Whether to save intermediate adjacency matrices and cluster
+      assignments, useful when used together with the GraphDecoder. Default is True.
+
+    Constraints
+    -----------
+    - input: Dict[str, Any] or torch-geometric Data object containing the following attributes:
+        - x: torch.Tensor of shape (num_nodes, node_features)
+        - edge_index: torch.Tensor of shape (2, num_edges)
+        - batch: torch.Tensor of shape (num_nodes)
+        - edge_attr: torch.Tensor of shape (num_edges, edge_features)
+
+    Example
+    -------
+    >>> encoder = dl.GraphEncoder(hidden_features=96, num_blocks=3, num_clusters=[5, 3, 1], thresholds=[0.1, 0.2, None], save_intermediates=False).build()
+    >>> inp = {}
+    >>> inp["x"] = torch.randn(10, 16)
+    >>> inp['batch'] = torch.zeros(10, dtype=int)
+    >>> inp["edge_index"] = torch.randint(0, 10, (2, 20))
+    >>> inp["edge_attr"] = torch.randn(20, 8)
+    >>> output = encoder(inp)
+
+    Return Values
+    -------------
+    The forward method returns a mapping object with the updated node features, edge_index, edge
+    attributes, and the cut and orthogonality losses from the MinCut pooling.
+    """
+
+    hidden_features: int
+    num_blocks: int
+    num_clusters: Sequence[int]
+    thresholds: Optional[Sequence[float]]
+    poolings: Optional[Sequence[nn.Module]]
+    save_intermediates: Optional[bool]
+
+    def __init__(
+        self,
+        hidden_features: int,
+        num_blocks: int,
+        num_clusters: Sequence[int],
+        thresholds: Optional[Sequence[float]] = None,
+        poolings: Optional[Sequence[Union[Type[nn.Module], nn.Module]]] = None,
+        save_intermediates: Optional[bool] = True,
+    ):
+        super().__init__(
+            hidden_features=hidden_features,
+            num_blocks=num_blocks,
+            num_clusters=num_clusters,
+            thresholds=thresholds,
+            poolings=poolings,
+            save_intermediates=save_intermediates,
+        )
+
+        if not isinstance(hidden_features, int) or hidden_features <= 0:
+            raise ValueError(f"hidden_features must be a positive integer, got {hidden_features}")
+
+        if poolings is None:
+            poolings = [MinCutPooling] * (num_blocks - 1) + [GlobalGraphPooling]
+
+        assert len(poolings) == num_blocks, "Number of poolings should match num_blocks."
+        assert len(num_clusters) == num_blocks, "Length of num_clusters should match num_blocks."
+
+        self.message_passing = MessagePassingNeuralNetworkGAUDI(
+            hidden_features=[hidden_features],
+            out_features=hidden_features,
+            out_activation=nn.ReLU,
+        )
+
+        self.dense = Layer(nn.Linear, hidden_features, hidden_features)
+        self.dense.set_input_map('x')
+        self.dense.set_output_map('x')
+
+        self.activate = Layer(nn.ReLU)
+        self.activate.set_input_map('x')
+        self.activate.set_output_map('x')
+
+        self.blocks = LayerList()
+
+        for i in range(num_blocks):
+            pool = poolings[i]
+
+            if save_intermediates:
+                edge_index_map = "edge_index" if i == 0 else f"edge_index_{i}"
+                select_output_map = f"s_{i}"
+                connect_output_map = f"edge_index_{i+1}"
+                batch_input_map = "batch" if i == 0 else f"batch_{i}"
+                batch_output_map = f"batch_{i+1}"
+                mincut_cut_loss_map = f"L_cut_{i}"
+                mincut_ortho_loss_map = f"L_ortho_{i}"
+
+                block = GraphEncoderBlock(
+                    in_features=hidden_features,
+                    out_features=hidden_features,
+                    num_clusters=num_clusters[i],
+                    threshold=thresholds[i] if thresholds is not None else None,
+                    pool=pool,
+                    edge_index_map=edge_index_map,
+                    select_output_map=select_output_map,
+                    connect_output_map=connect_output_map,
+                    batch_input_map=batch_input_map,
+                    batch_output_map=batch_output_map,
+                    mincut_cut_loss_map=mincut_cut_loss_map,
+                    mincut_ortho_loss_map=mincut_ortho_loss_map,
+                )
+
+            else:
+                mincut_cut_loss_map = f"L_cut_{i}"
+                mincut_ortho_loss_map = f"L_ortho_{i}"
+
+                block = GraphEncoderBlock(
+                    in_features=hidden_features,
+                    out_features=hidden_features,
+                    num_clusters=num_clusters[i],
+                    threshold=thresholds[i] if thresholds is not None else None,
+                    pool=pool,
+                    mincut_cut_loss_map=mincut_cut_loss_map,
+                    mincut_ortho_loss_map=mincut_ortho_loss_map,
+                )
+
+            self.blocks.append(block)
+
+    def forward(self, x):
+        # TODO: expose these input mappings more cleanly.
+        x['input_edge_index'] = x['edge_index']
+        x['input_edge_attr'] = x['edge_attr']
+        x = self.message_passing(x)
+        x = self.dense(x)
+        x = self.activate(x)
+        for block in self.blocks:
+            x = block(x)
+        return x
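+
+# Key-naming sketch with save_intermediates=True and num_blocks=3: block i reads
+# "edge_index" (for i == 0) or f"edge_index_{i}", and writes f"s_{i}",
+# f"edge_index_{i+1}", f"batch_{i+1}", f"L_cut_{i}" and f"L_ortho_{i}", so the
+# GraphDecoder below can consume the saved assignments in reverse order.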
+
+
+class GraphDecoder(DeeplayModule):
+    """
+    A Graph Decoder module that reconstructs graph structures from learned representations generated
+    by the GraphEncoder. This module decodes the latent graph features back into graph node
+    and edge attributes.
+
+    Parameters
+    ----------
+    hidden_features: int
+        The dimensionality of the hidden layers of the decoder. This should match the hidden
+        features of the corresponding GraphEncoder.
+    num_blocks: int
+        The number of processing blocks in the decoder. This should match the number of blocks
+        of the GraphEncoder.
+    output_node_dim: int
+        The dimensionality of the output node features. This should match the original dimensionality
+        of the input node features of the GraphEncoder.
+    output_edge_dim: int
+        The dimensionality of the output edge features. This should match the original dimensionality
+        of the input edge attributes of the GraphEncoder.
+
+    Configurables
+    -------------
+    - hidden_features (int): Number of features of the hidden layers.
+    - num_blocks (int): Number of processing blocks in the decoder.
+    - output_node_dim (int): Number of dimensions of the output node features.
+    - output_edge_dim (int): Number of dimensions of the output edge attributes.
+    - upsamplings (template-like): A list of upsampling layers to use. Default is MinCut upsampling
+      for all layers except the first, which is global upsampling. This should mirror the pooling
+      layers of the GraphEncoder.
+
+    Constraints
+    -----------
+    - input: Dict[str, Any] or torch-geometric Data object containing the following attributes:
+        - x: torch.Tensor of shape (num_nodes, node_features)
+        - edge_index: torch.Tensor of shape (2, num_edges)
+        - batch: torch.Tensor of shape (num_nodes)
+
+    Example
+    -------
+    >>> encoder = dl.GraphEncoder(hidden_features=96, num_blocks=3, num_clusters=[20, 5, 1], thresholds=[0.1, 0.5, None], save_intermediates=True).build()
+    >>> inp = {}
+    >>> inp["x"] = torch.randn(10, 16)
+    >>> inp['batch'] = torch.zeros(10, dtype=int)
+    >>> inp["edge_index"] = torch.randint(0, 10, (2, 20))
+    >>> inp["edge_attr"] = torch.randn(20, 8)
+    >>> encoder_output = encoder(inp)
+    >>> decoder = dl.GraphDecoder(hidden_features=96, num_blocks=3, output_node_dim=16, output_edge_dim=8).build()
+    >>> decoder_output = decoder(encoder_output)
+
+    Return Values
+    -------------
+    The forward method returns a mapping object with the reconstructed node features and edge attributes.
+    """
+
+    hidden_features: int
+    num_blocks: int
+    output_node_dim: int
+    output_edge_dim: int
+    upsamplings: Optional[Sequence[nn.Module]]
+
+    def __init__(
+        self,
+        hidden_features: int,
+        num_blocks: int,
+        output_node_dim: int,
+        output_edge_dim: int,
+        upsamplings: Optional[Sequence[Union[Type[nn.Module], nn.Module]]] = None,
+    ):
+        super().__init__(
+            hidden_features=hidden_features,
+            output_node_dim=output_node_dim,
+            output_edge_dim=output_edge_dim,
+            num_blocks=num_blocks,
+            upsamplings=upsamplings,
+        )
+
+        if not isinstance(hidden_features, int) or hidden_features <= 0:
+            raise ValueError(f"hidden_features must be a positive integer, got {hidden_features}")
+
+        if upsamplings is None:
+            upsamplings = [GlobalGraphUpsampling] + [MinCutUpsampling] * (num_blocks - 1)
+
+        assert len(upsamplings) == num_blocks, "Number of upsamplings should match num_blocks."
+
+        self.blocks = LayerList()
+
+        for i in range(num_blocks):
+            upsample = upsamplings[i]
+            edge_index_map = "edge_index" if i == num_blocks - 1 else f"edge_index_{num_blocks-1-i}"
+            select_input_map = f"s_{num_blocks-1-i}"
+            connect_input_map = f"edge_index_{num_blocks-i}"
+
+            block = GraphDecoderBlock(
+                in_features=hidden_features,
+                out_features=hidden_features,
+                upsample=upsample,
+                edge_index_map=edge_index_map,
+                select_input_map=select_input_map,
+                connect_input_map=connect_input_map,
+            )
+
+            self.blocks.append(block)
+
+        self.dense = Layer(nn.Linear, hidden_features, hidden_features)
+        self.dense.set_input_map('x')
+        self.dense.set_output_map('x')
+
+        self.activate = Layer(nn.ReLU)
+        self.activate.set_input_map('x')
+        self.activate.set_output_map('x')
+
+        self.get_edge_attr = GetEdgeFeaturesNew()
+        self.get_edge_attr.set_input_map("x", "input_edge_index")
+        self.get_edge_attr.set_output_map("edge_attr_sender", "edge_attr_receiver")
+
+        self.dense_sender = Layer(nn.Linear, hidden_features, hidden_features)
+        self.dense_sender.set_input_map('edge_attr_sender')
+        self.dense_sender.set_output_map('edge_attr_sender')
+
+        self.activate_sender = Layer(nn.ReLU)
+        self.activate_sender.set_input_map('edge_attr_sender')
+        self.activate_sender.set_output_map('edge_attr_sender')
+
+        self.dense_receiver = Layer(nn.Linear, hidden_features, hidden_features)
+        self.dense_receiver.set_input_map('edge_attr_receiver')
+        self.dense_receiver.set_output_map('edge_attr_receiver')
+
+        self.activate_receiver = Layer(nn.ReLU)
+        self.activate_receiver.set_input_map('edge_attr_receiver')
+        self.activate_receiver.set_output_map('edge_attr_receiver')
+
+        self.concat_edge_attr = Cat()
+        self.concat_edge_attr.set_input_map('edge_attr_sender', 'edge_attr_receiver')
+        self.concat_edge_attr.set_output_map('edge_attr')
+
+        # edge feature head:
+        self.dense_edge_mlp_1 = Layer(nn.Linear, hidden_features * 2, hidden_features)
+        self.dense_edge_mlp_1.set_input_map('edge_attr')
+        self.dense_edge_mlp_1.set_output_map('edge_attr')
+
+        self.activate_edge_mlp_1 = Layer(nn.ReLU)
+        self.activate_edge_mlp_1.set_input_map('edge_attr')
+        self.activate_edge_mlp_1.set_output_map('edge_attr')
+
+        self.dense_edge_mlp_2 = Layer(nn.Linear, hidden_features, output_edge_dim)
+        self.dense_edge_mlp_2.set_input_map('edge_attr')
+        self.dense_edge_mlp_2.set_output_map('edge_attr')
+
+        # node feature head:
+        self.node_mlp = MultiLayerPerceptron(
+            in_features=hidden_features,
+            hidden_features=[hidden_features, hidden_features],
+            out_features=output_node_dim,
+            out_activation=None,
+        )
+        self.node_mlp.set_input_map('x')
+        self.node_mlp.set_output_map('x')
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+
+        x = self.dense(x)
+        x = self.activate(x)
+
+        x = self.get_edge_attr(x)
+        x = self.dense_sender(x)
+        x = self.activate_sender(x)
+        x = self.dense_receiver(x)
+        x = self.activate_receiver(x)
+        x = self.concat_edge_attr(x)
+
+        x = self.dense_edge_mlp_1(x)
+        x = self.activate_edge_mlp_1(x)
+        x = self.dense_edge_mlp_2(x)
+
+        x = self.node_mlp(x)
+        return x
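+
+# Shape sketch for the edge head above (assumed sizes): with hidden_features = H
+# and E edges, `get_edge_attr` gathers (E, H) sender and (E, H) receiver node
+# features, the two dense+ReLU branches keep them at (E, H), `Cat` concatenates
+# them to (E, 2H), and the two dense layers map (E, 2H) -> (E, H) -> (E, output_edge_dim).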
+
+
+class GraphEncoderBlock(DeeplayModule):
+    """
+    A Graph Encoder Block that processes graph data through a Graph Convolutional Neural Network (GCN)
+    and applies pooling operations to generate encoded representations of the graph structure.
+    This block is a fundamental component of the GraphEncoder, enabling hierarchical feature extraction.
+
+    Parameters
+    ----------
+    in_features: int
+        The number of input features for each node in the graph.
+    out_features: int
+        The number of output features for each node after processing.
+
+    Configurables
+    -------------
+    - in_features (int): The number of input features for each node in the graph.
+    - out_features (int): The number of output features for each node after processing.
+    - pool (template-like): The pooling operation to be used. Defaults to MinCutPooling.
+    - num_clusters (int): The number of clusters for MinCutPooling. Must be provided if using MinCutPooling.
+    - threshold (float): Threshold value for pooling operations.
+    - edge_index_map (str): The mapping for edge index inputs. Defaults to "edge_index".
+    - select_output_map (str): The mapping for the selection outputs from the pooling layer. Defaults to "s".
+    - connect_output_map (str): The mapping for connecting outputs to subsequent layers. Defaults to "edge_index_pool".
+    - batch_input_map (str): The mapping for batch input. Defaults to "batch".
+    - batch_output_map (str): The mapping for batch output. Defaults to "batch".
+    - mincut_cut_loss_map (str): The mapping for saving the mincut cut loss. Defaults to "L_cut".
+    - mincut_ortho_loss_map (str): The mapping for saving the mincut orthogonality loss. Defaults to "L_ortho".
+
+    Constraints
+    -----------
+    - input: Dict[str, Any] or torch-geometric Data object containing the following attributes:
+        - x: torch.Tensor of shape (num_nodes, node_features)
+        - edge_index: torch.Tensor of shape (2, num_edges)
+        - batch: torch.Tensor of shape (num_nodes)
+        - edge_attr: torch.Tensor of shape (num_edges, edge_features)
+
+    Example
+    -------
+    >>> block = dl.GraphEncoderBlock(in_features=16, out_features=16, num_clusters=5, threshold=0.1).build()
+    >>> inp = {}
+    >>> inp["x"] = torch.randn(10, 16)
+    >>> inp['batch'] = torch.zeros(10, dtype=int)
+    >>> inp["edge_index"] = torch.randint(0, 10, (2, 20))
+    >>> inp["edge_attr"] = torch.randn(20, 8)
+    >>> output = block(inp)
+    """
+
+    in_features: Optional[int]
+    hidden_features: Sequence[Optional[int]]
+    out_features: int
+    pool: Optional[nn.Module]
+    num_clusters: Optional[int]
+    threshold: Optional[float]
+    edge_index_map: Optional[str]
+    select_output_map: Optional[str]
+    connect_output_map: Optional[str]
+    batch_input_map: Optional[str]
+    batch_output_map: Optional[str]
+    mincut_cut_loss_map: Optional[str]
+    mincut_ortho_loss_map: Optional[str]
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        pool: Union[Type[nn.Module], nn.Module, None] = MinCutPooling,
+        num_clusters: Optional[int] = None,
+        threshold: Optional[float] = None,
+        edge_index_map: Optional[str] = "edge_index",
+        select_output_map: Optional[str] = "s",
+        connect_output_map: Optional[str] = "edge_index_pool",
+        batch_input_map: Optional[str] = "batch",
+        batch_output_map: Optional[str] = "batch",
+        mincut_cut_loss_map: Optional[str] = 'L_cut',
+        mincut_ortho_loss_map: Optional[str] = 'L_ortho',
+    ):
+        super().__init__(
+            in_features=in_features,
+            num_clusters=num_clusters,
+            threshold=threshold,
+            out_features=out_features,
+            pool=pool,
+        )
+        self.edge_index_map = edge_index_map
+        self.connect_output_map = connect_output_map
+
+        self.gcn = GraphConvolutionalNeuralNetworkConcat(
+            in_features=in_features,
+            hidden_features=[],
+            out_features=out_features,
+            out_activation=nn.ReLU,
+        )
+
+        self.gcn.propagate.set_input_map("x", edge_index_map)
+
+        if pool == MinCutPooling:
+            if num_clusters is None:
+                raise ValueError("num_clusters must be provided for MinCutPooling")
+
+            self.pool = pool(hidden_features=[out_features], num_clusters=num_clusters, threshold=threshold)
+            self.pool.mincut_loss.set_input_map(edge_index_map, select_output_map)
+            self.pool.mincut_loss.set_output_map(mincut_cut_loss_map, mincut_ortho_loss_map)
+        else:
+            self.pool = pool()
+
+        self.pool.select.set_output_map(select_output_map)
+
+        if hasattr(self.pool, 'reduce'):
+            self.pool.reduce.set_input_map('x', select_output_map)
+        if hasattr(self.pool, 'batch_compatible'):
+            self.pool.batch_compatible.set_input_map(select_output_map, batch_input_map)
+            self.pool.batch_compatible.set_output_map(select_output_map, batch_output_map)
+        if hasattr(self.pool, 'connect'):
+            self.pool.connect.set_input_map(edge_index_map, select_output_map)
+            self.pool.connect.set_output_map(connect_output_map)
+        if hasattr(self.pool, 'red_self_con') and self.pool.red_self_con is not None:
+            self.pool.red_self_con.set_input_map(connect_output_map)
+            self.pool.red_self_con.set_output_map(connect_output_map)
+        if hasattr(self.pool, 'apply_threshold') and self.pool.apply_threshold is not None:
+            self.pool.apply_threshold.set_input_map(connect_output_map)
+            self.pool.apply_threshold.set_output_map(connect_output_map)
+        if hasattr(self.pool, 'sparse'):
+            self.pool.sparse.set_input_map(connect_output_map)
+            self.pool.sparse.set_output_map(connect_output_map)
+
+    def forward(self, x):
+        x = self.gcn(x)
+        x = self.pool(x)
+        return x
+
+
+class GraphDecoderBlock(DeeplayModule):
+    """
+    A Graph Decoder Block that upsamples a graph and applies a Graph Convolutional Neural Network (GCN).
+    This block is a fundamental component of the GraphDecoder, enabling the reconstruction of graph features
+    in a Graph Encoder Decoder model.
+
+    Parameters
+    ----------
+    in_features: int
+        The number of input features for each node in the graph.
+    out_features: int
+        The number of output features for each node after processing.
+
+    Configurables
+    -------------
+    - in_features (int): The number of input features for each node in the graph.
+    - out_features (int): The number of output features for each node after processing.
+    - upsample (template-like): The upsampling operation to be used. Defaults to MinCutUpsampling.
+    - edge_index_map (str): The mapping for edge index inputs. Defaults to "edge_index".
+    - select_input_map (str): The mapping for selection inputs for the upsampling layer. Defaults to "s".
+    - connect_input_map (str): The mapping for the connectivity for the upsampling layer. Defaults to "edge_index_pool".
+    - connect_output_map (str): The mapping for the connectivity outputs of the upsampling layer. Defaults to "-".
+
+    Example
+    -------
+    >>> encoderblock = dl.GraphEncoderBlock(in_features=16, out_features=16, num_clusters=5, threshold=0.2).build()
+    >>> inp = {}
+    >>> inp["x"] = torch.randn(10, 16)
+    >>> inp['batch'] = torch.zeros(10, dtype=int)
+    >>> inp["edge_index"] = torch.randint(0, 10, (2, 20))
+    >>> inp["edge_attr"] = torch.randn(20, 8)
+    >>> encoderblock_output = encoderblock(inp)
+    >>> decoderblock = dl.GraphDecoderBlock(in_features=16, out_features=16).build()
+    >>> decoderblock_output = decoderblock(encoderblock_output)
+    """
+
+    in_features: int
+    out_features: int
+    upsample: Optional[nn.Module]
+    edge_index_map: Optional[str]
+    select_input_map: Optional[str]
+    connect_input_map: Optional[str]
+    connect_output_map: Optional[str]
+    batch_map: Optional[str]
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        upsample: Union[Type[nn.Module], nn.Module, None] = MinCutUpsampling,
+        edge_index_map: Optional[str] = "edge_index",
+        select_input_map: Optional[str] = "s",
+        connect_input_map: Optional[str] = "edge_index_pool",
+        connect_output_map: Optional[str] = "-",
+    ):
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            upsample=upsample,
+            edge_index_map=edge_index_map,
+            select_input_map=select_input_map,
+            connect_input_map=connect_input_map,
+        )
+
+        if upsample == MinCutUpsampling:
+            self.upsample = upsample()
+            self.upsample.upsample.set_input_map('x', connect_input_map, select_input_map)
+            self.upsample.upsample.set_output_map('x', connect_output_map)
+        else:
+            self.upsample = upsample()
+            self.upsample.upsample.set_input_map('x', select_input_map)
+
+        self.gcn = GraphConvolutionalNeuralNetworkConcat(
+            in_features=in_features,
+            hidden_features=[],
+            out_features=out_features,
+            out_activation=nn.ReLU,
+        )
+
+        self.gcn.propagate.set_input_map("x", edge_index_map)
+
+    def forward(self, x):
+        x = self.upsample(x)
+        x = self.gcn(x)
+        return x
\ No newline at end of file
diff --git a/deeplay/components/gnn/mpn/__init__.py b/deeplay/components/gnn/mpn/__init__.py
index be5bf416..bcee56ed 100644
--- a/deeplay/components/gnn/mpn/__init__.py
+++ b/deeplay/components/gnn/mpn/__init__.py
@@ -1,9 +1,11 @@
 from .mpn import MessagePassingNeuralNetwork
+from .mpn_gaudi import MessagePassingNeuralNetworkGAUDI
 from .rmpn import ResidualMessagePassingNeuralNetwork
 from .transformation import *
 from .propagation import Sum, WeightedSum, Mean, Max, Min, Prod
 from .update import *
+from .get_edge_features import *
 from .cla import CombineLayerActivation
 from .ldw import LearnableDistancewWeighting
diff --git a/deeplay/components/gnn/mpn/get_edge_features.py b/deeplay/components/gnn/mpn/get_edge_features.py
new file mode 100644
index 00000000..05379ef8
--- /dev/null
+++ b/deeplay/components/gnn/mpn/get_edge_features.py
@@ -0,0 +1,15 @@
+from .cla import CombineLayerActivation
+
+
+class GetEdgeFeatures(CombineLayerActivation):
+    """Extract the sender and receiver node features for each edge."""
+
+    def get_forward_args(self, x):
+        """Get the node features of neighboring nodes for each edge:
+        - node features of sender nodes (x[edge_index[0]])
+        - node features of receiver nodes (x[edge_index[1]])
+
+        edge_index denotes the connectivity of the graph.
+        """
+        x, edge_index = x
+        return x[edge_index[0]], x[edge_index[1]]
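+
+# Gather sketch: for x of shape (N, F) and edge_index of shape (2, E), both
+# x[edge_index[0]] and x[edge_index[1]] have shape (E, F) — one row of node
+# features per edge, for the sender and the receiver respectively:
+#
+#   x = torch.arange(6.0).reshape(3, 2)           # 3 nodes, 2 features
+#   edge_index = torch.tensor([[0, 2], [1, 1]])   # edges 0 -> 1 and 2 -> 1
+#   x[edge_index[0]]  # tensor([[0., 1.], [4., 5.]]), senders 0 and 2
+#   x[edge_index[1]]  # tensor([[2., 3.], [2., 3.]]), receiver 1 twice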
diff --git a/deeplay/components/gnn/mpn/mpn_gaudi.py b/deeplay/components/gnn/mpn/mpn_gaudi.py
new file mode 100644
index 00000000..6cdc9b02
--- /dev/null
+++ b/deeplay/components/gnn/mpn/mpn_gaudi.py
@@ -0,0 +1,150 @@
+from typing import List, Optional, Literal, Any, Sequence, Type, overload, Union
+
+from deeplay import DeeplayModule, Layer, LayerList
+from deeplay.ops import Cat
+
+from ..tpu import TransformPropagateUpdate
+
+from .transformation import TransformOnlySenderNodes
+from .propagation import Mean
+from .update import Update
+
+import torch.nn as nn
+
+
+class MessagePassingNeuralNetworkGAUDI(DeeplayModule):
+    hidden_features: Sequence[Optional[int]]
+    out_features: int
+    blocks: LayerList[TransformPropagateUpdate]
+
+    @property
+    def input(self):
+        """Return the input layer of the network. Equivalent to `.blocks[0]`."""
+        return self.blocks[0]
+
+    @property
+    def hidden(self):
+        """Return the hidden layers of the network. Equivalent to `.blocks[:-1]`."""
+        return self.blocks[:-1]
+
+    @property
+    def output(self):
+        """Return the last layer of the network. Equivalent to `.blocks[-1]`."""
+        return self.blocks[-1]
+
+    @property
+    def transform(self) -> LayerList[Layer]:
+        """Return the transform layers of the network. Equivalent to `.blocks.transform`."""
+        return self.blocks.transform
+
+    @property
+    def propagate(self) -> LayerList[Layer]:
+        """Return the propagate layers of the network. Equivalent to `.blocks.propagate`."""
+        return self.blocks.propagate
+
+    @property
+    def update(self) -> LayerList[Layer]:
+        """Return the update layers of the network. Equivalent to `.blocks.update`."""
+        return self.blocks.update
+
+    def __init__(
+        self,
+        hidden_features: Sequence[int],
+        out_features: int,
+        out_activation: Union[Type[nn.Module], nn.Module, None] = None,
+    ):
+        super().__init__()
+
+        self.hidden_features = hidden_features
+        self.out_features = out_features
+
+        if any(h <= 0 for h in hidden_features):
+            raise ValueError(
+                f"all hidden_features must be positive, got {hidden_features}"
+            )
+
+        if out_features is None:
+            raise ValueError("out_features must be specified")
+
+        if out_features <= 0:
+            raise ValueError(
+                f"Number of output features must be positive, got {out_features}"
+            )
+
+        if out_activation is None:
+            out_activation = Layer(nn.Identity)
+        elif isinstance(out_activation, type) and issubclass(out_activation, nn.Module):
+            out_activation = Layer(out_activation)
+
+        self.blocks = LayerList()
+        for i, c_out in enumerate([*hidden_features, out_features]):
+            activation = (
+                Layer(nn.ReLU) if i < len(hidden_features) else out_activation
+            )
+
+            transform = TransformOnlySenderNodes(
+                combine=Cat(),
+                layer=Layer(nn.LazyLinear, c_out),
+                activation=activation.new(),
+            )
+            transform.set_input_map("x", "edge_index", "input_edge_attr")
+            transform.set_output_map("edge_attr")
+
+            propagate = Mean()
+            propagate.set_input_map("x", "edge_index", "edge_attr")
+            propagate.set_output_map("aggregate")
+
+            update = Update(
+                combine=Cat(),
+                layer=Layer(nn.LazyLinear, c_out),
+                activation=activation.new(),
+            )
+            update.set_input_map("x", "aggregate")
+            update.set_output_map("x")
+
+            block = TransformPropagateUpdate(
+                transform=transform,
+                propagate=propagate,
+                update=update,
+            )
+            self.blocks.append(block)
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+        return x
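+
+    # Dataflow sketch per block: `transform` reads ("x", "edge_index",
+    # "input_edge_attr") and writes per-edge messages to "edge_attr";
+    # `propagate` mean-aggregates the messages per node into "aggregate";
+    # `update` concatenates ("x", "aggregate") and writes the updated node
+    # features back to "x".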
+
+    @overload
+    def configure(
+        self,
+        /,
+        hidden_features: Optional[List[int]] = None,
+        out_features: Optional[int] = None,
+        out_activation: Union[Type[nn.Module], nn.Module, None] = None,
+    ) -> None: ...
+
+    @overload
+    def configure(
+        self,
+        name: Literal["blocks"],
+        order: Optional[Sequence[str]] = None,
+        transform: Optional[Type[nn.Module]] = None,
+        propagate: Optional[Type[nn.Module]] = None,
+        update: Optional[Type[nn.Module]] = None,
+        **kwargs: Any,
+    ) -> None: ...
+
+    @overload
+    def configure(
+        self,
+        name: Literal["blocks"],
+        index: Union[int, slice, List[Union[int, slice]]],
+        order: Optional[Sequence[str]] = None,
+        transform: Optional[Type[nn.Module]] = None,
+        propagate: Optional[Type[nn.Module]] = None,
+        update: Optional[Type[nn.Module]] = None,
+        **kwargs: Any,
+    ) -> None: ...
+
+    configure = DeeplayModule.configure
diff --git a/deeplay/components/gnn/mpn/transformation.py b/deeplay/components/gnn/mpn/transformation.py
index 7ee44740..7263915f 100644
--- a/deeplay/components/gnn/mpn/transformation.py
+++ b/deeplay/components/gnn/mpn/transformation.py
@@ -14,3 +14,16 @@ def get_forward_args(self, x):
         """
         x, edge_index, edge_attr = x
         return x[edge_index[0]], x[edge_index[1]], edge_attr
+
+class TransformOnlySenderNodes(CombineLayerActivation):
+    """Transform module for MPN that uses only the sender node features."""
+
+    def get_forward_args(self, x):
+        """Get the arguments for the Transform module:
+        - node features of sender nodes (x[edge_index[0]])
+        - edge features (edge_attr)
+
+        edge_index denotes the connectivity of the graph.
+        """
+        x, edge_index, edge_attr = x
+        return x[edge_index[0]], edge_attr
diff --git a/deeplay/components/gnn/pooling/__init__.py b/deeplay/components/gnn/pooling/__init__.py
new file mode 100644
index 00000000..96b025f7
--- /dev/null
+++ b/deeplay/components/gnn/pooling/__init__.py
@@ -0,0 +1,2 @@
+from .mincut import MinCutPooling, MinCutUpsampling
+from .globalpool import GlobalGraphPooling, GlobalGraphUpsampling
\ No newline at end of file
diff --git a/deeplay/components/gnn/pooling/globalpool.py b/deeplay/components/gnn/pooling/globalpool.py
new file mode 100644
index 00000000..55f06546
--- /dev/null
+++ b/deeplay/components/gnn/pooling/globalpool.py
@@ -0,0 +1,121 @@
+from typing import Optional
+
+import torch.nn as nn
+import torch
+from deeplay.module import DeeplayModule
+
+class GlobalGraphPooling(DeeplayModule):
+    """
+    Pools all the nodes of the graph to a single cluster.
+
+    (Inspired by MinCut pooling ('Spectral Clustering with Graph Neural Networks for Graph Pooling'),
+    but with a deterministic assignment matrix S that pools all nodes into one cluster.)
+
+    Constraints
+    -----------
+    - input: Dict[str, Any] or torch-geometric Data object containing the following attributes:
+        - x: torch.Tensor of shape (num_nodes, node_features)
+
+    - output: Dict[str, Any] or torch-geometric Data object containing the following attributes:
+        - x: torch.Tensor of shape (num_clusters, node_features)
+        - s: torch.Tensor of shape (num_nodes, num_clusters)
+
+    Examples
+    --------
+    >>> global_pool = GlobalGraphPooling().build()
+    >>> inp = {}
+    >>> inp["x"] = torch.randn(3, 2)
+    >>> inp["batch"] = torch.zeros(3, dtype=int)
+    >>> inp["edge_index"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
+    >>> out = global_pool(inp)
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+        class Select(DeeplayModule):
+            def forward(self, x):
+                return torch.ones((x.shape[0], 1), device=x.device)
+
+        class ClusterMatrixForBatch(DeeplayModule):
+            def forward(self, S, B):
+                unique_graphs = torch.unique(B)
+                num_graphs = len(unique_graphs)
+
+                S_ = torch.zeros(S.shape[0] * num_graphs, device=S.device)
+
+                row_indices = torch.arange(S.shape[0])
+                col_indices = B
+
+                S_[row_indices * num_graphs + col_indices] = S.view(-1)
+                B_ = torch.arange(num_graphs, device=B.device)
+
+                return S_.reshape([S.shape[0], -1]), B_
+
+        class Reduce(DeeplayModule):
+            def forward(self, x, s):
+                return torch.matmul(s.transpose(-2, -1), x)
+
+        self.select = Select()
+        self.select.set_input_map('x')
+        self.select.set_output_map('s')
+
+        self.batch_compatible = ClusterMatrixForBatch()
+        self.batch_compatible.set_input_map('s', 'batch')
+        self.batch_compatible.set_output_map('s', 'batch')
+
+        self.reduce = Reduce()
+        self.reduce.set_input_map('x', 's')
+        self.reduce.set_output_map('x')
+
+    def forward(self, x):
+        x = self.select(x)
+        x = self.batch_compatible(x)
+        x = self.reduce(x)
+        return x
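+
+# Pool/unpool round-trip sketch: with S = ones(N, 1) per graph, `Reduce`
+# computes S^T @ X, i.e. the feature-wise sum over the graph's nodes, and
+# `GlobalGraphUpsampling` below broadcasts the pooled row back to every node
+# via S @ X_pool.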
+
+
+class GlobalGraphUpsampling(DeeplayModule):
+    """
+    Reverse of GlobalGraphPooling. Only the node features are upsampled.
+
+    Constraints
+    -----------
+    - input: Dict[str, Any] or torch-geometric Data object containing the following attributes:
+        - x: torch.Tensor of shape (num_clusters, node_features)
+        - s: torch.Tensor of shape (num_nodes, num_clusters)
+
+    - output: Dict[str, Any] or torch-geometric Data object containing the following attributes:
+        - x: torch.Tensor of shape (num_nodes, node_features)
+
+    Examples
+    --------
+    >>> global_upsampling = GlobalGraphUpsampling()
+    >>> global_upsampling = global_upsampling.build()
+
+    >>> inp = {}
+    >>> inp["x"] = torch.randn(1, 2)
+    >>> inp["s"] = torch.ones((3, 1))
+    >>> out = global_upsampling(inp)
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+        class Upsample(DeeplayModule):
+            def forward(self, x, s):
+                return torch.matmul(s, x)
+
+        self.upsample = Upsample()
+        self.upsample.set_input_map('x', 's')
+        self.upsample.set_output_map('x')
+
+    def forward(self, x):
+        x = self.upsample(x)
+        return x
\ No newline at end of file
diff --git a/deeplay/components/gnn/pooling/mincut.py b/deeplay/components/gnn/pooling/mincut.py
new file mode 100644
index 00000000..ce307d4e
--- /dev/null
+++ b/deeplay/components/gnn/pooling/mincut.py
@@ -0,0 +1,319 @@
+from typing import Sequence, Optional
+
+from deeplay import DeeplayModule
+
+from deeplay.components.mlp import MultiLayerPerceptron
+
+import torch
+import torch.nn as nn
+
+class MinCutPooling(DeeplayModule):
+    """
+    MinCut graph pooling as described in 'Spectral Clustering with Graph Neural Networks for Graph Pooling'.
+
+    Parameters
+    ----------
+    num_clusters: int
+        The number of clusters to which each graph is pooled.
+    hidden_features: Sequence[int]
+        The number of hidden features for the Multi-Layer Perceptron (MLP) that selects the clusters.
+
+    Configurables
+    -------------
+    - num_clusters (int): The number of clusters to which each graph is pooled.
+    - hidden_features (list[int]): The number of hidden features for the cluster-selection MLP.
+    - reduce_self_connection (bool): Whether to reduce self-connections in the adjacency matrix. Defaults to True.
+    - threshold (float): A threshold applied to the pooled adjacency matrix to binarize the edges. If None, no threshold is applied. Default is None.
+
+    Constraints
+    -----------
+    - input: Dict[str, Any] or torch-geometric Data object containing the following attributes:
+        - x: torch.Tensor of shape (num_nodes, node_features)
+        - edge_index: torch.Tensor of shape (2, num_edges)
+        - batch: torch.Tensor of shape (num_nodes)
+
+    Example
+    -------
+    >>> MinCut = dl.components.gnn.pooling.MinCutPooling(hidden_features=[8], num_clusters=5, reduce_self_connection=True, threshold=0.25).build()
+    >>> inp = {}
+    >>> inp["x"] = torch.randn(10, 16)
+    >>> inp['batch'] = torch.zeros(10, dtype=int)
+    >>> inp["edge_index"] = torch.randint(0, 10, (2, 20))
+    >>> output = MinCut(inp)
+    """
+
+    num_clusters: int
+    hidden_features: Sequence[int]
+    reduce_self_connection: Optional[bool]
+    threshold: Optional[float]
+
+    def __init__(
+        self,
+        num_clusters: int,
+        hidden_features: Sequence[int],
+        reduce_self_connection: Optional[bool] = True,
+        threshold: Optional[float] = None,
+    ):
+        super().__init__()
+
+        self.num_clusters = num_clusters
+        self.reduce_self_connection = reduce_self_connection
+        self.threshold = threshold
+
+        class ClusterMatrixForBatch(DeeplayModule):
+            def forward(self, S, B):
+                unique_graphs = torch.unique(B)
+                num_graphs = len(unique_graphs)
+
+                S_ = torch.zeros(S.shape[0] * S.shape[1] * num_graphs, device=S.device)
+
+                row_indices = torch.arange(S.shape[0]).repeat_interleave(S.shape[1])
+                col_indices = (
+                    B.repeat_interleave(S.shape[1]) * S.shape[1]
+                    + torch.arange(S.shape[1]).repeat(S.shape[0])
+                )
+
+                S_[row_indices * (S.shape[1] * num_graphs) + col_indices] = S.view(-1)
+
+                B_ = torch.arange(num_graphs).repeat_interleave(S.shape[1])
+
+                return S_.reshape([S.shape[0], -1]), B_
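+
+        # Block-diagonal sketch: for two single-node graphs with K = 2 clusters,
+        # S = [[a, b], [c, d]] and B = [0, 1] become
+        # S_ = [[a, b, 0, 0], [0, 0, c, d]] and B_ = [0, 0, 1, 1], so nodes of
+        # graph g only load on the K cluster columns belonging to graph g.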
+
+        class Reduce(DeeplayModule):
+            def forward(self, x, s):
+                return torch.matmul(s.transpose(-2, -1), x)
+
+        class Connect(DeeplayModule):
+            def forward(self, A, s):
+                if A.is_sparse:
+                    return torch.spmm(s.transpose(-2, -1), torch.spmm(A, s))
+                elif not A.is_sparse and A.size(0) == 2:
+                    A = torch.sparse_coo_tensor(
+                        A,
+                        torch.ones(A.size(1), device=A.device),
+                        (s.size(0),) * 2,
+                        device=A.device,
+                    )
+                    return torch.spmm(s.transpose(-2, -1), torch.spmm(A, s))
+                elif not A.is_sparse and len({A.size(0), A.size(1), s.size(0)}) == 1:
+                    return s.transpose(-2, -1) @ A.type(s.dtype) @ s
+                else:
+                    raise ValueError(
+                        "Unsupported adjacency matrix format.",
+                        "Ensure it is a pytorch sparse tensor, an edge index tensor, or a square dense tensor.",
+                        "Consider updating the propagate layer to handle alternative formats.",
+                    )
+
+        class ReduceSelfConnection(DeeplayModule):
+            def __init__(
+                self,
+                eps: Optional[float] = 1e-15,
+            ):
+                super().__init__()
+                self.eps = eps
+
+            def forward(self, A):
+                ind = torch.arange(A.shape[0])
+                A[ind, ind] = 0
+                D = torch.einsum('jk->j', A)
+                D_inv_sq = torch.pow(D, -0.5)
+                D_inv_sq = torch.where(torch.isinf(D_inv_sq), torch.tensor(0.0), D_inv_sq)
+                D_inv_sq = torch.diag(D_inv_sq)
+
+                A = D_inv_sq @ A @ D_inv_sq
+                return A
+
+        class MinCutLoss(DeeplayModule):
+            def __init__(
+                self,
+                eps: Optional[float] = 1e-15,
+            ):
+                super().__init__()
+                self.eps = eps
+
+            def forward(self, A, S):
+                n_nodes = S.size(0)     # number of nodes
+                n_clusters = S.size(1)  # total number of clusters (= clusters per graph * num graphs)
+
+                if A.is_sparse:
+                    degree = torch.sum(A, dim=0)
+                elif not A.is_sparse and A.size(0) == 2:
+                    A = torch.sparse_coo_tensor(
+                        A,
+                        torch.ones(A.size(1)),
+                        (n_nodes,) * 2,
+                        device=A.device,
+                    )
+                    degree = torch.sum(A, dim=0)
+                elif not A.is_sparse and len({A.size(0), A.size(1)}) == 1:
+                    degree = torch.sum(A, dim=0)
+                else:
+                    raise ValueError(
+                        "Unsupported adjacency matrix format.",
+                        "Ensure it is a pytorch sparse tensor, an edge index tensor, or a square dense tensor.",
+                        "Consider updating the propagate layer to handle alternative formats.",
+                    )
+
+                eps = torch.sparse_coo_tensor(
+                    indices=torch.arange(n_nodes).repeat(2, 1),
+                    values=torch.zeros(n_nodes) + self.eps,
+                    size=(n_nodes, n_nodes),
+                )
+
+                D = torch.eye(n_nodes) * degree + eps
+
+                # cut loss:
+                L_cut = -(
+                    torch.trace(torch.matmul(S.transpose(-2, -1), torch.matmul(A, S)))
+                    / torch.trace(torch.matmul(S.transpose(-2, -1), torch.matmul(D, S)))
+                )
+
+                # orthogonality loss:
+                L_ortho = torch.linalg.norm(
+                    torch.matmul(S.transpose(-2, -1), S)
+                    / torch.linalg.norm(torch.matmul(S.transpose(-2, -1), S), ord='fro')
+                    - torch.eye(n_clusters) / torch.sqrt(torch.tensor(n_clusters)),
+                    ord='fro',
+                )
+
+                return L_cut, L_ortho
+
+        class ApplyThreshold(DeeplayModule):
+            def __init__(self, threshold: float):
+                super().__init__()
+                self.threshold = threshold
+
+            def forward(self, A):
+                return torch.where(A >= self.threshold, 1.0, 0.0)
+
+        class SparseEdgeIndex(DeeplayModule):
+            """Output the edge index as a sparse tensor."""
+            def forward(self, A):
+                if A.is_sparse:
+                    return A
+                return A.to_sparse()
+
+        # select: S = MLP(X)
+        self.select = MultiLayerPerceptron(
+            in_features=None,
+            hidden_features=hidden_features,
+            out_features=num_clusters,
+            out_activation=nn.Softmax(dim=1))
+        self.select.set_input_map("x")
+        self.select.set_output_map('s')
+
+        # make S compatible with batches:
+        self.batch_compatible = ClusterMatrixForBatch()
+        self.batch_compatible.set_input_map("s", "batch")
+        self.batch_compatible.set_output_map("s", "batch")
+
+        # mincut loss
+        self.mincut_loss = MinCutLoss()
+        self.mincut_loss.set_input_map('edge_index', 's')
+        self.mincut_loss.set_output_map('L_cut', 'L_ortho')
+
+        # reduce: X' = S^T * X
+        self.reduce = Reduce()
+        self.reduce.set_input_map("x", 's')
+        self.reduce.set_output_map("x")
+
+        # connect: A' = S^T * A * S
+        self.connect = Connect()
+        self.connect.set_input_map('edge_index', 's')
+        self.connect.set_output_map("edge_index")
+
+        # reduce self connection
+        self.red_self_con = None
+        if reduce_self_connection:
+            self.red_self_con = ReduceSelfConnection()
+            self.red_self_con.set_input_map('edge_index')
+            self.red_self_con.set_output_map('edge_index')
+
+        # apply threshold to A
+        self.apply_threshold = None
+        if threshold is not None:
+            self.apply_threshold = ApplyThreshold(self.threshold)
+            self.apply_threshold.set_input_map('edge_index')
+            self.apply_threshold.set_output_map('edge_index')
+
+        # make A sparse
+        self.sparse = SparseEdgeIndex()
+        self.sparse.set_input_map('edge_index')
+        self.sparse.set_output_map('edge_index')
+
+    def forward(self, x):
+        x = self.select(x)
+        x = self.batch_compatible(x)
+        x = self.mincut_loss(x)
+        x = self.reduce(x)
+        x = self.connect(x)
+
+        if self.red_self_con is not None:
+            x = self.red_self_con(x)
+
+        if self.apply_threshold is not None:
+            x = self.apply_threshold(x)
+
+        x = self.sparse(x)
+
+        return x
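+
+# Loss sketch (notation from 'Spectral Clustering with Graph Neural Networks
+# for Graph Pooling'), with D the degree matrix and K the number of clusters:
+#   L_cut   = -Tr(S^T A S) / Tr(S^T D S)            (in [-1, 0], lower is better)
+#   L_ortho = || S^T S / ||S^T S||_F - I_K / sqrt(K) ||_F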
+
+
+class MinCutUpsampling(DeeplayModule):
+    """
+    Reverse of MinCutPooling as described in 'Spectral Clustering with Graph Neural Networks for Graph Pooling'.
+
+    Constraints
+    -----------
+    - input: Dict[str, Any] or torch-geometric Data object containing the following attributes:
+        - x: torch.Tensor of shape (num_clusters, node_features)
+        - edge_index_pool: torch.Tensor of shape (2, num_edges)
+        - batch: torch.Tensor of shape (num_clusters)
+        - s: torch.Tensor of shape (num_nodes, num_clusters)
+
+    Example
+    -------
+    >>> mincut_upsample = MinCutUpsampling().build()
+    >>> inp = {}
+    >>> inp["x"] = torch.randn(2, 1)
+    >>> inp["batch"] = torch.zeros(2, dtype=int)
+    >>> inp['s'] = torch.tensor([[1.0, 0], [0, 1.0], [1.0, 0]])
+    >>> inp["edge_index_pool"] = torch.tensor([[0, 1], [1, 0]])
+    >>> out = mincut_upsample(inp)
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+        class Upsample(DeeplayModule):
+            def forward(self, x_pool, a_pool, s):
+                x = torch.matmul(s, x_pool)
+
+                if a_pool.is_sparse:
+                    a = torch.spmm(s, torch.spmm(a_pool, s.T))
+                elif not a_pool.is_sparse and a_pool.size(0) == 2:
+                    a_pool = torch.sparse_coo_tensor(
+                        a_pool,
+                        torch.ones(a_pool.size(1), device=a_pool.device),
+                        (s.T.size(0),) * 2,
+                        device=a_pool.device,
+                    )
+                    a = torch.spmm(s, torch.spmm(a_pool, s.T))
+                elif not a_pool.is_sparse and len({a_pool.size(0), a_pool.size(1), s.T.size(0)}) == 1:
+                    a = s @ a_pool.type(s.dtype) @ s.T
+                else:
+                    raise ValueError(
+                        "Unsupported adjacency matrix format.",
+                        "Ensure it is a pytorch sparse tensor, an edge index tensor, or a square dense tensor.",
+                    )
+
+                return x, a
+
+        self.upsample = Upsample()
+        self.upsample.set_input_map('x', 'edge_index_pool', 's')
+        self.upsample.set_output_map('x', 'edge_index')
+
+    def forward(self, x):
+        x = self.upsample(x)
+        return x
\ No newline at end of file
diff --git a/deeplay/ops/__init__.py b/deeplay/ops/__init__.py
index 7715eb2c..22b5e923 100644
--- a/deeplay/ops/__init__.py
+++ b/deeplay/ops/__init__.py
@@ -2,3 +2,4 @@
 from .logs import FromLogs
 from .attention import *
 from .merge import *
+from .get_edge_features import *
diff --git a/deeplay/ops/get_edge_features.py b/deeplay/ops/get_edge_features.py
new file mode 100644
index 00000000..a5b8e7aa
--- /dev/null
+++ b/deeplay/ops/get_edge_features.py
@@ -0,0 +1,14 @@
+
+from deeplay import DeeplayModule
+
+class GetEdgeFeaturesNew(DeeplayModule):
+    """Extract the sender and receiver node features for each edge."""
+
+    def forward(self, x, edge_index):
+        """Get the node features of neighboring nodes for each edge:
+        - node features of sender nodes (x[edge_index[0]])
+        - node features of receiver nodes (x[edge_index[1]])
+
+        edge_index denotes the connectivity of the graph.
+        """
+        return x[edge_index[0]], x[edge_index[1]]
diff --git a/deeplay/tests/test_gnn.py b/deeplay/tests/test_gnn.py
index 91862992..f7ef1eb7 100644
--- a/deeplay/tests/test_gnn.py
+++ b/deeplay/tests/test_gnn.py
@@ -22,6 +22,14 @@
     Max,
     Layer,
     GlobalMeanPool,
+    GlobalGraphPooling,
+    GlobalGraphUpsampling,
+    MinCutPooling,
+    MinCutUpsampling,
+    GraphEncoder,
+    GraphDecoder,
+    GraphEncoderBlock,
+    GraphDecoderBlock,
 )
 
 import itertools
@@ -729,3 +737,119 @@ def test_gtoempm_defaults(self):
         out = model(inp)
 
         self.assertEqual(out.shape, (20, 1))
+
+
+class TestComponentPool(unittest.TestCase):
+    def test_global_pool(self):
+        global_pool = GlobalGraphPooling()
+        global_pool = global_pool.build()
+
+        inp = {}
+        inp["x"] = torch.randn(3, 2)
+        inp["batch"] = torch.zeros(3, dtype=int)
+        inp["edge_index"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
+        out = global_pool(inp)
+        self.assertEqual(out["x"].shape, (1, 2))
+
+    def test_global_upsampling(self):
+        global_upsampling = GlobalGraphUpsampling()
+        global_upsampling = global_upsampling.build()
+
+        inp = {}
+        inp["x"] = torch.randn(1, 2)
+        inp["s"] = torch.ones((3, 1))
+        out = global_upsampling(inp)
+        self.assertEqual(out["x"].shape, (3, 2))
+
+    def test_mincut_pool(self):
+        mincut = MinCutPooling(num_clusters=2, hidden_features=[5])
+        mincut = mincut.build()
+
+        inp = {}
+        inp["x"] = torch.randn(3, 1)
+        inp["batch"] = torch.zeros(3, dtype=int)
+        inp["edge_index"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
+        out = mincut(inp)
+
+        self.assertEqual(out["x"].shape, (2, 1))
+        self.assertEqual(out['edge_index'].shape, (2, 2))
+        self.assertEqual(out["s"].shape, (3, 2))
+        # rows of the softmax assignment matrix should each sum to one:
+        self.assertTrue((torch.sum(out['s'], axis=1) - torch.tensor([1., 1., 1.])).abs().sum() < 1e-5)
+
+    def test_mincut_upsample(self):
+        mincut_upsample = MinCutUpsampling()
+        mincut_upsample = mincut_upsample.build()
+
+        inp = {}
+        inp["x"] = torch.randn(2, 1)
+        inp["batch"] = torch.zeros(2, dtype=int)
+        inp['s'] = torch.tensor([[1.0, 0], [0, 1.0], [1.0, 0]])
+        inp["edge_index_pool"] = torch.tensor([[0, 1], [1, 0]])
+        out = mincut_upsample(inp)
+
+        self.assertEqual(out["x"].shape, (3, 1))
+
+
+class TestComponentsGraphEncoderDecoder(unittest.TestCase):
+    def test_graph_encoder_block(self):
+        encoder_block = GraphEncoderBlock(in_features=1, out_features=4, num_clusters=2)
+        encoder_block = encoder_block.build()
+
+        inp = {}
+        inp["x"] = torch.randn(3, 1)
+        inp["batch"] = torch.zeros(3, dtype=int)
+        inp["edge_index"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
+        out = encoder_block(inp)
+
+        self.assertEqual(out["x"].shape, (2, 4))
+        self.assertEqual(out["s"].shape, (3, 2))
+        self.assertEqual(out["edge_index_pool"].shape, (2, 2))
+
+    def test_graph_decoder_block(self):
+        decoder_block = GraphDecoderBlock(in_features=1, out_features=4)
+        decoder_block = decoder_block.build()
+
+        inp = {}
+        inp["x"] = torch.randn(2, 1)
+        inp["batch"] = torch.zeros(2, dtype=int)
+        inp["edge_index_pool"] = torch.tensor([[0, 1], [1, 0]])
+        inp["edge_index"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
+        inp['s'] = torch.tensor([[1.0, 0], [0, 1.0], [1.0, 0]])
+        out = decoder_block(inp)
+
+        self.assertEqual(out["x"].shape, (3, 4))
+        self.assertTrue(torch.all(inp["edge_index"] == out["edge_index"]))
+
+    def test_graph_encoder(self):
+        graph_encoder = GraphEncoder(hidden_features=2, num_blocks=3, num_clusters=[3, 2, 1])
+        graph_encoder = graph_encoder.build()
+
+        self.assertEqual(len(graph_encoder.blocks), 3)
+
+        inp = {}
+        inp["x"] = torch.randn(4, 2)
+        inp["batch"] = torch.zeros(4, dtype=int)
+        inp["edge_attr"] = torch.randn(6, 1)
+        out = graph_encoder(inp)
+
+        self.assertEqual(out["x"].shape, (1, 2))
+        self.assertEqual(out["s_1"].shape, (3, 2))
+
+    def test_graph_decoder(self):
+        graph_decoder = GraphDecoder(hidden_features=2, num_blocks=2, output_node_dim=2, output_edge_dim=1)
+        graph_decoder = graph_decoder.build()
+
+        self.assertEqual(len(graph_decoder.blocks), 2)
+
+        inp = {}
+        inp["x"] = torch.randn(1, 2)
+        inp["batch"] = torch.zeros(1, dtype=int)
+        inp["edge_index_1"] = torch.tensor([[0, 1], [1, 0]])
+        inp["edge_index"] = torch.tensor([[0, 1, 1, 2, 1, 3], [1, 0, 2, 1, 3, 1]])
+        inp["s_1"] = torch.ones((2, 1))
+        inp["s_0"] = torch.tensor([[1.0, 0], [0, 1.0], [0, 1.0], [1.0, 0]])
+        out = graph_decoder(inp)
+
+        self.assertEqual(out["x"].shape, (4, 2))
+        self.assertEqual(out["edge_attr"].shape, (6, 1))
"text/plain": [ + "torch.Size([1, 2])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['x'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 1])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['s'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "global_upsampling = GlobalGraphUpsampling()\n", + "global_upsampling = global_upsampling.build()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "inp = {}\n", + "inp[\"x\"] = torch.randn(1, 2)\n", + "inp[\"s\"] = torch.ones((3,1))\n", + "out = global_upsampling(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 2])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['x'].shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Mincut" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "mincut = MinCutPooling(num_clusters = 2, hidden_features = [5])\n", + "mincut = mincut.build()" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "MinCutPooling(\n", + " (select): MultiLayerPerceptron(\n", + " (blocks): LayerList(\n", + " (0): LinearBlock(\n", + " (layer): LazyLinear(in_features=0, out_features=5, bias=True)\n", + " (activation): ReLU()\n", + " )\n", + " (1): LinearBlock(\n", + " (layer): Linear(in_features=5, out_features=2, bias=True)\n", + " (activation): Softmax(dim=1)\n", + " )\n", + " )\n", + " )\n", + " (batch_compatible): ClusterMatrixForBatch()\n", + " (mincut_loss): MinCutLoss()\n", + " (reduce): Reduce()\n", + " (connect): Connect()\n", + " (red_self_con): ReduceSelfConnection()\n", + " (sparse): SparseEdgeIndex()\n", + ")" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mincut" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "inp = {}\n", + "inp[\"x\"] = torch.randn(3, 1)\n", + "inp[\"batch\"] = torch.zeros(3, dtype=int)\n", + "inp[\"edge_index\"] = torch.tensor([[0, 1, 1, 2, 1], [1, 0, 2, 1, 0]])\n", + "out = mincut(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': tensor([[-1.4034],\n", + " [-1.7131]], grad_fn=),\n", + " 'batch': tensor([0, 0]),\n", + " 'edge_index': tensor(indices=tensor([[0, 1],\n", + " [1, 0]]),\n", + " values=tensor([0.1269, 0.1285]),\n", + " size=(2, 2), nnz=2, layout=torch.sparse_coo, grad_fn=),\n", + " 's': tensor([[0.4595, 0.5405],\n", + " [0.4439, 0.5561],\n", + " [0.4480, 0.5520]], grad_fn=),\n", + " 'L_cut': tensor(-1.0003, grad_fn=),\n", + " 'L_ortho': tensor(0.7652, grad_fn=)}" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 2])" + ] + }, + "execution_count": 50, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "out['s'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([1., 1., 1.], grad_fn=)" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.sum(out['s'], axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(indices=tensor([[0, 1],\n", + " [1, 0]]),\n", + " values=tensor([0.1209, 0.1388]),\n", + " size=(2, 2), nnz=2, layout=torch.sparse_coo, grad_fn=)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['edge_index']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Add test on different configurations with the layers in MinCut?" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "mincut_upsample = MinCutUpsampling()\n", + "mincut_upsample = mincut_upsample.build()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "inp = {}\n", + "inp[\"x\"] = torch.randn(2, 1)\n", + "inp[\"batch\"] = torch.zeros(2, dtype=int)\n", + "inp['s'] = torch.tensor([\n", + " [1.0, 0],\n", + " [0, 1.0],\n", + " [1.0, 0]\n", + "])\n", + "inp[\"edge_index_pool\"] = torch.tensor([[0, 1], [1, 0]])\n", + "out = mincut_upsample(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': tensor([[1.3595],\n", + " [0.2606],\n", + " [1.3595]]),\n", + " 'batch': tensor([0, 0]),\n", + " 's': tensor([[1., 0.],\n", + " [0., 1.],\n", + " [1., 0.]]),\n", + " 'edge_index_pool': tensor([[0, 1],\n", + " [1, 0]]),\n", + " 'edge_index': tensor([[0., 1., 0.],\n", + " [1., 0., 1.],\n", + " [0., 1., 0.]])}" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0., 1., 0.],\n", + " [1., 0., 1.],\n", + " [0., 1., 0.]])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['edge_index']" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "a_pool = torch.tensor([[0, 1], [1, 0]])" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "s = torch.tensor([\n", + " [1, 0],\n", + " [0, 1],\n", + " [0.8, 0.2]\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "a_pool = torch.sparse_coo_tensor(\n", + " a_pool,\n", + " torch.ones(a_pool.size(1)),\n", + " ((s.T).size(0),) * 2,\n", + " device=a_pool.device,\n", + " )\n", + "a = torch.spmm(a_pool, s.T)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(indices=tensor([[0, 1],\n", + " [1, 0]]),\n", + " values=tensor([1., 1.]),\n", + " size=(2, 2), nnz=2, layout=torch.sparse_coo)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a_pool" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + 
"outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.0000, 1.0000, 0.2000],\n", + " [1.0000, 0.0000, 0.8000]])" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1.0000, 0.0000, 0.8000],\n", + " [0.0000, 1.0000, 0.2000]])" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "s.T" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "mincut = MinCutPooling(num_clusters=2, hidden_features=[5], reduce_self_connection=False)\n", + "mincut = mincut.build()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "inp = {}\n", + "inp[\"x\"] = torch.randn(3, 1)\n", + "inp[\"batch\"] = torch.zeros(3, dtype=int)\n", + "inp[\"edge_index\"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n", + "out = mincut(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': tensor([[0.9204],\n", + " [1.3934]], grad_fn=),\n", + " 'batch': tensor([0, 0]),\n", + " 'edge_index': tensor(indices=tensor([[0, 0, 1, 1],\n", + " [0, 1, 0, 1]]),\n", + " values=tensor([0.6270, 0.9567, 0.9567, 1.4596]),\n", + " size=(2, 2), nnz=4, layout=torch.sparse_coo, grad_fn=),\n", + " 's': tensor([[0.3918, 0.6082],\n", + " [0.3950, 0.6050],\n", + " [0.4019, 0.5981]], grad_fn=),\n", + " 'L_cut': tensor(-0.9999, grad_fn=),\n", + " 'L_ortho': tensor(0.7653, grad_fn=)}" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Graph Encoder and decoder" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### blocks" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "encoder_block = GraphEncoderBlock(in_features=1, out_features=4, num_clusters=2)\n", + "encoder_block = encoder_block.build()" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "GraphEncoderBlock(\n", + " (gcn): GraphConvolutionalNeuralNetwork(\n", + " (normalize): sparse_laplacian_normalization()\n", + " (blocks): LayerList(\n", + " (0): TransformPropagateUpdate(\n", + " (transform): Linear(in_features=1, out_features=4, bias=True)\n", + " (propagate): Propagate()\n", + " (update): ReLU()\n", + " )\n", + " )\n", + " )\n", + " (pool): MinCutPooling(\n", + " (select): MultiLayerPerceptron(\n", + " (blocks): LayerList(\n", + " (0): LinearBlock(\n", + " (layer): LazyLinear(in_features=0, out_features=4, bias=True)\n", + " (activation): ReLU()\n", + " )\n", + " (1): LinearBlock(\n", + " (layer): Linear(in_features=4, out_features=2, bias=True)\n", + " (activation): Softmax(dim=1)\n", + " )\n", + " )\n", + " )\n", + " (batch_compatible): ClusterMatrixForBatch()\n", + " (mincut_loss): MinCutLoss()\n", + " (reduce): Reduce()\n", + " (connect): Connect()\n", + " (red_self_con): ReduceSelfConnection()\n", + " (sparse): SparseEdgeIndex()\n", + " )\n", + ")" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "encoder_block" + ] + }, + { + "cell_type": "code", + 
"execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "inp = {}\n", + "inp[\"x\"] = torch.randn(3, 1)\n", + "inp[\"batch\"] = torch.zeros(3, dtype=int)\n", + "inp[\"edge_index\"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n", + "out = encoder_block(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': tensor([[0.4139, 0.0000, 1.0564, 1.0303],\n", + " [0.1099, 0.0000, 0.2809, 0.2738]], grad_fn=),\n", + " 'batch': tensor([0, 0]),\n", + " 'edge_index': tensor([[0, 1, 1, 2],\n", + " [1, 0, 2, 1]]),\n", + " 'laplacian': tensor(indices=tensor([[0, 0, 1, 1, 1, 2, 2],\n", + " [0, 1, 0, 1, 2, 1, 2]]),\n", + " values=tensor([0.5000, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000]),\n", + " size=(3, 3), nnz=7, layout=torch.sparse_coo),\n", + " 's': tensor([[0.7871, 0.2129],\n", + " [0.7905, 0.2095],\n", + " [0.7918, 0.2082]], grad_fn=),\n", + " 'L_cut': tensor(-1.0000, grad_fn=),\n", + " 'L_ortho': tensor(0.7654, grad_fn=),\n", + " 'edge_index_pool': tensor(indices=tensor([[0, 1],\n", + " [1, 0]]),\n", + " values=tensor([0.0838, 0.0838]),\n", + " size=(2, 2), nnz=2, layout=torch.sparse_coo, grad_fn=)}" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 2])" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['s'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "decoder_block = GraphDecoderBlock(in_features=1, out_features=4)\n", + "decoder_block = decoder_block.build()\n", + "\n", + "inp = {}\n", + "inp[\"x\"] = torch.randn(2, 1)\n", + "inp[\"batch\"] = torch.zeros(2, dtype=int)\n", + "inp[\"edge_index_pool\"] = torch.tensor([[0, 1], [1, 0]])\n", + "inp[\"edge_index\"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n", + "inp['s'] = torch.tensor([\n", + " [1.0, 0],\n", + " [0, 1.0],\n", + " [1.0, 0]\n", + "])\n", + "out = decoder_block(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': tensor([[0.0000, 0.0000, 0.4061, 0.0000],\n", + " [0.0000, 0.0000, 0.5079, 0.0000],\n", + " [0.0000, 0.0000, 0.4061, 0.0000]], grad_fn=),\n", + " 'batch': tensor([0, 0]),\n", + " 'edge_index_pool': tensor([[0, 1],\n", + " [1, 0]]),\n", + " 'edge_index': tensor([[0, 1, 1, 2],\n", + " [1, 0, 2, 1]]),\n", + " 's': tensor([[1., 0.],\n", + " [0., 1.],\n", + " [1., 0.]]),\n", + " '-': tensor([[0., 1., 0.],\n", + " [1., 0., 1.],\n", + " [0., 1., 0.]]),\n", + " 'laplacian': tensor(indices=tensor([[0, 0, 1, 1, 1, 2, 2],\n", + " [0, 1, 0, 1, 2, 1, 2]]),\n", + " values=tensor([0.5000, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000]),\n", + " size=(3, 3), nnz=7, layout=torch.sparse_coo)}" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 4])" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['x'].shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Encoder and decoder" + ] + }, + { + 
"cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "graph_encoder = GraphEncoder(hidden_features=2, num_blocks=2, num_clusters=[2,1])\n", + "graph_encoder = graph_encoder.build()" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "inp = {}\n", + "inp[\"x\"] = torch.randn(4, 2)\n", + "inp[\"batch\"] = torch.zeros(4, dtype=int)\n", + "inp[\"edge_index\"] = torch.tensor([[0, 1, 1, 2, 1, 3], [1, 0, 2, 1, 3, 1]])\n", + "inp[\"edge_attr\"] = torch.randn(6, 1)\n", + "out = graph_encoder(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(graph_encoder.blocks)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': tensor([[0.8827, 0.0000]], grad_fn=),\n", + " 'batch': tensor([0, 0, 0, 0]),\n", + " 'edge_index': tensor([[0, 1, 1, 2, 1, 3],\n", + " [1, 0, 2, 1, 3, 1]]),\n", + " 'edge_attr': tensor([[0.0046, 0.0058],\n", + " [0.5006, 0.0829],\n", + " [0.3284, 0.0937],\n", + " [0.1720, 0.0512],\n", + " [0.3950, 0.0706],\n", + " [0.5021, 0.0839]], grad_fn=),\n", + " 'aggregate': tensor([[0.5006, 0.0829],\n", + " [0.6788, 0.1409],\n", + " [0.3284, 0.0937],\n", + " [0.3950, 0.0706]], grad_fn=),\n", + " 'laplacian': tensor(indices=tensor([[0, 0, 1, 1],\n", + " [0, 1, 0, 1]]),\n", + " values=tensor([0.5000, 0.5000, 0.5000, 0.5000]),\n", + " size=(2, 2), nnz=4, layout=torch.sparse_coo),\n", + " 's_0': tensor([[0.4963, 0.5037],\n", + " [0.4963, 0.5037],\n", + " [0.4963, 0.5037],\n", + " [0.4963, 0.5037]], grad_fn=),\n", + " 'batch_1': tensor([0, 0]),\n", + " 'L_cut_0': tensor(-1., grad_fn=),\n", + " 'L_ortho_0': tensor(0.7654, grad_fn=),\n", + " 'edge_index_1': tensor(indices=tensor([[0, 1],\n", + " [1, 0]]),\n", + " values=tensor([0.1442, 0.1442]),\n", + " size=(2, 2), nnz=2, layout=torch.sparse_coo, grad_fn=),\n", + " 's_1': tensor([[1.],\n", + " [1.]]),\n", + " 'batch_2': tensor([0])}" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 2])" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['edge_index_1'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2\n" + ] + } + ], + "source": [ + "graph_decoder = GraphDecoder(hidden_features=2, num_blocks=2, output_node_dim=2, output_edge_dim=1)\n", + "graph_decoder = graph_decoder.build()\n", + "\n", + "print(len(graph_decoder.blocks))\n", + "\n", + "inp = {}\n", + "inp[\"x\"] = torch.randn(1, 2)\n", + "inp[\"batch\"] = torch.zeros(1, dtype=int)\n", + "inp[\"edge_index_1\"] = torch.tensor([[0, 1], [1, 0]])\n", + "inp[\"edge_index\"] = torch.tensor([[0, 1, 1, 2, 1, 3], [1, 0, 2, 1, 3, 1]])\n", + "inp['s_1'] = torch.ones((2,1))\n", + "inp['s_0'] = torch.tensor([[1.0, 0], [0, 1.0], [0, 1.0], [1.0, 0]])\n", + "out = graph_decoder(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.5838],\n", + " 
[0.5838],\n", + " [0.5838],\n", + " [0.5838],\n", + " [0.5838],\n", + " [0.5838]], grad_fn=)" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['edge_attr']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}