Skip to content

Commit 2cd6d0b

Browse files
Update infer bundle class to create deepedit infer class (#1527)
* Update infer bundle class to create deepedit infer class Signed-off-by: Andres <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update for pre_commit to pass Signed-off-by: Andres <[email protected]> * Add background as label Signed-off-by: Andres <[email protected]> * Update for pre_commit to pass Signed-off-by: Andres <[email protected]> * Update bundle infer class for deepedit Signed-off-by: Andres <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update key name for preprocessing Signed-off-by: Andres <[email protected]> * Update key name for preprocessing and postprocessing Signed-off-by: Andres <[email protected]> --------- Signed-off-by: Andres <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c37627f commit 2cd6d0b

File tree

3 files changed

+32
-9
lines changed

3 files changed

+32
-9
lines changed

monailabel/tasks/infer/bundle.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ def __init__(
109109
self.bundle_path = path
110110
self.bundle_config_path = os.path.join(path, "configs", config_paths[0])
111111
self.bundle_config = self._load_bundle_config(self.bundle_path, self.bundle_config_path)
112+
# For deepedit inferer - allow the use of clicks
113+
self.bundle_config.config["use_click"] = True if type.lower() == "deepedit" else False
112114

113115
if self.dropout > 0:
114116
self.bundle_config["network_def"]["dropout"] = self.dropout
@@ -133,7 +135,13 @@ def __init__(
133135
self.key_image, image = next(iter(metadata["network_data_format"]["inputs"].items()))
134136
self.key_pred, pred = next(iter(metadata["network_data_format"]["outputs"].items()))
135137

136-
labels = {v.lower(): int(k) for k, v in pred.get("channel_def", {}).items() if v.lower() != "background"}
138+
# labels = ({v.lower(): int(k) for k, v in pred.get("channel_def", {}).items() if v.lower() != "background"})
139+
labels = {}
140+
for k, v in pred.get("channel_def", {}).items():
141+
if (not type.lower() == "deepedit") and (v.lower() != "background"):
142+
labels[v.lower()] = int(k)
143+
else:
144+
labels[v.lower()] = int(k)
137145
description = metadata.get("description")
138146
spatial_shape = image.get("spatial_shape")
139147
dimension = len(spatial_shape) if spatial_shape else 3
@@ -192,6 +200,7 @@ def pre_transforms(self, data=None) -> Sequence[Callable]:
192200
if self.bundle_config.get(k):
193201
c = self.bundle_config.get_parsed_content(k, instantiate=True)
194202
pre = list(c.transforms) if isinstance(c, Compose) else c
203+
195204
pre = self._filter_transforms(pre, self.pre_filter)
196205

197206
for t in pre:
@@ -254,6 +263,7 @@ def post_transforms(self, data=None) -> Sequence[Callable]:
254263
if self.bundle_config.get(k):
255264
c = self.bundle_config.get_parsed_content(k, instantiate=True)
256265
post = list(c.transforms) if isinstance(c, Compose) else c
266+
257267
post = self._filter_transforms(post, self.post_filter)
258268

259269
if self.add_post_restore:

monailabel/utils/others/generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def get_zoo_bundle(model_dir, conf, models, conf_key):
263263
print("")
264264
print("---------------------------------------------------------------------------------------")
265265
print(
266-
"Github access rate limit reached, pleaes provide personal auth token by setting env MONAI_ZOO_AUTH_TOKEN"
266+
"Github access rate limit reached, please provide personal auth token by setting env MONAI_ZOO_AUTH_TOKEN"
267267
)
268268
print("or --conf auth_token <personal auth token>")
269269
exit(-1)

sample-apps/monaibundle/main.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,22 @@ def init_infers(self) -> Dict[str, InferTask]:
5858
#################################################
5959
# Models
6060
#################################################
61+
6162
for n, b in self.models.items():
62-
i = BundleInferTask(b, self.conf)
63-
logger.info(f"+++ Adding Inferer:: {n} => {i}")
64-
infers[n] = i
63+
if "deepedit" in n:
64+
# Adding automatic inferer
65+
i = BundleInferTask(b, self.conf, type="segmentation")
66+
logger.info(f"+++ Adding Inferer:: {n}_seg => {i}")
67+
infers[n + "_seg"] = i
68+
# Adding inferer for managing clicks
69+
i = BundleInferTask(b, self.conf, type="deepedit")
70+
logger.info("+++ Adding DeepEdit Inferer")
71+
infers[n] = i
72+
else:
73+
i = BundleInferTask(b, self.conf)
74+
logger.info(f"+++ Adding Inferer:: {n} => {i}")
75+
infers[n] = i
76+
6577
return infers
6678

6779
def init_trainers(self) -> Dict[str, TrainTask]:
@@ -147,8 +159,9 @@ def main():
147159
app_dir = os.path.dirname(__file__)
148160
studies = args.studies
149161

150-
app = MyApp(app_dir, studies, {"preload": "true", "models": "spleen_ct_segmentation"})
151-
train(app)
162+
app = MyApp(app_dir, studies, {"preload": "false", "models": "spleen_deepedit_annotation"})
163+
# train(app)
164+
infer(app)
152165

153166

154167
def infer(app):
@@ -157,7 +170,7 @@ def infer(app):
157170

158171
res = app.infer(
159172
request={
160-
"model": "spleen_ct_segmentation",
173+
"model": "spleen_deepedit_annotation",
161174
"image": "image",
162175
}
163176
)
@@ -170,7 +183,7 @@ def infer(app):
170183
def train(app):
171184
app.train(
172185
request={
173-
"model": "spleen_ct_segmentation",
186+
"model": "spleen_deepedit_annotation",
174187
"max_epochs": 2,
175188
"multi_gpu": False,
176189
"val_split": 0.1,

0 commit comments

Comments
 (0)