Skip to content

Commit 6f24c51

Browse files
committed
Refactor ImageOverlayWriter documentation and enhance model_id validation
- Moved the documentation for ImageOverlayWriter into the class docstring for better organization and clarity. - Improved the model_id validation logic in AppGenerator to prevent code injection and path traversal, ensuring stricter input checks. - Updated the generated application template to reflect changes in the channel_first logic. - Added unit tests to verify the correctness of the refactored channel_first logic. Signed-off-by: Victor Chang <[email protected]>
1 parent 802691f commit 6f24c51

File tree

4 files changed

+90
-20
lines changed

4 files changed

+90
-20
lines changed

monai/deploy/operators/image_overlay_writer_operator.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,6 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
"""
13-
Image Overlay Writer
14-
15-
Blends a segmentation mask onto an RGB image and saves the result as a PNG.
16-
17-
Named inputs:
18-
- image: original RGB frame as Image or ndarray (HWC, uint8/float)
19-
- pred: predicted mask as Image or ndarray (H x W or 1 x H x W). If multi-channel
20-
probability tensor is provided, you may pre-argmax before this operator.
21-
- filename: base name (stem) for output file
22-
"""
23-
2412
import logging
2513
from pathlib import Path
2614
from typing import Optional, Tuple
@@ -34,6 +22,17 @@
3422

3523

