
Commit 2430b65

writing extracting patch feature functions.
1 parent 112003e commit 2430b65

File tree

16 files changed: 419 additions, 24 deletions


.idea/.gitignore

Lines changed: 2 additions & 0 deletions

HyperG/utils/data/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1,2 +1 @@
-from .mri.io import read_mri_series, save_mri_series
 from .data_helper import split_id

HyperG/utils/data/data_helper.py

Lines changed: 6 additions & 1 deletion
@@ -1,4 +1,6 @@
 from random import shuffle
+import pickle
+import os.path as osp


 def split_id(id_list, ratio):
@@ -7,4 +9,7 @@ def split_id(id_list, ratio):

     id_train = id_list[:train_len]
     id_val = id_list[train_len:]
-    return id_train, id_val
+    return id_train, id_val
+
+
+
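For orientation, a minimal usage sketch of split_id as exported from HyperG.utils.data; the id list and ratio below are made up for illustration, and the shuffling behavior is presumed from the module's shuffle import:

from HyperG.utils.data import split_id

# hypothetical case ids; split_id cuts the (presumably shuffled) list at the given ratio
case_ids = [f'case_{i:03d}' for i in range(10)]
id_train, id_val = split_id(case_ids, ratio=0.8)
print(len(id_train), len(id_val))  # expected roughly an 8 / 2 split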

HyperG/utils/data/mri/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from .io import read_mri_series, save_mri_series
+
+__all__ = ['read_mri_series', 'save_mri_series']
HyperG/utils/data/pathology/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+from .overview_patches import draw_patches_on_slide
+from .sample_patches import sample_patch_coors
+
+__all__ = ['sample_patch_coors', 'draw_patches_on_slide']
HyperG/utils/data/pathology/overview_patches.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+import numpy as np
+import openslide
+from PIL import Image
+
+from .sample_patches import get_just_gt_level
+
+SAMPLED = 2
+SAMPLED_COLOR = [0, 0, 255]
+
+
+def draw_patches_on_slide(slide_dir, patch_coors, mini_frac=32):
+    slide = openslide.open_slide(slide_dir)
+    mini_size = np.ceil(np.array(slide.level_dimensions[0]) / mini_frac).astype(np.int)
+    mini_level = get_just_gt_level(slide, mini_size)
+
+    img = slide.read_region((0, 0), mini_level, slide.level_dimensions[mini_level])
+    img = img.resize(mini_size)
+
+    sampled_mask = gather_sampled_patches(patch_coors, mini_size, mini_frac)
+    sampled_patches_img = fuse_img_mask(np.asarray(img), sampled_mask)
+
+    img.close()
+    return Image.fromarray(sampled_patches_img)
+
+
+def gather_sampled_patches(patch_coors, mini_size, mini_frac):
+    # generate sampled area mask
+    sampled_mask = np.zeros((mini_size[1], mini_size[0]), np.uint8)
+    for _coor in patch_coors:
+        _mini_coor = (int(_coor[0] / mini_frac), int(_coor[1] / mini_frac))
+        _mini_patch_size = (int(_coor[2] / mini_frac), int(_coor[3] / mini_frac))
+        sampled_mask[_mini_coor[1]:_mini_coor[1] + _mini_patch_size[1],
+                     _mini_coor[0]:_mini_coor[0] + _mini_patch_size[0]] = SAMPLED
+    sampled_mask = np.asarray(Image.fromarray(sampled_mask).resize(mini_size))
+
+    return sampled_mask
+
+
+def fuse_img_mask(img: np.array, mask: np.array, alpha=0.7):
+    assert img.shape == mask.shape
+    img = img.copy()
+    if (mask != 0).any():
+        img[mask != 0] = alpha * img[mask != 0] + \
+                         (1 - alpha) * np.array(SAMPLED_COLOR)
+    return Image.fromarray(img)
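As a reading aid, a small sketch of the (x, y, width, height) patch-coordinate convention these helpers consume, assuming the module path HyperG/utils/data/pathology/overview_patches.py used above; the coordinates and overview size are made up:

from HyperG.utils.data.pathology.overview_patches import gather_sampled_patches

# two hypothetical 256x256 patches given as level-0 (x, y, width, height) tuples
patch_coors = [(512, 1024, 256, 256), (4096, 2048, 256, 256)]
mini_size = (200, 150)   # (width, height) of the downsampled overview image
mask = gather_sampled_patches(patch_coors, mini_size, mini_frac=32)
print(mask.shape)        # (150, 200): rows are y / height, columns are x / width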

HyperG/utils/data/pathology/patch_funs.py

Lines changed: 0 additions & 9 deletions
This file was deleted.
HyperG/utils/data/pathology/sample_patches.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+import os.path as osp
+from itertools import product
+from random import shuffle
+
+import numpy as np
+import openslide
+import tqdm
+from scipy import ndimage
+from skimage.filters import threshold_otsu
+from skimage.morphology import dilation, star
+
+BACKGROUND = 0
+FOREGROUND = 1
+
+
+def sample_patch_coors(slide_dir, num_sample=2000, patch_size=256):
+    slide = openslide.open_slide(slide_dir)
+    slide_name = osp.basename(slide_dir)
+    slide_name = slide_name[:slide_name.rfind('.')]
+
+    mini_frac = 32
+    mini_size = np.ceil(np.array(slide.level_dimensions[0]) / mini_frac).astype(np.int)
+    mini_level = get_just_gt_level(slide, mini_size)
+    mini_patch_size = patch_size // mini_frac
+
+    bg_mask = generate_background_mask(slide, mini_level, mini_size)
+    assert bg_mask.shape == (mini_size[1], mini_size[0])
+
+    # extract patches from available area
+    patch_coors = []
+    num_row, num_col = bg_mask.shape
+    num_row = num_row - mini_patch_size
+    num_col = num_col - mini_patch_size
+
+    row_col = list(product(range(num_row), range(num_col)))
+    shuffle(row_col)
+    cnt = 0
+
+    # attention center
+    H_min = int(np.ceil(mini_patch_size / 8))
+    H_max = int(np.ceil(mini_patch_size / 8 * 7))
+    W_min = int(np.ceil(mini_patch_size / 8))
+    W_max = int(np.ceil(mini_patch_size / 8 * 7))
+    # half of the center
+    th_num = int(np.ceil((mini_patch_size * 3 / 4 * mini_patch_size * 3 / 4)))
+
+    for row, col in tqdm(row_col):
+        if cnt >= num_sample:
+            break
+        mini_patch = bg_mask[row:row + mini_patch_size, col: col + mini_patch_size]
+        origin = (int(col * mini_frac), int(row * mini_frac), patch_size, patch_size)
+        if np.count_nonzero(mini_patch[H_min:H_max, W_min:W_max]) >= th_num:
+            # # filter those white background
+            # if is_bg(slide, origin, patch_size):
+            #     continue
+            patch_coors.append(origin)
+            cnt += 1
+
+    return patch_coors
+
+
+# get the just size that equal to mask_size
+def get_just_gt_level(slide: openslide, size):
+    level = slide.level_count - 1
+    while level >= 0 and slide.level_dimensions[level][0] < size[0] and \
+            slide.level_dimensions[level][1] < size[1]:
+        level -= 1
+    return level
+
+
+def generate_background_mask(slide: openslide, mini_level, mini_size):
+    img = slide.read_region((0, 0), mini_level, slide.level_dimensions[mini_level])
+    img = img.resize(mini_size)
+    bg_mask = threshold_segmentation(img)
+    img.close()
+    return bg_mask
+
+
+# background segmentation algorithm
+def threshold_segmentation(img):
+    # calculate the overview level size and retrieve the image
+    img_hsv = img.convert('HSV')
+    img_hsv_np = np.array(img_hsv)
+
+    # dilate image and then threshold the image
+    schannel = img_hsv_np[:, :, 1]
+    mask = np.zeros(schannel.shape)
+
+    schannel = dilation(schannel, star(3))
+    schannel = ndimage.gaussian_filter(schannel, sigma=(5, 5), order=0)
+    threshold_global = threshold_otsu(schannel)
+
+    mask[schannel > threshold_global] = FOREGROUND
+    mask[schannel <= threshold_global] = BACKGROUND
+
+    return mask
+
+
+def is_bg(slide, origin, patch_size):
+    img = slide.read_region(origin, 0, (patch_size, patch_size))
+    # bad case is background
+    if np.array(img)[:, :, 1].mean() > 200:  # is bg
+        img.close()
+        return True
+    else:
+        img.close()
+        return False
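Putting the two new pathology modules together, the intended call pattern looks roughly like this sketch; the slide path and output name are placeholders:

from HyperG.utils.data.pathology import sample_patch_coors, draw_patches_on_slide

slide_path = '/path/to/slide.svs'   # placeholder whole-slide image
coors = sample_patch_coors(slide_path, num_sample=500, patch_size=256)
overview = draw_patches_on_slide(slide_path, coors, mini_frac=32)
overview.save('sampled_patches_overview.jpg')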

examples/clssification/breast_pathology/train.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 batch_size = 1

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-save_dir = osp.join(result_root, 'model_save.pth')
+model_save_dir = osp.join(result_root, 'model_best.pth')

 # check directions
 assert check_dir(data_root, False)
@@ -115,4 +115,4 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
 exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

 model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=20)
-torch.save(model_ft.cpu().state_dict(), save_dir)
+torch.save(model_ft.cpu().state_dict(), model_save_dir)
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+import glob
+import json
+import os
+import os.path as osp
+import pickle
+
+import numpy as np
+from extract_patch_feature import extract_ft
+from torch.utils.data import Dataset
+from torch.utils.data.dataset import T_co
+
+from HyperG.utils.data import split_id
+from HyperG.utils.data.pathology import sample_patch_coors, draw_patches_on_slide
+
+
+def split_train_val(data_root, ratio=0.8, save_split_dir=None, resplit=False):
+    if not resplit and save_split_dir is not None and osp.exists(save_split_dir):
+        with open(save_split_dir, 'rb') as f:
+            result = pickle.load(f)
+        return result
+
+    all_list = glob.glob(osp.join(data_root, '*.svs'))
+    with open(osp.join(data_root, 'opti_survival.json'), 'r') as fp:
+        lbls = json.load(fp)
+
+    all_dict = {}
+    for full_dir in all_list:
+        id = get_id(full_dir)
+        all_dict[id]['img_dir'] = full_dir
+        all_dict[id]['survival_time'] = lbls[id]
+
+    id_list = list(all_dict.keys())
+    train_list, val_list = split_id(id_list, ratio)
+
+    train_list = [all_dict[_id] for _id in train_list]
+    val_list = [all_dict[_id] for _id in val_list]
+
+    result = {'train': train_list, 'val': val_list}
+    if save_split_dir is not None:
+        save_folder = osp.split(save_split_dir)[0]
+        if not osp.exists(save_folder):
+            os.makedirs(save_folder)
+        with open(save_split_dir, 'wb') as f:
+            pickle.dump(result, f)
+
+    return result
+
+
+def preprocess(data_dict, patch_ft_dir, patch_coors_dir, num_sample=2000,
+               patch_size=256, sampled_vis=None, mini_frac=32):
+    # check if each slide patch feature exists
+    all_dir_list = []
+    for phase in ['train', 'val']:
+        for _dir in data_dict[phase]:
+            all_dir_list.append(_dir['img_dir'])
+    to_do_list = check_patch_ft(all_dir_list, patch_ft_dir)
+
+    if to_do_list is not None:
+        for _idx, _dir in enumerate(to_do_list):
+            print(f'processing {_idx + 1}/{len(to_do_list)}...')
+            _id = get_id(_dir)
+            _patch_coors = sample_patch_coors(_dir, num_sample=2000, patch_size=256)
+
+            # save sampled patch coordinates
+            with open(osp.join(patch_coors_dir, f'{_id}_coors.pkl')) as fp:
+                pickle.dump(_patch_coors, fp)
+
+            # visualize sampled patches on slide
+            if sampled_vis is not None:
+                _vis_img = draw_patches_on_slide(_dir, _patch_coors, mini_frac=32)
+                with open(osp.join(sampled_vis, f'{_id}_sampled_patches.jpg')) as fp:
+                    _vis_img.save(fp)
+
+    # extract patch feature for each slide
+    for _dir in all_dir_list:
+        _id = get_id(_dir)
+        _patch_coors = None
+        fts = extract_ft(_dir, _patch_coors)
+        np.save(osp.join(patch_ft_dir, f'{_id}_fts.npy'), fts.cpu().numpy())
+
+
+def get_dataloader(data_dict, patch_ft_dir):
+    pass
+
+
+class slide_patch(Dataset):
+
+    def __getitem__(self, index: int) -> T_co:
+        return super().__getitem__(index)
+
+    def __len__(self) -> int:
+        return super().__len__()
+
+
+def check_patch_ft(dir_list, patch_ft_dir):
+    to_do_list = []
+    done_list = glob.glob(osp.join(patch_ft_dir, '*_ft.npy'))
+    done_list = [get_id(_dir).split('_ft.')[0] for _dir in done_list]
+    for _dir in dir_list:
+        id = get_id(_dir)
+        if id not in done_list:
+            to_do_list.append(_dir)
+    return to_do_list
+
+
+def get_id(_dir):
+    return osp.splitext(osp.split(_dir)[1])[0]
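A rough driver sketch for the split/preprocess helpers in this last file; its module path is not visible in the diff, so the import and all directories below are placeholders:

from data_helper import split_train_val, preprocess   # hypothetical module name

data_root = '/path/to/slides'   # expected to hold *.svs files plus opti_survival.json
data_dict = split_train_val(data_root, ratio=0.8,
                            save_split_dir='/path/to/result/split.pkl')
preprocess(data_dict,
           patch_ft_dir='/path/to/result/patch_fts',
           patch_coors_dir='/path/to/result/patch_coors',
           num_sample=2000, patch_size=256,
           sampled_vis='/path/to/result/sampled_vis')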
