Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deeplay/applications/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .regression import *
from .detection import *
from .autoencoders import *
from .clustering import *

# from .classification import *
# from .segmentation import ImageSegmentor
Expand Down
1 change: 1 addition & 0 deletions deeplay/applications/clustering/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .miro import MIRO
109 changes: 109 additions & 0 deletions deeplay/applications/clustering/miro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from typing import Callable, Optional

import numpy as np
import torch
import torch.nn as nn

from deeplay.models import RecurrentMessagePassingModel
from deeplay.applications import Application
from deeplay.external import Optimizer, Adam

from sklearn.cluster import DBSCAN


class MIRO(Application):
    """Clustering application based on a recurrent message-passing GNN.

    The model predicts, for every node, a displacement (in units of
    ``connectivity_radius``) that moves the node towards its cluster
    center. ``clustering`` then runs DBSCAN on the displaced
    ("squeezed") positions to obtain cluster labels.

    Parameters
    ----------
    num_outputs : int
        Dimensionality of the predicted displacement vectors.
    connectivity_radius : float
        Radius used to scale predicted displacements back to position space.
    model : nn.Module, optional
        Backbone; defaults to a ``RecurrentMessagePassingModel``.
    nd_loss_weight : float
        Weight of the neighbor-distance consistency loss term.
    loss : torch.nn.Module, optional
        Per-node regression loss. ``None`` (the default) builds a fresh
        ``torch.nn.L1Loss`` per instance — a module instance used as a
        default argument value would be created once at import time and
        shared by every call (mutable-default-argument pitfall).
    optimizer : Optimizer, optional
        Defaults to ``Adam(lr=1e-4)``.
    """

    num_outputs: int
    connectivity_radius: float
    model: nn.Module
    nd_loss_weight: float
    loss: torch.nn.Module
    metrics: list
    optimizer: Optimizer

    def __init__(
        self,
        num_outputs: int = 2,
        connectivity_radius: float = 1.0,
        model: Optional[nn.Module] = None,
        nd_loss_weight: float = 10,
        loss: Optional[torch.nn.Module] = None,
        optimizer=None,
        **kwargs,
    ):

        self.num_outputs = num_outputs
        self.connectivity_radius = connectivity_radius
        self.model = model or self._get_default_model()
        self.nd_loss_weight = nd_loss_weight

        super().__init__(
            # Build the default loss per instance (see class docstring).
            loss=loss if loss is not None else torch.nn.L1Loss(),
            optimizer=optimizer or Adam(lr=1e-4),
            **kwargs,
        )

    def _get_default_model(self):
        # 20 message-passing iterations: the backbone emits one prediction
        # per iteration, which compute_loss averages over.
        rgnn = RecurrentMessagePassingModel(
            hidden_features=256, out_features=self.num_outputs, num_iter=20
        )
        return rgnn

    def forward(self, x):
        return self.model(x)

    def compute_loss(self, y_hat, y, edges, position):
        """Average over iterations of (regression loss + weighted neighbor-distance loss)."""
        loss = 0
        for pred in y_hat:
            loss += self.loss(pred, y) + self.nd_loss_weight * self.compute_nd_loss(
                pred, y, edges, position
            )
        return loss / len(y_hat)

    def compute_nd_loss(self, y_hat, y, edges, position):
        """Loss between edgewise distances of predicted vs. ground-truth squeezed positions.

        Both positions are "squeezed" by subtracting the (scaled) displacement,
        then compared through the pairwise distances along ``edges``.
        """
        compressed_gt = position - y * self.connectivity_radius
        compressed_gt_distances = torch.norm(
            compressed_gt[edges[0]] - compressed_gt[edges[1]], dim=1
        )
        compressed_pred = position - y_hat * self.connectivity_radius
        compressed_pred_distances = torch.norm(
            compressed_pred[edges[0]] - compressed_pred[edges[1]], dim=1
        )
        return self.loss(compressed_pred_distances, compressed_gt_distances)

    def squeeze(self, x, from_iter=-1, scaling=None):
        """Return squeezed node positions as a numpy array, optionally rescaled.

        ``from_iter`` selects which backbone iteration's prediction to use
        (default: the last). ``scaling`` defaults to ``np.array([1.0, 1.0])``;
        it is built per call rather than used as a mutable default argument.
        """
        if scaling is None:
            scaling = np.array([1.0, 1.0])
        pred = self(x)[from_iter].detach().cpu().numpy()
        return (x.position.cpu() - pred * self.connectivity_radius).numpy() * scaling

    def clustering(
        self,
        x,
        eps,
        min_samples,
        from_iter=-1,
        **kwargs,
    ):
        """Cluster nodes by running DBSCAN on the squeezed positions.

        ``eps`` and ``min_samples`` are forwarded to ``sklearn.cluster.DBSCAN``;
        extra kwargs are forwarded to ``squeeze``. Returns the DBSCAN labels.
        """
        squeezed = self.squeeze(x, from_iter, **kwargs)
        clusters = DBSCAN(eps=eps, min_samples=min_samples).fit(squeezed)

        return clusters.labels_

    def training_step(self, batch, batch_idx):
        x, y = self.train_preprocess(batch)
        y_hat = self(x)
        loss = self.compute_loss(y_hat, y, x.edge_index, x.position)

        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        self.log_metrics(
            "train",
            y_hat,
            y,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss
27 changes: 26 additions & 1 deletion deeplay/components/dict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Any, Union, Tuple, overload
from typing import Any, Dict, Tuple, Union, overload

from deeplay import DeeplayModule

Expand Down Expand Up @@ -54,3 +54,28 @@ def forward(

x.update({key: torch.add(x[key], y[key]) for key in self.keys})
return x


class CatDictElements(DeeplayModule):
    """
    Concatenates specified elements within a dictionary-like structure along a given dimension.

    Parameters:
    - keys: Tuple of tuples, where each tuple contains the source and target keys to concatenate.
    - dim: Dimension along which concatenation occurs. Default is -1.
    """

    def __init__(self, *keys: Tuple[tuple], dim: int = -1):
        super().__init__()
        # Split the (source, target) pairs into two parallel tuples.
        sources, targets = zip(*keys)
        self.source = sources
        self.target = targets
        self.dim = dim

    def forward(self, x: Union[Dict[str, Any], Data]) -> Union[Dict[str, Any], Data]:
        # Work on a shallow copy so the caller's container is not mutated.
        out = x.clone() if isinstance(x, Data) else x.copy()
        merged = {}
        for src, tgt in zip(self.source, self.target):
            # Each target entry becomes [target, source] joined along `dim`.
            merged[tgt] = torch.cat([out[tgt], out[src]], dim=self.dim)
        out.update(merged)
        return out
1 change: 1 addition & 0 deletions deeplay/components/gnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .gcn import *
from .mpn import *
from .tpu import *
from .rgb import *
56 changes: 56 additions & 0 deletions deeplay/components/gnn/rgb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch
from deeplay import DeeplayModule
from deeplay.components.dict import CatDictElements


class RecurrentGraphBlock(DeeplayModule):
    """Recurrently applies combine -> layer and collects one head output per iteration.

    Parameters
    ----------
    layer : DeeplayModule
        The graph layer applied at every iteration.
    head : DeeplayModule
        Maps the state to an output; called once per iteration, so
        ``forward`` returns ``num_iter`` outputs.
    hidden_features : int
        Width of the zero-initialized hidden states.
    num_iter : int
        Number of recurrent iterations.
    combine : DeeplayModule, optional
        Module merging inputs with hidden states; must expose ``source`` and
        ``target`` attributes (see ``CatDictElements``). Defaults to
        ``CatDictElements(("x", "hidden"))``, built per instance — a module
        instance used as a default argument value would be created once at
        import time and shared (parameters included) by every
        RecurrentGraphBlock (mutable-default-argument pitfall).
    """

    combine: DeeplayModule
    layer: DeeplayModule
    head: DeeplayModule

    def __init__(
        self,
        layer: DeeplayModule,
        head: DeeplayModule,
        hidden_features: int,
        num_iter: int,
        combine=None,
    ):
        super().__init__()
        # Build the default combine per instance (see class docstring).
        if combine is None:
            combine = CatDictElements(("x", "hidden"))
        self.combine = combine
        self.layer = layer
        self.head = head
        self.hidden_features = hidden_features
        self.num_iter = num_iter

        if not all(hasattr(self.combine, attr) for attr in ("source", "target")):
            raise AttributeError(
                "The 'combine' module must have 'source' and 'target' attributes to specify "
                "the keys to concatenate. Found None. Ensure that the 'combine' module is initialized "
                "with valid 'source' and 'target' keys. Check CatDictElements for reference."
            )
        # Hidden states are stored under the combine targets (e.g. "hidden").
        self.hidden_variables_name = self.combine.target

    def initialize_hidden(self, x):
        """Insert a zero hidden state for every missing hidden key in ``x``."""
        for source, hidden_variable_name in zip(
            self.combine.source, self.hidden_variables_name
        ):
            if hidden_variable_name not in x:
                x.update(
                    {
                        # Allocate directly on the source tensor's device,
                        # avoiding the allocate-then-copy of `.to(...)`.
                        hidden_variable_name: torch.zeros(
                            x[source].size(0),
                            self.hidden_features,
                            device=x[source].device,
                        )
                    }
                )
        return x

    def forward(self, x):
        """Run ``num_iter`` combine -> layer steps; return the list of head outputs."""
        x = self.initialize_hidden(x)
        outputs = []
        for _ in range(self.num_iter):
            x = self.combine(x)
            x = self.layer(x)
            outputs.append(self.head(x))

        return outputs
1 change: 1 addition & 0 deletions deeplay/models/gnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .mpm import MPM
from .rmpm import RecurrentMessagePassingModel
from .gtogmpm import GraphToGlobalMPM, GlobalMeanPool
from .gtonmpm import GraphToNodeMPM
from .gtoempm import GraphToEdgeMPM
Expand Down
115 changes: 115 additions & 0 deletions deeplay/models/gnn/rmpm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from typing import Type, Union

from deeplay import (
DeeplayModule,
Parallel,
MultiLayerPerceptron,
MessagePassingNeuralNetwork,
Sequential,
RecurrentGraphBlock,
CatDictElements,
)

import torch.nn as nn


class RecurrentMessagePassingModel(DeeplayModule):
    """Recurrent Message Passing Neural Network (RMPN) model.

    Encodes node ("x") and edge ("edge_attr") features to ``hidden_features``,
    then runs a recurrent message-passing backbone for ``num_iter`` iterations;
    the head maps the hidden node state ("hidden_x") to ``out_features`` at
    every iteration, so ``forward`` returns one output per iteration.

    Parameters
    ----------
    hidden_features : int
        Width of the hidden node/edge states. Must be a positive integer.
    out_features : int
        Output dimensionality of the head. Must be a positive integer.
    num_iter : int
        Number of recurrent iterations. Must be a positive integer.
    out_activation : type[nn.Module] | nn.Module | None
        Optional activation applied by the head.
    """

    hidden_features: int
    out_features: int
    num_iter: int

    def __init__(
        self,
        hidden_features: int,
        out_features: int,
        num_iter: int,
        out_activation: Union[Type[nn.Module], nn.Module, None] = None,
    ):
        super().__init__()

        # Validate all size arguments up front (before storing them) so an
        # invalid call never leaves a half-initialized instance. The
        # original code only type-checked hidden_features and never
        # validated num_iter; the checks are applied uniformly here.
        for name, value in (
            ("hidden_features", hidden_features),
            ("out_features", out_features),
            ("num_iter", num_iter),
        ):
            if not isinstance(value, int):
                raise ValueError(f"{name} must be an integer, got {value}")
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")

        self.hidden_features = hidden_features
        self.out_features = out_features
        self.num_iter = num_iter

        # Per-key encoders projecting raw node/edge features to hidden size.
        self.encoder = Parallel(
            **{
                key: MultiLayerPerceptron(
                    in_features=None,
                    hidden_features=[],
                    out_features=hidden_features,
                    flatten_input=False,
                ).set_input_map(key)
                for key in ("x", "edge_attr")
            }
        )

        # Concatenate encoded inputs onto the recurrent hidden states.
        combine = CatDictElements(("x", "hidden_x"), ("edge_attr", "hidden_edge_attr"))
        backbone_layer = Sequential(
            [
                # Project the (concatenated) hidden states back to hidden size.
                Parallel(
                    **{
                        key: MultiLayerPerceptron(
                            in_features=None,
                            hidden_features=[],
                            out_features=hidden_features,
                            flatten_input=False,
                        ).set_input_map(key)
                        for key in ("hidden_x", "hidden_edge_attr")
                    }
                ),
                MessagePassingNeuralNetwork([], hidden_features),
            ]
        )
        head = MultiLayerPerceptron(
            hidden_features,
            [],
            out_features,
            out_activation=out_activation,
            flatten_input=False,
        )

        self.backbone = RecurrentGraphBlock(
            combine=combine,
            layer=backbone_layer,
            head=head,
            hidden_features=hidden_features,
            num_iter=self.num_iter,
        )

        # Wire the message-passing stage to read/write the hidden keys.
        self.backbone.layer[1].transform.set_input_map(
            "hidden_x", "edge_index", "hidden_edge_attr"
        )
        self.backbone.layer[1].transform.set_output_map("hidden_edge_attr")

        self.backbone.layer[1].propagate.set_input_map(
            "hidden_x", "edge_index", "hidden_edge_attr"
        )
        self.backbone.layer[1].propagate.set_output_map("aggregate")

        # Replace the default update step with an MLP that maps the
        # aggregated messages back into the hidden node state.
        update = MultiLayerPerceptron(None, [], hidden_features, flatten_input=False)
        update.set_input_map("aggregate")
        update.set_output_map("hidden_x")
        self.backbone.layer[1].blocks[0].replace("update", update)

        self.backbone.layer[1][..., "activation"].configure(nn.ReLU)

        self.backbone.head.set_input_map("hidden_x")
        self.backbone.head.set_output_map()

    def forward(self, x):
        x = self.encoder(x)
        x = self.backbone(x)
        return x
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ torch-geometric
kornia
scipy
scikit-image
scikit-learn
rich
dill