Skip to content

Commit 5d9061d

Browse files
committed
holoscan conversion check + explicit params added + setter cleanup
Signed-off-by: bluna301 <[email protected]>
1 parent 6f72b0b commit 5d9061d

File tree

2 files changed

+122
-85
lines changed

2 files changed

+122
-85
lines changed

monai/deploy/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
# Need to import explicit ones to quiet mypy complaints
3232
from holoscan.core import *
33-
from holoscan.core import Application, Condition, ConditionType, Fragment, Operator, OperatorSpec
33+
from holoscan.core import Application, Condition, ConditionType, Fragment, Operator, OperatorSpec, Resource
3434

3535
from .app_context import AppContext, init_app_context
3636
from .arg_parser import parse_args

monai/deploy/operators/monai_seg_inference_operator.py

Lines changed: 121 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,15 @@
1414
import os
1515
from pathlib import Path
1616
from threading import Lock
17-
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
17+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
1818

1919
import numpy as np
20+
import torch
2021

2122
from monai.deploy.utils.importutil import optional_import
2223
from monai.utils import StrEnum # Will use the built-in StrEnum when SDK requires Python 3.11.
23-
from monai.utils import BlendMode, PytorchPadMode
2424

2525
MONAI_UTILS = "monai.utils"
26-
torch, _ = optional_import("torch", "1.5")
2726
np_str_obj_array_pattern, _ = optional_import("torch.utils.data._utils.collate", name="np_str_obj_array_pattern")
2827
Dataset, _ = optional_import("monai.data", name="Dataset")
2928
DataLoader, _ = optional_import("monai.data", name="DataLoader")
@@ -42,7 +41,7 @@
4241
# Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477
4342
Compose: Any = Compose_
4443

45-
from monai.deploy.core import AppContext, ConditionType, Fragment, Image, OperatorSpec
44+
from monai.deploy.core import AppContext, Condition, ConditionType, Fragment, Image, OperatorSpec, Resource
4645

4746
from .inference_operator import InferenceOperator
4847

@@ -56,11 +55,6 @@ class InfererType(StrEnum):
5655
SLIDING_WINDOW = "sliding_window"
5756

5857

