 import pickle
 
 import numpy as np
+import torch
 from extract_patch_feature import extract_ft
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader
 from torch.utils.data.dataset import T_co
 
 from HyperG.utils.data import split_id
@@ -24,18 +25,23 @@ def split_train_val(data_root, ratio=0.8, save_split_dir=None, resplit=False):
         lbls = json.load(fp)
 
     all_dict = {}
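+    # largest survival time in the cohort; used to scale labels into [0, 1]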
+    survival_time_max = 0
     for full_dir in all_list:
-        id = get_id(full_dir)
-        all_dict[id]['img_dir'] = full_dir
-        all_dict[id]['survival_time'] = lbls[id]
+        _id = get_id(full_dir)
+        all_dict[_id] = {'img_dir': full_dir,
+                         'survival_time': lbls[_id]}
+        survival_time_max = max(survival_time_max, lbls[_id])
 
     id_list = list(all_dict.keys())
     train_list, val_list = split_id(id_list, ratio)
 
-    train_list = [all_dict[_id] for _id in train_list]
-    val_list = [all_dict[_id] for _id in val_list]
+    result = {'survival_time_max': survival_time_max,
+              'train': {}, 'val': {}}
+    for _id in train_list:
+        result['train'][_id] = all_dict[_id]
+    for _id in val_list:
+        result['val'][_id] = all_dict[_id]
 
-    result = {'train': train_list, 'val': val_list}
     if save_split_dir is not None:
         save_folder = osp.split(save_split_dir)[0]
         if not osp.exists(save_folder):
@@ -80,22 +86,44 @@ def preprocess(data_dict, patch_ft_dir, patch_coors_dir, num_sample=2000,
 
 
 def get_dataloader(data_dict, patch_ft_dir):
-    pass
+    all_ft_list = glob.glob(osp.join(patch_ft_dir, '*_fts.npy'))
 
+    # index the pre-extracted patch-feature files by slide id
+    ft_dict = {}
+    for _dir in all_ft_list:
+        ft_dict[get_id(_dir)] = _dir
 
-class slide_patch(Dataset):
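+    # one SlidePatch dataset and DataLoader per phase ('train' / 'val');
+    # batch_size=1 keeps one whole slide per batch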
+    SP_datasets = {phase: SlidePatch(data_dict[phase], ft_dict,
+                                     data_dict['survival_time_max'])
+                   for phase in ['train', 'val']}
+    SP_dataloaders = {phase: DataLoader(SP_datasets[phase], batch_size=1,
+                                        shuffle=True, num_workers=4)
+                      for phase in ['train', 'val']}
+    dataset_size = {phase: len(SP_datasets[phase]) for phase in ['train', 'val']}
+    return SP_dataloaders, dataset_size
 
-    def __getitem__(self, index: int) -> T_co:
-        return super().__getitem__(index)
+
+class SlidePatch(Dataset):
+
+    def __init__(self, data_dict: dict, ft_dict, survival_time_max):
+        super().__init__()
+        self.st_max = float(survival_time_max)
+        self.id_list = list(data_dict.keys())
+        self.data_dict = data_dict
+        self.ft_dict = ft_dict
+
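+    # a sample is one slide: all of its patch features plus its survival
+    # time normalized by the cohort-wide maximum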
+    def __getitem__(self, idx: int):
+        _id = self.id_list[idx]
+        fts = torch.tensor(np.load(self.ft_dict[_id])).float()
+        st = torch.tensor(self.data_dict[_id]['survival_time']).float()
+        return fts, st / self.st_max
 
     def __len__(self) -> int:
-        return super().__len__()
+        return len(self.id_list)
 
 
 def check_patch_ft(dir_list, patch_ft_dir):
     to_do_list = []
     done_list = glob.glob(osp.join(patch_ft_dir, '*_ft.npy'))
-    done_list = [get_id(_dir).split('_ft.')[0] for _dir in done_list]
+    done_list = [get_id(_dir) for _dir in done_list]
     for _dir in dir_list:
         id = get_id(_dir)
         if id not in done_list:
@@ -104,4 +132,4 @@ def check_patch_ft(dir_list, patch_ft_dir):
 
 
 def get_id(_dir):
-    return osp.splitext(osp.split(_dir)[1])[0]
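+    # keep only the slide id, e.g. a hypothetical 'CASE-123_fts.npy' -> 'CASE-123'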
+    return osp.splitext(osp.split(_dir)[1])[0].split('_')[0]
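
For context, a minimal sketch of how the reworked helpers could be chained together after this change. The paths, split ratio, and preprocess arguments below are placeholder assumptions (preprocess's remaining parameters are assumed to have defaults), and it assumes preprocess writes the '<id>_fts.npy' files that get_dataloader globs for:

    # hypothetical end-to-end wiring of the helpers in this file
    data_dict = split_train_val('data/slides', ratio=0.8,
                                save_split_dir='data/split.pkl')
    preprocess(data_dict, 'data/fts', 'data/coors', num_sample=2000)
    dataloaders, sizes = get_dataloader(data_dict, 'data/fts')

    for fts, st in dataloaders['train']:
        # fts: [1, num_patches, feat_dim] tensor of one slide's patch features
        # st: that slide's survival time scaled into [0, 1]
        ...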