Skip to content

Commit ebd6c13

Browse files
NuClick + Classification model using Consep dataset (#1120)
Signed-off-by: SACHIDANAND ALLE <[email protected]>
1 parent dff17b2 commit ebd6c13

27 files changed

+648
-471
lines changed

monailabel/interfaces/utils/wsi.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ctypes import cdll
1717
from math import ceil
1818

19+
import numpy as np
1920
from monai.utils import optional_import
2021

2122
logger = logging.getLogger(__name__)
@@ -59,6 +60,10 @@ def create_infer_wsi_tasks(request, image):
5960
infer_tasks = []
6061
count = 0
6162
pw, ph = tile_size[0], tile_size[1]
63+
64+
ignore_small_patches = request.get("ignore_small_patches", False)
65+
ignore_non_click_patches = request.get("ignore_non_click_patches", False)
66+
6267
for row in range(rows):
6368
for col in range(cols):
6469
tx = col * pw + x
@@ -67,6 +72,22 @@ def create_infer_wsi_tasks(request, image):
6772
tw = min(pw, x + w - tx)
6873
th = min(ph, y + h - ty)
6974

75+
if ignore_small_patches and (tw < pw or th < ph):
76+
continue
77+
78+
if ignore_non_click_patches and (request.get("foreground") or request.get("background")):
79+
80+
def filter_points(ptype):
81+
pos = request.get(ptype)
82+
pos = (np.array(pos) - (tx, ty)).astype(int).tolist() if pos else []
83+
pos = [p for p in pos if 0 < p[0] < tw and 0 < p[1] < th]
84+
return pos
85+
86+
fg = filter_points("foreground")
87+
bg = filter_points("background")
88+
if not fg and not bg:
89+
continue
90+
7091
task = copy.deepcopy(request)
7192
task.update(
7293
{

monailabel/tasks/train/basic_train.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,12 @@ def __init__(
173173
self._labels = [] if labels is None else [labels] if isinstance(labels, str) else labels
174174
self._disable_tracking = disable_tracking
175175

176+
def info(self):
177+
r = super().info()
178+
if self._labels:
179+
r["labels"] = self._labels
180+
return r
181+
176182
@abstractmethod
177183
def network(self, context: Context):
178184
pass
@@ -239,7 +245,12 @@ def train_inferer(self, context: Context):
239245
return SimpleInferer()
240246

241247
def train_key_metric(self, context: Context):
242-
return {self.TRAIN_KEY_METRIC: MeanDice(output_transform=from_engine(["pred", "label"]))}
248+
return {
249+
self.TRAIN_KEY_METRIC: MeanDice(
250+
output_transform=from_engine(["pred", "label"]),
251+
include_background=False,
252+
)
253+
}
243254

244255
def load_path(self, output_dir, pretrained=True):
245256
load_path = os.path.join(output_dir, self._key_metric_filename)
@@ -297,7 +308,12 @@ def val_handlers(self, context: Context):
297308
return val_handlers if context.local_rank == 0 else None
298309

299310
def val_key_metric(self, context):
300-
return {self.VAL_KEY_METRIC: MeanDice(output_transform=from_engine(["pred", "label"]))}
311+
return {
312+
self.VAL_KEY_METRIC: MeanDice(
313+
output_transform=from_engine(["pred", "label"]),
314+
include_background=False,
315+
)
316+
}
301317

302318
def train_iteration_update(self, context: Context):
303319
return None

monailabel/utils/others/generic.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,12 +200,16 @@ def device_list():
200200
return devices
201201

202202

203-
def create_dataset_from_path(folder, images="images", labels="labels", img_ext=".jpg", lab_ext=".png"):
204-
images = [i for i in os.listdir(os.path.join(folder, images)) if i.endswith(img_ext)]
205-
images = sorted(os.path.join(folder, "images", i) for i in images)
203+
def create_dataset_from_path(folder, image_dir="images", label_dir="labels", img_ext=".jpg", lab_ext=".png"):
204+
def _list_files(d, ext):
205+
files = [i for i in os.listdir(d) if i.endswith(ext)]
206+
return sorted(os.path.join(d, i) for i in files)
206207

207-
labels = [i for i in os.listdir(os.path.join(folder, labels)) if i.endswith(lab_ext)]
208-
labels = sorted(os.path.join(folder, "labels", i) for i in labels)
208+
image_dir = os.path.join(folder, image_dir) if image_dir else folder
209+
images = _list_files(image_dir, img_ext)
210+
211+
label_dir = os.path.join(folder, label_dir) if label_dir else folder
212+
labels = _list_files(label_dir, img_ext)
209213

210214
for i, l in zip(images, labels):
211215
if get_basename_no_ext(i) != get_basename_no_ext(l):

plugins/qupath/src/main/java/qupath/lib/extension/monailabel/Extension.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ public void installExtension(QuPathGUI qupath) {
4545

4646
var activeLearning = ActionTools.createAction(new NextSample(qupath), "Next Sample/Patch...");
4747
activeLearning.setAccelerator(KeyCombination.keyCombination("ctrl+n"));
48-
activeLearning.disabledProperty().bind(qupath.imageDataProperty().isNull());
4948
MenuTools.addMenuItems(qupath.getMenu("MONAI Label", true), activeLearning);
5049

5150
MenuTools.addMenuItems(qupath.getMenu("MONAI Label", true), ActionUtils.ACTION_SEPARATOR);

plugins/qupath/src/main/java/qupath/lib/extension/monailabel/MonaiLabelClient.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,13 @@ public static class LabelInfo {
124124
public static class NextSampleInfo {
125125
public String id;
126126
public int[] bbox = { 0, 0, 0, 0 };
127+
public String path;
127128
}
128129

129130
public static class InferParams {
130131
public List<List<Integer>> foreground = new ArrayList<>();
131132
public List<List<Integer>> background = new ArrayList<>();
133+
public int max_workers = 1;
132134

133135
public void addClicks(ArrayList<Point2> clicks, boolean f) {
134136
List<List<Integer>> t = f ? foreground : background;

plugins/qupath/src/main/java/qupath/lib/extension/monailabel/Settings.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
package qupath.lib.extension.monailabel;
1515

1616
import javafx.application.Platform;
17-
import javafx.beans.property.BooleanProperty;
17+
import javafx.beans.property.IntegerProperty;
1818
import javafx.beans.property.StringProperty;
1919
import qupath.lib.gui.QuPathGUI;
2020
import qupath.lib.gui.prefs.PathPrefs;
@@ -27,10 +27,10 @@ public static StringProperty serverURLProperty() {
2727
return serverURL;
2828
}
2929

30-
private static BooleanProperty wsi = PathPrefs.createPersistentPreference("wsi", Boolean.FALSE);
30+
private static IntegerProperty maxWorkers = PathPrefs.createPersistentPreference("max_workers", 1);
3131

32-
public static BooleanProperty wsiProperty() {
33-
return wsi;
32+
public static IntegerProperty maxWorkersProperty() {
33+
return maxWorkers;
3434
}
3535

3636
void addProperties(QuPathGUI qupath) {
@@ -41,8 +41,8 @@ void addProperties(QuPathGUI qupath) {
4141

4242
qupath.getPreferencePane().addPropertyPreference(Settings.serverURLProperty(), String.class, "Server URL",
4343
"MONAI Label", "Set MONAI Label Server URL (default: http://127.0.0.1:8000)");
44-
qupath.getPreferencePane().addPropertyPreference(Settings.wsiProperty(), Boolean.class, "Enable WSI",
45-
"MONAI Label", "Allow WSI Inference when ROI is not selected");
44+
qupath.getPreferencePane().addPropertyPreference(Settings.maxWorkersProperty(), Integer.class, "Max Workers",
45+
"MONAI Label", "Max Workers (WSI Inference)");
4646

4747
}
4848
}

plugins/qupath/src/main/java/qupath/lib/extension/monailabel/commands/NextSample.java

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
package qupath.lib.extension.monailabel.commands;
1515

16+
import java.io.File;
1617
import java.util.ArrayList;
1718
import java.util.Arrays;
1819
import java.util.List;
@@ -48,7 +49,7 @@ public void run() {
4849
try {
4950
var viewer = qupath.getViewer();
5051
var imageData = viewer.getImageData();
51-
String image = Utils.getNameWithoutExtension(imageData.getServerPath());
52+
String image = imageData != null ? Utils.getNameWithoutExtension(imageData.getServerPath()) : "";
5253

5354
ResponseInfo info = MonaiLabelClient.info();
5455
List<String> names = new ArrayList<String>();
@@ -60,19 +61,27 @@ public void run() {
6061
selectedStrategy = names.isEmpty() ? "" : names.get(0);
6162
}
6263

64+
boolean imageLoaded = qupath.imageDataProperty() == null || qupath.imageDataProperty().isNull().get() ? false : true;
65+
boolean nextPatch = imageLoaded;
66+
int[] patchSize = {1024, 1024};
67+
int[] imageSize = {0, 0};
68+
6369
ParameterList list = new ParameterList();
6470
list.addChoiceParameter("Strategy", "Active Learning Strategy", selectedStrategy, names);
65-
list.addBooleanParameter("NextPatch", "Next Patch (from current Image)", true);
66-
list.addStringParameter("PatchSize", "PatchSize", Arrays.toString(selectedPatchSize));
71+
if (nextPatch) {
72+
list.addBooleanParameter("NextPatch", "Next Patch (from current Image)", nextPatch);
73+
list.addStringParameter("PatchSize", "PatchSize", Arrays.toString(selectedPatchSize));
74+
}
6775

6876
if (Dialogs.showParameterDialog("MONAILabel", list)) {
6977
String strategy = (String) list.getChoiceParameterValue("Strategy");
70-
boolean nextPatch = list.getBooleanParameterValue("NextPatch").booleanValue();
71-
int[] patchSize = Utils.parseStringArray(list.getStringParameterValue("PatchSize"));
78+
if (nextPatch) {
79+
nextPatch = list.getBooleanParameterValue("NextPatch").booleanValue();
80+
patchSize = Utils.parseStringArray(list.getStringParameterValue("PatchSize"));
81+
imageSize = new int[] { imageData.getServer().getWidth(), imageData.getServer().getHeight() };
7282

73-
var server = imageData.getServer();
74-
int[] imageSize = new int[] { server.getWidth(), server.getHeight() };
75-
logger.info(String.join(",", imageData.getProperties().keySet()));
83+
logger.info(String.join(",", imageData.getProperties().keySet()));
84+
}
7685

7786
selectedStrategy = strategy;
7887
selectedPatchSize = patchSize;
@@ -99,6 +108,17 @@ public void run() {
99108
var obj = PathObjects.createAnnotationObject(roi);
100109
imageData.getHierarchy().addPathObject(obj);
101110
imageData.getHierarchy().getSelectionModel().setSelectedObject(obj);
111+
} else {
112+
String message = "This will close the current image without saving.\nAre you sure to continue?";
113+
if (!imageLoaded || Dialogs.showConfirmDialog("MONAILabel", message)) {
114+
File f = new File(sample.path);
115+
if (f.isFile() && f.exists()) {
116+
qupath.openImage(f.getAbsolutePath(), false, false);
117+
} else {
118+
// TODO:: Download and Open image (wsi) from Remote
119+
// String f = "C:\\Dataset\\Pathology\\TCGA-02-0010-01Z-00-DX4.07de2e55-a8fe-40ee-9e98-bcb78050b9f7.svs";
120+
}
121+
}
102122
}
103123
}
104124

plugins/qupath/src/main/java/qupath/lib/extension/monailabel/commands/RunInference.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import qupath.lib.extension.monailabel.MonaiLabelClient;
3636
import qupath.lib.extension.monailabel.MonaiLabelClient.RequestInfer;
3737
import qupath.lib.extension.monailabel.MonaiLabelClient.ResponseInfo;
38+
import qupath.lib.extension.monailabel.Settings;
3839
import qupath.lib.extension.monailabel.Utils;
3940
import qupath.lib.geom.Point2;
4041
import qupath.lib.gui.QuPathGUI;
@@ -50,6 +51,7 @@
5051
import qupath.lib.regions.RegionRequest;
5152
import qupath.lib.roi.PointsROI;
5253
import qupath.lib.roi.ROIs;
54+
import qupath.lib.roi.RectangleROI;
5355
import qupath.lib.roi.interfaces.ROI;
5456
import qupath.lib.scripting.QP;
5557

@@ -72,6 +74,22 @@ public void run() {
7274
var imageData = viewer.getImageData();
7375
var selected = imageData.getHierarchy().getSelectionModel().getSelectedObject();
7476
var roi = selected != null ? selected.getROI() : null;
77+
78+
// Select first RectangleROI if not selected explicitly
79+
if (roi == null || !(roi instanceof RectangleROI)) {
80+
List<PathObject> objs = imageData.getHierarchy().getFlattenedObjectList(null);
81+
for (int i = 0; i < objs.size(); i++) {
82+
var obj = objs.get(i);
83+
ROI r = obj.getROI();
84+
if (r instanceof RectangleROI) {
85+
roi = r;
86+
Dialogs.showWarningNotification("MONALabel", "ROI is NOT explicitly selected; using first Rectangle ROI from Hierarchy");
87+
imageData.getHierarchy().getSelectionModel().setSelectedObject(obj);
88+
break;
89+
}
90+
}
91+
}
92+
7593
int[] bbox = Utils.getBBOX(roi);
7694
int tileSize = selectedTileSize;
7795
if (bbox[2] == 0 && bbox[3] == 0 && selectedBBox != null) {
@@ -204,6 +222,7 @@ private void runInference(String model, ResponseInfo info, int[] bbox, int tileS
204222
}
205223
req.params.addClicks(fg, true);
206224
req.params.addClicks(bg, false);
225+
req.params.max_workers = Settings.maxWorkersProperty().intValue();
207226

208227

209228
Document dom = MonaiLabelClient.infer(model, image, imageFile, sessionId, req);

sample-apps/pathology/lib/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,11 @@
88
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
11+
12+
import ctypes.util
13+
import platform
14+
from ctypes import cdll
15+
16+
# For windows (preload openslide dll using file_library) https://github.com/openslide/openslide-python/pull/151
17+
if platform.system() == "Windows":
18+
cdll.LoadLibrary(str(ctypes.util.find_library("libopenslide-0.dll")))

sample-apps/pathology/lib/configs/classification_nuclei.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **
3232

3333
# Labels
3434
self.labels = {
35-
"Neoplastic cells": 0,
36-
"Inflammatory": 1,
37-
"Connective/Soft tissue cells": 2,
38-
"Dead Cells": 3,
39-
"Epithelial": 4,
35+
"Neoplastic cells": 1,
36+
"Inflammatory": 2,
37+
"Connective/Soft tissue cells": 3,
38+
"Dead Cells": 4,
39+
"Epithelial": 5,
4040
}
4141
self.label_colors = {
4242
"Neoplastic cells": (255, 0, 0),
@@ -46,16 +46,31 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **
4646
"Epithelial": (0, 0, 255),
4747
}
4848

49+
consep = strtobool(self.conf.get("consep", "false"))
50+
if consep:
51+
self.labels = {
52+
"Other": 1,
53+
"Inflammatory": 2,
54+
"Epithelial": 3,
55+
"Spindle-Shaped": 4,
56+
}
57+
self.label_colors = {
58+
"Other": (255, 0, 0),
59+
"Inflammatory": (255, 255, 0),
60+
"Epithelial": (0, 0, 255),
61+
"Spindle-Shaped": (0, 255, 0),
62+
}
63+
4964
# Model Files
5065
self.path = [
51-
os.path.join(self.model_dir, f"pretrained_{name}.pt"), # pretrained
52-
os.path.join(self.model_dir, f"{name}.pt"), # published
66+
os.path.join(self.model_dir, f"pretrained_{name}{'_consep' if consep else ''}.pt"), # pretrained
67+
os.path.join(self.model_dir, f"{name}{'_consep' if consep else ''}.pt"), # published
5368
]
5469

5570
# Download PreTrained Model
5671
if strtobool(self.conf.get("use_pretrained_model", "true")):
5772
url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}"
58-
url = f"{url}/pathology_classification_densenet121_nuclei.pt"
73+
url = f"{url}/pathology_classification_densenet121_nuclei{'_consep' if consep else ''}.pt"
5974
download_file(url, self.path[0])
6075

6176
# Network

0 commit comments

Comments
 (0)