59-
# define other StrEnum types
60-
BlendModeType = BlendMode
61-
PytorchPadModeType = PytorchPadMode
62-
63-
6458
# @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
6559
class MonaiSegInferenceOperator(InferenceOperator):
6660
"""This segmentation operator uses MONAI transforms and performs Simple or Sliding Window Inference.
@@ -90,9 +84,10 @@ class MonaiSegInferenceOperator(InferenceOperator):
9084
@staticmethod
9185
def filter_sw_kwargs(**kwargs) -> Dict[str, Any]:
9286
"""
93-
Returns a dictionary of named parameters of the sliding_window_inference function that are:
94-
- Not explicitly defined in the __init__ of this class
95-
- Not explicitly used when calling sliding_window_inference
87+
Returns a dictionary of named parameters of the sliding_window_inference function that:
88+
- Are not explicitly defined in the __init__ of this class
89+
- Are not explicitly used when calling sliding_window_inference
90+
- Can be successfully converted from Python --> Holoscan's C++ layer
9691
9792
Args:
9893
**kwargs: extra arguments passed into __init__ beyond the explicitly defined args.
@@ -101,19 +96,30 @@ def filter_sw_kwargs(**kwargs) -> Dict[str, Any]:
10196
filtered_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
10297
"""
10398

99+
logger = logging.getLogger(f"{__name__}.{MonaiSegInferenceOperator.__name__}")
100+
104101
init_params = inspect.signature(MonaiSegInferenceOperator).parameters
105102

106103
# inputs + predictor explicitly used when calling sliding_window_inference
107-
explicit_used = ["inputs", "predictor"]
104+
explicit_used = {"inputs", "predictor"}
105+
106+
# Holoscan convertible types (not exhaustive)
107+
# This will be revisited when there is a better way to handle this.
108+
allowed_types = (str, int, float, bool, bytes, list, tuple, torch.Tensor, Condition, Resource)
108109

109110
filtered_params = {}
110111
for name, val in kwargs.items():
112+
# Drop explicitly defined kwargs
111113
if name in init_params or name in explicit_used:
112-
# match log formatting
113-
logger = logging.getLogger(f"{__name__}.{MonaiSegInferenceOperator.__name__}")
114-
logger.warning(f"{name!r} is already explicity defined or used; ignoring input arg")
115-
else:
116-
filtered_params[name] = val
114+
logger.warning(f"{name!r} is already explicitly defined or used; dropping kwarg.")
115+
continue
116+
# Drop kwargs that can't be converted by Holoscan
117+
if not isinstance(val, allowed_types):
118+
logger.warning(
119+
f"{name!r} type of {type(val).__name__!r} is a non-convertible kwarg for Holoscan; dropping kwarg."
120+
)
121+
continue
122+
filtered_params[name] = val
117123
return filtered_params
118124

119125
def __init__(
@@ -125,10 +131,11 @@ def __init__(
125131
post_transforms: Compose,
126132
app_context: AppContext,
127133
model_name: Optional[str] = "",
128-
overlap: float = 0.25,
129134
sw_batch_size: int = 4,
130-
mode: Union[BlendModeType, str] = BlendModeType.CONSTANT,
131-
padding_mode: Union[PytorchPadModeType, str] = PytorchPadModeType.CONSTANT,
135+
overlap: Union[Sequence[float], float] = 0.25,
136+
sw_device: Optional[Union[torch.device, str]] = None,
137+
device: Optional[Union[torch.device, str]] = None,
138+
process_fn: Optional[Callable] = None,
132139
inferer: Union[InfererType, str] = InfererType.SLIDING_WINDOW,
133140
model_path: Path = MODEL_LOCAL_PATH,
134141
**kwargs,
@@ -137,25 +144,25 @@ def __init__(
137144
138145
Args:
139146
fragment (Fragment): An instance of the Application class which is derived from Fragment.
140-
roi_size (Union[Sequence[int], int]): The window size to execute "SLIDING_WINDOW" evaluation.
141-
An optional input only to be passed for "SLIDING_WINDOW".
142-
If using a "SIMPLE" Inferer, this input is ignored.
147+
roi_size (Sequence[int], int, optional): The window size to execute "SLIDING_WINDOW" evaluation.
148+
Applicable for "SLIDING_WINDOW" only.
143149
pre_transforms (Compose): MONAI Compose object used for pre-transforms.
144150
post_transforms (Compose): MONAI Compose object used for post-transforms.
145151
app_context (AppContext): Object holding the I/O and model paths, and potentially loaded models.
146152
model_name (str, optional): Name of the model. Default to "" for single model app.
147-
overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
153+
sw_batch_size (int): The batch size to run window slices. Defaults to 4.
154+
Applicable for "SLIDING_WINDOW" only.
155+
overlap (Sequence[float], float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
148156
Applicable for "SLIDING_WINDOW" only.
149-
sw_batch_size(int): The batch size to run window slices. Defaults to 4.
157+
sw_device (torch.device, str, optional): Device for the window data. Defaults to None.
150158
Applicable for "SLIDING_WINDOW" only.
151-
mode (BlendModeType): How to blend output of overlapping windows, "CONSTANT" or "GAUSSIAN". Defaults to "CONSTANT".
159+
device (torch.device, str, optional): Device for the stitched output prediction. Defaults to None.
152160
Applicable for "SLIDING_WINDOW" only.
153-
padding_mode (PytorchPadModeType): Padding mode for ``inputs``, when ``roi_size`` is larger than inputs,
154-
"CONSTANT", "REFLECT", "REPLICATE", or "CIRCULAR". Defaults to "CONSTANT".
161+
process_fn (Callable, optional): Process inference output and adjust the importance map per window. Defaults to None.
155162
Applicable for "SLIDING_WINDOW" only.
156-
inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
163+
inferer (InfererType, str): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
157164
model_path (Path): Path to the model file. Defaults to model/models.ts of current working dir.
158-
**kwargs: any other sliding window parameters to forward (e.g. `sigma_scale`, `cval`, etc.).
165+
**kwargs: any other sliding window parameters to forward (e.g. `mode`, `cval`, etc.).
159166
"""
160167

