
Commit c242bdf

license update + code optimizations

Signed-off-by: bluna301 <[email protected]>
1 parent: b1f28c5

17 files changed (+195 −161 lines)

examples/apps/cchmc_ped_abd_ct_seg_app/README.md
Lines changed: 12 additions & 3 deletions

@@ -9,7 +9,7 @@ For questions, please feel free to contact Elan Somasundaram (Elanchezhian.Somas
 ## Unique Features

 Some unique features of this MAP pipeline include:
-- **Custom Inference Operator:** custom `AbdomenSegOperator` enables either PyTorch or TorchScript model loading as desired
+- **Custom Inference Operator:** custom `AbdomenSegOperator` enables either PyTorch or TorchScript model loading
 - **DICOM Secondary Capture Output:** custom `DICOMSecondaryCaptureWriterOperator` writes a DICOM SC with organ contours
 - **Output Filtering:** model produces Liver-Spleen-Pancreas segmentations, but seg visibility in the outputs (DICOM SEG, SC, SR) can be controlled in `app.py`
 - **MONAI Deploy Express MongoDB Write:** custom operators (`MongoDBEntryCreatorOperator` and `MongoDBWriterOperator`) allow for writing to the MongoDB database associated with MONAI Deploy Express
@@ -20,9 +20,18 @@ Several scripts have been compiled that quickly execute useful actions (such as
 ## Notes
 The DICOM Series selection criteria has been customized based on the model's training and CCHMC use cases. By default, Axial CT series with Slice Thickness between 3.0 - 5.0 mm (inclusive) will be selected for.

-If PyTorch model loading is desired, please uncomment the "PyTorch Model Loading" section in the `AbdomenSegOperator`.
+If MongoDB writing is not desired, please comment out the relevant sections in `app.py` and the `AbdomenSegOperator`.

-If MongoDB writing is desired, please uncomment the relevant sections in `app.py` and the `AbdomenSegOperator`. Please note that MongoDB connection values (username, password, and port) are the default values pulled from the v0.6.0 MONAI Deploy Express [.env](https://github.com/Project-MONAI/monai-deploy/blob/main/deploy/monai-deploy-express/.env) and [docker-compose.yaml](https://github.com/Project-MONAI/monai-deploy/blob/main/deploy/monai-deploy-express/docker-compose.yml) files; these default values are harcoded into the `MongoDBWriterOperator`. If your instance of MONAI Deploy Express has modified values for these fields, the `MongoDBWriterOperator` will need to be udpated accordingly.
+To execute the pipeline with MongoDB writing enabled, it is best to create a `.env` file that the `MongoDBWriterOperator` can load in. Below is an example `.env` file that follows the format outlined in this operator; note that these values are the default variable values as defined in the [.env](https://github.com/Project-MONAI/monai-deploy/blob/main/deploy/monai-deploy-express/.env) and [docker-compose.yaml](https://github.com/Project-MONAI/monai-deploy/blob/main/deploy/monai-deploy-express/docker-compose.yml) files of v0.6.0 of MONAI Deploy Express:
+
+```dotenv
+MONGODB_USERNAME=root
+MONGODB_PASSWORD=rootpassword
+MONGODB_PORT=27017
+MONGODB_IP_DOCKER=172.17.0.1 # default Docker bridge network IP
+```
+
+Prior to packaging into a MAP, the MongoDB credentials should be hardcoded into the `MongoDBWriterOperator`.

 The MONAI Deploy Express MongoDB Docker container (`mdl-mongodb`) needs to be connected to the Docker bridge network in order for the MongoDB write to be successful. Executing the following command in a MONAI Deploy Express terminal will establish this connection:
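For illustration only (not part of this commit): the `MongoDBWriterOperator` itself is not shown in this diff, but reading the `.env` values above might reduce to a minimal sketch like the following, assuming the operator uses `python-dotenv` and falls back to the documented MONAI Deploy Express defaults when no `.env` file is present:

```python
# Hypothetical sketch -- assumes python-dotenv is installed and that the operator
# falls back to the v0.6.0 MONAI Deploy Express defaults documented in the README
# (e.g. after the credentials are hardcoded for MAP packaging).
import os

from dotenv import load_dotenv

load_dotenv()  # reads a .env file from the working directory, if one exists

MONGODB_USERNAME = os.environ.get("MONGODB_USERNAME", "root")
MONGODB_PASSWORD = os.environ.get("MONGODB_PASSWORD", "rootpassword")
MONGODB_PORT = os.environ.get("MONGODB_PORT", "27017")
MONGODB_IP_DOCKER = os.environ.get("MONGODB_IP_DOCKER", "172.17.0.1")

# connection URI assembled from the loaded (or default) values
mongodb_uri = f"mongodb://{MONGODB_USERNAME}:{MONGODB_PASSWORD}@{MONGODB_IP_DOCKER}:{MONGODB_PORT}/"
```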

examples/apps/cchmc_ped_abd_ct_seg_app/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-# Copyright 2021-2024 MONAI Consortium
+# Copyright 2021-2025 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

examples/apps/cchmc_ped_abd_ct_seg_app/__main__.py
Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-# Copyright 2021-2024 MONAI Consortium
+# Copyright 2021-2025 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

examples/apps/cchmc_ped_abd_ct_seg_app/abdomen_seg_operator.py
Lines changed: 58 additions & 55 deletions

@@ -1,4 +1,4 @@
-# Copyright 2021-2024 MONAI Consortium
+# Copyright 2021-2025 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -13,12 +13,14 @@
 from pathlib import Path
 from typing import List

+import torch
 from numpy import float32, int16

 # import custom transforms from post_transforms.py
 from post_transforms import CalculateVolumeFromMaskd, ExtractVolumeToTextd, LabelToContourd, OverlayImageLabeld

-from monai.deploy.core import AppContext, Fragment, Operator, OperatorSpec
+import monai
+from monai.deploy.core import AppContext, Fragment, Model, Operator, OperatorSpec
 from monai.deploy.operators.monai_seg_inference_operator import InfererType, InMemImageReader, MonaiSegInferenceOperator
 from monai.transforms import (
     Activationsd,
@@ -36,11 +38,6 @@
     Spacingd,
 )

-# # PyTorch model pipeline dependencies
-# import torch
-# import monai
-# from monai.deploy.core import Model
-

 # this operator performs inference with the new version of the bundle
 class AbdomenSegOperator(Operator):
@@ -95,6 +92,48 @@ def _find_model_file_path(self, model_path: Path):

         raise ValueError(f"Model file not found in the provided path: {model_path}")

+    # load a PyTorch model and register it in app_context
+    def _load_pytorch_model(self):
+
+        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        _kernel_size: tuple = (3, 3, 3, 3, 3, 3)
+        _strides: tuple = (1, 2, 2, 2, 2, (2, 2, 1))
+        _upsample_kernel_size: tuple = (2, 2, 2, 2, (2, 2, 1))
+
+        # create DynUNet model with the specified architecture parameters + move to computational device (GPU or CPU)
+        # parameters pulled from inference.yaml file of the MONAI bundle
+        model = monai.networks.nets.dynunet.DynUNet(
+            spatial_dims=3,
+            in_channels=1,
+            out_channels=4,
+            kernel_size=_kernel_size,
+            strides=_strides,
+            upsample_kernel_size=_upsample_kernel_size,
+            norm_name="INSTANCE",
+            deep_supervision=False,
+            res_block=True,
+        ).to(_device)
+
+        # load model state dictionary (i.e. mapping param names to tensors) via torch.load
+        # weights_only=True to avoid arbitrary code execution during unpickling
+        state_dict = torch.load(self.model_path, weights_only=True)
+
+        # assign loaded weights to model architecture via load_state_dict
+        model.load_state_dict(state_dict)
+
+        # set model in evaluation (inference) mode
+        model.eval()
+
+        # create a MONAI Model object to encapsulate the PyTorch model and metadata
+        loaded_model = Model(self.model_path, name="ped_abd_ct_seg")
+
+        # assign loaded PyTorch model as the predictor for the Model object
+        loaded_model.predictor = model
+
+        # register the loaded Model object in the application context so other operators can access it
+        # MonaiSegInferenceOperator uses _get_model method to load models; looks at app_context.models first
+        self.app_context.models = loaded_model
+
     def setup(self, spec: OperatorSpec):
         spec.input(self.input_name_image)

@@ -104,8 +143,8 @@ def setup(self, spec: OperatorSpec):
         # DICOM SR
         spec.output(self.output_name_text_dicom_sr)

-        # # MongoDB
-        # spec.output(self.output_name_text_mongodb)
+        # MongoDB
+        spec.output(self.output_name_text_mongodb)

         # DICOM SC
         spec.output(self.output_name_sc_path)
@@ -122,50 +161,14 @@ def compute(self, op_input, op_output, context):
         pre_transforms = self.pre_process(_reader)
         post_transforms = self.post_process(pre_transforms)

-        ##########
-
-        # # PyTorch Model Loading:
-
-        # _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        # _kernel_size: tuple = (3, 3, 3, 3, 3, 3)
-        # _strides: tuple = (1, 2, 2, 2, 2, (2, 2, 1))
-        # _upsample_kernel_size: tuple = (2, 2, 2, 2, (2, 2, 1))
-
-        # # create DynUNet model with the specified architecture parameters + move to computational device (GPU or CPU)
-        # # parameters pulled from inference.yaml file of the MONAI bundle
-        # model = monai.networks.nets.dynunet.DynUNet(
-        #     spatial_dims=3,
-        #     in_channels=1,
-        #     out_channels=4,
-        #     kernel_size=_kernel_size,
-        #     strides=_strides,
-        #     upsample_kernel_size=_upsample_kernel_size,
-        #     norm_name="INSTANCE",
-        #     deep_supervision=False,
-        #     res_block=True
-        # ).to(_device)
-
-        # # load model state dictionary (i.e. mapping param names to tensors) via torch.load
-        # # weights_only=True to avoid arbitrary code execution during unpickling
-        # state_dict = torch.load(self.model_path, weights_only=True)
-
-        # # assign loaded weights to model architecture via load_state_dict
-        # model.load_state_dict(state_dict)
-
-        # # set model in evaluation (inference) mode
-        # model.eval()
-
-        # # create a MONAI Model object to encapsulate the PyTorch model and metadata
-        # loaded_model = Model(self.model_path, name="ped_abd_ct_seg")
-
-        # # assign loaded PyTorch model as the predictor for the Model object
-        # loaded_model.predictor = model
-
-        # # register the loaded Model object in the application context so other operators can access it
-        # # MonaiSegInferenceOperator uses _get_model method to load models; looks at app_context.models first
-        # self.app_context.models = loaded_model
-
-        ##########
+        # if PyTorch model
+        if self.model_path.suffix.lower() == ".pt":
+            # load the PyTorch model
+            self._logger.info("PyTorch model detected")
+            self._load_pytorch_model()
+        # else, we have TorchScript model
+        else:
+            self._logger.info("TorchScript model detected")

         # delegates inference and saving output to the built-in operator.
         infer_operator = MonaiSegInferenceOperator(
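Illustrative only, not part of the commit: the hunk above dispatches on the model file suffix, treating any non-`.pt` file as TorchScript. A small hypothetical sketch of the same dispatch written as a standalone helper, with an explicit guard added for unexpected extensions, might look like this:

```python
# Hypothetical helper -- not in the committed operator; shows the suffix-based
# dispatch with an extra guard for unsupported model file extensions.
from pathlib import Path


def select_model_loader(model_path: Path) -> str:
    """Return which loading path applies for a given model file."""
    suffix = model_path.suffix.lower()
    if suffix == ".pt":
        return "pytorch"  # handled by AbdomenSegOperator._load_pytorch_model()
    if suffix == ".ts":
        return "torchscript"  # handled by the built-in MonaiSegInferenceOperator
    raise ValueError(f"Unsupported model file extension: {suffix}")
```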
@@ -202,8 +205,8 @@ def compute(self, op_input, op_output, context):
         # DICOM SR
         op_output.emit(result_text_dicom_sr, self.output_name_text_dicom_sr)

-        # # MongoDB
-        # op_output.emit(result_text_mongodb, self.output_name_text_mongodb)
+        # MongoDB
+        op_output.emit(result_text_mongodb, self.output_name_text_mongodb)

         # DICOM SC
         # temporary DICOM SC (w/o source DICOM metadata) saved in output_folder / temp directory

examples/apps/cchmc_ped_abd_ct_seg_app/app.py
Lines changed: 19 additions & 30 deletions

@@ -1,4 +1,4 @@
-# Copyright 2021-2024 MONAI Consortium
+# Copyright 2021-2025 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -18,28 +18,22 @@
 # custom DICOM Secondary Capture (SC) writer operator
 from dicom_sc_writer_operator import DICOMSCWriterOperator

+# custom MongoDB operators
+from mongodb_entry_creator_operator import MongoDBEntryCreatorOperator
+from mongodb_writer_operator import MongoDBWriterOperator
+
 # required for setting SegmentDescription attributes
 # direct import as this is not part of App SDK package
 from pydicom.sr.codedict import codes

 from monai.deploy.conditions import CountCondition
 from monai.deploy.core import Application
-
-# DICOM operators
 from monai.deploy.operators.dicom_data_loader_operator import DICOMDataLoaderOperator
-
-# DICOM writer operators
 from monai.deploy.operators.dicom_seg_writer_operator import DICOMSegmentationWriterOperator, SegmentDescription
-
-# custom DICOMSeriesSelectorOperator
 from monai.deploy.operators.dicom_series_selector_operator import DICOMSeriesSelectorOperator
 from monai.deploy.operators.dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator
 from monai.deploy.operators.dicom_text_sr_writer_operator import DICOMTextSRWriterOperator, EquipmentInfo, ModelInfo

-# # MongoDB operators
-# from mongodb_entry_creator_operator import MongoDBEntryCreatorOperator
-# from mongodb_writer_operator import MongoDBWriterOperator
-

 # inherit new Application class instance, AIAbdomenSegApp, from MONAI Application base class
 # base class provides support for chaining up operators and executing application
@@ -195,23 +189,16 @@ def compose(self):
             output_folder=app_output_path / "SC",
         )

-        # # MongoDB database, collection, and MAP version info
-        # database_name = "CTLiverSpleenSegPredictions"
-        # collection_name = "OrganVolumes"
-        # map_version = "1.0.0"
+        # MongoDB database, collection, and MAP version info
+        database_name = "CTLiverSpleenSegPredictions"
+        collection_name = "OrganVolumes"
+        map_version = "1.0.0"

-        # # custom MongoDB Entry Creator op
-        # mongodb_entry_creator = MongoDBEntryCreatorOperator(
-        #     self,
-        #     map_version=map_version
-        # )
+        # custom MongoDB Entry Creator op
+        mongodb_entry_creator = MongoDBEntryCreatorOperator(self, map_version=map_version)

-        # # custom MongoDB Writer op
-        # mongodb_writer = MongoDBWriterOperator(
-        #     self,
-        #     database_name=database_name,
-        #     collection_name=collection_name
-        # )
+        # custom MongoDB Writer op
+        mongodb_writer = MongoDBWriterOperator(self, database_name=database_name, collection_name=collection_name)

         # create the processing pipeline, by specifying the source and destination operators, and
         # ensuring the output from the former matches the input of the latter, in both name and type
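For illustration only: this diff does not show the body of `MongoDBWriterOperator`, but a write against the database and collection named above might reduce to the following hypothetical `pymongo` sketch, using the default connection values documented in the README:

```python
# Hypothetical sketch -- assumes the operator uses pymongo and the default
# MONAI Deploy Express connection values (root/rootpassword, port 27017,
# Docker bridge IP 172.17.0.1) documented in the README.
from pymongo import MongoClient

client = MongoClient("mongodb://root:rootpassword@172.17.0.1:27017/")
collection = client["CTLiverSpleenSegPredictions"]["OrganVolumes"]

# placeholder document standing in for the dict emitted by MongoDBEntryCreatorOperator
mongodb_database_entry = {"map_version": "1.0.0", "organ_volumes": {}}  # placeholder values

result = collection.insert_one(mongodb_database_entry)
print(f"Inserted document with _id: {result.inserted_id}")
```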
@@ -243,10 +230,12 @@ def compose(self):
         )
         self.add_flow(abd_seg_op, dicom_sc_writer, {("dicom_sc_dir", "dicom_sc_dir")})

-        # # MongoDB
-        # self.add_flow(series_selector_op, mongodb_entry_creator, {("study_selected_series_list", "study_selected_series_list")})
-        # self.add_flow(abd_seg_op, mongodb_entry_creator, {("result_text_mongodb", "text")})
-        # self.add_flow(mongodb_entry_creator, mongodb_writer, {("mongodb_database_entry", "mongodb_database_entry")})
+        # MongoDB
+        self.add_flow(
+            series_selector_op, mongodb_entry_creator, {("study_selected_series_list", "study_selected_series_list")}
+        )
+        self.add_flow(abd_seg_op, mongodb_entry_creator, {("result_text_mongodb", "text")})
+        self.add_flow(mongodb_entry_creator, mongodb_writer, {("mongodb_database_entry", "mongodb_database_entry")})

         logging.info(f"End {self.compose.__name__}")

examples/apps/cchmc_ped_abd_ct_seg_app/app.yaml
Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-# Copyright 2021-2024 MONAI Consortium
+# Copyright 2021-2025 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

examples/apps/cchmc_ped_abd_ct_seg_app/dicom_sc_writer_operator.py
Lines changed: 2 additions & 7 deletions

@@ -1,4 +1,4 @@
-# Copyright 2021-2024 MONAI Consortium
+# Copyright 2021-2025 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

@@ -197,13 +197,8 @@ def write(self, dicom_sc_dir, dicom_series: DICOMSeries, output_dir: Path):

         output_dir.mkdir(parents=True, exist_ok=True) # just in case

-        # find the temporary DICOM SC file in the directory (assuming there's only one .dcm file)
+        # find the temporary DICOM SC file in the directory; there should only be one .dcm file present
         dicom_files = list(dicom_sc_dir.glob("*.dcm"))
-        if len(dicom_files) == 0:
-            raise FileNotFoundError(f"No DICOM files found in the directory: {dicom_sc_dir}")
-        if len(dicom_files) > 1:
-            self.logger.warning(f"Multiple DICOM files found, using the first one: {dicom_files[0]}")
-
         dicom_sc_file = dicom_files[0]

         # load the temporary DICOM SC file using pydicom
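Illustrative only: the lines that actually load the temporary DICOM SC file are outside this hunk, but the comment above suggests a plain `pydicom` read. A minimal hypothetical sketch, with placeholder paths, might look like this:

```python
# Hypothetical sketch -- the real write() method and its temp directory layout
# are not shown in this diff; paths here are placeholders.
from pathlib import Path

import pydicom

dicom_sc_dir = Path("output/SC/temp")  # placeholder temp directory
dicom_files = list(dicom_sc_dir.glob("*.dcm"))
dicom_sc_file = dicom_files[0]  # the updated code assumes exactly one .dcm file is present

ds = pydicom.dcmread(dicom_sc_file)  # load the temporary DICOM SC dataset
print(ds.SOPInstanceUID)
```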
