Skip to content

Commit 4736d5f

Browse files
matt3opre-commit-ci[bot]diazandr3s
authored
Add SW_FastEdit code (#1689)
* Add SW_FastEdit code inference code Signed-off-by: Matthias Hadlich <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Matthias Hadlich <[email protected]> * Fix pre-commit issues Signed-off-by: Matthias Hadlich <[email protected]> * Fix typing issues Signed-off-by: Matthias Hadlich <[email protected]> * Update model name/path Signed-off-by: Andres Diaz-Pinto <[email protected]> --------- Signed-off-by: Matthias Hadlich <[email protected]> Signed-off-by: Andres Diaz-Pinto <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Andres Diaz-Pinto <[email protected]>
1 parent e31eaee commit 4736d5f

File tree

4 files changed

+512
-1
lines changed

4 files changed

+512
-1
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import logging
13+
import os
14+
from typing import Any, Dict, Optional, Union
15+
16+
import lib.infers
17+
import lib.trainers
18+
from monai.networks.nets.dynunet import DynUNet
19+
20+
from monailabel.interfaces.config import TaskConfig
21+
from monailabel.interfaces.tasks.infer_v2 import InferTask
22+
from monailabel.interfaces.tasks.train import TrainTask
23+
from monailabel.utils.others.generic import download_file, strtobool
24+
25+
logger = logging.getLogger(__name__)
26+
27+
28+
class SWFastEditConfig(TaskConfig):
29+
def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **kwargs):
30+
super().init(name, model_dir, conf, planner, **kwargs)
31+
32+
# Labels
33+
self.labels = [
34+
"tumor",
35+
"background",
36+
]
37+
38+
self.label_names = {label: self.labels.index(label) for label in self.labels}
39+
print(self.label_names)
40+
# Model Files
41+
self.path = [
42+
os.path.join(self.model_dir, f"pretrained_{name}.pt"), # pretrained
43+
os.path.join(self.model_dir, f"{name}.pt"), # published
44+
]
45+
46+
# Download PreTrained Model
47+
# Model is pretrained on PET scans from the AutoPET dataset
48+
if strtobool(self.conf.get("use_pretrained_model", "true")):
49+
url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}"
50+
url = f"{url}/radiology_segmentation_sw_fastedit_pet.pt"
51+
print(f"Downloading from {self.path[0]}")
52+
download_file(url, self.path[0])
53+
54+
# Network
55+
self.network = DynUNet(
56+
spatial_dims=3,
57+
# 1 dim for the image, the other ones for the signal per label with is the size of image
58+
in_channels=1 + len(self.labels),
59+
out_channels=len(self.labels),
60+
kernel_size=[3, 3, 3, 3, 3, 3],
61+
strides=[1, 2, 2, 2, 2, [2, 2, 1]],
62+
upsample_kernel_size=[2, 2, 2, 2, [2, 2, 1]],
63+
norm_name="instance",
64+
deep_supervision=False,
65+
res_block=True,
66+
)
67+
68+
AUTOPET_SPACING = (2.03642011, 2.03642011, 3.0)
69+
self.target_spacing = AUTOPET_SPACING # AutoPET default
70+
71+
def infer(self) -> Union[InferTask, Dict[str, InferTask]]:
72+
inferer = lib.infers.SWFastEdit(
73+
path=self.path,
74+
network=self.network,
75+
labels=self.labels,
76+
label_names=self.label_names,
77+
preload=strtobool(self.conf.get("preload", "false")),
78+
config={"cache_transforms": True, "cache_transforms_in_memory": True, "cache_transforms_ttl": 1200},
79+
target_spacing=self.target_spacing,
80+
)
81+
# Reenable this for the Auto Segmentation support
82+
# seg_inferer = lib.infers.SWFastEdit(
83+
# path=self.path,
84+
# network=self.network,
85+
# labels=self.labels,
86+
# label_names=self.label_names,
87+
# preload=strtobool(self.conf.get("preload", "false")),
88+
# target_spacing=self.target_spacing,
89+
# type=InferType.SEGMENTATION,
90+
# )
91+
92+
return {
93+
self.name: inferer,
94+
# f"{self.name}_seg": seg_inferer,
95+
}
96+
# return task
97+
98+
def trainer(self) -> Optional[TrainTask]:
99+
return None

