Skip to content

Commit c6793fd

Browse files
Nic-Mamonai-bot
andauthored
2525 Add decollate logic to DeepGrow module (8/July) (#2530)
* [DLMED] add decollate logic to DeepGrow Signed-off-by: Nic Ma <[email protected]> * [MONAI] python code formatting Signed-off-by: monai-bot <[email protected]> * [DLMED] fix flake8 issue Signed-off-by: Nic Ma <[email protected]> * [DLMED] add AddRandomGuidanced to test Signed-off-by: Nic Ma <[email protected]> * [DLMED] update network Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix typo Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix issue in addsignal Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix guidance issue Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix flake8 Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix mypy and docs Signed-off-by: Nic Ma <[email protected]> Co-authored-by: monai-bot <[email protected]>
1 parent 17dde67 commit c6793fd

File tree

4 files changed

+78
-110
lines changed

4 files changed

+78
-110
lines changed

monai/apps/deepgrow/interaction.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import torch
1414

15+
from monai.data import decollate_batch, list_data_collate
1516
from monai.engines import SupervisedEvaluator, SupervisedTrainer
1617
from monai.engines.utils import IterationEvents
1718
from monai.transforms import Compose
@@ -74,6 +75,9 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd
7475
batchdata[self.key_probability] = torch.as_tensor(
7576
([1.0 - ((1.0 / self.max_interactions) * j)] if self.train else [1.0]) * len(inputs)
7677
)
77-
batchdata = self.transforms(batchdata)
78+
# decollate batch data to execute click transforms
79+
batchdata_list = [self.transforms(i) for i in decollate_batch(batchdata, detach=True)]
80+
# collate list into a batch for next round interaction
81+
batchdata = list_data_collate(batchdata_list)
7882

7983
return engine._iteration(engine, batchdata)

monai/apps/deepgrow/transforms.py

Lines changed: 26 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
11+
import json
1112
from typing import Callable, Dict, Optional, Sequence, Union
1213

1314
import numpy as np
@@ -144,7 +145,7 @@ def _apply(self, label, sid):
144145
def __call__(self, data):
145146
d = dict(data)
146147
self.randomize(data)
147-
d[self.guidance] = self._apply(d[self.label], self.sid)
148+
d[self.guidance] = json.dumps(self._apply(d[self.label], self.sid).astype(int).tolist())
148149
return d
149150

150151

@@ -159,7 +160,7 @@ class AddGuidanceSignald(Transform):
159160
guidance: key to store guidance.
160161
sigma: standard deviation for Gaussian kernel.
161162
number_intensity_ch: channel index.
162-
batched: whether input is batched or not.
163+
163164
"""
164165

165166
def __init__(
@@ -168,17 +169,16 @@ def __init__(
168169
guidance: str = "guidance",
169170
sigma: int = 2,
170171
number_intensity_ch: int = 1,
171-
batched: bool = False,
172172
):
173173
self.image = image
174174
self.guidance = guidance
175175
self.sigma = sigma
176176
self.number_intensity_ch = number_intensity_ch
177-
self.batched = batched
178177

179178
def _get_signal(self, image, guidance):
180179
dimensions = 3 if len(image.shape) > 3 else 2
181180
guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance
181+
guidance = json.loads(guidance) if isinstance(guidance, str) else guidance
182182
if dimensions == 3:
183183
signal = np.zeros((len(guidance), image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32)
184184
else:
@@ -210,16 +210,9 @@ def _get_signal(self, image, guidance):
210210
return signal
211211

212212
def _apply(self, image, guidance):
213-
if not self.batched:
214-
signal = self._get_signal(image, guidance)
215-
return np.concatenate([image, signal], axis=0)
216-
217-
images = []
218-
for i, g in zip(image, guidance):
219-
i = i[0 : 0 + self.number_intensity_ch, ...]
220-
signal = self._get_signal(i, g)
221-
images.append(np.concatenate([i, signal], axis=0))
222-
return images
213+
signal = self._get_signal(image, guidance)
214+
image = image[0 : 0 + self.number_intensity_ch, ...]
215+
return np.concatenate([image, signal], axis=0)
223216

224217
def __call__(self, data):
225218
d = dict(data)
@@ -234,26 +227,17 @@ class FindDiscrepancyRegionsd(Transform):
234227
"""
235228
Find discrepancy between prediction and actual during click interactions during training.
236229
237-
If batched is true:
238-
239-
label is in shape (B, C, D, H, W) or (B, C, H, W)
240-
pred has same shape as label
241-
discrepancy will have shape (B, 2, C, D, H, W) or (B, 2, C, H, W)
242-
243230
Args:
244231
label: key to label source.
245232
pred: key to prediction source.
246233
discrepancy: key to store discrepancies found between label and prediction.
247-
batched: whether input is batched or not.
234+
248235
"""
249236

250-
def __init__(
251-
self, label: str = "label", pred: str = "pred", discrepancy: str = "discrepancy", batched: bool = True
252-
):
237+
def __init__(self, label: str = "label", pred: str = "pred", discrepancy: str = "discrepancy"):
253238
self.label = label
254239
self.pred = pred
255240
self.discrepancy = discrepancy
256-
self.batched = batched
257241

258242
@staticmethod
259243
def disparity(label, pred):
@@ -266,13 +250,7 @@ def disparity(label, pred):
266250
return [pos_disparity, neg_disparity]
267251

268252
def _apply(self, label, pred):
269-
if not self.batched:
270-
return self.disparity(label, pred)
271-
272-
disparity = []
273-
for la, pr in zip(label, pred):
274-
disparity.append(self.disparity(la, pr))
275-
return disparity
253+
return self.disparity(label, pred)
276254

277255
def __call__(self, data):
278256
d = dict(data)
@@ -286,53 +264,32 @@ def __call__(self, data):
286264
class AddRandomGuidanced(Randomizable, Transform):
287265
"""
288266
Add random guidance based on discrepancies that were found between label and prediction.
289-
290-
If batched is True, input shape is as below:
291-
292-
Guidance is of shape (B, 2, N, # of dim) where B is batch size, 2 means positive and negative,
293-
N means how many guidance points, # of dim is the total number of dimensions of the image
294-
(for example if the image is CDHW, then # of dim would be 4).
295-
296-
Discrepancy is of shape (B, 2, C, D, H, W) or (B, 2, C, H, W)
297-
298-
Probability is of shape (B, 1)
299-
300-
else:
301-
302-
Guidance is of shape (2, N, # of dim)
303-
304-
Discrepancy is of shape (2, C, D, H, W) or (2, C, H, W)
305-
306-
Probability is of shape (1)
267+
input shape is as below:
268+
Guidance is of shape (2, N, # of dim)
269+
Discrepancy is of shape (2, C, D, H, W) or (2, C, H, W)
270+
Probability is of shape (1)
307271
308272
Args:
309273
guidance: key to guidance source.
310274
discrepancy: key that represents discrepancies found between label and prediction.
311275
probability: key that represents click/interaction probability.
312-
batched: whether input is batched or not.
276+
313277
"""
314278

315279
def __init__(
316280
self,
317281
guidance: str = "guidance",
318282
discrepancy: str = "discrepancy",
319283
probability: str = "probability",
320-
batched: bool = True,
321284
):
322285
self.guidance = guidance
323286
self.discrepancy = discrepancy
324287
self.probability = probability
325-
self.batched = batched
326288
self._will_interact = None
327289

328290
def randomize(self, data=None):
329291
probability = data[self.probability]
330-
if not self.batched:
331-
self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability])
332-
else:
333-
self._will_interact = []
334-
for p in probability:
335-
self._will_interact.append(self.R.choice([True, False], p=[p, 1.0 - p]))
292+
self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability])
336293

337294
def find_guidance(self, discrepancy):
338295
distance = distance_transform_cdt(discrepancy).flatten()
@@ -368,24 +325,16 @@ def add_guidance(self, discrepancy, will_interact):
368325

369326
def _apply(self, guidance, discrepancy):
370327
guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance
371-
if not self.batched:
372-
pos, neg = self.add_guidance(discrepancy, self._will_interact)
373-
if pos:
374-
guidance[0].append(pos)
375-
guidance[1].append([-1] * len(pos))
376-
if neg:
377-
guidance[0].append([-1] * len(neg))
378-
guidance[1].append(neg)
379-
else:
380-
for g, d, w in zip(guidance, discrepancy, self._will_interact):
381-
pos, neg = self.add_guidance(d, w)
382-
if pos:
383-
g[0].append(pos)
384-
g[1].append([-1] * len(pos))
385-
if neg:
386-
g[0].append([-1] * len(neg))
387-
g[1].append(neg)
388-
return np.asarray(guidance)
328+
guidance = json.loads(guidance) if isinstance(guidance, str) else guidance
329+
pos, neg = self.add_guidance(discrepancy, self._will_interact)
330+
if pos:
331+
guidance[0].append(pos)
332+
guidance[1].append([-1] * len(pos))
333+
if neg:
334+
guidance[0].append([-1] * len(neg))
335+
guidance[1].append(neg)
336+
337+
return json.dumps(np.asarray(guidance).astype(int).tolist())
389338

390339
def __call__(self, data):
391340
d = dict(data)

tests/test_deepgrow_interaction.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,39 +11,60 @@
1111

1212
import unittest
1313

14+
import numpy as np
1415
import torch
1516

1617
from monai.apps.deepgrow.interaction import Interaction
18+
from monai.apps.deepgrow.transforms import (
19+
AddGuidanceSignald,
20+
AddInitialSeedPointd,
21+
AddRandomGuidanced,
22+
FindAllValidSlicesd,
23+
FindDiscrepancyRegionsd,
24+
)
1725
from monai.data import Dataset
1826
from monai.engines import SupervisedTrainer
1927
from monai.engines.utils import IterationEvents
20-
from monai.transforms import Activationsd, Compose, ToNumpyd
28+
from monai.transforms import Activationsd, Compose, ToNumpyd, ToTensord
2129

2230

2331
def add_one(engine):
24-
if engine.state.best_metric is -1:
32+
if engine.state.best_metric == -1:
2533
engine.state.best_metric = 0
2634
else:
2735
engine.state.best_metric = engine.state.best_metric + 1
2836

2937

3038
class TestInteractions(unittest.TestCase):
3139
def run_interaction(self, train, compose):
32-
data = []
33-
for i in range(5):
34-
data.append({"image": torch.tensor([float(i)]), "label": torch.tensor([float(i)])})
35-
network = torch.nn.Linear(1, 1)
40+
data = [{"image": np.ones((1, 2, 2, 2)).astype(np.float32), "label": np.ones((1, 2, 2, 2))} for _ in range(5)]
41+
network = torch.nn.Linear(2, 2)
3642
lr = 1e-3
3743
opt = torch.optim.SGD(network.parameters(), lr)
3844
loss = torch.nn.L1Loss()
39-
dataset = Dataset(data, transform=None)
45+
train_transforms = Compose(
46+
[
47+
FindAllValidSlicesd(label="label", sids="sids"),
48+
AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"),
49+
AddGuidanceSignald(image="image", guidance="guidance"),
50+
ToTensord(keys=("image", "label")),
51+
]
52+
)
53+
dataset = Dataset(data, transform=train_transforms)
4054
data_loader = torch.utils.data.DataLoader(dataset, batch_size=5)
4155

42-
iteration_transforms = [Activationsd(keys="pred", sigmoid=True), ToNumpyd(keys="pred")]
56+
iteration_transforms = [
57+
Activationsd(keys="pred", sigmoid=True),
58+
ToNumpyd(keys=["image", "label", "pred"]),
59+
FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"),
60+
AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"),
61+
AddGuidanceSignald(image="image", guidance="guidance"),
62+
ToTensord(keys=("image", "label")),
63+
]
4364
iteration_transforms = Compose(iteration_transforms) if compose else iteration_transforms
4465

4566
i = Interaction(transforms=iteration_transforms, train=train, max_interactions=5)
46-
self.assertEqual(len(i.transforms.transforms), 2, "Mismatch in expected transforms")
67+
self.assertEqual(len(i.transforms.transforms), 6, "Mismatch in expected transforms")
4768

4869
# set up engine
4970
engine = SupervisedTrainer(

tests/test_deepgrow_transforms.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030

3131
IMAGE = np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]])
3232
LABEL = np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]])
33-
BATCH_IMAGE = np.array([[[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]])
34-
BATCH_LABEL = np.array([[[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]])
3533

3634
DATA_1 = {
3735
"image": IMAGE,
@@ -61,24 +59,22 @@
6159
}
6260

6361
DATA_3 = {
64-
"image": BATCH_IMAGE,
65-
"label": BATCH_LABEL,
66-
"pred": np.array([[[[[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]]),
62+
"image": IMAGE,
63+
"label": LABEL,
64+
"pred": np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]),
6765
}
6866

6967
DATA_4 = {
70-
"image": BATCH_IMAGE,
71-
"label": BATCH_LABEL,
72-
"guidance": np.array([[[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]]),
68+
"image": IMAGE,
69+
"label": LABEL,
70+
"guidance": np.array([[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]),
7371
"discrepancy": np.array(
7472
[
75-
[
76-
[[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
77-
[[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
78-
]
73+
[[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
74+
[[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
7975
]
8076
),
81-
"probability": [1.0],
77+
"probability": 1.0,
8278
}
8379

8480
DATA_5 = {
@@ -192,11 +188,11 @@
192188
ADD_INITIAL_POINT_TEST_CASE_1 = [
193189
{"label": "label", "guidance": "guidance", "sids": "sids"},
194190
DATA_1,
195-
np.array([[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]),
191+
"[[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]",
196192
]
197193

198194
ADD_GUIDANCE_TEST_CASE_1 = [
199-
{"image": "image", "guidance": "guidance", "batched": False},
195+
{"image": "image", "guidance": "guidance"},
200196
DATA_2,
201197
np.array(
202198
[
@@ -233,18 +229,16 @@
233229
DATA_3,
234230
np.array(
235231
[
236-
[
237-
[[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
238-
[[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
239-
]
232+
[[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
233+
[[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
240234
]
241235
),
242236
]
243237

244238
ADD_RANDOM_GUIDANCE_TEST_CASE_1 = [
245-
{"guidance": "guidance", "discrepancy": "discrepancy", "probability": "probability", "batched": True},
239+
{"guidance": "guidance", "discrepancy": "discrepancy", "probability": "probability"},
246240
DATA_4,
247-
np.array([[[[1, 0, 2, 2], [1, 0, 1, 3]], [[-1, -1, -1, -1], [-1, -1, -1, -1]]]]),
241+
"[[[1, 0, 2, 2], [1, 0, 1, 3]], [[-1, -1, -1, -1], [-1, -1, -1, -1]]]",
248242
]
249243

250244
ADD_GUIDANCE_FROM_POINTS_TEST_CASE_1 = [
@@ -398,7 +392,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
398392
add_fn = AddInitialSeedPointd(**arguments)
399393
add_fn.set_random_state(seed)
400394
result = add_fn(input_data)
401-
np.testing.assert_allclose(result[arguments["guidance"]], expected_result)
395+
self.assertEqual(result[arguments["guidance"]], expected_result)
402396

403397

404398
class TestAddGuidanceSignald(unittest.TestCase):
@@ -422,7 +416,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
422416
add_fn = AddRandomGuidanced(**arguments)
423417
add_fn.set_random_state(seed)
424418
result = add_fn(input_data)
425-
np.testing.assert_allclose(result[arguments["guidance"]], expected_result, rtol=1e-5)
419+
self.assertEqual(result[arguments["guidance"]], expected_result)
426420

427421

428422
class TestAddGuidanceFromPointsd(unittest.TestCase):

0 commit comments

Comments
 (0)