1 change: 1 addition & 0 deletions deeplay/applications/autoencoders/__init__.py
@@ -1,2 +1,3 @@
from .vae import VariationalAutoEncoder
from .wae import WassersteinAutoEncoder
from .vgae import VariationalGraphAutoEncoder
160 changes: 160 additions & 0 deletions deeplay/applications/autoencoders/vgae.py
@@ -0,0 +1,160 @@
from typing import Optional, Callable

from deeplay.applications import Application
from deeplay.external import Optimizer, Adam

from deeplay import (
    DeeplayModule,
    Layer,
)

import torch
import torch.nn as nn


class VariationalGraphAutoEncoder(Application):
""" Variational Auto Encoder for Graphs

Parameters
----------
encoder : nn.Module
decoder : nn.Module
hidden_features : int
Number of features of the hidden layers
latent_dim: int
Number of latent dimensions
alpha: float
Weighting for the node feature reconstruction loss. Defaults to 1
beta: float
Weighting for the KL loss. Defaults to 1e-7
gamma: float
Weighting for the edge feature reconstruction loss. Defaults to 1
delta: float
Weighting for the MinCut loss. Defaults to 1
reconstruction_loss: Reconstruction loss
Loss metric for the reconstruction of the node and edge features. Defaults to L1 (Mean absolute error).
optimizer: Optimizer
Optimizer to use for training.
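
    Examples
    --------
    A minimal usage sketch. The encoder/decoder arguments and the dict-style
    graph input with keys ``x``, ``edge_index``, and ``edge_attr`` are
    assumptions for illustration, not guaranteed by this module::

        vgae = VariationalGraphAutoEncoder(
            encoder=GraphEncoder(...),
            decoder=GraphDecoder(...),
            hidden_features=96,
            latent_dim=32,
        )
        out = vgae(batch)  # adds "mu", "log_var", and "z" to the graph dict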
"""

    encoder: torch.nn.Module
    decoder: torch.nn.Module
    hidden_features: int
    latent_dim: int
    alpha: float
    beta: float
    gamma: float
    delta: float
    reconstruction_loss: torch.nn.Module
    optimizer: Optimizer

    def __init__(
        self,
        encoder: Optional[nn.Module] = None,
        decoder: Optional[nn.Module] = None,
        hidden_features: Optional[int] = 96,
        latent_dim: Optional[int] = 32,
        alpha: Optional[float] = 1,
        beta: Optional[float] = 1e-7,
        gamma: Optional[float] = 1,
        delta: Optional[float] = 1,
        reconstruction_loss: Optional[Callable] = None,
        optimizer: Optional[Optimizer] = None,
        **kwargs,
    ):
        self.encoder = encoder

        self.fc_mu = Layer(nn.Linear, hidden_features, latent_dim)
        self.fc_mu.set_input_map("x")
        self.fc_mu.set_output_map("mu")

        self.fc_var = Layer(nn.Linear, hidden_features, latent_dim)
        self.fc_var.set_input_map("x")
        self.fc_var.set_output_map("log_var")

        self.fc_dec = Layer(nn.Linear, latent_dim, hidden_features)
        self.fc_dec.set_input_map("z")
        self.fc_dec.set_output_map("x")

        self.decoder = decoder

        self.reconstruction_loss = reconstruction_loss or nn.L1Loss()
        self.latent_dim = latent_dim
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.delta = delta

        super().__init__(**kwargs)

        class Reparameterize(DeeplayModule):
            def forward(self, mu, log_var):
                # Reparameterization trick: z = mu + sigma * eps with
                # eps ~ N(0, I), so sampling stays differentiable
                # with respect to mu and log_var.
                std = torch.exp(0.5 * log_var)
                eps = torch.randn_like(std)
                return eps * std + mu

        self.reparameterize = Reparameterize()
        self.reparameterize.set_input_map("mu", "log_var")
        self.reparameterize.set_output_map("z")

        self.optimizer = optimizer or Adam(lr=1e-3)

        @self.optimizer.params
        def params(self):
            return self.parameters()


    def encode(self, x):
        # The fc layers read and write named fields of the graph dict:
        # fc_mu adds "mu" and fc_var adds "log_var", both computed from "x".
        x = self.encoder(x)
        x = self.fc_mu(x)
        x = self.fc_var(x)
        return x

    def decode(self, x):
        x = self.fc_dec(x)
        x = self.decoder(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = self.train_preprocess(batch)
        node_features, edge_features = y
        x = self(x)
        node_features_hat = x["x"]
        edge_features_hat = x["edge_attr"]
        mu = x["mu"]
        log_var = x["log_var"]
        # Collect any MinCut losses that upstream layers stored in the
        # output dict under keys prefixed "L_cut" / "L_ortho".
        mincut_cut_loss = sum(
            value for key, value in x.items() if key.startswith("L_cut")
        )
        mincut_ortho_loss = sum(
            value for key, value in x.items() if key.startswith("L_ortho")
        )
        rec_loss_nodes, rec_loss_edges, KLD = self.compute_loss(
            node_features_hat,
            node_features,
            edge_features_hat,
            edge_features,
            mu,
            log_var,
        )

        tot_loss = (
            self.alpha * rec_loss_nodes
            + self.gamma * rec_loss_edges
            + self.beta * KLD
            + self.delta * (mincut_cut_loss + mincut_ortho_loss)
        )

        loss = {
            "rec_loss_nodes": rec_loss_nodes,
            "rec_loss_edges": rec_loss_edges,
            "KL": KLD,
            "MinCut cut loss": mincut_cut_loss,
            "MinCut orthogonality loss": mincut_ortho_loss,
            "total_loss": tot_loss,
        }
        for name, v in loss.items():
            self.log(
                f"train_{name}",
                v,
                on_step=True,
                on_epoch=True,
                prog_bar=True,
                logger=True,
            )
        return tot_loss

    def compute_loss(self, n_hat, n, e_hat, e, mu, log_var):
        rec_loss_nodes = self.reconstruction_loss(n_hat, n)
        rec_loss_edges = self.reconstruction_loss(e_hat, e)

        # KL divergence between q(z|x) = N(mu, diag(sigma^2)) and N(0, I):
        # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        return rec_loss_nodes, rec_loss_edges, KLD

    def forward(self, x):
        x = self.encode(x)
        x = self.reparameterize(x)
        x = self.decode(x)
        return x
2 changes: 2 additions & 0 deletions deeplay/components/gnn/__init__.py
@@ -1,3 +1,5 @@
from .gcn import *
from .mpn import *
from .tpu import *
from .pooling import *
from .graphencdec import GraphEncoderBlock, GraphDecoderBlock, GraphEncoder, GraphDecoder
1 change: 1 addition & 0 deletions deeplay/components/gnn/gcn/__init__.py
@@ -1,2 +1,3 @@
from .gcn import GraphConvolutionalNeuralNetwork
from .normalization import *
from .gcn_concat import GraphConvolutionalNeuralNetworkConcat
184 changes: 184 additions & 0 deletions deeplay/components/gnn/gcn/gcn_concat.py
@@ -0,0 +1,184 @@
from typing import List, Optional, Literal, Any, Sequence, Type, overload, Union

from deeplay import DeeplayModule, Layer, LayerList

from ..tpu import TransformPropagateUpdate
from deeplay.ops import Cat

import torch
import torch.nn as nn


class GraphConvolutionalNeuralNetworkConcat(DeeplayModule):
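    """Graph convolutional neural network with a concatenation skip connection.

    Stacks transform-update-propagate blocks that read the node features
    "x" and write "x_prime", then concatenates "x_prime" with the input
    features "x" and mixes the result through a final dense layer and ReLU.

    Examples
    --------
    A minimal usage sketch; the dict-style input with keys ``x`` and
    ``edge_index`` is assumed here for illustration::

        gcn = GraphConvolutionalNeuralNetworkConcat(
            in_features=4, hidden_features=[16, 16], out_features=8
        )
    """
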
    in_features: Optional[int]
    hidden_features: Sequence[Optional[int]]
    out_features: int
    blocks: LayerList[TransformPropagateUpdate]

    @property
    def input(self):
        """Return the input block of the network. Equivalent to `.blocks[0]`."""
        return self.blocks[0]

    @property
    def hidden(self):
        """Return the hidden blocks of the network. Equivalent to `.blocks[:-1]`."""
        return self.blocks[:-1]

    @property
    def output(self):
        """Return the last block of the network. Equivalent to `.blocks[-1]`."""
        return self.blocks[-1]

    @property
    def transform(self) -> LayerList[Layer]:
        """Return the transform layers of the network. Equivalent to `.blocks.transform`."""
        return self.blocks.transform

    @property
    def propagate(self) -> LayerList[Layer]:
        """Return the propagate layers of the network. Equivalent to `.blocks.propagate`."""
        return self.blocks.propagate

    @property
    def update(self) -> LayerList[Layer]:
        """Return the update layers of the network. Equivalent to `.blocks.update`."""
        return self.blocks.update

    def __init__(
        self,
        in_features: int,
        hidden_features: Sequence[int],
        out_features: int,
        out_activation: Union[Type[nn.Module], nn.Module, None] = None,
    ):
        super().__init__()

        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features

        if in_features is None:
            raise ValueError("in_features must be specified")

        if out_features is None:
            raise ValueError("out_features must be specified")

        if in_features <= 0:
            raise ValueError(f"in_features must be positive, got {in_features}")

        if out_features <= 0:
            raise ValueError(f"out_features must be positive, got {out_features}")

        if any(h <= 0 for h in hidden_features):
            raise ValueError(
                f"all hidden_features must be positive, got {hidden_features}"
            )

        if out_activation is None:
            out_activation = Layer(nn.Identity)
        elif isinstance(out_activation, type) and issubclass(out_activation, nn.Module):
            out_activation = Layer(out_activation)

        class Propagate(DeeplayModule):
            def forward(self, x, A):
                # Accept the adjacency as a sparse tensor, a 2 x E edge
                # index, or a square dense matrix.
                if A.is_sparse:
                    return torch.spmm(A, x)
                elif A.size(0) == 2:
                    # Edge index: build a sparse adjacency with unit weights.
                    A = torch.sparse_coo_tensor(
                        A,
                        torch.ones(A.size(1), device=A.device),
                        (x.size(0),) * 2,
                        device=A.device,
                    )
                    return torch.spmm(A, x)
                elif A.size(0) == A.size(1) == x.size(0):
                    return A.type(x.dtype) @ x
                else:
                    raise ValueError(
                        "Unsupported adjacency matrix format. Ensure it is a "
                        "pytorch sparse tensor, an edge index tensor, or a "
                        "square dense tensor. Consider updating the propagate "
                        "layer to handle alternative formats."
                    )

        self.blocks = LayerList()

        for i, (c_in, c_out) in enumerate(
            zip([in_features, *hidden_features], [*hidden_features, out_features])
        ):
            transform = Layer(nn.Linear, c_in, c_out)
            transform.set_input_map("x")
            transform.set_output_map("x_prime")

            propagate = Layer(Propagate)
            propagate.set_input_map("x_prime", "edge_index")
            propagate.set_output_map("x_prime")

            update = (
                Layer(nn.ReLU) if i < len(self.hidden_features) else out_activation
            )
            update.set_input_map("x_prime")
            update.set_output_map("x_prime")

            # Note the block order: the update (activation) is applied
            # before propagation.
            block = TransformPropagateUpdate(
                transform=transform,
                propagate=propagate,
                update=update,
                order=["transform", "update", "propagate"],
            )
            self.blocks.append(block)

        # Skip connection: concatenate the block output "x_prime"
        # (out_features wide) with the input node features "x"
        # (in_features wide) before the final dense layer.
        self.concat = Cat()
        self.concat.set_input_map("x_prime", "x")
        self.concat.set_output_map("x")

        self.dense = Layer(nn.Linear, out_features + in_features, out_features)
        self.dense.set_input_map("x")
        self.dense.set_output_map("x")

        self.activate = Layer(nn.ReLU)
        self.activate.set_input_map("x")
        self.activate.set_output_map("x")

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        x = self.concat(x)
        x = self.dense(x)
        x = self.activate(x)

        return x

    @overload
    def configure(
        self,
        /,
        in_features: Optional[int] = None,
        hidden_features: Optional[List[int]] = None,
        out_features: Optional[int] = None,
        out_activation: Union[Type[nn.Module], nn.Module, None] = None,
    ) -> None: ...

    @overload
    def configure(
        self,
        name: Literal["blocks"],
        order: Optional[Sequence[str]] = None,
        transform: Optional[Type[nn.Module]] = None,
        propagate: Optional[Type[nn.Module]] = None,
        update: Optional[Type[nn.Module]] = None,
        **kwargs: Any,
    ) -> None: ...

    @overload
    def configure(
        self,
        name: Literal["blocks"],
        index: Union[int, slice, List[Union[int, slice]]],
        order: Optional[Sequence[str]] = None,
        transform: Optional[Type[nn.Module]] = None,
        propagate: Optional[Type[nn.Module]] = None,
        update: Optional[Type[nn.Module]] = None,
        **kwargs: Any,
    ) -> None: ...

    configure = DeeplayModule.configure
7 changes: 5 additions & 2 deletions deeplay/components/gnn/gcn/normalization.py
@@ -11,8 +11,11 @@ def add_self_loops(self, A, num_nodes):
        """
        loop_index = torch.arange(num_nodes, device=A.device)
        loop_index = loop_index.unsqueeze(0).repeat(2, 1)
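        # Illustrative example: for num_nodes = 3, loop_index is
        # [[0, 1, 2],
        #  [0, 1, 2]],
        # i.e., one self-edge (i, i) per node, appended column-wise below.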

        if A.is_sparse:
            A = torch.cat([A.indices(), loop_index], dim=1)
        elif A.size(0) == 2:
            A = torch.cat([A, loop_index], dim=1)

        return A
