
Commit 1e659b0

finish survival prediction, test pass!
1 parent 797039a commit 1e659b0

File tree

12 files changed (+229, -114 lines)


HyperG/hyedge/gather_neighbor.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def neighbor_distance(x: torch.Tensor, k_nearest, dis_metric=pairwise_euclidean_
     :return:
     """
 
-    assert len(x.shape) == 2, 'should be a tensor with (N x C) or (B x C x M x N)'
+    assert len(x.shape) == 2, 'should be a tensor with dimension (N x C)'
 
     # N x C
     node_num = x.size(0)
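For context, a minimal sketch of what the tightened assert expects. Only the 2-D (N x C) input requirement comes from the diff; the import path and the unused return value here are assumptions for illustration.

```python
import torch
from HyperG.hyedge import neighbor_distance  # import path assumed for illustration

# neighbor_distance now only accepts a plain node-feature matrix of shape (N x C)
x = torch.randn(100, 64)               # 100 nodes, 64-dimensional features
H = neighbor_distance(x, k_nearest=5)  # hyperedges from k-nearest neighbors

# A batched tensor such as (B x C x M x N) would now trip the assert:
# neighbor_distance(torch.randn(2, 64, 8, 8), k_nearest=5)  # AssertionError
```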

HyperG/models/BaseCNNs.py

Lines changed: 11 additions & 3 deletions
@@ -4,9 +4,10 @@
 
 class ResNetFeature(nn.Module):
 
-    def __init__(self, depth=34, pretrained=True):
+    def __init__(self, depth=34, pooling=False, pretrained=True):
         super().__init__()
         assert depth in [18, 34, 50, 101, 152]
+        self.pooling = pooling
 
         if depth == 18:
             base_model = torchvision.models.resnet18(pretrained=pretrained)
@@ -34,8 +35,15 @@ def __init__(self, depth=34, pretrained=True):
     def forward(self, x):
         x = self.features(x)
 
-        # Attention! No reshape!
-        return x
+        if self.pooling:
+            # -> batch_size x C x N
+            x = x.view(x.size(0), x.size(1), -1)
+            # -> batch_size x C
+            x = x.mean(dim=-1)
+            return x
+        else:
+            # Attention! No reshape!
+            return x
 
 
 class ResNetClassifier(nn.Module):
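A minimal usage sketch of the new `pooling` flag; the shapes assume a standard 224x224 input and ResNet-34, and the snippet itself is illustration, not part of the commit.

```python
import torch
from HyperG.models import ResNetFeature

model = ResNetFeature(depth=34, pooling=True, pretrained=True).eval()
x = torch.randn(4, 3, 224, 224)    # a batch of 4 RGB patches

with torch.no_grad():
    fts = model(x)

# pooling=True averages the spatial map into one vector per image,
# e.g. (4, 512) for ResNet-34; pooling=False keeps the old behaviour
# and returns the unpooled feature map, e.g. (4, 512, 7, 7).
print(fts.shape)
```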

HyperG/utils/data/pathology/overview_patches.py

Lines changed: 5 additions & 5 deletions
@@ -13,17 +13,17 @@ def draw_patches_on_slide(slide_dir, patch_coors, mini_frac=32):
     mini_size = np.ceil(np.array(slide.level_dimensions[0]) / mini_frac).astype(np.int)
     mini_level = get_just_gt_level(slide, mini_size)
 
-    img = slide.read_region((0, 0), mini_level, slide.level_dimensions[mini_level])
+    img = slide.read_region((0, 0), mini_level, slide.level_dimensions[mini_level]).convert('RGB')
     img = img.resize(mini_size)
 
     sampled_mask = gather_sampled_patches(patch_coors, mini_size, mini_frac)
     sampled_patches_img = fuse_img_mask(np.asarray(img), sampled_mask)
 
     img.close()
-    return Image.fromarray(sampled_patches_img)
+    return sampled_patches_img
 
 
-def gather_sampled_patches(patch_coors, mini_size, mini_frac):
+def gather_sampled_patches(patch_coors, mini_size, mini_frac) -> np.array:
     # generate sampled area mask
     sampled_mask = np.zeros((mini_size[1], mini_size[0]), np.uint8)
     for _coor in patch_coors:
@@ -36,8 +36,8 @@ def gather_sampled_patches(patch_coors, mini_size, mini_frac):
     return sampled_mask
 
 
-def fuse_img_mask(img: np.array, mask: np.array, alpha=0.7):
-    assert img.shape == mask.shape
+def fuse_img_mask(img: np.array, mask: np.array, alpha=0.7) -> Image:
+    assert img.shape[:2] == mask.shape
     img = img.copy()
     if (mask != 0).any():
         img[mask != 0] = alpha * img[mask != 0] + \
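A short sketch of why the `.convert('RGB')` call and the relaxed assert go together, assuming `read_region` returns an RGBA PIL image whose array has shape (H, W, 4):

```python
import numpy as np
from PIL import Image

rgba = Image.new('RGBA', (64, 64))            # stand-in for slide.read_region(...)
print(np.asarray(rgba).shape)                 # (64, 64, 4)
rgb = rgba.convert('RGB')
print(np.asarray(rgb).shape)                  # (64, 64, 3)

# fuse_img_mask now compares only the spatial dims, so a (H, W, 3) image
# can be fused with a (H, W) mask:
mask = np.zeros((64, 64), np.uint8)
assert np.asarray(rgb).shape[:2] == mask.shape
```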

HyperG/utils/data/pathology/sample_patches.py

Lines changed: 5 additions & 2 deletions
@@ -4,7 +4,7 @@
 
 import numpy as np
 import openslide
-import tqdm
+from tqdm import tqdm
 from scipy import ndimage
 from skimage.filters import threshold_otsu
 from skimage.morphology import dilation, star
@@ -44,7 +44,8 @@ def sample_patch_coors(slide_dir, num_sample=2000, patch_size=256):
     # half of the center
     th_num = int(np.ceil((mini_patch_size * 3 / 4 * mini_patch_size * 3 / 4)))
 
-    for row, col in tqdm(row_col):
+    pbar = tqdm(total=num_sample)
+    for row, col in row_col:
         if cnt >= num_sample:
             break
         mini_patch = bg_mask[row:row + mini_patch_size, col: col + mini_patch_size]
@@ -55,6 +56,8 @@ def sample_patch_coors(slide_dir, num_sample=2000, patch_size=256):
             # continue
         patch_coors.append(origin)
         cnt += 1
+        pbar.update(1)
+    pbar.close()
 
     return patch_coors
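The switch from wrapping `row_col` in `tqdm` to a manually driven bar means progress tracks accepted patches rather than scanned candidates. A small sketch of that pattern; the candidate loop and the acceptance test here are made up:

```python
from tqdm import tqdm

num_sample = 100
candidates = range(10_000)

pbar = tqdm(total=num_sample)
accepted = 0
for c in candidates:
    if accepted >= num_sample:
        break
    if c % 3 == 0:            # stand-in for the foreground test on a mini patch
        accepted += 1
        pbar.update(1)        # advance only when a patch is actually kept
pbar.close()
```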

HyperG/utils/meter/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-from .class_error import ClassErrorMeter
 from .transductive import trans_class_acc, trans_iou_socre
+from .inductive import CIndexMeter
 
-__all__ = ['trans_class_acc', 'trans_iou_socre', 'ClassErrorMeter']
+__all__ = ['trans_class_acc', 'trans_iou_socre', 'CIndexMeter']

HyperG/utils/meter/class_error.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

HyperG/utils/meter/inductive.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+import numpy as np
+import torch
+
+
+class CIndexMeter:
+    def __init__(self):
+        super(CIndexMeter, self).__init__()
+        self.reset()
+
+    def reset(self):
+        self.output = np.array([])
+        self.target = np.array([])
+
+    def add(self, output: torch.tensor, target: torch.tensor):
+        output = output.cpu().detach().squeeze().numpy()[np.newaxis]
+        target = target.cpu().detach().squeeze().numpy()[np.newaxis]
+
+        assert output.ndim == target.ndim, 'target and output do not match'
+        assert output.ndim == 1
+
+        self.output = np.hstack([self.output, output])
+        self.target = np.hstack([self.target, target])
+
+    def value(self):
+        output = self.output[np.newaxis]
+        target = self.target[np.newaxis]
+
+        num_sample = output.shape[-1]
+        num_hit = (~((output.T > output) ^ (target.T > target))).sum()
+
+        return float(num_hit - num_sample) / float(num_sample * num_sample - num_sample)
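Reading the code: `value()` counts, over all N x N ordered pairs, those where the prediction ordering agrees with the target ordering (the XNOR of the two comparison matrices), drops the N diagonal self-pairs via the `- num_sample` term, and normalises by the N^2 - N remaining pairs, giving a concordance-index-style score in [0, 1]. A hedged usage sketch; note that `add` asserts a 1-D result after `squeeze()[np.newaxis]`, so it effectively takes one scalar prediction/target pair per call:

```python
import torch
from HyperG.utils.meter import CIndexMeter

meter = CIndexMeter()

# One scalar prediction / survival-time pair per call (values are made up):
for pred, target in [(0.1, 1.0), (0.4, 2.0), (0.9, 3.0), (1.5, 4.0)]:
    meter.add(torch.tensor(pred), torch.tensor(target))

# Predictions are ordered exactly like the targets, so every off-diagonal
# pair is concordant and the meter reports 1.0.
print(meter.value())
```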

examples/regression/survival_prediction/data_helper.py

Lines changed: 35 additions & 19 deletions
@@ -8,7 +8,6 @@
 import torch
 from extract_patch_feature import extract_ft
 from torch.utils.data import Dataset, DataLoader
-from torch.utils.data.dataset import T_co
 
 from HyperG.utils.data import split_id
 from HyperG.utils.data.pathology import sample_patch_coors, draw_patches_on_slide
@@ -28,15 +27,19 @@ def split_train_val(data_root, ratio=0.8, save_split_dir=None, resplit=False):
     survival_time_max = 0
     for full_dir in all_list:
         _id = get_id(full_dir)
+        all_dict[_id] = {}
+        st = int(lbls[_id])
         all_dict[_id]['img_dir'] = full_dir
-        all_dict[_id]['survival_time'] = lbls[_id]
+        all_dict[_id]['survival_time'] = st
         survival_time_max = survival_time_max \
-            if survival_time_max > lbls[_id] else lbls[_id]
+            if survival_time_max > st else st
 
     id_list = list(all_dict.keys())
     train_list, val_list = split_id(id_list, ratio)
 
-    result = {'survival_time_max': survival_time_max}
+    result = {'survival_time_max': survival_time_max,
+              'train': {},
+              'val': {}}
     for _id in train_list:
         result['train'][_id] = all_dict[_id]
     for _id in val_list:
@@ -52,40 +55,52 @@ def split_train_val(data_root, ratio=0.8, save_split_dir=None, resplit=False):
     return result
 
 
+# def tmp_get_split(data_root):
+#     def tmp_get_id(_dir):
+#         _num = int(osp.splitext(osp.split(_dir)[1])[0].split('_')[1])
+#         return f'TCGA-GBM-{_num}'
+#
+#     result = {'train': {}, 'val': {}}
+#     for phase in ['train', 'val']:
+#         glob.glob(osp.join(data_root, phase, '*.npy'))
+
+
 def preprocess(data_dict, patch_ft_dir, patch_coors_dir, num_sample=2000,
                patch_size=256, sampled_vis=None, mini_frac=32):
     # check if each slide patch feature exists
     all_dir_list = []
     for phase in ['train', 'val']:
-        for _dir in data_dict[phase]:
-            all_dir_list.append(_dir['img_dir'])
+        for _id in data_dict[phase].keys():
+            all_dir_list.append(data_dict[phase][_id]['img_dir'])
     to_do_list = check_patch_ft(all_dir_list, patch_ft_dir)
 
     if to_do_list is not None:
         for _idx, _dir in enumerate(to_do_list):
-            print(f'processing {_idx + 1}/{len(to_do_list)}...')
+            print(f'{_idx + 1}/{len(to_do_list)}: processing slide {_dir}...')
+
+            print(f'sampling patch...')
             _id = get_id(_dir)
             _patch_coors = sample_patch_coors(_dir, num_sample=2000, patch_size=256)
 
             # save sampled patch coordinates
-            with open(osp.join(patch_coors_dir, f'{_id}_coors.pkl')) as fp:
+            with open(osp.join(patch_coors_dir, f'{_id}_coors.pkl'), 'wb') as fp:
                 pickle.dump(_patch_coors, fp)
 
             # visualize sampled patches on slide
             if sampled_vis is not None:
+                _vis_img_dir = osp.join(sampled_vis, f'{_id}_sampled_patches.jpg')
+                print(f'saving sampled patch_slide visualization {_vis_img_dir}...')
                 _vis_img = draw_patches_on_slide(_dir, _patch_coors, mini_frac=32)
-                with open(osp.join(sampled_vis, f'{_id}_sampled_patches.jpg')) as fp:
+                with open(_vis_img_dir, 'w') as fp:
                     _vis_img.save(fp)
 
-    # extract patch feature for each slide
-    for _dir in all_dir_list:
-        _id = get_id(_dir)
-        _patch_coors = None
-        fts = extract_ft(_dir, _patch_coors)
-        np.save(osp.join(patch_ft_dir, f'{_id}_fts.npy'), fts.cpu().numpy())
+            # extract patch feature for each slide
+            print(f'extracting feature...')
+            fts = extract_ft(_dir, _patch_coors, depth=34, batch_size=512)
+            np.save(osp.join(patch_ft_dir, f'{_id}_fts.npy'), fts.cpu().numpy())
 
 
-def get_dataloader(data_dict, patch_ft_dir):
+def get_dataloaders(data_dict, patch_ft_dir):
     all_ft_list = glob.glob(osp.join(patch_ft_dir, '*_fts.npy'))
 
     ft_dict = {}
98113
shuffle=True, num_workers=4)
99114
for phase in ['train', 'val']}
100115
dataset_size = {phase: len(SP_datasets[phase]) for phase in ['train', 'val']}
101-
return SP_dataloaders, dataset_size
116+
len_ft = SP_datasets['train'][0][0].size(1)
117+
return SP_dataloaders, dataset_size, len_ft
102118

103119

104120
class SlidePatch(Dataset):
@@ -113,7 +129,7 @@ def __init__(self, data_dict: dict, ft_dict, survival_time_max):
113129
def __getitem__(self, idx: int):
114130
id = self.id_list[idx]
115131
fts = torch.tensor(np.load(self.ft_dict[id])).float()
116-
st = torch.tensor(self.data_dict[id]['survival_time_max']).float()
132+
st = torch.tensor(self.data_dict[id]['survival_time']).float()
117133
return fts, st / self.st_max
118134

119135
def __len__(self) -> int:
@@ -122,7 +138,7 @@ def __len__(self) -> int:
122138

123139
def check_patch_ft(dir_list, patch_ft_dir):
124140
to_do_list = []
125-
done_list = glob.glob(osp.join(patch_ft_dir, '*_ft.npy'))
141+
done_list = glob.glob(osp.join(patch_ft_dir, '*_fts.npy'))
126142
done_list = [get_id(_dir) for _dir in done_list]
127143
for _dir in dir_list:
128144
id = get_id(_dir)
examples/regression/survival_prediction/extract_patch_feature.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+import openslide
+import torch
+from tqdm import tqdm
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+
+from HyperG.models import ResNetFeature
+
+
+def extract_ft(slide_dir: str, patch_coors, depth=34, batch_size=16):
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    slide = openslide.open_slide(slide_dir)
+
+    model_ft = ResNetFeature(depth=depth, pooling=True, pretrained=True)
+    model_ft = model_ft.to(device)
+    model_ft.eval()
+
+    dataset = Patches(slide, patch_coors)
+    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+
+    fts = []
+    with tqdm(total=len(dataset)) as pbar:
+        for _patches in dataloader:
+            _patches = _patches.to(device)
+            with torch.no_grad():
+                _fts = model_ft(_patches)
+            fts.append(_fts)
+            pbar.update(_patches.size(0))
+
+    fts = torch.cat(fts, dim=0)
+    assert fts.size(0) == len(patch_coors)
+    return fts
+
+
+class Patches(Dataset):
+
+    def __init__(self, slide: openslide, patch_coors) -> None:
+        super().__init__()
+        self.slide = slide
+        self.patch_coors = patch_coors
+        self.transform = transforms.Compose([
+            transforms.Resize(224),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                 std=[0.229, 0.224, 0.225])
+        ])
+
+    def __getitem__(self, idx: int):
+        coor = self.patch_coors[idx]
+        img = self.slide.read_region((coor[0], coor[1]), 0, (coor[2], coor[3])).convert('RGB')
+        return self.transform(img)
+
+    def __len__(self) -> int:
+        return len(self.patch_coors)
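A hedged end-to-end sketch of how the new helper might be called for one slide; the paths are placeholders, while `sample_patch_coors` and `extract_ft` are the imports used in data_helper.py above:

```python
import numpy as np
from HyperG.utils.data.pathology import sample_patch_coors
from extract_patch_feature import extract_ft

slide_dir = '/data/slides/slide_1.svs'            # placeholder path

# sample patch coordinates on the slide, then pool one ResNet-34 vector per patch
patch_coors = sample_patch_coors(slide_dir, num_sample=2000, patch_size=256)
fts = extract_ft(slide_dir, patch_coors, depth=34, batch_size=512)

# -> roughly a (2000, 512) tensor, one feature row per sampled patch
np.save('/data/patch_fts/slide_1_fts.npy', fts.cpu().numpy())
```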
