Skip to content

Commit 9a0bf41

Browse files
Fixes for Train+Infer (#1156)
* Remove deepedit nuclei and fix validation dataloader transforms Signed-off-by: Sachidanand Alle <[email protected]> * Fix infer decollate Signed-off-by: Sachidanand Alle <[email protected]> * remove deepedit from pathology Signed-off-by: Sachidanand Alle <[email protected]> Signed-off-by: Sachidanand Alle <[email protected]>
1 parent 7fca1dc commit 9a0bf41

File tree

10 files changed

+23
-368
lines changed

10 files changed

+23
-368
lines changed

monailabel/tasks/infer/basic_infer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,8 +477,11 @@ def run_inferer(self, data: Dict[str, Any], convert_to_batch=True, device="cuda"
477477
torch.cuda.empty_cache()
478478

479479
if convert_to_batch:
480-
outputs_d = decollate_batch(outputs)
481-
outputs = outputs_d[0]
480+
if data.get("decollate_batch", False): # TODO:: Use automatically depending on the input/output
481+
outputs_d = decollate_batch(outputs)
482+
outputs = outputs_d[0]
483+
else:
484+
outputs = outputs[0]
482485

483486
if isinstance(outputs, dict):
484487
data.update(outputs)

monailabel/tasks/train/basic_train.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,19 @@ def lr_scheduler_handler(self, context: Context):
198198
lr_scheduler = torch.optim.lr_scheduler.StepLR(context.optimizer, step_size=1000, gamma=0.1)
199199
return LrScheduleHandler(lr_scheduler, print_lr=True)
200200

201-
def _dataset(self, context, datalist, replace_rate=0.25):
201+
def _dataset(self, context, datalist, is_train, replace_rate=0.25):
202202
if context.multi_gpu:
203203
world_size = torch.distributed.get_world_size()
204204
if len(datalist) // world_size: # every gpu gets full data when datalist is smaller
205205
datalist = partition_dataset(data=datalist, num_partitions=world_size, even_divisible=True)[
206206
context.local_rank
207207
]
208208

209-
transforms = self._validate_transforms(self.train_pre_transforms(context), "Training", "pre")
209+
transforms = (
210+
self._validate_transforms(self.train_pre_transforms(context), "Training", "pre")
211+
if is_train
212+
else self._validate_transforms(self.val_pre_transforms(context), "Validation", "pre")
213+
)
210214
dataset = (
211215
CacheDataset(datalist, transforms)
212216
if context.dataset_type == "CacheDataset"
@@ -234,8 +238,8 @@ def _dataloader(self, context, dataset, batch_size, num_workers, shuffle=False):
234238
num_workers=num_workers,
235239
)
236240

237-
def train_data_loader(self, context, num_workers=0, shuffle=False):
238-
dataset, datalist = self._dataset(context, context.train_datalist)
241+
def train_data_loader(self, context, num_workers=0, shuffle=True):
242+
dataset, datalist = self._dataset(context, context.train_datalist, is_train=True)
239243
logger.info(f"{context.local_rank} - Records for Training: {len(datalist)}")
240244
logger.debug(f"{context.local_rank} - Training: {datalist}")
241245

@@ -288,7 +292,7 @@ def train_additional_metrics(self, context: Context):
288292
return None
289293

290294
def val_data_loader(self, context: Context, num_workers=0):
291-
dataset, datalist = self._dataset(context, context.val_datalist)
295+
dataset, datalist = self._dataset(context, context.val_datalist, is_train=False)
292296
logger.info(f"{context.local_rank} - Records for Validation: {len(datalist)}")
293297
logger.debug(f"{context.local_rank} - Validation: {datalist}")
294298

runtests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ function run_integration_tests() {
468468
# network training/inference/eval integration tests
469469
if [ $doNetTests = true ]; then
470470
run_integration_tests "radiology" "tests/data/dataset/local/spleen" "deepedit,segmentation_spleen,segmentation,deepgrow_2d,deepgrow_3d"
471-
run_integration_tests "pathology" "tests/data/pathology" "deepedit_nuclei,segmentation_nuclei,nuclick"
471+
run_integration_tests "pathology" "tests/data/pathology" "segmentation_nuclei,nuclick"
472472
run_integration_tests "monaibundle" "tests/data/dataset/local/spleen" "spleen_ct_segmentation_v0.1.0,spleen_deepedit_annotation_v0.1.0,swin_unetr_btcv_segmentation_v0.1.0"
473473
run_integration_tests "endoscopy" "tests/data/endoscopy" "tooltracking,inbody,deepedit"
474474
fi

sample-apps/pathology/README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,16 @@ weights/model (UNET).
2424
- Connective/Soft tissue cells
2525
- Dead Cells
2626
- Epithelial
27-
- **DeepEdit Nuclei** - It is a combination of
28-
both [Interaction + Auto Segmentation](https://github.com/Project-MONAI/MONAILabel/wiki/DeepEdit) model which is
29-
trained to segment Nuclei cells that combines all above labels as *Nuclei*.
30-
- **NuClick** - This is NuClick implementation (UNet model) as provided at: https://github.com/mostafajahanifar/nuclick_torch. Training task for monailabel is not yet supported.
27+
- **NuClick** - This is NuClick implementation (UNet model) as provided at: https://github.com/mostafajahanifar/nuclick_torch.
28+
- **Classification Nuclei** - It is a simple classification model which can be used along with NuClick model to classify Nuclei cells.
3129

3230
### Dataset
3331

34-
Above _Nuclei_ models are trained
35-
on [PanNuke Dataset for Nuclei Instance Segmentation and Classification](https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke)
32+
Above _Nuclei_ models are trained either on
33+
- [PanNuke Dataset for Nuclei Instance Segmentation and Classification](https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke)
34+
- [CoNSeP Dataset](https://warwick.ac.uk/fac/cross_fac/tia/data/hovernet)
35+
36+
Pass `--conf consep true` option while starting MONAILabel server to use models trained on CoNSeP Dataset
3637

3738
### Inputs
3839

@@ -86,7 +87,7 @@ The current version of plugin comes with **limited features** to support basic a
8687
- Make sure MONAILabel Server URL is correctly through `Preferences`.
8788
- Open Sample Whole Slide Image in QuPath (which is shared as studies for MONAILabel server)
8889
- Add/Select Rectangle ROI to run annotations using MONAI Label models.
89-
- For Interative model (e.g. DeepEdit) you can choose to provide `Positive` and `Negative` points through Annotation panel.
90+
- For Interactive model you can choose to provide `Positive` and `Negative` points through Annotation panel.
9091

9192
![image](../../docs/images/qupath.jpg)
9293

sample-apps/pathology/lib/configs/deepedit_nuclei.py

Lines changed: 0 additions & 93 deletions
This file was deleted.

sample-apps/pathology/lib/infers/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,5 @@
1010
# limitations under the License.
1111

1212
from .classification_nuclei import ClassificationNuclei
13-
from .deepedit_nuclei import DeepEditNuclei
1413
from .nuclick import NuClick
1514
from .segmentation_nuclei import SegmentationNuclei

sample-apps/pathology/lib/infers/deepedit_nuclei.py

Lines changed: 0 additions & 90 deletions
This file was deleted.

sample-apps/pathology/lib/trainers/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,5 @@
1010
# limitations under the License.
1111

1212
from .classification_nuclei import ClassificationNuclei
13-
from .deepedit_nuclei import DeepEditNuclei
1413
from .nuclick import NuClick
1514
from .segmentation_nuclei import SegmentationNuclei

0 commit comments

Comments
 (0)