Skip to content

Commit 32072f3

Browse files
committed
heart mri example test pass!
1 parent 3d9d1fb commit 32072f3

File tree

11 files changed

+54
-32
lines changed

11 files changed

+54
-32
lines changed

.idea/HyperG_package.iml

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/deployment.xml

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/misc.xml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

HyperG/conv/hyconv.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,30 +22,30 @@ def reset_parameters(self):
2222

2323
def gen_hyedge_ft(self, x: torch.Tensor, H: torch.Tensor, hyedge_weight=None):
2424
ft_dim = x.size(1)
25-
node_idx, edge_idx = H
25+
node_idx, hyedge_idx = H
2626
hyedge_num = count_hyedge(H)
2727

2828
# a vector to normalize hyperedge feature
2929
hyedge_norm = 1.0 / degree_hyedge(H).float()
3030
if hyedge_weight is not None:
3131
hyedge_norm *= hyedge_weight
32-
hyedge_norm = hyedge_norm[edge_idx]
32+
hyedge_norm = hyedge_norm[hyedge_idx]
3333

3434
x = x[node_idx] * hyedge_norm.unsqueeze(1)
35-
x = torch.zeros(hyedge_num, ft_dim).scatter_add(0, edge_idx.unsqueeze(1).repeat(1, ft_dim), x)
35+
x = torch.zeros(hyedge_num, ft_dim).to(x.device).scatter_add(0, hyedge_idx.unsqueeze(1).repeat(1, ft_dim), x)
3636
return x
3737

3838
def gen_node_ft(self, x: torch.Tensor, H: torch.Tensor):
3939
ft_dim = x.size(1)
40-
node_idx, edge_idx = H
40+
node_idx, hyedge_idx = H
4141
node_num = count_node(H)
4242

4343
# a vector to normalize node feature
4444
node_norm = 1.0 / degree_node(H).float()
45-
node_norm = node_norm[edge_idx]
45+
node_norm = node_norm[node_idx]
4646

47-
x = x[edge_idx] * node_norm.unsqueeze(1)
48-
x = torch.zeros(node_num, ft_dim).scatter_add(0, node_idx.unsqueeze(1).repeat(1, ft_dim), x)
47+
x = x[hyedge_idx] * node_norm.unsqueeze(1)
48+
x = torch.zeros(node_num, ft_dim).to(x.device).scatter_add(0, node_idx.unsqueeze(1).repeat(1, ft_dim), x)
4949
return x
5050

5151
def forward(self, x: torch.Tensor, H: torch.Tensor, hyedge_weight=None):

HyperG/hyedge/utils/degree.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,15 @@
66
def degree_node(H):
77
node_idx, edge_idx = H
88
node_num = count_node(H)
9-
return torch.zeros(node_num).scatter_add(0, node_idx, torch.ones_like(node_idx).float()).long()
9+
src = torch.ones_like(node_idx).float().to(H.device)
10+
out = torch.zeros(node_num).to(H.device)
11+
return out.scatter_add(0, node_idx, src).long()
12+
# return torch.zeros(node_num).scatter_add(0, node_idx, torch.ones_like(node_idx).float()).long()
1013

1114

12-
def degree_hyedge(H):
15+
def degree_hyedge(H: torch.Tensor):
1316
node_idx, hyedge_idx = H
1417
edge_num = count_hyedge(H)
15-
return torch.zeros(edge_num).scatter_add(0, hyedge_idx, torch.ones_like(hyedge_idx).float()).long()
18+
src = torch.ones_like(hyedge_idx).float().to(H.device)
19+
out = torch.zeros(edge_num).to(H.device)
20+
return out.scatter_add(0, hyedge_idx, src).long()

HyperG/hygraph/fusion.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,22 @@
22

33
import torch
44

5-
from HyperG.hyedge import count_hyedge, count_node
5+
from HyperG.hyedge import count_hyedge, count_node, contiguous_hyedge_idx
66

77

88
def hyedge_concat(Hs: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], same_node=True):
99
node_num = 0
1010
hyedge_num = 0
1111
Hs_new = []
1212
for H in Hs:
13+
_H = H.clone()
1314
if not same_node:
14-
H[0, :] += node_num
15-
H[1, :] += hyedge_num
15+
_H[0, :] += node_num
16+
_H[1, :] += hyedge_num
1617