3624
class ImageOverlayWriter(Operator):
25+
"""
26+
Image Overlay Writer
27+
28+
Blends a segmentation mask onto an RGB image and saves the result as a PNG.
29+
30+
Named inputs:
31+
- image: original RGB frame as Image or ndarray (HWC, uint8/float)
32+
- pred: predicted mask as Image or ndarray (H x W or 1 x H x W). If multi-channel
33+
probability tensor is provided, you may pre-argmax before this operator.
34+
- filename: base name (stem) for output file
35+
"""
3736
def __init__(
3837
self,
3938
fragment: Fragment,

tools/pipeline-generator/pipeline_generator/generator/app_generator.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from ..config.settings import Settings, load_config
2121
from .bundle_downloader import BundleDownloader
2222

23+
import re
24+
2325
logger = logging.getLogger(__name__)
2426

2527

@@ -92,9 +94,14 @@ def generate_app(
9294
Returns:
9395
Path to the generated application directory
9496
"""
95-
# Validate model_id to prevent code injection
96-
if not model_id or not all(c.isalnum() or c in "/-_" for c in model_id):
97-
raise ValueError(f"Invalid model_id: {model_id}. Only alphanumeric characters, /, -, and _ are allowed.")
97+
# Validate model_id to prevent code injection and path traversal
98+
# Only allow model IDs like "owner/model-name" or "model_name", no leading/trailing slash, no "..", no empty segments
99+
model_id_pattern = r"^(?!.*\.\.)(?!/)(?!.*//)(?!.*\/$)[A-Za-z0-9_-]+(\/[A-Za-z0-9_-]+)*$"
100+
101+
if not model_id or not re.match(model_id_pattern, model_id):
102+
raise ValueError(
103+
f"Invalid model_id: {model_id}. Only alphanumeric characters, hyphens, underscores, and single slashes between segments are allowed. No leading/trailing slashes, consecutive slashes, or '..' allowed."
104+
)
98105

99106
# Create output directory
100107
output_dir.mkdir(parents=True, exist_ok=True)
@@ -250,6 +257,18 @@ def _prepare_context(
250257
elif isinstance(cfgs, dict):
251258
resolved_channel_first = cfgs.get("channel_first", None)
252259

260+
# Determine final channel_first value
261+
if resolved_channel_first is not None:
262+
# Use explicit override from configuration
263+
channel_first = resolved_channel_first
264+
else:
265+
# Apply default logic: False for image input classification, True otherwise
266+
input_type_resolved = input_type or ("dicom" if use_dicom else ("image" if use_image else "nifti"))
267+
if input_type_resolved == "image" and "classification" not in task.lower():
268+
channel_first = False
269+
else:
270+
channel_first = True
271+
253272
# Collect dependency hints from metadata.json
254273
required_packages_version = metadata.get("required_packages_version", {}) if metadata else {}
255274
extra_dependencies = getattr(model_config, "dependencies", []) if model_config else []
@@ -280,7 +299,7 @@ def _prepare_context(
280299
"authors": metadata.get("authors", "MONAI"),
281300
"output_postfix": output_postfix,
282301
"model_type": model_type,
283-
"channel_first_override": resolved_channel_first,
302+
"channel_first": channel_first,
284303
"required_packages_version": required_packages_version,
285304
"extra_dependencies": extra_dependencies,
286305
}

tools/pipeline-generator/pipeline_generator/templates/app.py.j2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ class {{ app_name }}(Application):
146146
loader_op = ImageDirectoryLoader(
147147
self,
148148
input_folder=app_input_path,
149-
channel_first={% if channel_first_override is not none %}{{ 'True' if channel_first_override else 'False' }}{% else %}{{ 'False' if input_type == 'image' and 'classification' not in task.lower() else 'True' }}{% endif %},
149+
channel_first={{ channel_first }},
150150
name="image_loader"
151151
)
152152
{% elif input_type == "custom" %}

tools/pipeline-generator/tests/test_generator.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,9 +306,9 @@ def test_model_config_with_channel_first_override(self):
306306
data_format="auto",
307307
)
308308

309-
# This covers lines 201-210
309+
# Verify channel_first logic is computed correctly
310310
call_args = mock_app_py.call_args[0][1]
311-
assert call_args["channel_first_override"] is False
311+
assert call_args["channel_first"] is False
312312

313313
def test_metadata_with_numpy_pytorch_versions(self):
314314
"""Test metadata with numpy_version and pytorch_version."""
@@ -450,7 +450,59 @@ def test_model_config_with_dict_configs(self):
450450
)
451451

452452
call_args = mock_app_py.call_args[0][1]
453-
assert call_args["channel_first_override"] is True
453+
assert call_args["channel_first"] is True
454+
455+
def test_channel_first_logic_refactoring(self):
456+
"""Test the refactored channel_first logic works correctly."""
457+
generator = AppGenerator()
458+
459+
# Test case 1: image input, non-classification task -> should be False
460+
context1 = generator._prepare_context(
461+
model_id="test/model",
462+
metadata={"task": "segmentation", "name": "Test Model"},
463+
inference_config={},
464+
model_file=None,
465+
app_name="TestApp",
466+
input_type="image",
467+
output_type="nifti"
468+
)
469+
assert context1["channel_first"] is False
470+
471+
# Test case 2: image input, classification task -> should be True
472+
context2 = generator._prepare_context(
473+
model_id="test/model",
474+
metadata={"task": "classification", "name": "Test Model"},
475+
inference_config={},
476+
model_file=None,
477+
app_name="TestApp",
478+
input_type="image",
479+
output_type="json"
480+
)
481+
assert context2["channel_first"] is True
482+
483+
# Test case 3: dicom input -> should be True
484+
context3 = generator._prepare_context(
485+
model_id="test/model",
486+
metadata={"task": "segmentation", "name": "Test Model"},
487+
inference_config={},
488+
model_file=None,
489+
app_name="TestApp",
490+
input_type="dicom",
491+
output_type="nifti"
492+
)
493+
assert context3["channel_first"] is True
494+
495+
# Test case 4: nifti input -> should be True
496+
context4 = generator._prepare_context(
497+
model_id="test/model",
498+
metadata={"task": "segmentation", "name": "Test Model"},
499+
inference_config={},
500+
model_file=None,
501+
app_name="TestApp",
502+
input_type="nifti",
503+
output_type="nifti"
504+
)
505+
assert context4["channel_first"] is True
454506

455507
def test_get_default_metadata(self):
456508
"""Test _get_default_metadata method directly."""

0 commit comments

Comments
 (0)