Skip to content

Commit 6007e0c

Browse files
Fix and Improve performance for pathology models (#1158)
* Fix and Improve performance for pathology models Signed-off-by: Sachidanand Alle <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix dependency Signed-off-by: Sachidanand Alle <[email protected]> * sync up nuclick changes Signed-off-by: Sachidanand Alle <[email protected]> Signed-off-by: Sachidanand Alle <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f8ebb38 commit 6007e0c

File tree

12 files changed

+883
-417
lines changed

12 files changed

+883
-417
lines changed

monailabel/transform/post.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from monai.config import KeysCollection, NdarrayOrTensor
2020
from monai.data import MetaTensor
2121
from monai.transforms import MapTransform, Resize, Transform, generate_spatial_bounding_box, get_extreme_points
22-
from monai.utils import InterpolateMode, ensure_tuple_rep
22+
from monai.utils import InterpolateMode, convert_to_numpy, ensure_tuple_rep
2323
from shapely.geometry import Point, Polygon
2424
from torchvision.utils import make_grid, save_image
2525

@@ -176,7 +176,7 @@ def __call__(self, data):
176176
color_map = d.get(self.key_label_colors) if self.colormap is None else self.colormap
177177

178178
foreground_points = d.get(self.key_foreground_points, []) if self.key_foreground_points else []
179-
foreground_points = [Point(pt[1], pt[0]) for pt in foreground_points] # polygons in (y, x) format
179+
foreground_points = [Point(pt[0], pt[1]) for pt in foreground_points] # polygons in (x, y) format
180180

181181
elements = []
182182
label_names = set()
@@ -188,8 +188,9 @@ def __call__(self, data):
188188
labels = [label for label in np.unique(p).tolist() if label > 0]
189189
logger.debug(f"Total Unique Masks (excluding background): {labels}")
190190
for label_idx in labels:
191-
p = d[key].array if isinstance(d[key], MetaTensor) else d[key]
191+
p = convert_to_numpy(d[key]) if isinstance(d[key], torch.Tensor) else d[key]
192192
p = np.where(p == label_idx, 1, 0).astype(np.uint8)
193+
p = np.moveaxis(p, 0, 1) # for cv2
193194

194195
label_name = self.labels.get(label_idx, label_idx)
195196
label_names.add(label_name)
@@ -237,29 +238,36 @@ def __call__(self, data):
237238

238239

239240
class DumpImagePrediction2Dd(Transform):
240-
def __init__(self, image_path, pred_path):
241+
def __init__(self, image_path, pred_path, pred_only=True):
241242
self.image_path = image_path
242243
self.pred_path = pred_path
244+
self.pred_only = pred_only
243245

244246
def __call__(self, data):
245247
d = dict(data)
246-
image = d["image"].array
247-
pred = d["pred"].array
248-
249-
img_tensor = make_grid(torch.from_numpy(image[:3] * 128 + 128), normalize=True)
250-
save_image(img_tensor, self.image_path)
251-
252-
image_pred = [pred, image[3][None], image[4][None]] if image.shape[0] == 5 else [pred]
253-
image_pred_np = np.array(image_pred)
254-
image_pred_t = torch.from_numpy(image_pred_np)
255-
256-
tensor = make_grid(
257-
tensor=image_pred_t,
258-
nrow=len(image_pred),
259-
normalize=True,
260-
pad_value=10,
261-
)
262-
save_image(tensor, self.pred_path)
248+
for bidx in range(d["image"].shape[0]):
249+
image = np.moveaxis(d["image"][bidx], 1, 2)
250+
pred = np.moveaxis(d["pred"][bidx], 0, 1)
251+
252+
img_tensor = make_grid(torch.from_numpy(image[:3] * 128 + 128), normalize=True)
253+
save_image(img_tensor, self.image_path)
254+
255+
if self.pred_only:
256+
pred_tensor = make_grid(torch.from_numpy(pred), normalize=True)
257+
save_image(pred_tensor[0], self.pred_path)
258+
return d
259+
260+
image_pred = [pred[None], image[3][None], image[4][None]] if image.shape[0] == 5 else [pred[None]]
261+
image_pred_np = np.array(image_pred)
262+
image_pred_t = torch.from_numpy(image_pred_np)
263+
264+
tensor = make_grid(
265+
tensor=image_pred_t,
266+
nrow=len(image_pred),
267+
normalize=True,
268+
pad_value=10,
269+
)
270+
save_image(tensor, self.pred_path)
263271
return d
264272

265273

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,6 @@ bcrypt==3.2.2
4343
shapely==1.8.2
4444
requests==2.28.1
4545
scikit-learn
46+
scipy
4647

4748
#sudo apt-get install openslide-tools -y

sample-apps/pathology/lib/handlers.py

Lines changed: 58 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __init__(
6262
self,
6363
summary_writer: Optional[SummaryWriter] = None,
6464
log_dir: str = "./runs",
65-
tag_name="val_acc",
65+
tag_name="val",
6666
interval: int = 1,
6767
batch_transform: Callable = lambda x: x,
6868
output_transform: Callable = lambda x: x,
@@ -88,7 +88,8 @@ def __init__(
8888
self.class_y_pred: List[Any] = []
8989

9090
def attach(self, engine: Engine) -> None:
91-
engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self, "iteration")
91+
if self.interval == 1:
92+
engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self, "iteration")
9293
engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.interval), self, "epoch")
9394