17-
Hs_new.append(H)
18+
Hs_new.append(_H)
1819

1920
hyedge_num += count_hyedge(H)
2021
node_num += count_node(H)
21-
22-
return torch.cat(Hs_new, dim=1)
22+
Hs_new = torch.cat(Hs_new, dim=1)
23+
return contiguous_hyedge_idx(Hs_new)

examples/segment/heart_mri/__init__.py

Whitespace-only changes.

examples/segment/heart_mri/data_helper.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def normalize(x):
1616

1717

1818
def split_train_val(root, ratio=0.8, save_dir=None, resplit=False):
19-
if not resplit and save_dir is not None:
20-
with open(save_dir, 'r') as f:
19+
if not resplit and save_dir is not None and osp.exists(save_dir):
20+
with open(save_dir, 'rb') as f:
2121
result = pickle.load(f)
2222
return result
2323

@@ -44,7 +44,7 @@ def split_train_val(root, ratio=0.8, save_dir=None, resplit=False):
4444
save_folder = osp.split(save_dir)[0]
4545
if not osp.exists(save_folder):
4646
os.makedirs(save_folder)
47-
with open(save_dir, 'w') as f:
47+
with open(save_dir, 'wb') as f:
4848
pickle.dump(result, f)
4949

5050
return result
@@ -70,12 +70,12 @@ def preprocess(data_list, patch_size, k_nearest):
7070
lbl.append(_lbl)
7171
mask_train.extend([0] * _node_num)
7272

73-
x, lbl = torch.cat(x, dim=0), torch.cat(lbl, dim=0)
73+
x, lbl = torch.cat(x, dim=0), torch.cat(lbl, dim=0).long()
7474
x = normalize(x)
7575

7676
H_grid = hyedge_concat(H_grid, same_node=False)
77-
mask_train = torch.tensor(mask_train)
78-
mask_val = 1 - mask_train
77+
mask_train = torch.tensor(mask_train).bool()
78+
mask_val = ~mask_train
7979

8080
H_global = neighbor_distance(x, k_nearest)
8181

examples/segment/heart_mri/train.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
import torch
22
import torch.nn.functional as F
3+
from data_helper import preprocess, split_train_val
34

45
from HyperG.models import HGNN
5-
from .data_helper import preprocess, split_train_val
66

77
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
88

99
k_nearest = 7
1010
patch_size = (5, 5)
11-
data_list = split_train_val('/repository/HyperG_example/example_data/heart_mri/processed',
12-
save_dir='/repository/HyperG_example/tmp/heart_mri', ratio=0.8)
11+
data_list = split_train_val('/repository/HyperG_example/example_data/heart_mri/processed', ratio=0.8,
12+
save_dir='/repository/HyperG_example/tmp/heart_mri/split.pkl', resplit=True)
1313
x, H, lbl, mask_train, mask_val = preprocess(data_list, patch_size, k_nearest)
1414

1515
x_ch = x.size(1)
16-
n_class = lbl.max() + 1
16+
n_class = lbl.max().item() + 1
1717
model = HGNN(x_ch, n_class, hidens=[16])
1818

1919
model, x, H, lbl, mask_train, mask_val = model.to(device), x.to(device), H.to(device), \
@@ -44,9 +44,9 @@ def val():
4444

4545
if __name__ == '__main__':
4646
best_acc = 0.0
47-
for epoch in range(1, 51):
47+
for epoch in range(1, 21):
4848
train()
4949
train_acc, val_acc = val()
5050
if val_acc > best_acc:
5151
best_acc = val_acc
52-
print(f'Epoch: {epoch}, Train:{train_acc:.4f}, Val:{val_acc:.4f}, Best Val:{best_acc}:.4f')
52+
print(f'Epoch: {epoch}, Train:{train_acc:.4f}, Val:{val_acc:.4f}, Best Val:{best_acc:.4f}')

test/models/test_HGNN.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
23
from HyperG.models import HGNN
34

45

0 commit comments

Comments
 (0)