Skip to content

Commit 7f847d1

Browse files
committed
modelnet40 example pass!
1 parent 07e786c commit 7f847d1

File tree

5 files changed

+106
-10
lines changed

5 files changed

+106
-10
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def sample_patch():
2+
pass
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import numpy as np
2+
import scipy.io as scio
3+
import torch
4+
5+
6+
def load_ft(data_dir, feature_name='GVCNN'):
7+
data = scio.loadmat(data_dir)
8+
lbls = data['Y'].astype(np.long)
9+
if lbls.min() == 1:
10+
lbls = lbls - 1
11+
idx = data['indices'].item()
12+
13+
if feature_name == 'MVCNN':
14+
fts = data['X'][0].item().astype(np.float32)
15+
elif feature_name == 'GVCNN':
16+
fts = data['X'][1].item().astype(np.float32)
17+
else:
18+
print(f'wrong feature name{feature_name}!')
19+
raise IOError
20+
21+
idx_train = (idx == 1)
22+
idx_test = (idx == 0)
23+
return torch.tensor(fts), torch.tensor(lbls).squeeze(), \
24+
torch.tensor(idx_train).squeeze().bool(), \
25+
torch.tensor(idx_test).squeeze().bool()
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import os.path as osp
2+
3+
import torch
4+
import torch.nn.functional as F
5+
from data_helper import load_ft
6+
7+
from HyperG.hyedge import neighbor_distance
8+
from HyperG.hygraph import hyedge_concat
9+
from HyperG.models import HGNN
10+
from HyperG.utils.meter import trans_class_acc
11+
12+
# initialize parameters
13+
data_root = '/repository/HyperG_example/example_data/modelnet40/processed'
14+
result_root = '/repository/HyperG_example/tmp/modelnet40'
15+
16+
k_nearest = 10
17+
feature_dir = osp.join(data_root, 'ModelNet40_mvcnn_gvcnn.mat')
18+
19+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20+
21+
# load data
22+
gvcnn_ft, target, mask_train, mask_val = load_ft(feature_dir, feature_name='GVCNN')
23+
mvcnn_ft, _, _, _ = load_ft(feature_dir, feature_name='MVCNN')
24+
25+
# init H and X
26+
gvcnn_H = neighbor_distance(gvcnn_ft, k_nearest)
27+
mvcnn_H = neighbor_distance(mvcnn_ft, k_nearest)
28+
29+
ft = torch.cat([mvcnn_ft, gvcnn_ft], dim=1)
30+
H = hyedge_concat([mvcnn_H, gvcnn_H])
31+
32+
x_ch = ft.size(1)
33+
n_class = target.max().item() + 1
34+
model = HGNN(x_ch, n_class, hiddens=[128])
35+
36+
model, ft, H, target, mask_train, mask_val = model.to(device), ft.to(device), H.to(device), \
37+
target.to(device), mask_train.to(device), mask_val.to(device)
38+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
39+
40+
41+
def train():
42+
model.train()
43+
optimizer.zero_grad()
44+
pred = model(ft, H)
45+
F.nll_loss(pred[mask_train], target[mask_train]).backward()
46+
optimizer.step()
47+
48+
49+
def val():
50+
model.eval()
51+
pred = model(ft, H)
52+
53+
_train_acc = trans_class_acc(pred, target, mask_train)
54+
_val_acc = trans_class_acc(pred, target, mask_val)
55+
56+
return _train_acc, _val_acc
57+
58+
59+
if __name__ == '__main__':
60+
best_acc, best_iou = 0.0, 0.0
61+
for epoch in range(1, 101):
62+
train()
63+
train_acc, val_acc = val()
64+
if val_acc > best_acc:
65+
best_acc = val_acc
66+
print(f'Epoch: {epoch}, Train:{train_acc:.4f}, Val:{val_acc:.4f}, '
67+
f'Best Val acc:{best_acc:.4f}')

examples/segment/heart_mri/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
import os
21
import os.path as osp
32

43
import torch
54
import torch.nn.functional as F
65
from data_helper import preprocess, split_train_val
76

87
from HyperG.models import HGNN
8+
from HyperG.utils import check_dir
99
from HyperG.utils.meter import trans_class_acc, trans_iou_socre
1010
from HyperG.utils.visualization import trans_vis_pred_target
11-
from HyperG.utils import check_dir
1211

1312
# initialize parameters
1413
data_root = '/repository/HyperG_example/example_data/heart_mri/processed'

test/utils/data/mri/test_read.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55

66
@pytest.mark.skip(reason='unpleasure')
77
def test_read_mri():
8-
img_dir = '/repository/HyperG_example/example_data/heart_mri/processed/0001.mha'
9-
img = read_mri(img_dir)
10-
11-
# print(img.shape)
12-
# import matplotlib.pyplot as plt
13-
# img = img.squeeze().numpy()
14-
# plt.imshow(img, cmap=plt.cm.bone)
15-
# plt.show()
8+
for _img_name in ['0001', '0002', '0003', '0004', '0005']:
9+
img_dir = f'/repository/HyperG_example/example_data/heart_mri/processed/{_img_name}.mha'
10+
save_dir = f'/repository/HyperG_example/tmp/heart_mri/{_img_name}.jpg'
11+
img = read_mri(img_dir)
12+
13+
print(img.shape)
14+
import matplotlib.pyplot as plt
15+
img = img.squeeze().numpy()
16+
plt.imshow(img, cmap=plt.cm.bone)
17+
plt.imsave(save_dir, img, cmap=plt.cm.bone)
18+
# plt.show()

0 commit comments

Comments
 (0)