9495
def __call__(self, engine: Engine, action) -> None:
@@ -130,39 +131,37 @@ def write_images(self, batch_data, output_data, epoch):
130131
image = batch_data[bidx]["image"].detach().cpu().numpy()
131132
y = output_data[bidx]["label"].detach().cpu().numpy()
132133

133-
tag_prefix = f"b{bidx} - " if self.batch_limit != 1 else ""
134-
img_tensor = make_grid(torch.from_numpy(image[:3] * 128 + 128), normalize=True)
135-
self.writer.add_image(tag=f"{tag_prefix}Image", img_tensor=img_tensor, global_step=epoch)
136-
137134
if self.class_names:
138135
sig_np = image[:3] * 128 + 128
139136
sig_np[0, :, :] = np.where(image[3] > 0, 1, sig_np[0, :, :])
140-
sig_tensor = make_grid(torch.from_numpy(sig_np), normalize=True)
141-
self.writer.add_image(tag=f"{tag_prefix}Signal", img_tensor=sig_tensor, global_step=epoch)
142137
if np.count_nonzero(image[3]) == 0:
143-
self.logger.info("+++++++++ BUG (Signal is ZERO)")
138+
self.logger.info(f"{self.tag_name} => +++++++++ BUG (Signal is ZERO)")
144139

145140
y_pred = output_data[bidx]["pred"].detach().cpu().numpy()
146141

147142
y_c = np.argmax(y)
148143
y_pred_c = np.argmax(y_pred)
149144

150-
tag_prefix = f"b{bidx} - " if self.batch_limit != 1 else ""
151-
label_pred_tag = f"{tag_prefix}Label vs Pred:"
145+
tag_prefix = f"{self.tag_name} - b{bidx} - " if self.batch_limit != 1 else f"{self.tag_name} - "
146+
label_pred_tag = f"{tag_prefix}Image/Signal/Label/Pred:"
152147

153-
y_img = Image.new("RGB", (200, 100))
148+
y_img = Image.new("RGB", image.shape[-2:])
154149
draw = ImageDraw.Draw(y_img)
155150
draw.text((10, 50), self.class_names.get(f"{y_c}", f"{y_c}"))
156151

157-
y_pred_img = Image.new("RGB", (200, 100), "green" if y_c == y_pred_c else "red")
152+
y_pred_img = Image.new("RGB", image.shape[-2:], "green" if y_c == y_pred_c else "red")
158153
draw = ImageDraw.Draw(y_pred_img)
159154
draw.text((10, 50), self.class_names.get(f"{y_pred_c}", f"{y_pred_c}"))
160155

