Skip to content

Commit d73ce66

Browse files
committed
fix: update pannuke dataset
1 parent e2840c6 commit d73ce66

File tree

2 files changed

+392
-0
lines changed

2 files changed

+392
-0
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .pannuke import Pannuke
2+
3+
__all__ = [
4+
"Pannuke",
5+
]
Lines changed: 387 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,387 @@
1+
import shutil
2+
from pathlib import Path
3+
from typing import Dict
4+
5+
import numpy as np
6+
from tqdm import tqdm
7+
8+
from cellseg_models_pytorch.utils import (
9+
Downloader,
10+
FileHandler,
11+
H5Handler,
12+
fix_duplicates,
13+
)
14+
15+
try:
16+
import tables as tb
17+
18+
_has_tb = True
19+
except ModuleNotFoundError:
20+
_has_tb = False
21+
22+
__all__ = ["Pannuke"]
23+
24+
25+
class Pannuke:
    """Download, pre-process and locate the PanNuke nuclei segmentation dataset.

    The three PanNuke folds are downloaded as zip archives into `save_dir`,
    unpacked, and re-saved patch by patch under the split names given in
    `fold_split`: either as `.png` images + `.mat` label files inside
    `<save_dir>/<split>/{images,labels}` or as one `<save_dir>/<split>.h5` file.
    """

    def __init__(
        self, save_dir: str, fold_split: Dict[str, str], verbose: bool = False
    ) -> None:
        """Set up the dataset and inspect what already exists in `save_dir`.

        Parameters:
            save_dir (str):
                Directory where the folds are downloaded and the processed
                data is written.
            fold_split (Dict[str, str]):
                Maps a fold key to a split name, e.g. `{"fold1": "train"}`. The
                fold number is read from the last character of a string key
                (integer keys are used as is). Values must be one of "train",
                "valid" or "test".
            verbose (bool, default=False):
                If True, prints what was found on disk and what is downloaded.

        Raises:
            ValueError: If a `fold_split` value is not an allowed split name.
        """
        self.save_dir = Path(save_dir)
        self.fold_split = fold_split
        self.verbose = verbose

        allowed_splits = ("train", "valid", "test")
        if not all(phase in allowed_splits for phase in fold_split.values()):
            raise ValueError(
                f"`fold_split` values need to be in {allowed_splits}. "
                f"Got: {list(fold_split.values())}"
            )

        # Snapshot of the on-disk state; refreshed by `prepare_data`.
        self.has_downloaded = self._check_if_downloaded()
        self.has_prepared_folders = self._check_if_folders_prepared()
        self.has_prepared_h5 = self._check_if_h5_prepared()

    @property
    def train_image_dir(self) -> Path:
        """Path to the directory of the training images."""
        return self._get_fold_dir("train", is_mask=False)

    @property
    def valid_image_dir(self) -> Path:
        """Path to the directory of the validation images."""
        return self._get_fold_dir("valid", is_mask=False)

    @property
    def test_image_dir(self) -> Path:
        """Path to the directory of the test images."""
        return self._get_fold_dir("test", is_mask=False)

    @property
    def train_label_dir(self) -> Path:
        """Path to the directory of the training labels."""
        return self._get_fold_dir("train", is_mask=True)

    @property
    def valid_label_dir(self) -> Path:
        """Path to the directory of the validation labels."""
        return self._get_fold_dir("valid", is_mask=True)

    @property
    def test_label_dir(self) -> Path:
        """Path to the directory of the test labels."""
        return self._get_fold_dir("test", is_mask=True)

    @property
    def train_h5_file(self) -> Path:
        """Path to the training HDF5 file."""
        return self._get_fold_h5("train")

    @property
    def valid_h5_file(self) -> Path:
        """Path to the validation HDF5 file."""
        return self._get_fold_h5("valid")

    @property
    def test_h5_file(self) -> Path:
        """Path to the test HDF5 file."""
        return self._get_fold_h5("test")

    @property
    def type_classes(self) -> Dict[int, str]:
        """Pannuke cell type classes."""
        return {
            0: "bg",
            1: "neoplastic",
            2: "inflammatory",
            3: "connective",
            4: "dead",
            5: "epithelial",
        }

    def download(self, root: str = None) -> None:
        """Download the three pannuke fold archives and unpack them.

        Parameters:
            root (str, default=None):
                Folder whose zip files are extracted (and then removed) after
                the download. Defaults to `save_dir`, which is where the
                archives are downloaded to.
        """
        if root is None:
            root = self.save_dir

        # create save_dir
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # init downloader
        downloader = Downloader(self.save_dir)
        for ix in (1, 2, 3):
            url = f"https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke/fold_{ix}.zip"
            downloader.download(url)
        FileHandler.extract_zips_in_folder(root, rm=True)

    def prepare_data(self, rm_orig: bool = False, to_h5: bool = False) -> None:
        """Prepare the pannuke datasets.

        1. Download pannuke folds from:
            "https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke/"
        2. Pre-process and split the images and masks into train, valid and test sets.

        The download only happens when something still has to be processed, so
        already prepared data is not re-downloaded after the original folds
        have been removed.

        Parameters:
            rm_orig (bool, default=False):
                After processing all the files, If True, removes the original
                un-processed files.
            to_h5 (bool, default=False):
                If True, saves the processed images and masks in one HDF5 file.
        """
        needs_folders = not self.has_prepared_folders
        needs_h5 = to_h5 and not self.has_prepared_h5

        if needs_folders or needs_h5:
            if not self.has_downloaded:
                if self.verbose:
                    print(f"Downloading three Pannuke folds to {self.save_dir}")
                self.download(self.save_dir)
                self.has_downloaded = self._check_if_downloaded()

            fold_paths = self._get_fold_paths(self.save_dir)

            if needs_folders:
                for fold, phase in self.fold_split.items():
                    phase_dir = self.save_dir / phase
                    self._prepare_data(
                        fold_paths,
                        fold,
                        phase,
                        phase_dir / "images",
                        phase_dir / "labels",
                        h5path=None,
                    )
                self.has_prepared_folders = self._check_if_folders_prepared()

            if needs_h5:
                # NOTE(review): every fold opens `<phase>.h5` in write mode, so
                # if several folds map to the same phase only the last fold
                # seems to survive in the file -- confirm and fix if unintended.
                for fold, phase in self.fold_split.items():
                    self._prepare_data(
                        fold_paths,
                        fold,
                        phase,
                        save_im_dir=None,
                        save_mask_dir=None,
                        h5path=self.save_dir / f"{phase}.h5",
                    )
                self.has_prepared_h5 = self._check_if_h5_prepared()
        else:
            print(
                "Found pre-processed Pannuke data. If in need of a re-download, please empty the `save_dir` folder."
            )

        if rm_orig and self.save_dir.is_dir():
            # Only directories are removed: `shutil.rmtree` fails on plain files.
            for entry in list(self.save_dir.iterdir()):
                if entry.is_dir() and "fold" in entry.name.lower():
                    shutil.rmtree(entry)
            self.has_downloaded = False

    def _prepare_data(
        self,
        fold_paths,
        fold,
        phase,
        save_im_dir: Path,
        save_mask_dir: Path,
        h5path: Path,
    ):
        """Resolve the fold number from a `fold_split` key and process that fold."""
        # determine fold number: ints are used as is, otherwise the last
        # character of the key is taken (e.g. "fold1" -> 1)
        fold_ix = fold if isinstance(fold, int) else int(fold[-1])

        self._process_pannuke_fold(
            fold_paths, fold_ix, phase, save_im_dir, save_mask_dir, h5path
        )

    def _check_if_downloaded(self) -> bool:
        """Check if `save_dir` contains exactly three un-processed fold folders."""
        if self.save_dir.exists() and self.save_dir.is_dir():
            folds_found = [
                d.name
                for d in self.save_dir.iterdir()
                if "fold" in d.name.lower() and d.is_dir()
            ]
            if len(folds_found) == 3:
                if self.verbose:
                    print(
                        f"Found all Pannuke folds {folds_found} inside {self.save_dir}."
                    )
                return True
        return False

    def _check_if_folders_prepared(self) -> bool:
        """Check if `save_dir` contains any of the processed split folders."""
        if self.save_dir.exists() and self.save_dir.is_dir():
            phases_found = [
                d.name
                for d in self.save_dir.iterdir()
                if d.name in ("train", "test", "valid") and d.is_dir()
            ]
            if phases_found:
                if self.verbose:
                    print(
                        f"Found processed Pannuke data saved in {phases_found} folders "
                        f"inside {self.save_dir}."
                    )
                return True
        return False

    def _check_if_h5_prepared(self) -> bool:
        """Check if `save_dir` contains any `.h5` files."""
        if self.save_dir.exists() and self.save_dir.is_dir():
            h5_found = [d.name for d in self.save_dir.iterdir() if d.suffix == ".h5"]
            if h5_found:
                if self.verbose:
                    print(
                        f"Found processed Pannuke data saved in {h5_found} hdf5 files "
                        f"inside {self.save_dir}."
                    )
                return True
        return False

    def _get_fold_paths(self, path: Path) -> Dict[str, Path]:
        """Get the paths to the .npy files in all of the fold folders.

        Files are looked up exactly three directory levels below `path` and
        keyed as `<parent folder name>_<file name without .npy>`.
        """
        return {
            f"{file.parent.name}_{file.name[:-4]}": file
            for file in Path(path).glob("*/*/*/*.npy")
            if file.is_file() and file.suffix == ".npy"
        }

    def _process_pannuke_fold(
        self,
        fold_paths: Dict[str, Path],
        fold: int,
        phase: str,
        save_im_dir: Path = None,
        save_mask_dir: Path = None,
        h5path: Path = None,
    ) -> None:
        """Save the patches of one pannuke fold to folders or to a HDF5 file.

        Parameters:
            fold_paths (Dict[str, Path]):
                Paths to the fold .npy files, keyed by `fold{n}_masks`,
                `fold{n}_images` and `fold{n}_types`.
            fold (int):
                The fold number.
            phase (str):
                Split name, only used in the progress bar description.
            save_im_dir (Path, default=None):
                Folder for the `.png` images. Used when `h5path` is None.
            save_mask_dir (Path, default=None):
                Folder for the `.mat` label files. Used when `h5path` is None.
            h5path (Path, default=None):
                If given, everything is written into this HDF5 file instead
                (the file is opened in write mode).

        Raises:
            ModuleNotFoundError: If `h5path` is given but `tables` is missing.
        """
        h5 = None
        h5handler = None
        if h5path is not None:
            if not _has_tb:
                raise ModuleNotFoundError(
                    "Please install `tables` to save the data in HDF5 format."
                )
            h5handler = H5Handler()
            h5 = tb.open_file(h5path, "w")
        else:
            # Create directories for the files.
            Path(save_im_dir).mkdir(parents=True, exist_ok=True)
            Path(save_mask_dir).mkdir(parents=True, exist_ok=True)

        # One try/finally around all the work so that the HDF5 handle is
        # closed no matter where a failure happens (loading included).
        try:
            if h5 is not None:
                patch_size = (256, 256)
                h5handler.init_img(h5, patch_size)
                h5handler.init_mask(h5, "inst", patch_size)
                h5handler.init_mask(h5, "type", patch_size)
                h5handler.init_meta_data(h5)

            masks = np.load(fold_paths[f"fold{fold}_masks"]).astype("int32")
            imgs = np.load(fold_paths[f"fold{fold}_images"]).astype("uint8")
            types = np.load(fold_paths[f"fold{fold}_types"])

            with tqdm(total=len(types)) as pbar:
                pbar.set_description(f"fold{fold}/{phase}")
                for tissue_type in np.unique(types):
                    selected = types == tissue_type
                    imgs_by_type = imgs[selected]
                    masks_by_type = masks[selected]
                    for j in range(imgs_by_type.shape[0]):
                        im = imgs_by_type[j, ...]
                        patch_mask = masks_by_type[j, ...]
                        type_map = self._get_type_map(patch_mask)
                        # the first five channels are passed as instance
                        # channels; the rest is left out of the instance map
                        inst_map = self._get_inst_map(patch_mask[..., 0:5])
                        name = f"{tissue_type}_fold{fold}_{j}"

                        if h5 is not None:
                            h5handler.append_array(h5, im[None, ...], "image")
                            h5handler.append_array(h5, inst_map[None, ...], "inst")
                            h5handler.append_array(h5, type_map[None, ...], "type")
                            h5handler.append_meta_data(
                                h5, name, coords=(0, 0, 256, 256)
                            )
                        else:
                            fn_im = Path(save_im_dir / name).with_suffix(".png")
                            FileHandler.write_img(fn_im, im)

                            fn_mask = Path(save_mask_dir / name).with_suffix(".mat")
                            FileHandler.to_mat(
                                masks={
                                    "inst": inst_map,
                                    "type": type_map,
                                },
                                path=fn_mask,
                            )
                        pbar.update(1)
        finally:
            if h5 is not None:
                h5.close()

    def _get_type_map(self, pannuke_mask: np.ndarray) -> np.ndarray:
        """Convert the pannuke mask to type map of shape (H, W).

        Positive values are binarized, the last channel (treated as the
        background channel) is moved to the front and the remaining channels
        are shifted one step back, so the per-pixel argmax is 0 for background
        and `channel index + 1` for the other channels.
        """
        binary = np.where(pannuke_mask > 0, 1, pannuke_mask)

        # roll by one along the channel axis: last channel -> first channel
        type_map = np.roll(binary, 1, axis=-1)

        return np.argmax(type_map, axis=-1)

    def _get_inst_map(self, pannuke_mask: np.ndarray) -> np.ndarray:
        """Convert pannuke mask to inst_map of shape (H, W).

        The instance ids of all channels are summed per pixel (zero marks
        "no instance"), after which `fix_duplicates` is applied to the result.
        """
        # Adding every id where that id occurs is the same as summing the
        # channels, and does not rely on 0 being present in every channel.
        inst_map = pannuke_mask.astype("int32").sum(axis=-1, dtype="int32")

        # fix duplicated instances
        inst_map = fix_duplicates(inst_map)
        return inst_map

    def _check_fold_exists(self, fold: str) -> bool:
        """Check if a directory with `fold` in its name exists in the save_dir."""
        if not self.save_dir.is_dir():
            return False
        return any(d.is_dir() and fold in d.name for d in self.save_dir.iterdir())

    def _check_file_exists(self, fold: str) -> bool:
        """Check if a file with `fold` in its name exists in the save_dir."""
        if not self.save_dir.is_dir():
            return False
        return any(d.is_file() and fold in d.name for d in self.save_dir.iterdir())

    def _get_fold_dir(self, fold: str, is_mask: bool = False) -> Path:
        """Return `<save_dir>/<fold>/images` (or `/labels` if `is_mask`).

        Raises:
            ValueError: If that directory does not exist.
        """
        dir_name = "labels" if is_mask else "images"

        # Check the exact directory that is returned, so that a missing
        # `save_dir` or a look-alike folder name ends up in the ValueError.
        fold_dir = self.save_dir / f"{fold}/{dir_name}"
        if fold_dir.is_dir():
            return fold_dir

        raise ValueError(
            f"No '{fold}' directory found in {self.save_dir}. Expected a directory "
            f"named '{fold}/{dir_name}'. Run `prepare_data()` to create the data. "
            "and make sure that the fold exists in the `fold_split`."
        )

    def _get_fold_h5(self, fold: str) -> Path:
        """Return `<save_dir>/<fold>.h5`.

        Raises:
            ValueError: If that file does not exist.
        """
        h5_file = self.save_dir / f"{fold}.h5"
        if h5_file.is_file():
            return h5_file

        raise ValueError(
            f"No '{fold}' HDF5 file found in {self.save_dir}. Expected a file "
            f"named '{fold}.h5'. Run `prepare_data(to_h5=True)` to create the data. "
            "and make sure that the fold exists in the `fold_split`."
        )

0 commit comments

Comments
 (0)