Skip to content

Commit 3921f6d

Browse files
Load Config Options from Bundle (#1302)
* add bundle configs option support Signed-off-by: tangy5 <[email protected]> * add bundle configs option support Signed-off-by: tangy5 <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update with merged bundle Signed-off-by: tangy5 <[email protected]> * small fix Signed-off-by: tangy5 <[email protected]> * small fix Signed-off-by: tangy5 <[email protected]> * small fix Signed-off-by: tangy5 <[email protected]> * sanity checks Signed-off-by: tangy5 <[email protected]> * update Signed-off-by: tangy5 <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: tangy5 <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d34b9c1 commit 3921f6d

File tree

5 files changed

+55
-4
lines changed

5 files changed

+55
-4
lines changed

monailabel/tasks/infer/basic_infer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,11 @@ def __call__(
267267
req = copy.deepcopy(self._config)
268268
req.update(request)
269269

270+
# model options
271+
self.path.append(
272+
os.path.join(os.path.dirname(self.path[0]), req.get("model_filename", "model.pt"))
273+
) if self.path and isinstance(self.path, list) else self.path
274+
270275
# device
271276
device = req.get("device", "cuda")
272277
device = device if isinstance(device, str) else device[0]

monailabel/tasks/infer/bundle.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
import glob
1213
import json
1314
import logging
1415
import os
@@ -66,6 +67,9 @@ def key_detector(self) -> Sequence[str]:
6667
def key_detector_ops(self) -> Sequence[str]:
6768
return ["detector_ops"]
6869

70+
def key_displayable_configs(self) -> Sequence[str]:
71+
return ["displayable_configs"]
72+
6973

7074
class BundleInferTask(BasicInferTask):
7175
"""
@@ -147,6 +151,17 @@ def __init__(
147151
preload=strtobool(conf.get("preload", "false")),
148152
**kwargs,
149153
)
154+
155+
# Add models options if more than one model is provided by bundle.
156+
pytorch_models = [os.path.basename(p) for p in glob.glob(os.path.join(path, "models", "*.pt"))]
157+
pytorch_models.sort(key=len)
158+
self._config.update({"model_filename": pytorch_models})
159+
# Add bundle's loadable params to MONAI Label config, load exposed keys and params to options panel
160+
for k in self.const.key_displayable_configs():
161+
if self.bundle_config.get(k):
162+
self.displayable_configs = self.bundle_config.get_parsed_content(k, instantiate=True) # type: ignore
163+
self._config.update(self.displayable_configs)
164+
150165
self.valid = True
151166
self.version = metadata.get("version")
152167
sys.path.remove(self.bundle_path)
@@ -160,6 +175,12 @@ def info(self) -> Dict[str, Any]:
160175
return i
161176

162177
def pre_transforms(self, data=None) -> Sequence[Callable]:
178+
# Update bundle parameters based on user's option
179+
for k in self.const.key_displayable_configs():
180+
if self.bundle_config.get(k):
181+
self.bundle_config[k].update({c: data[c] for c in self.displayable_configs.keys()})
182+
self.bundle_config.parse()
183+
163184
sys.path.insert(0, self.bundle_path)
164185
unload_module("scripts")
165186
self._update_device(data)

monailabel/tasks/train/bundle.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
import glob
1213
import json
1314
import logging
1415
import os
@@ -79,6 +80,9 @@ def key_experiment_name(self) -> str:
7980
def key_run_name(self) -> str:
8081
return "run_name"
8182

83+
def key_displayable_configs(self) -> Sequence[str]:
84+
return ["displayable_configs"]
85+
8286

8387
class BundleTrainTask(TrainTask):
8488
def __init__(
@@ -127,7 +131,11 @@ def info(self):
127131
return i
128132

129133
def config(self):
130-
return {
134+
# Add models and param optiom to train option panel
135+
pytorch_models = [os.path.basename(p) for p in glob.glob(os.path.join(self.bundle_path, "models", "*.pt"))]
136+
pytorch_models.sort(key=len)
137+
138+
config_options = {
131139
"device": "cuda", # DEVICE
132140
"pretrained": True, # USE EXISTING CHECKPOINT/PRETRAINED MODEL
133141
"max_epochs": 50, # TOTAL EPOCHS TO RUN
@@ -139,8 +147,15 @@ def config(self):
139147
else ["None", "mlflow"],
140148
"tracking_uri": settings.MONAI_LABEL_TRACKING_URI,
141149
"tracking_experiment_name": "",
150+
"model_filename": pytorch_models,
142151
}
143152

153+
for k in self.const.key_displayable_configs():
154+
if self.bundle_config.get(k):
155+
config_options.update(self.bundle_config.get_parsed_content(k, instantiate=True)) # type: ignore
156+
157+
return config_options
158+
144159
def _fetch_datalist(self, request, datastore: Datastore):
145160
datalist = datastore.datalist()
146161

@@ -177,8 +192,8 @@ def _partition_datalist(self, datalist, request, shuffle=False):
177192
logger.info(f"Total Records for Validation: {len(val_datalist) if val_datalist else ''}")
178193
return train_datalist, val_datalist
179194

180-
def _load_checkpoint(self, output_dir, pretrained, train_handlers):
181-
load_path = os.path.join(output_dir, self.const.model_pytorch()) if pretrained else None
195+
def _load_checkpoint(self, model_pytorch, pretrained, train_handlers):
196+
load_path = model_pytorch if pretrained else None
182197
if os.path.exists(load_path):
183198
logger.info(f"Add Checkpoint Loader for Path: {load_path}")
184199

@@ -226,7 +241,9 @@ def __call__(self, request, datastore: Datastore):
226241
logger.info(f"(Experiment Management) Run Name: {tracking_run_name}")
227242

228243
train_handlers = self.bundle_config.get(self.const.key_train_handlers(), [])
229-
self._load_checkpoint(os.path.join(self.bundle_path, "models"), pretrained, train_handlers)
244+
245+
model_pytorch = os.path.join(self.bundle_path, "models", request.get("model_filename", "model.pt"))
246+
self._load_checkpoint(model_pytorch, pretrained, train_handlers)
230247

231248
overrides = {
232249
self.const.key_bundle_root(): self.bundle_path,
@@ -236,6 +253,12 @@ def __call__(self, request, datastore: Datastore):
236253
self.const.key_train_handlers(): train_handlers,
237254
}
238255

256+
# update config options from user
257+
for k in self.const.key_displayable_configs():
258+
if self.bundle_config.get(k):
259+
displayable_configs = self.bundle_config.get_parsed_content(k, instantiate=True)
260+
overrides[k] = {c: request[c] for c in displayable_configs.keys()}
261+
239262
if tracking and tracking.lower() != "none":
240263
overrides[self.const.key_tracking()] = tracking
241264
if tracking_uri:

monailabel/utils/others/modelzoo_list.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,5 @@
3131
"swin_unetr_btcv_segmentation", # 3D transformer model for multi-organ segmentation. Added Oct 2022
3232
"wholeBrainSeg_Large_UNEST_segmentation", # whole brain segmentation for T1 MRI brain images. Added Oct 2022
3333
"lung_nodule_ct_detection", # The first lung nodule detection task can be used for MONAI Label. Added Dec 2022
34+
"wholeBody_ct_segmentation", # The SegResNet trained TotalSegmentator dataset with 104 tissues. Added Feb 2023
3435
]

sample-apps/monaibundle/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ The Bundle App supports most labeling models in the Model Zoo, please see the ta
4646
| [wholeBrainSeg_UNEST_segmentation](https://github.com/Project-MONAI/model-zoo/tree/dev/models/wholeBrainSeg_Large_UNEST_segmentation) | UNesT | Whole Brain | MRI T1 | A pre-trained for inference (3D) 133 whole brain structures segmentation |
4747
| [spleen_deepedit_annotation](https://github.com/Project-MONAI/model-zoo/tree/dev/models/spleen_deepedit_annotation) | DeepEdit | Spleen| CT | An interactive method for 3D spleen Segmentation |
4848
| [lung_nodule_ct_detection](https://github.com/Project-MONAI/model-zoo/tree/dev/models/lung_nodule_ct_detection) | RetinaNet | Lung Nodule| CT | The detection model for 3D CT images |
49+
| [wholeBody_ct_segmentation](https://github.com/Project-MONAI/model-zoo/tree/dev/models/wholeBody_ct_segmentation) | SegResNet | 104 body structures| CT | The segmentation model for 104 tissue from 3D CT images (TotalSegmentator Dataset) |
4950

5051
Supported tasks update based on [Model-Zoo](https://github.com/Project-MONAI/model-zoo/tree/dev/models) release.
5152

0 commit comments

Comments
 (0)