161-
label_pred = [np.moveaxis(np.array(y_img), -1, 0), np.moveaxis(np.array(y_pred_img), -1, 0)]
162156
img_tensor = make_grid(
163-
tensor=torch.from_numpy(np.array(label_pred)),
164-
nrow=3,
165-
normalize=False,
157+
tensor=[
158+
torch.from_numpy(sig_np),
159+
torch.from_numpy(np.stack((np.where(image[3] > 0, 255, 0),) * 3)),
160+
torch.from_numpy(np.moveaxis(np.array(y_img), -1, 0)),
161+
torch.from_numpy(np.moveaxis(np.array(y_pred_img), -1, 0)),
162+
],
163+
nrow=4,
164+
normalize=True,
166165
pad_value=10,
167166
)
168167
self.writer.add_image(tag=label_pred_tag, img_tensor=img_tensor, global_step=epoch)
@@ -171,35 +170,60 @@ def write_images(self, batch_data, output_data, epoch):
171170
if self.batch_limit == 1 and bidx < (len(batch_data) - 1) and np.sum(y) == 0:
172171
continue
173172

173+
tag_prefix = f"{self.tag_name} - b{bidx} - " if self.batch_limit != 1 else ""
174+
img_np = image[:3] * 128 + 128
175+
if image.shape[0] > 3:
176+
img_np[0, :, :] = np.where(image[3] > 0, 1, img_np[0, :, :])
177+
img_tensor = make_grid(torch.from_numpy(img_np), normalize=True)
178+
self.writer.add_image(tag=f"{tag_prefix}Image", img_tensor=img_tensor, global_step=epoch)
179+
174180
y_pred = output_data[bidx]["pred"].detach().cpu().numpy()
175181

176182
for region in range(y_pred.shape[0]):
177183
if region == 0 and y_pred.shape[0] > 1: # one-hot; background
178184
continue
179185