sample-apps/radiology/lib/infers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
from .segmentation import Segmentation
1717
from .segmentation_spleen import SegmentationSpleen
1818
from .segmentation_vertebra import SegmentationVertebra
19+
from .sw_fastedit import SWFastEdit
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import json
13+
import logging
14+
import os
15+
import pathlib
16+
import shutil
17+
from pathlib import Path
18+
from typing import Callable, Sequence, Union
19+
20+
import nibabel as nib
21+
import numpy as np
22+
import torch
23+
from lib.transforms.transforms import AddEmptySignalChannels, AddGuidanceSignal
24+
from monai.inferers import Inferer, SlidingWindowInferer
25+
from monai.transforms import (
26+
Activationsd,
27+
AsDiscreted,
28+
CenterSpatialCropd,
29+
EnsureChannelFirstd,
30+
EnsureTyped,
31+
Identityd,
32+
LoadImaged,
33+
Orientationd,
34+
ScaleIntensityRangePercentilesd,
35+
SignalFillEmptyd,
36+
Spacingd,
37+
SqueezeDimd,
38+
)
39+
from monai.utils import set_determinism
40+
41+
from monailabel.interfaces.tasks.infer_v2 import InferType
42+
from monailabel.tasks.infer.basic_infer import BasicInferTask, CallBackTypes
43+
44+
# monai_version = pkg_resources.get_distribution("monai").version
45+
# if not pkg_resources.parse_version(monai_version) >= pkg_resources.parse_version("1.3.0"):
46+
# raise UserWarning("This code needs at least MONAI 1.3.0")
47+
48+
49+
logger = logging.getLogger(__name__)
50+
51+
52+
class SWFastEdit(BasicInferTask):
53+
54+
def __init__(
55+
self,
56+
path,
57+
network=None,
58+
type=InferType.DEEPEDIT,
59+
labels=None,
60+
label_names=None,
61+
dimension=3,
62+
target_spacing=(2.03642011, 2.03642011, 3.0),
63+
description="",
64+
**kwargs,
65+
):
66+
super().__init__(
67+
path=path,
68+
network=network,
69+
type=type,
70+
labels=labels,
71+
dimension=dimension,
72+
description=description,
73+
**kwargs,
74+
)
75+
self.label_names = label_names
76+
self.target_spacing = target_spacing
77+
78+
set_determinism(42)
79+
self.model_state_dict = "net"
80+
self.load_strict = True
81+
self._amp = True
82+
# Either no crop with None or crop like (128,128,128), sliding window does not need this parameter unless
83+
# too much memory is used for the stitching of the output windows
84+
self.val_crop_size = None
85+
86+
# Inferer parameters
87+
# Increase the overlap for up to 1% more Dice, however the time and memory consumption increase a lot!
88+
self.sw_overlap = 0.25
89+
# Should be the same ROI size as it was trained on
90+
self.sw_roi_size = (128, 128, 128)
91+
92+
# Reduce this if you run into OOMs
93+
self.train_sw_batch_size = 8
94+
# Reduce this if you run into OOMs
95+
self.val_sw_batch_size = 16
96+
97+
def __call__(self, request, callbacks=None):
98+
if callbacks is None:
99+
callbacks = {}
100+
callbacks[CallBackTypes.POST_TRANSFORMS] = post_callback
101+
102+
return super().__call__(request, callbacks)
103+
104+
def pre_transforms(self, data=None) -> Sequence[Callable]:
105+
# print("#########################################")
106+
# data['label_dict'] = self.label_names
107+
data["label_names"] = self.label_names
108+
109+
# Make sure the click keys already exist
110+
for label in self.label_names:
111+
if label not in data:
112+
data[label] = []
113+
# data['click_path'] = self.click_path
114+
115+
cpu_device = torch.device("cpu")
116+
device = data.get("device") if data else None
117+
loglevel = logging.DEBUG
118+
input_keys = "image"
119+
120+
t = []
121+
t_val_1 = [
122+
LoadImaged(keys=input_keys, reader="ITKReader", image_only=False),
123+
EnsureChannelFirstd(keys=input_keys),
124+
ScaleIntensityRangePercentilesd(
125+
keys="image", lower=0.05, upper=99.95, b_min=0.0, b_max=1.0, clip=True, relative=False
126+
),
127+
# ScaleIntensityRanged(keys="image", a_min=0, a_max=43, b_min=0.0, b_max=1.0, clip=True),
128+
SignalFillEmptyd(keys=input_keys),
129+
]
130+
t.extend(t_val_1)
131+
# self.add_cache_transform(t, data)
132+
t_val_2 = [
133+
AddEmptySignalChannels(keys=input_keys, device=device),
134+
AddGuidanceSignal(
135+
keys=input_keys,
136+
sigma=1,
137+
disks=True,
138+
device=device,
139+
),
140+
Orientationd(keys=input_keys, axcodes="RAS"),
141+
Spacingd(keys=input_keys, pixdim=self.target_spacing),
142+
(
143+
CenterSpatialCropd(keys=input_keys, roi_size=self.val_crop_size)
144+
if self.val_crop_size is not None
145+
else Identityd(keys=input_keys, allow_missing_keys=True)
146+
),
147+
EnsureTyped(keys=input_keys, device=device),
148+
]
149+
t.extend(t_val_2)
150+
return t
151+
152+
def inferer(self, data=None) -> Inferer:
153+
sw_params = {
154+
"roi_size": self.sw_roi_size,
155+
"mode": "gaussian",
156+
"cache_roi_weight_map": False,
157+
"overlap": self.sw_overlap,
158+
}
159+
eval_inferer = SlidingWindowInferer(sw_batch_size=self.val_sw_batch_size, **sw_params)
160+
return eval_inferer
161+
162+
def inverse_transforms(self, data=None) -> Union[None, Sequence[Callable]]:
163+
return [] # Self-determine from the list of pre-transforms provided
164+
165+
def post_transforms(self, data=None) -> Sequence[Callable]:
166+
device = data.get("device") if data else None
167+
return [
168+
EnsureTyped(keys="pred", device=device),
169+
Activationsd(keys="pred", softmax=True),
170+
AsDiscreted(keys="pred", argmax=True),
171+
SqueezeDimd(keys="pred", dim=0),
172+
EnsureTyped(keys="pred", device="cpu" if data else None, dtype=torch.uint8),
173+
]
174+
175+
176+
def post_callback(data):
177+
"""
178+
Saves clicks in the same folder where the created labels are stored.
179+
Can also help debugging by providing a way of saving nifti files.
180+
"""
181+
image_name = Path(os.path.basename(data["image_path"]))
182+
true_image_name = image_name.name.removesuffix("".join(image_name.suffixes))
183+
image_folder = Path(data["image_path"]).parent
184+
185+
labels_folder = os.path.join(image_folder, "labels", "final")
186+
if not os.path.exists(labels_folder):
187+
print(f"##### Creating {labels_folder}")
188+
pathlib.Path(labels_folder).mkdir(parents=True)
189+
190+
# Save the clicks
191+
clicks_per_label = {}
192+
for key in data["label_names"].keys():
193+
clicks_per_label[key] = data[key]
194+
assert isinstance(data[key], list)
195+
196+
click_file_path = os.path.join(labels_folder, f"{true_image_name}_clicks.json")
197+
logger.info(f"Now dumping dict: {clicks_per_label} to file {click_file_path} ...")
198+
with open(click_file_path, "w") as clicks_file:
199+
json.dump(clicks_per_label, clicks_file)
200+
201+
# Save debug NIFTI, not fully working since the inverse transform of the image is not avaible
202+
if False:
203+
logger.info("SAVING NIFTI")
204+
inputs = data["image"]
205+
pred = data["pred"]
206+
logger.info(f"inputs.shape is {inputs.shape}")
207+
logger.info(f"sum of fgg is {torch.sum(inputs[1])}")
208+
logger.info(f"sum of bgg is {torch.sum(inputs[2])}")
209+
logger.info(f"Image path is {data['image_path']}, copying file")
210+
shutil.copyfile(data["image_path"], f"{path}/im.nii.gz")
211+
# save_nifti(f"{path}/im", inputs[0].cpu().detach().numpy())
212+
save_nifti(f"{path}/guidance_fgg", inputs[1].cpu().detach().numpy())
213+
save_nifti(f"{path}/guidance_bgg", inputs[2].cpu().detach().numpy())
214+
logger.info(f"pred.shape is {pred.shape}")
215+
save_nifti(f"{path}/pred", pred.cpu().detach().numpy())
216+
return data
217+
218+
219+
def save_nifti(name, im):
220+
"""ONLY FOR DEBUGGING"""
221+
affine = np.eye(4)
222+
affine[0][0] = -1
223+
ni_img = nib.Nifti1Image(im, affine=affine)
224+
ni_img.header.get_xyzt_units()
225+
ni_img.to_filename(f"{name}.nii.gz")

0 commit comments

Comments
 (0)