
Commit 3f4fef4

Optimize pathology inference - support multi-gpu (#712)
* Optimize pathology inference - support multi-gpu
  Signed-off-by: Sachidanand Alle <[email protected]>
* Optimize pathology inference - support multi-gpu
  Signed-off-by: Sachidanand Alle <[email protected]>
1 parent f9e37fe commit 3f4fef4


19 files changed: 200 additions, 302 deletions


monailabel/endpoints/wsi_infer.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def run_wsi_inference(
         request["image"] = session.image
         request["session"] = session.to_json()
 
-    logger.info(f"WSI Infer Request: {request}")
+    logger.debug(f"WSI Infer Request: {request}")
 
     result = instance.infer_wsi(request)
     if result is None:

monailabel/interfaces/app.py

Lines changed: 30 additions & 11 deletions
@@ -13,6 +13,7 @@
 import logging
 import os
 import platform
+import random
 import shutil
 import tempfile
 import time
@@ -592,7 +593,8 @@ def infer_wsi(self, request, datastore=None):
                 f"WSI/Inference Task is not Initialized. There is no model '{model}' available",
             )
 
-        image = request["image"]
+        img_id = request["image"]
+        image = img_id
         request_c = copy.deepcopy(task.config())
         request_c.update(request)
         request = request_c
@@ -609,38 +611,55 @@ def infer_wsi(self, request, datastore=None):
             image = datastore.get_image_uri(request["image"])
 
         start = time.time()
-        logger.info(f"WSI Infer Request (final): {request}")
         infer_tasks = create_infer_wsi_tasks(request, image)
+        if len(infer_tasks) > 1:
+            logger.info(f"WSI Infer Request (final): {request}")
+
         logger.debug(f"Total WSI Tasks: {len(infer_tasks)}")
         request["logging"] = request.get("logging", "WARNING" if len(infer_tasks) > 1 else "INFO")
 
-        multi_gpu = request.get("multi_gpu", False)
+        multi_gpu = request.get("multi_gpu", True)
         multi_gpus = request.get("gpus", "all")
         gpus = (
             list(range(torch.cuda.device_count())) if not multi_gpus or multi_gpus == "all" else multi_gpus.split(",")
         )
         device_ids = [f"cuda:{id}" for id in gpus] if multi_gpu else [request.get("device", "cuda")]
-        logger.info(f"MultiGpu: {multi_gpu}; Using Device(s): {device_ids}")
 
         res_json = {"annotations": [None] * len(infer_tasks)}
         for idx, t in enumerate(infer_tasks):
             t["logging"] = request["logging"]
-            t["device"] = device_ids[idx % len(device_ids)]
+            t["device"] = (
+                device_ids[idx % len(device_ids)]
+                if len(infer_tasks) > 1
+                else device_ids[random.randint(0, len(device_ids) - 1)]
+            )
 
-        if len(infer_tasks) > 1 and len(device_ids) > 1:
-            with ThreadPoolExecutor(max_workers=len(device_ids), thread_name_prefix="WSI Infer") as executor:
+        total = len(infer_tasks)
+        max_workers = request.get("max_workers", len(device_ids))
+
+        if len(infer_tasks) > 1 and (max_workers == 0 or max_workers > 1):
+            logger.info(f"MultiGpu: {multi_gpu}; Using Device(s): {device_ids}; Max Workers: {max_workers}")
+            futures = {}
+            with ThreadPoolExecutor(max_workers if max_workers else None, "WSI Infer") as executor:
                 for t in infer_tasks:
-                    tid = t["id"]
-                    future = executor.submit(self._run_infer_wsi_task, t)
+                    futures[t["id"]] = t, executor.submit(self._run_infer_wsi_task, t)
+
+                for tid, (t, future) in futures.items():
                     res = future.result()
                     res_json["annotations"][tid] = res
-                    logger.info(f"{tid} => {len(res_json)} / {len(infer_tasks)}; Latencies: {res.get('latencies')}")
+                    finished = len([a for a in res_json["annotations"] if a])
+                    logger.info(
+                        f"{img_id} => {tid} => {t['device']} => {finished} / {total}; Latencies: {res.get('latencies')}"
+                    )
         else:
            for t in infer_tasks:
                tid = t["id"]
                res = self._run_infer_wsi_task(t)
                res_json["annotations"][tid] = res
-                logger.info(f"{tid} => {len(res_json)} / {len(infer_tasks)}; Latencies: {res.get('latencies')}")
+                finished = len([a for a in res_json["annotations"] if a])
+                logger.info(
+                    f"{img_id} => {tid} => {t['device']} => {finished} / {total}; Latencies: {res.get('latencies')}"
                )
 
         latency_total = time.time() - start
         logger.debug("WSI Infer Time Taken: {:.4f}".format(latency_total))

monailabel/interfaces/tasks/infer.py

Lines changed: 26 additions & 22 deletions
@@ -22,6 +22,7 @@
 from monailabel.interfaces.exception import MONAILabelError, MONAILabelException
 from monailabel.interfaces.utils.transform import run_transforms
 from monailabel.transform.writer import Writer
+from monailabel.utils.others.generic import device_list
 
 logger = logging.getLogger(__name__)
 
@@ -57,7 +58,7 @@ class InferTask:
 
     def __init__(
         self,
-        path: Union[str, Sequence[str]],
+        path: Union[None, str, Sequence[str]],
         network: Union[None, Any],
         type: Union[str, InferType],
         labels: Union[str, None, Sequence[str], Dict[Any, Any]],
@@ -70,6 +71,7 @@ def __init__(
         config: Union[None, Dict[str, Any]] = None,
         load_strict: bool = False,
         roi_size=None,
+        preload=False,
     ):
         """
         :param path: Model File Path. Supports multiple paths to support versions (Last item will be picked as latest)
@@ -84,8 +86,9 @@ def __init__(
         :param config: K,V pairs to be part of user config
         :param load_strict: Load model in strict mode
         :param roi_size: ROI size for scanning window inference
+        :param preload: Preload model/network on all available GPU devices
         """
-        self.path = path
+        self.path = [] if not path else [path] if isinstance(path, str) else path
         self.network = network
         self.type = type
         self.labels = [] if labels is None else [labels] if isinstance(labels, str) else labels
@@ -99,8 +102,9 @@ def __init__(
         self.roi_size = roi_size
 
         self._networks: Dict = {}
+
         self._config: Dict[str, Any] = {
-            # "device": "cuda",
+            # "device": device_list(),
             # "result_extension": None,
             # "result_dtype": None,
             # "result_compress": False
@@ -111,6 +115,11 @@ def __init__(
         if config:
             self._config.update(config)
 
+        if preload:
+            for device in device_list():
+                logger.info(f"Preload Network for device: {device}")
+                self._get_network(device)
+
     def info(self) -> Dict[str, Any]:
         return {
             "type": self.type,
@@ -127,19 +136,19 @@ def is_valid(self) -> bool:
         if self.network or self.type == InferType.SCRIBBLES:
             return True
 
-        paths = [self.path] if isinstance(self.path, str) else self.path
+        paths = self.path
         for path in reversed(paths):
-            if os.path.exists(path):
+            if path and os.path.exists(path):
                 return True
         return False
 
     def get_path(self):
         if not self.path:
             return None
 
-        paths = [self.path] if isinstance(self.path, str) else self.path
+        paths = self.path
         for path in reversed(paths):
-            if os.path.exists(path):
+            if path and os.path.exists(path):
                 return path
         return None
 
@@ -247,7 +256,10 @@ def __call__(self, request) -> Tuple[str, Dict[str, Any]]:
 
         # device
         device = req.get("device", "cuda")
+        if device.startswith("cuda") and not torch.cuda.is_available():
+            device = "cpu"
         req["device"] = device
+
         logger.setLevel(req.get("logging", "INFO").upper())
         logger.info(f"Infer Request (final): {req}")
 
@@ -346,9 +358,6 @@ def _get_network(self, device):
                 f"Model Path ({self.path}) does not exist/valid",
             )
 
-        if device.startswith("cuda") and not torch.cuda.is_available():
-            device = "cpu"
-
         cached = self._networks.get(device)
         statbuf = os.stat(path) if path else None
         network = None
@@ -360,16 +369,15 @@ def _get_network(self, device):
 
         if network is None:
             if self.network:
-                network = self.network
+                network = copy.deepcopy(self.network)
+                network.to(torch.device(device))
+
                 if path:
                     checkpoint = torch.load(path, map_location=torch.device(device))
                     model_state_dict = checkpoint.get(self.model_state_dict, checkpoint)
                     network.load_state_dict(model_state_dict, strict=self.load_strict)
             else:
-                network = torch.jit.load(path, map_location=torch.device(device))
-
-            if device.startswith("cuda"):
-                network = network.cuda(device)
+                network = torch.jit.load(path, map_location=torch.device(device)).to(torch.device(device))
373381

374382
network.eval()
375383
self._networks[device] = (network, statbuf.st_mtime if statbuf else 0)
@@ -388,22 +396,18 @@ def run_inferer(self, data, convert_to_batch=True, device="cuda"):
388396
"""
389397

390398
inferer = self.inferer(data)
391-
logger.info("Inferer:: {} => {}".format(inferer.__class__.__name__, inferer.__dict__))
392-
393-
device = device if device else "cuda"
394-
if device.startswith("cuda") and not torch.cuda.is_available():
395-
device = "cpu"
399+
logger.info("Inferer:: {} => {} => {}".format(device, inferer.__class__.__name__, inferer.__dict__))
396400

397401
network = self._get_network(device)
398402
if network:
399403
inputs = data[self.input_key]
400404
inputs = inputs if torch.is_tensor(inputs) else torch.from_numpy(inputs)
401405
inputs = inputs[None] if convert_to_batch else inputs
402-
if device.startswith("cuda"):
403-
inputs = inputs.cuda(torch.device(device))
406+
inputs = inputs.to(torch.device(device))
404407

405408
with torch.no_grad():
406409
outputs = inferer(inputs, network)
410+
407411
if device.startswith("cuda"):
408412
torch.cuda.empty_cache()
409413
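Taken together, the infer.py changes make _networks a per-device cache: each device gets its own deep copy of the network moved onto that device, entries are invalidated when the checkpoint file's mtime changes, and preload=True warms the cache for every device at construction time. A simplified sketch of the idea; NetworkCache and the "model" checkpoint key are assumptions for illustration, not the MONAILabel class:

import copy
import os

import torch
from torch import nn


class NetworkCache:
    def __init__(self, network, path=None, preload=False):
        self.network = network
        self.path = path
        self._networks = {}  # device -> (network copy, checkpoint mtime)
        if preload:
            for device in [f"cuda:{i}" for i in range(torch.cuda.device_count())] or ["cpu"]:
                self.get(device)  # warm the cache, as preload=True does above

    def get(self, device):
        mtime = os.stat(self.path).st_mtime if self.path else 0
        cached = self._networks.get(device)
        if cached and cached[1] == mtime:
            return cached[0]

        # Deep copy so per-device weights never alias the shared instance.
        network = copy.deepcopy(self.network).to(torch.device(device))
        if self.path:
            checkpoint = torch.load(self.path, map_location=torch.device(device))
            network.load_state_dict(checkpoint.get("model", checkpoint), strict=False)
        network.eval()
        self._networks[device] = (network, mtime)
        return network


cache = NetworkCache(nn.Linear(4, 2), preload=True)
print(cache.get("cpu"))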

monailabel/interfaces/utils/wsi.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def create_infer_wsi_tasks(request, image):
     rows = ceil(h / tile_size[1])  # ROW
 
     if rows * cols > 1:
-        logger.info(f"Total Tiles to infer {rows} x {cols}: {rows * cols}")
+        logger.info(f"Total Tiles to infer {rows} x {cols}: {rows * cols}; Dimensions: {w} x {h}")
 
     infer_tasks = []
     count = 0
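For context, the tile count logged here falls out of a simple grid computation over the slide dimensions. A sketch with an assumed tile size (the real value comes from the request):

from math import ceil


def tile_grid(w, h, tile_size=(2048, 2048)):
    # Tiles needed to cover a w x h region; 2048 is illustrative, not the server default.
    cols = ceil(w / tile_size[0])  # COL
    rows = ceil(h / tile_size[1])  # ROW
    return rows, cols


w, h = 50000, 30000
rows, cols = tile_grid(w, h)
print(f"Total Tiles to infer {rows} x {cols}: {rows * cols}; Dimensions: {w} x {h}")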

monailabel/main.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def args_start_server(self, parser):
         parser.add_argument("--ssl_certfile", default=None, type=str, help="SSL certificate file")
         parser.add_argument("--ssl_keyfile_password", default=None, type=str, help="SSL key file password")
         parser.add_argument("--ssl_ca_certs", default=None, type=str, help="CA certificates file")
-        parser.add_argument("--workers", default=1, type=int, help="Number of worker processes")
+        parser.add_argument("--workers", default=None, type=int, help="Number of worker processes")
         parser.add_argument("--limit_concurrency", default=None, type=int, help="Max concurrent connections")
         parser.add_argument("--access_log", action="store_true", help="Enable access log")
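With the default changed to None, the worker count is deferred to uvicorn, which falls back to the WEB_CONCURRENCY environment variable or 1. A hedged sketch of how the flag flows through; "app:app" is a placeholder ASGI import path, not MONAILabel's entry point:

import argparse

import uvicorn

parser = argparse.ArgumentParser()
parser.add_argument("--workers", default=None, type=int, help="Number of worker processes")
args = parser.parse_args()

# workers=None lets uvicorn decide: $WEB_CONCURRENCY if set, otherwise 1.
uvicorn.run("app:app", host="0.0.0.0", port=8000, workers=args.workers)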

monailabel/tasks/scoring/epistemic.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import copy
 import logging
 import os
 import time
@@ -115,7 +115,7 @@ def _load_model(self, path, network):
         logger.info(f"Using {model_file} for running Epistemic")
         model_ts = int(os.stat(model_file).st_mtime) if model_file and os.path.exists(model_file) else 1
         if network:
-            model = network
+            model = copy.deepcopy(network)
             if model_file:
                 if torch.cuda.is_available():
                     checkpoint = torch.load(model_file)
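The deepcopy matters because _load_model previously loaded checkpoint weights into the very network instance shared with the rest of the app, so a scoring run could silently change the weights other tasks were using. A self-contained illustration; nn.Linear and the zeroed state dict stand in for the real network and checkpoint:

import copy

import torch
from torch import nn

shared = nn.Linear(4, 2)  # stands in for the app-wide network instance

# Without the deepcopy, load_state_dict would overwrite the shared weights
# for every other consumer of `shared`; with it, scoring gets its own copy.
model = copy.deepcopy(shared)
state = {k: torch.zeros_like(v) for k, v in shared.state_dict().items()}  # fake checkpoint
model.load_state_dict(state)

print(shared.weight.data.equal(model.weight.data))  # False: shared weights are untouched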

monailabel/tasks/scoring/tta.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import copy
 import logging
 import os
 import time
@@ -126,7 +126,7 @@ def _load_model(self, path, network):
         logger.info(f"Using {model_file} for running TTA")
         model_ts = int(os.stat(model_file).st_mtime) if model_file and os.path.exists(model_file) else 1
         if network:
-            model = network
+            model = copy.deepcopy(network)
             if model_file:
                 checkpoint = torch.load(model_file)
                 model_state_dict = checkpoint.get("model", checkpoint)

monailabel/tasks/train/basic_train.py

Lines changed: 2 additions & 2 deletions
@@ -51,7 +51,7 @@
 from monailabel.interfaces.datastore import Datastore
 from monailabel.interfaces.tasks.train import TrainTask
 from monailabel.tasks.train.handler import PublishStatsAndModel, prepare_stats
-from monailabel.utils.others.generic import remove_file
+from monailabel.utils.others.generic import device_list, remove_file
 
 logger = logging.getLogger(__name__)
 
@@ -136,7 +136,7 @@ def __init__(
         self._config = {
             "name": "train_01",
             "pretrained": True,
-            "device": "cuda",
+            "device": device_list(),
             "max_epochs": 50,
             "early_stop_patience": -1,
             "val_split": 0.2,

monailabel/utils/others/generic.py

Lines changed: 9 additions & 1 deletion
@@ -19,7 +19,7 @@
 import subprocess
 import time
 
-import torch.cuda
+import torch
 from monai.apps import download_url
 
 logger = logging.getLogger(__name__)
@@ -173,3 +173,11 @@ def download_file(url, path, delay=1, skip_on_exists=True):
         download_url(url, path)
     if delay > 0:
         time.sleep(delay)
+
+
+def device_list():
+    devices = [] if torch.cuda.is_available() else ["cpu"]
+    for i in range(torch.cuda.device_count()):
+        devices.append(f"cuda:{i}")
+
+    return devices
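device_list() gives the other files a single source of truth for selectable devices: every visible GPU, with CPU as the fallback. A small usage sketch, mirroring how basic_train.py above uses it as a config default:

import torch


def device_list():
    # Same helper as above: all CUDA devices, or ["cpu"] when none are visible.
    devices = [] if torch.cuda.is_available() else ["cpu"]
    for i in range(torch.cuda.device_count()):
        devices.append(f"cuda:{i}")
    return devices


# On a 2-GPU machine: {'device': ['cuda:0', 'cuda:1'], ...}; CPU-only: ['cpu'].
config = {"name": "train_01", "device": device_list(), "max_epochs": 50}
print(config)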