186+
cl = np.count_nonzero(y[region])
187+
cp = np.count_nonzero(y_pred[region])
180188
self.logger.info(
181-
"{} - {} - Image: {};"
189+
"{} => {} - {} - Image: {};"
182190
" Label: {} (nz: {});"
183191
" Pred: {} (nz: {});"
184-
" Sig: (pos-nz: {}, neg-nz: {})".format(
192+
" Diff: {:.2f}%; "
193+
"{}".format(
194+
self.tag_name,
185195
bidx,
186196
region,
187197
image.shape,
188198
y.shape,
189-
np.count_nonzero(y[region]),
199+
cl,
190200
y_pred.shape,
191-
np.count_nonzero(y_pred[region]),
192-
np.count_nonzero(image[3]) if image.shape[0] == 5 else 0,
193-
np.count_nonzero(image[4]) if image.shape[0] == 5 else 0,
201+
cp,
202+
100 * (cp - cl) / (cl + 1),
203+
" Sig: (pos-nz: {}, neg-nz: {})".format(
204+
np.count_nonzero(image[3]) if image.shape[0] == 5 else 0,
205+
np.count_nonzero(image[4]) if image.shape[0] == 5 else 0,
206+
)
207+
if image.shape[0] == 5
208+
else "",
194209
)
195210
)
196211

197-
tag_prefix = f"b{bidx}:l{region} - " if self.batch_limit != 1 else f"l{region} - "
212+
tag_prefix = (
213+
f"{self.tag_name} - b{bidx}:l{region} - "
214+
if self.batch_limit != 1
215+
else f"{self.tag_name} - l{region} - "
216+
)
198217

199218
label_pred = [y[region][None], y_pred[region][None]]
200219
label_pred_tag = f"{tag_prefix}Label vs Pred:"
201220
if image.shape[0] == 5:
202-
label_pred = [y[region][None], y_pred[region][None], image[3][None], image[4][None]]
221+
label_pred = [
222+
y[region][None],
223+
y_pred[region][None],
224+
image[3][None] > 0,
225+
image[4][None] > 0,
226+
]
203227
label_pred_tag = f"{tag_prefix}Label vs Pred vs Pos vs Neg"
204228

205229
img_tensor = make_grid(
@@ -222,12 +246,12 @@ def write_region_metrics(self, epoch):
222246
for n, m in v.items():
223247
ltext.append(f"{n} => {m:.4f}")
224248
cname = self.class_names.get(k, k)
225-
self.writer.add_scalar(f"cr_{k}_{n}", m, epoch)
249+
self.writer.add_scalar(f"{self.tag_name}_cr_{k}_{n}", m, epoch)
226250

227-
self.logger.info(f"Epoch[{epoch}] Metrics -- Class: {cname}; {'; '.join(ltext)}")
251+
self.logger.info(f"{self.tag_name} => Epoch[{epoch}] Metrics -- Class: {cname}; {'; '.join(ltext)}")
228252
else:
229-
self.logger.info(f"Epoch[{epoch}] Metrics -- {k} => {v:.4f}")
230-
self.writer.add_scalar(f"cr_{k}", v, epoch)
253+
self.logger.info(f"{self.tag_name} => Epoch[{epoch}] Metrics -- {k} => {v:.4f}")
254+
self.writer.add_scalar(f"{self.tag_name}_cr_{k}", v, epoch)
231255

232256
self.class_y = []
233257
self.class_y_pred = []
@@ -237,13 +261,15 @@ def write_region_metrics(self, epoch):
237261
metric_sum = 0
238262
for region in self.metric_data:
239263
metric = self.metric_data[region].mean()
240-
self.logger.info(f"Epoch[{epoch}] Metrics -- Region: {region:0>2d}, {self.tag_name}: {metric:.4f}")
264+
self.logger.info(
265+
f"{self.tag_name} => Epoch[{epoch}] Metrics (Dice) -- Region: {region:0>2d}: {metric:.4f}"
266+
)
241267

242-
self.writer.add_scalar(f"dice_{region:0>2d}", metric, epoch)
268+
self.writer.add_scalar(f"{self.tag_name}_dice_{region:0>2d}", metric, epoch)
243269
metric_sum += metric
244270

245271
metric_avg = metric_sum / len(self.metric_data)
246-
self.writer.add_scalar("dice_regions_avg", metric_avg, epoch)
272+
self.writer.add_scalar(f"{self.tag_name}_dice_regions_avg", metric_avg, epoch)
247273

248274
self.writer.flush()
249275
self.metric_data = {}

sample-apps/pathology/lib/infers/classification_nuclei.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
from typing import Any, Callable, Dict, Sequence
1414

1515
import numpy as np
16-
from lib.transforms import FixNuclickClassd, LoadImagePatchd
16+
from lib.nuclick import AddLabelAsGuidanced
17+
from lib.transforms import LoadImagePatchd
1718
from monai.inferers import Inferer, SimpleInferer
18-
from monai.transforms import Activationsd, AsChannelFirstd, EnsureTyped, ScaleIntensityRangeD
19+
from monai.transforms import Activationsd, EnsureChannelFirstd, ScaleIntensityRangeD
1920

2021
from monailabel.interfaces.tasks.infer_v2 import InferType
2122
from monailabel.tasks.infer.basic_infer import BasicInferTask
@@ -62,10 +63,9 @@ def pre_transforms(self, data=None) -> Sequence[Callable]:
6263
return [
6364
LoadImagePatchd(keys="image", dtype=np.uint8),
6465
LoadImagePatchd(keys="label", dtype=np.uint8, mode="L"),
65-
EnsureTyped(keys=("image", "label")),
66-
AsChannelFirstd(keys="image"),
66+
EnsureChannelFirstd(keys=("image", "label")),
6767
ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
68-
FixNuclickClassd(image="image", label="label", offset=-1),
68+
AddLabelAsGuidanced(keys="image", source="label"),
6969
]
7070

7171
def inferer(self, data=None) -> Inferer:

sample-apps/pathology/lib/infers/nuclick.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,16 @@
1414

1515
import numpy as np
1616
import torch
17-
from lib.transforms import FixNuclickClassd, LoadImagePatchd, NuClickPostFilterLabelExd
18-
from monai.apps.nuclick.transforms import AddClickSignalsd, NuclickKeys
17+
from lib.nuclick import AddClickSignalsd, AddLabelAsGuidanced, NuclickKeys, PostFilterLabeld
18+
from lib.transforms import LoadImagePatchd
1919
from monai.config import KeysCollection
2020
from monai.transforms import (
2121
Activationsd,
2222
AsChannelFirstd,
2323
AsDiscreted,
2424
EnsureTyped,
2525
MapTransform,
26+
ScaleIntensityRangeD,
2627
SqueezeDimd,
2728
ToNumpyd,
2829
)
@@ -113,36 +114,38 @@ def info(self) -> Dict[str, Any]:
113114
def pre_transforms(self, data=None):
114115
return [
115116
LoadImagePatchd(keys="image", mode="RGB", dtype=np.uint8, padding=False),
117+
EnsureTyped(keys="image", device=data.get("device") if data else None),
116118
AsChannelFirstd(keys="image"),
117119
ConvertInteractiveClickSignals(
118120
source_annotation_keys="nuclick points",
119121
target_data_keys=NuclickKeys.FOREGROUND,
120122
allow_missing_keys=True,
121123
),
122-
AddClickSignalsd(image="image", foreground=NuclickKeys.FOREGROUND),
123-
EnsureTyped(keys="image", device=data.get("device") if data else None),
124+
ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
125+
AddClickSignalsd(image="image", foreground=NuclickKeys.FOREGROUND, gaussian=False),
124126
]
125127

126128
def run_inferer(self, data, convert_to_batch=True, device="cuda"):
127129
output = super().run_inferer(data, False, device)
128130
if self.task_classification:
129-
data2 = copy.deepcopy(self.task_classification.config())
130131
pred1 = output["pred"]
131132
pred1 = torch.sigmoid(pred1)
132133
pred1 = pred1 >= 0.5
133134

135+
data2 = copy.deepcopy(self.task_classification.config())
134136
data2.update({"image": output["image"][:, :3], "label": pred1, "device": device})
135-
136-
data2 = self.task_classification.run_pre_transforms(data2, [FixNuclickClassd(image="image", label="label")])
137+
data2 = self.task_classification.run_pre_transforms(
138+
data2, [AddLabelAsGuidanced(keys="image", source="label")]
139+
)
137140

138141
output2 = self.task_classification.run_inferer(data2, False, device)
139142
pred2 = output2["pred"]
140143
pred2 = torch.softmax(pred2, dim=1)
141144
pred2 = torch.argmax(pred2, dim=1)
142145
pred2 = [int(p) for p in pred2]
143146

144-
output["pred_classes"] = [v + 1 for v in pred2]
145-
logger.info(f"Predicted Classes: {output['pred_classes']}")
147+
output[NuclickKeys.PRED_CLASSES] = [v + 1 for v in pred2]
148+
logger.info(f"Predicted Classes: {output[NuclickKeys.PRED_CLASSES]}")
146149
return output
147150

148151
def post_transforms(self, data=None) -> Sequence[Callable]:
@@ -152,7 +155,7 @@ def post_transforms(self, data=None) -> Sequence[Callable]:
152155
AsDiscreted(keys="pred", threshold=0.5),
153156
SqueezeDimd(keys="pred", dim=1),
154157
ToNumpyd(keys=("image", "pred")),
155-
NuClickPostFilterLabelExd(keys="pred"),
158+
PostFilterLabeld(keys="pred"),
156159
FindContoursd(keys="pred", labels=self.labels),
157160
]
158161

0 commit comments

Comments
 (0)