Skip to content

Commit d0c1e33

Browse files
committed
Enhance bundle organization and improve model handling in pipeline generator
- Added functionality to ensure the bundle root is included in sys.path for script imports.
- Introduced a new model configuration for pediatric abdominal CT segmentation in the config file.
- Improved the organization of model files by preferring PyTorch models over TensorRT models and handling subdirectory structures.
- Enhanced unit tests to verify the new model organization logic and ensure correct behavior under various scenarios.

Signed-off-by: Victor Chang <[email protected]>
1 parent 502e352 commit d0c1e33

File tree

6 files changed

+272
-5
lines changed

6 files changed

+272
-5
lines changed

monai/deploy/operators/monai_bundle_inference_operator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,11 @@ def _init_config(self, config_names):
468468
config_names ([str]): Names of the config (files) in the bundle
469469
"""
470470

471+
# Ensure bundle root is on sys.path so 'scripts.*' can be imported
472+
bundle_root = str(self._bundle_path)
473+
if bundle_root not in sys.path:
474+
sys.path.insert(0, bundle_root)
475+
471476
parser = get_bundle_config(str(self._bundle_path), config_names)
472477
self._parser = parser
473478

tools/pipeline-generator/pipeline_generator/config/config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ endpoints:
4343
- model_id: "MONAI/pancreas_ct_dints_segmentation"
4444
input_type: "nifti"
4545
output_type: "nifti"
46+
- model_id: "MONAI/pediatric_abdominal_ct_segmentation"
47+
input_type: "nifti"
48+
output_type: "nifti"
49+
dependencies:
50+
- nibabel>=3.2.0 # Required for NIfTI file I/O support
51+
- itk>=5.3.0 # Required for ITK-based image readers/writers
4652
- model_id: "MONAI/Llama3-VILA-M3-3B"
4753
input_type: "custom"
4854
output_type: "custom"

tools/pipeline-generator/pipeline_generator/generator/app_generator.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -348,11 +348,21 @@ def _detect_data_format(self, inference_config: Dict[str, Any], modality: str) -
348348
# Check preprocessing transforms for hints
349349
if "preprocessing" in inference_config:
350350
transforms = inference_config["preprocessing"].get("transforms", [])
351-
for transform in transforms:
352-
target = transform.get("_target_", "")
353-
if "LoadImaged" in target or "LoadImage" in target:
354-
# This suggests NIfTI format
351+
# Handle case where transforms might be a string expression (e.g., "$@preprocessing_transforms + @deepedit_transforms")
352+
if isinstance(transforms, str):
353+
# If transforms is a string expression, we can't analyze it directly
354+
# Look for LoadImaged in the inference config keys instead
355+
config_str = str(inference_config)
356+
if "LoadImaged" in config_str or "LoadImage" in config_str:
355357
return False
358+
elif isinstance(transforms, list):
359+
for transform in transforms:
360+
# Ensure transform is a dictionary before calling .get()
361+
if isinstance(transform, dict):
362+
target = transform.get("_target_", "")
363+
if "LoadImaged" in target or "LoadImage" in target:
364+
# This suggests NIfTI format
365+
return False
356366

357367
# Default based on modality
358368
return modality in ["CT", "MR", "MRI"]

tools/pipeline-generator/pipeline_generator/generator/bundle_downloader.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,84 @@ def organize_bundle_structure(self, bundle_path: Path) -> None:
192192
logger.debug(f"Moved {config_file} to configs/")
193193

194194
# Move model files to models/
195-
model_extensions = [".pt", ".ts", ".onnx"]
195+
# Prefer PyTorch (.pt) > ONNX (.onnx) > TorchScript (.ts) for better compatibility
196+
model_extensions = [".pt", ".onnx", ".ts"]
197+
198+
# First move model files from root directory
196199
for ext in model_extensions:
197200
for model_file in bundle_path.glob(f"*{ext}"):
198201
if model_file.is_file() and not (models_dir / model_file.name).exists():
199202
model_file.rename(models_dir / model_file.name)
200203
logger.debug(f"Moved {model_file.name} to models/")
204+
205+
# Check if we already have a suitable model in the main directory
206+
# Prefer .pt files, then .onnx, then .ts
207+
has_suitable_model = False
208+
for ext in model_extensions:
209+
if any(models_dir.glob(f"*{ext}")):
210+
has_suitable_model = True
211+
break
212+
213+
# If no suitable model in main directory, move from subdirectories
214+
if not has_suitable_model:
215+
# Also move model files from subdirectories to the main models/ directory
216+
# This handles cases where models are in subdirectories like models/A100/
217+
# Prefer PyTorch models over TensorRT models for better compatibility
218+
for ext in model_extensions:
219+
model_files = list(models_dir.glob(f"**/*{ext}"))
220+
if not model_files:
221+
continue
222+
223+
# Filter files that are not in the main models directory
224+
subdirectory_files = [f for f in model_files if f.parent != models_dir]
225+
if not subdirectory_files:
226+
continue
227+
228+
target_name = f"model{ext}"
229+
target_path = models_dir / target_name
230+
if target_path.exists():
231+
continue # Target already exists
232+
233+
# Prefer non-TensorRT models for better compatibility
234+
# TensorRT models often have "_trt" in their name
235+
preferred_file = None
236+
for model_file in subdirectory_files:
237+
if "_trt" not in model_file.name.lower():
238+
preferred_file = model_file
239+
break
240+
241+
# If no non-TensorRT model found, use the first available
242+
if preferred_file is None:
243+
preferred_file = subdirectory_files[0]
244+
245+
# Move the preferred model file
246+
preferred_file.rename(target_path)
247+
logger.debug(f"Moved {preferred_file.name} from {preferred_file.parent.name}/ to models/{target_name}")
248+
249+
# Clean up empty subdirectory if it exists
250+
try:
251+
if preferred_file.parent.exists() and not any(preferred_file.parent.iterdir()):
252+
preferred_file.parent.rmdir()
253+
logger.debug(f"Removed empty directory {preferred_file.parent}")
254+
except OSError:
255+
pass # Directory not empty or other issue
256+
break # Only move one model file total
257+
258+
# Ensure we have model.pt or model.ts in the main directory for MONAI Deploy
259+
# Create symlinks with standard names if needed
260+
standard_model_path = models_dir / "model.pt"
261+
if not standard_model_path.exists():
262+
# Look for any .pt file to link to model.pt
263+
pt_files = list(models_dir.glob("*.pt"))
264+
if pt_files:
265+
# Create a copy with the standard name
266+
pt_files[0].rename(standard_model_path)
267+
logger.debug(f"Renamed {pt_files[0].name} to model.pt")
268+
else:
269+
# No .pt file found, look for .ts file and create model.ts instead
270+
standard_ts_path = models_dir / "model.ts"
271+
if not standard_ts_path.exists():
272+
ts_files = list(models_dir.glob("*.ts"))
273+
if ts_files:
274+
ts_files[0].rename(standard_ts_path)
275+
logger.debug(f"Renamed {ts_files[0].name} to model.ts")

tools/pipeline-generator/tests/test_bundle_downloader.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,3 +339,134 @@ def test_get_inference_config_logs_error(self, mock_logger, tmp_path):
339339

340340
assert result is None
341341
mock_logger.error.assert_called()
342+
343+
def test_organize_bundle_structure_subdirectory_models(self, tmp_path):
344+
"""Test organizing models from subdirectories to main models/ directory."""
345+
bundle_path = tmp_path / "bundle"
346+
models_dir = bundle_path / "models"
347+
subdir = models_dir / "A100"
348+
subdir.mkdir(parents=True)
349+
350+
# Create model file in subdirectory
351+
subdir_model = subdir / "dynunet_FT_trt_16.ts"
352+
subdir_model.write_text("tensorrt model")
353+
354+
# Organize structure
355+
self.downloader.organize_bundle_structure(bundle_path)
356+
357+
# Model should be moved to main models/ directory with standard name
358+
assert (models_dir / "model.ts").exists()
359+
assert not subdir_model.exists()
360+
assert not subdir.exists() # Empty subdirectory should be removed
361+
362+
def test_organize_bundle_structure_prefers_pytorch_over_tensorrt(self, tmp_path):
363+
"""Test that PyTorch models are preferred over TensorRT models."""
364+
bundle_path = tmp_path / "bundle"
365+
models_dir = bundle_path / "models"
366+
subdir = models_dir / "A100"
367+
subdir.mkdir(parents=True)
368+
369+
# Create both PyTorch and TensorRT models in subdirectory
370+
pytorch_model = subdir / "dynunet_FT.pt"
371+
tensorrt_model = subdir / "dynunet_FT_trt_16.ts"
372+
pytorch_model.write_bytes(b"pytorch model")
373+
tensorrt_model.write_text("tensorrt model")
374+
375+
# Organize structure
376+
self.downloader.organize_bundle_structure(bundle_path)
377+
378+
# PyTorch model should be preferred and moved
379+
assert (models_dir / "model.pt").exists()
380+
assert not (models_dir / "model.ts").exists()
381+
assert not pytorch_model.exists()
382+
# TensorRT model should remain in subdirectory
383+
assert tensorrt_model.exists()
384+
385+
def test_organize_bundle_structure_standard_naming_pytorch(self, tmp_path):
386+
"""Test renaming PyTorch models to standard names."""
387+
bundle_path = tmp_path / "bundle"
388+
models_dir = bundle_path / "models"
389+
models_dir.mkdir(parents=True)
390+
391+
# Create PyTorch model with custom name
392+
custom_model = models_dir / "dynunet_FT.pt"
393+
custom_model.write_bytes(b"pytorch model")
394+
395+
# Organize structure
396+
self.downloader.organize_bundle_structure(bundle_path)
397+
398+
# Model should be renamed to standard name
399+
assert (models_dir / "model.pt").exists()
400+
assert not custom_model.exists()
401+
402+
def test_organize_bundle_structure_standard_naming_torchscript(self, tmp_path):
403+
"""Test renaming TorchScript models to standard names when no PyTorch model exists."""
404+
bundle_path = tmp_path / "bundle"
405+
models_dir = bundle_path / "models"
406+
models_dir.mkdir(parents=True)
407+
408+
# Create only TorchScript model with custom name
409+
custom_model = models_dir / "custom_model.ts"
410+
custom_model.write_text("torchscript model")
411+
412+
# Organize structure
413+
self.downloader.organize_bundle_structure(bundle_path)
414+
415+
# Model should be renamed to standard name
416+
assert (models_dir / "model.ts").exists()
417+
assert not custom_model.exists()
418+
419+
def test_organize_bundle_structure_skips_when_suitable_model_exists(self, tmp_path):
420+
"""Test that subdirectory organization is skipped when suitable model already exists."""
421+
bundle_path = tmp_path / "bundle"
422+
models_dir = bundle_path / "models"
423+
subdir = models_dir / "A100"
424+
subdir.mkdir(parents=True)
425+
426+
# Create model in main directory
427+
main_model = models_dir / "existing_model.pt"
428+
main_model.write_bytes(b"existing pytorch model")
429+
430+
# Create model in subdirectory
431+
subdir_model = subdir / "dynunet_FT_trt_16.ts"
432+
subdir_model.write_text("tensorrt model")
433+
434+
# Organize structure
435+
self.downloader.organize_bundle_structure(bundle_path)
436+
437+
# Main model should be renamed to standard name
438+
assert (models_dir / "model.pt").exists()
439+
assert not main_model.exists()
440+
441+
# Subdirectory model should remain untouched
442+
assert subdir_model.exists()
443+
assert subdir.exists()
444+
445+
def test_organize_bundle_structure_multiple_extensions_preference(self, tmp_path):
446+
"""Test extension preference order: .pt > .onnx > .ts."""
447+
bundle_path = tmp_path / "bundle"
448+
models_dir = bundle_path / "models"
449+
subdir = models_dir / "A100"
450+
subdir.mkdir(parents=True)
451+
452+
# Create models with different extensions in subdirectory
453+
pt_model = subdir / "model.pt"
454+
onnx_model = subdir / "model.onnx"
455+
ts_model = subdir / "model.ts"
456+
457+
pt_model.write_bytes(b"pytorch model")
458+
onnx_model.write_bytes(b"onnx model")
459+
ts_model.write_text("torchscript model")
460+
461+
# Organize structure
462+
self.downloader.organize_bundle_structure(bundle_path)
463+
464+
# Should prefer .pt model
465+
assert (models_dir / "model.pt").exists()
466+
assert not (models_dir / "model.onnx").exists()
467+
assert not (models_dir / "model.ts").exists()
468+
assert not pt_model.exists()
469+
470+
# Other models should remain in subdirectory
471+
assert onnx_model.exists()
472+
assert ts_model.exists()

tools/pipeline-generator/tests/test_generator.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,46 @@ def test_inference_config_with_loadimage_transform(self):
580580
result = generator._detect_data_format(inference_config, "CT")
581581
assert result is False
582582

583+
def test_inference_config_with_string_transforms(self):
584+
"""Test _detect_data_format with string transforms expression."""
585+
generator = AppGenerator()
586+
587+
# Create inference config with string transforms (like spleen_deepedit_annotation)
588+
inference_config = {
589+
"preprocessing": {
590+
"_target_": "Compose",
591+
"transforms": "$@preprocessing_transforms + @deepedit_transforms + @extra_transforms"
592+
},
593+
"preprocessing_transforms": [
594+
{"_target_": "LoadImaged", "keys": "image"},
595+
{"_target_": "EnsureChannelFirstd", "keys": "image"}
596+
]
597+
}
598+
599+
# This should return False (NIfTI format) because LoadImaged is found in config string
600+
result = generator._detect_data_format(inference_config, "CT")
601+
assert result is False
602+
603+
def test_inference_config_with_string_transforms_no_loadimage(self):
604+
"""Test _detect_data_format with string transforms expression without LoadImaged."""
605+
generator = AppGenerator()
606+
607+
# Create inference config with string transforms but no LoadImaged
608+
inference_config = {
609+
"preprocessing": {
610+
"_target_": "Compose",
611+
"transforms": "$@preprocessing_transforms + @other_transforms"
612+
},
613+
"preprocessing_transforms": [
614+
{"_target_": "SomeOtherTransform", "keys": "image"},
615+
{"_target_": "EnsureChannelFirstd", "keys": "image"}
616+
]
617+
}
618+
619+
# This should return True (DICOM format) for CT modality when no LoadImaged found
620+
result = generator._detect_data_format(inference_config, "CT")
621+
assert result is True
622+
583623
def test_detect_model_type_pathology(self):
584624
"""Test _detect_model_type for pathology models."""
585625
generator = AppGenerator()

0 commit comments

Comments (0)