Skip to content

Commit 07e786c

Browse files
committed
breast pathology pass!
1 parent 32072f3 commit 07e786c

File tree

23 files changed

+480
-40
lines changed

23 files changed

+480
-40
lines changed

HyperG/conv/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
from .hyconv import HyConv
1+
from .hyconv import HyConv
2+
3+
__all__ = ['HyConv']

HyperG/conv/hyconv.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def gen_node_ft(self, x: torch.Tensor, H: torch.Tensor):
4949
return x
5050

5151
def forward(self, x: torch.Tensor, H: torch.Tensor, hyedge_weight=None):
52+
assert len(x.shape) == 2, 'the input of HyperConv should be N x C'
5253
# feature transform
5354
x = x.matmul(self.theta)
5455

HyperG/hyedge/gather_neighbor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,14 @@ def neighbor_distance(x: torch.Tensor, k_nearest, dis_metric=pairwise_euclidean_
7676
:param k_nearest:
7777
:return:
7878
"""
79-
assert len(x.shape) == 2
79+
80+
assert len(x.shape) == 2, 'should be a tensor with (N x C) or (B x C x M x N)'
8081

8182
# N x C
8283
node_num = x.size(0)
8384
dis_matrix = dis_metric(x)
8485
_, nn_idx = torch.topk(dis_matrix, k_nearest, dim=1, largest=False)
85-
hyedge_idx = torch.arange(node_num).unsqueeze(0).repeat(k_nearest, 1).transpose(1, 0).reshape(-1)
86+
hyedge_idx = torch.arange(node_num).to(x.device).unsqueeze(0).repeat(k_nearest, 1).transpose(1, 0).reshape(-1)
8687
H = torch.stack([nn_idx.reshape(-1), hyedge_idx])
8788
return H
8889

HyperG/models/BaseCNNs.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import torch.nn as nn
2+
import torchvision
3+
4+
5+
class ResNetFeature(nn.Module):
6+
7+
def __init__(self, depth=34, pretrained=True):
8+
super().__init__()
9+
assert depth in [18, 34, 50, 101, 152]
10+
11+
if depth == 18:
12+
base_model = torchvision.models.resnet18(pretrained=pretrained)
13+
self.len_feature = 512
14+
self.features = nn.Sequential(*list(base_model.children())[:-2])
15+
elif depth == 34:
16+
base_model = torchvision.models.resnet34(pretrained=pretrained)
17+
self.len_feature = 512
18+
self.features = nn.Sequential(*list(base_model.children())[:-2])
19+
elif depth == 50:
20+
base_model = torchvision.models.resnet50(pretrained=pretrained)
21+
self.len_feature = 2048
22+
self.features = nn.Sequential(*list(base_model.children())[:-2])
23+
elif depth == 101:
24+
base_model = torchvision.models.resnet101(pretrained=pretrained)
25+
self.len_feature = 2048
26+
self.features = nn.Sequential(*list(base_model.children())[:-2])
27+
elif depth == 152:
28+
base_model = torchvision.models.resnet152(pretrained=pretrained)
29+
self.len_feature = 2048
30+
self.features = nn.Sequential(*list(base_model.children())[:-2])
31+
else:
32+
raise NotImplementedError(f'ResNet-{depth} is not implemented!')
33+
34+
def forward(self, x):
35+
x = self.features(x)
36+
37+
# Attention! No reshape!
38+
return x
39+
40+
41+
class ResNetClassifier(nn.Module):
42+
43+
def __init__(self, n_class, len_feature):
44+
super().__init__()
45+
self.len_feature = len_feature
46+
self.classifier = nn.Linear(self.len_feature, n_class)
47+
48+
def forward(self, x):
49+
# -> batch_size x C x N
50+
x = x.view(x.size(0), x.size(1), -1)
51+
52+
# -> batch_size x C
53+
x = x.mean(dim=-1)
54+
55+
x = self.classifier(x)
56+
return x

HyperG/models/CNN_HGNN.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import torch.nn as nn
2+
import torch.nn.functional as F
3+
4+
from HyperG.conv import HyConv
5+
from HyperG.hyedge import neighbor_distance
6+
from HyperG.models import ResNetFeature, ResNetClassifier
7+
8+
9+
class ResNet_HGNN(nn.Module):
10+
def __init__(self, n_class, depth, k_nearest, hiddens=[512], dropout=0.5, pretrained=True):
11+
super().__init__()
12+
self.dropout = dropout
13+
self.k_nearest = k_nearest
14+
self.ft_layers = ResNetFeature(depth=depth, pretrained=pretrained)
15+
16+
# hypergraph convolution for feature refine
17+
self.hyconvs = []
18+
dim_in = self.ft_layers.len_feature
19+
for h in hiddens:
20+
dim_out = h
21+
self.hyconvs.append(HyConv(dim_in, dim_out))
22+
dim_in = dim_out
23+
self.hyconvs = nn.ModuleList(self.hyconvs)
24+
25+
self.cls_layers = ResNetClassifier(n_class=n_class, len_feature=dim_in)
26+
27+
def forward(self, x):
28+
x = self.ft_layers(x)
29+
30+
assert x.size(0) == 1, 'when construct hypergraph, only support batch size = 1!'
31+
x = x.view(x.size(1), x.size(2) * x.size(3))
32+
# -> N x C
33+
x = x.permute(1, 0)
34+
H = neighbor_distance(x, k_nearest=self.k_nearest)
35+
# Hypergraph Convs
36+
for hyconv in self.hyconvs:
37+
x = hyconv(x, H)
38+
x = F.leaky_relu(x, inplace=True)
39+
x = F.dropout(x, self.dropout)
40+
# N x C -> 1 x C x N
41+
x = x.permute(1, 0).unsqueeze(0)
42+
43+
x = self.cls_layers(x)
44+
45+
return x

HyperG/models/CNNs.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torch.nn as nn
2+
from HyperG.models import ResNetFeature, ResNetClassifier
3+
4+
5+
class ResNet(nn.Module):
6+
def __init__(self, n_class, depth=34, pretrained=True):
7+
super().__init__()
8+
9+
self.ft_layers = ResNetFeature(depth=depth, pretrained=pretrained)
10+
self.cls_layers = ResNetClassifier(n_class=n_class, len_feature=self.ft_layers.len_feature)
11+
12+
def forward(self, x):
13+
x = self.ft_layers(x)
14+
x = self.cls_layers(x)
15+
16+
return x

HyperG/models/HGNN.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66

77
class HGNN(nn.Module):
8-
def __init__(self, in_ch, n_class, hidens=[16], dropout=0.5) -> None:
8+
def __init__(self, in_ch, n_class, hiddens=[16], dropout=0.5) -> None:
99
super().__init__()
1010
self.dropout = dropout
1111
_in = in_ch
1212
self.hyconvs = []
13-
for _h in hidens:
13+
for _h in hiddens:
1414
_out = _h
1515
self.hyconvs.append(HyConv(_in, _out))
1616
_in = _out
@@ -23,4 +23,4 @@ def forward(self, x, H, hyedge_weight=None):
2323
x = F.leaky_relu(x, inplace=True)
2424
x = F.dropout(x, self.dropout)
2525
x = self.last_hyconv(x, H)
26-
return x
26+
return F.log_softmax(x, dim=1)

HyperG/models/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,6 @@
1-
from .HGNN import HGNN
1+
from .BaseCNNs import ResNetFeature, ResNetClassifier
2+
from .CNN_HGNN import ResNet_HGNN
3+
from .HGNN import HGNN
4+
from .CNNs import ResNet
5+
6+
__all__ = ['ResNetFeature', 'ResNetClassifier', 'HGNN', 'ResNet_HGNN', 'ResNet']

HyperG/utils/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import os
2+
3+
4+
def check_dir(_dir, make=True):
5+
if os.path.exists(_dir):
6+
return True
7+
else:
8+
if make:
9+
os.makedirs(_dir)
10+
return True
11+
return False

HyperG/utils/data/pathology/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)