161168
self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
@@ -165,29 +172,33 @@ def __init__(
165172
self._pred_dataset_key = "pred"
166173
self._input_image = None # Image will come in when compute is called.
167174
self._reader: Any = None
168-
self._roi_size = ensure_tuple(roi_size)
169-
self._pre_transform = pre_transforms
170-
self._post_transforms = post_transforms
175+
self.roi_size = ensure_tuple(roi_size)
176+
self.pre_transforms = pre_transforms
177+
self.post_transforms = post_transforms
171178
self._model_name = model_name.strip() if isinstance(model_name, str) else ""
172-
self._overlap = overlap
173-
self._sw_batch_size = sw_batch_size
174-
self._mode = mode
175-
self._padding_mode = padding_mode
176-
self._inferer = inferer
179+
self.overlap = overlap
180+
self.sw_batch_size = sw_batch_size
181+
self.inferer = inferer
177182
self._implicit_params = self.filter_sw_kwargs(**kwargs) # Filter keyword args
178183

184+
# Sliding window inference args whose type Holoscan can't convert - define explicitly
185+
self.sw_device = sw_device
186+
self.device = device
187+
self.process_fn = process_fn
188+
179189
# Add this so that the local model path can be set from the calling app
180190
self.model_path = model_path
181-
self.input_name_image = "image"
182-
self.output_name_seg = "seg_image"
191+
self._input_name_image = "image"
192+
self._output_name_seg = "seg_image"
183193

184194
# The execution context passed in on compute does not have the required model info, so need to
185195
# get and keep the model via the AppContext obj on construction.
186196
self.app_context = app_context
187197

188-
self.model = self._get_model(self.app_context, self.model_path, self._model_name)
198+
self._model = self._get_model(self.app_context, self.model_path, self._model_name)
189199

190-
super().__init__(fragment, *args, **kwargs)
200+
# Pass filtered kwargs
201+
super().__init__(fragment, *args, **self._implicit_params)
191202

192203
def _get_model(self, app_context: AppContext, model_path: Path, model_name: str):
193204
"""Load the model with the given name from context or model path
@@ -212,8 +223,8 @@ def _get_model(self, app_context: AppContext, model_path: Path, model_name: str)
212223
return model
213224

214225
def setup(self, spec: OperatorSpec):
215-
spec.input(self.input_name_image)
216-
spec.output(self.output_name_seg).condition(ConditionType.NONE) # Downstream receiver optional.
226+
spec.input(self._input_name_image)
227+
spec.output(self._output_name_seg).condition(ConditionType.NONE) # Downstream receiver optional.
217228

218229
@property
219230
def roi_size(self):
@@ -252,9 +263,13 @@ def overlap(self):
252263
return self._overlap
253264

254265
@overlap.setter
255-
def overlap(self, val: float):
256-
if val < 0 or val > 1:
257-
raise ValueError("Overlap must be between 0 and 1.")
266+
def overlap(self, val: Union[Sequence[float], float]):
267+
if not isinstance(val, (Sequence, int, float)) or isinstance(val, str):
268+
raise TypeError(f"Overlap must be type Sequence[float] | float, got {type(val).__name__}.")
269+
elif isinstance(val, Sequence) and not all(isinstance(x, (int, float)) and 0 <= x < 1 for x in val):
270+
raise ValueError("Each overlap value must be >= 0 and < 1.")
271+
elif isinstance(val, (int, float)) and not (0 <= float(val) < 1):
272+
raise ValueError(f"Overlap must be >= 0 and < 1, got {val}.")
258273
self._overlap = val
259274

260275
@property
@@ -269,37 +284,58 @@ def sw_batch_size(self, val: int):
269284
self._sw_batch_size = val
270285

271286
@property
272-
def mode(self) -> Union[BlendModeType, str]:
273-
"""The blend mode used during sliding window inference"""
274-
return self._mode
287+
def sw_device(self):
288+
"""Device for the window data."""
289+
return self._sw_device
290+
291+
@sw_device.setter
292+
def sw_device(self, val: Optional[Union[torch.device, str]]):
293+
if val is not None and not isinstance(val, (torch.device, str)):
294+
raise TypeError(f"sw_device must be type torch.device | str | None, got {type(val).__name__}.")
295+
self._sw_device = val
296+
297+
@property
298+
def device(self):
299+
"""Device for the stitched output prediction."""
300+
return self._device
275301

276-
@mode.setter
277-
def mode(self, val: BlendModeType):
278-
if not isinstance(val, BlendModeType):
279-
raise ValueError(f"Value must be of the correct type {BlendModeType}.")
280-
self._mode = val
302+
@device.setter
303+
def device(self, val: Optional[Union[torch.device, str]]):
304+
if val is not None and not isinstance(val, (torch.device, str)):
305+
raise TypeError(f"device must be type torch.device | str | None, got {type(val).__name__}.")
306+
self._device = val
281307

282308
@property
283-
def padding_mode(self) -> Union[PytorchPadModeType, str]:
284-
"""The padding mode to use when padding input images for inference"""
285-
return self._padding_mode
309+
def process_fn(self):
310+
"""Process inference output and adjust the importance map per window."""
311+
return self._process_fn
286312

287-
@padding_mode.setter
288-
def padding_mode(self, val: PytorchPadModeType):
289-
if not isinstance(val, PytorchPadModeType):
290-
raise ValueError(f"Value must be of the correct type {PytorchPadModeType}.")
291-
self._padding_mode = val
313+
@process_fn.setter
314+
def process_fn(self, val: Optional[Callable]):
315+
if val is not None and not callable(val):
316+
raise TypeError(f"process_fn must be type Callable | None, got {type(val).__name__}.")
317+
self._process_fn = val
292318

293319
@property
294320
def inferer(self) -> Union[InfererType, str]:
295321
"""The type of inferer to use"""
296322
return self._inferer
297323

298324
@inferer.setter
299-
def inferer(self, val: InfererType):
300-
if not isinstance(val, InfererType):
301-
raise ValueError(f"Value must be of the correct type {InfererType}.")
302-
self._inferer = val
325+
def inferer(self, val: Union[InfererType, str]):
326+
if isinstance(val, InfererType):
327+
self._inferer = val
328+
return
329+
330+
if isinstance(val, str):
331+
s = val.strip().lower()
332+
valid = (InfererType.SIMPLE.value, InfererType.SLIDING_WINDOW.value)
333+
if s in valid:
334+
self._inferer = InfererType(s)
335+
return
336+
raise ValueError(f"inferer must be one of {valid}, got {val!r}.")
337+
338+
raise TypeError(f"inferer must be InfererType or str, got {type(val).__name__}.")
303339

304340
def _convert_dicom_metadata_datatype(self, metadata: Dict):
305341
"""Converts metadata in pydicom types to the corresponding native types.
@@ -354,10 +390,10 @@ def compute(self, op_input, op_output, context):
354390
else:
355391
self._executing = True
356392
try:
357-
input_image = op_input.receive(self.input_name_image)
393+
input_image = op_input.receive(self._input_name_image)
358394
if not input_image:
359395
raise ValueError("Input is None.")
360-
op_output.emit(self.compute_impl(input_image, context), self.output_name_seg)
396+
op_output.emit(self.compute_impl(input_image, context), self._output_name_seg)
361397
finally:
362398
# Reset state on completing this method execution.
363399
with self._lock:
@@ -372,12 +408,12 @@ def compute_impl(self, input_image, context):
372408
# Need to give a name to the image as in-mem Image obj has no name.
373409
img_name = str(input_img_metadata.get("SeriesInstanceUID", "Img_in_context"))
374410

375-
pre_transforms: Compose = self._pre_transform
376-
post_transforms: Compose = self._post_transforms
411+
pre_transforms: Compose = self.pre_transforms
412+
post_transforms: Compose = self.post_transforms
377413
self._reader = InMemImageReader(input_image)
378414

379-
pre_transforms = self._pre_transform if self._pre_transform else self.pre_process(self._reader)
380-
post_transforms = self._post_transforms if self._post_transforms else self.post_process(pre_transforms)
415+
pre_transforms = self.pre_transforms if self.pre_transforms else self.pre_process(self._reader)
416+
post_transforms = self.post_transforms if self.post_transforms else self.post_process(pre_transforms)
381417

382418
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
383419
dataset = Dataset(data=[{self._input_dataset_key: img_name}], transform=pre_transforms)
@@ -389,23 +425,24 @@ def compute_impl(self, input_image, context):
389425
for d in dataloader:
390426
images = d[self._input_dataset_key].to(device)
391427
self._logger.info(f"Input of {type(images)} shape: {images.shape}")
392-
if self._inferer == InfererType.SLIDING_WINDOW:
428+
if self.inferer == InfererType.SLIDING_WINDOW:
393429
d[self._pred_dataset_key] = sliding_window_inference(
394430
inputs=images,
395-
roi_size=self._roi_size,
396-
sw_batch_size=self.sw_batch_size,
431+
roi_size=self.roi_size,
397432
overlap=self.overlap,
398-
mode=self._mode,
399-
padding_mode=self._padding_mode,
400-
predictor=self.model,
401-
**self._implicit_params, # additional sliding window arguments
433+
sw_batch_size=self.sw_batch_size,
434+
sw_device=self.sw_device,
435+
device=self.device,
436+
process_fn=self.process_fn,
437+
predictor=self._model,
438+
**self._implicit_params, # Additional sliding window arguments
402439
)
403-
elif self._inferer == InfererType.SIMPLE:
440+
elif self.inferer == InfererType.SIMPLE:
404441
# Instantiates the SimpleInferer and directly uses its __call__ function
405-
d[self._pred_dataset_key] = simple_inference()(inputs=images, network=self.model)
442+
d[self._pred_dataset_key] = simple_inference()(inputs=images, network=self._model)
406443
else:
407444
raise ValueError(
408-
f"Unknown inferer: {self._inferer!r}. Available options are "
445+
f"Unknown inferer: {self.inferer!r}. Available options are "
409446
f"{InfererType.SLIDING_WINDOW!r} and {InfererType.SIMPLE!r}."
410447
)
411448

0 commit comments

Comments
 (0)