Skip to content

Commit 797039a

Browse files
committed
finish survival prediction data_helper!
1 parent 2430b65 commit 797039a

File tree

1 file changed

+42
-14
lines changed

1 file changed

+42
-14
lines changed

examples/regression/survival_prediction/data_helper.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
import pickle
66

77
import numpy as np
8+
import torch
89
from extract_patch_feature import extract_ft
9-
from torch.utils.data import Dataset
10+
from torch.utils.data import Dataset, DataLoader
1011
from torch.utils.data.dataset import T_co
1112

1213
from HyperG.utils.data import split_id
@@ -24,18 +25,23 @@ def split_train_val(data_root, ratio=0.8, save_split_dir=None, resplit=False):
2425
lbls = json.load(fp)
2526

2627
all_dict = {}
28+
survival_time_max = 0
2729
for full_dir in all_list:
28-
id = get_id(full_dir)
29-
all_dict[id]['img_dir'] = full_dir
30-
all_dict[id]['survival_time'] = lbls[id]
30+
_id = get_id(full_dir)
31+
all_dict[_id]['img_dir'] = full_dir
32+
all_dict[_id]['survival_time'] = lbls[_id]
33+
survival_time_max = survival_time_max \
34+
if survival_time_max > lbls[_id] else lbls[_id]
3135

3236
id_list = list(all_dict.keys())
3337
train_list, val_list = split_id(id_list, ratio)
3438

35-
train_list = [all_dict[_id] for _id in train_list]
36-
val_list = [all_dict[_id] for _id in val_list]
39+
result = {'survival_time_max': survival_time_max}
40+
for _id in train_list:
41+
result['train'][_id] = all_dict[_id]
42+
for _id in val_list:
43+
result['val'][_id] = all_dict[_id]
3744

38-
result = {'train': train_list, 'val': val_list}
3945
if save_split_dir is not None:
4046
save_folder = osp.split(save_split_dir)[0]
4147
if not osp.exists(save_folder):
@@ -80,22 +86,44 @@ def preprocess(data_dict, patch_ft_dir, patch_coors_dir, num_sample=2000,
8086

8187

8288
def get_dataloader(data_dict, patch_ft_dir, batch_size=1, num_workers=4):
    """Build train/val DataLoaders over pre-extracted patch feature files.

    Args:
        data_dict: split dict produced by ``split_train_val`` — has keys
            ``'train'`` and ``'val'`` (each mapping slide id -> record) plus
            a top-level ``'survival_time_max'`` scalar.
        patch_ft_dir: directory containing per-slide feature files.
        batch_size: DataLoader batch size (default 1, matching the original
            hard-coded value — slides have variable patch counts).
        num_workers: DataLoader worker processes (default 4, as before).

    Returns:
        (dataloaders, dataset_size): two dicts keyed by ``'train'``/``'val'``.
    """
    phases = ['train', 'val']

    # Map slide id -> feature file path.
    # NOTE(review): pattern is '*_fts.npy' here, but check_patch_ft globs
    # '*_ft.npy' — confirm which suffix preprocess actually writes.
    ft_dict = {get_id(_dir): _dir
               for _dir in glob.glob(osp.join(patch_ft_dir, '*_fts.npy'))}

    SP_datasets = {phase: SlidePatch(data_dict[phase], ft_dict,
                                     data_dict['survival_time_max'])
                   for phase in phases}
    # NOTE(review): shuffle=True is applied to 'val' as well — confirm that
    # shuffled validation batches are intended.
    SP_dataloaders = {phase: DataLoader(SP_datasets[phase],
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=num_workers)
                      for phase in phases}
    dataset_size = {phase: len(SP_datasets[phase]) for phase in phases}
    return SP_dataloaders, dataset_size
87102

88-
def __getitem__(self, index: int) -> T_co:
89-
return super().__getitem__(index)
103+
104+
class SlidePatch(Dataset):
    """Dataset yielding (patch_features, normalized_survival_time) per slide.

    Each item loads the slide's pre-extracted patch feature matrix from disk
    and returns its survival-time label scaled into [0, 1] by the maximum
    survival time of the whole split.
    """

    def __init__(self, data_dict: dict, ft_dict: dict, survival_time_max):
        """
        Args:
            data_dict: slide id -> record dict; each record holds at least
                a ``'survival_time'`` entry (see ``split_train_val``).
            ft_dict: slide id -> path of the saved ``.npy`` feature file.
            survival_time_max: normalization constant — the largest survival
                time observed across the split.
        """
        super().__init__()
        self.st_max = float(survival_time_max)
        self.id_list = list(data_dict.keys())
        self.data_dict = data_dict
        self.ft_dict = ft_dict

    def __getitem__(self, idx: int):
        # Renamed from `id` to avoid shadowing the builtin.
        slide_id = self.id_list[idx]
        fts = torch.tensor(np.load(self.ft_dict[slide_id])).float()
        # BUG FIX: per-slide records store the label under 'survival_time';
        # 'survival_time_max' only exists at the top level of the split dict,
        # so the original lookup raised KeyError on every access.
        st = torch.tensor(self.data_dict[slide_id]['survival_time']).float()
        # Normalize the regression target into [0, 1].
        return fts, st / self.st_max

    def __len__(self) -> int:
        return len(self.id_list)
93121

94122

95123
def check_patch_ft(dir_list, patch_ft_dir):
96124
to_do_list = []
97125
done_list = glob.glob(osp.join(patch_ft_dir, '*_ft.npy'))
98-
done_list = [get_id(_dir).split('_ft.')[0] for _dir in done_list]
126+
done_list = [get_id(_dir) for _dir in done_list]
99127
for _dir in dir_list:
100128
id = get_id(_dir)
101129
if id not in done_list:
@@ -104,4 +132,4 @@ def check_patch_ft(dir_list, patch_ft_dir):
104132

105133

106134
def get_id(_dir):
    """Return the slide id for a path: the filename's stem up to the
    first underscore (e.g. '/x/TCGA-AA-1234_fts.npy' -> 'TCGA-AA-1234')."""
    filename = osp.split(_dir)[1]
    stem = osp.splitext(filename)[0]
    return stem.split('_')[0]

0 commit comments

Comments
 (0)