
Commit 9593cfe

🐛 Fix Multi-GPU Support with torch.compile (#923)
Change the multi-GPU mode from `DataParallel` to `DistributedDataParallel` so that it works with `torch.compile`. However, this effectively limits inference to a single GPU whenever `torch.compile` is used; making multiple GPUs work together with `torch.compile` is not a trivial change. We will release a future fix that fully corrects this with the new engine.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b564590 commit 9593cfe
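To make the compatibility issue concrete, here is a minimal standalone sketch of the pattern this commit adopts; it is an illustration written for this note, not code from the commit. The model is wrapped in `DistributedDataParallel` inside a single-process, world-size-1 group, which `torch.compile` can handle, whereas `torch.nn.DataParallel` does not work with compilation. The address/port values mirror those used in the patch below.

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

model = torch.nn.Linear(8, 2).to("cuda")

# Old path: torch.compile(torch.nn.DataParallel(model)) does not work.
# New path: a single-process, world_size=1 DDP group. torch.compile can
# handle DDP-wrapped modules, but with one rank only one GPU does the work,
# which is exactly the limitation the commit message describes.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group(backend="nccl", rank=0, world_size=1)
ddp_model = DistributedDataParallel(model, device_ids=[0])
compiled = torch.compile(ddp_model)

with torch.inference_mode():
    out = compiled(torch.randn(4, 8, device="cuda"))

dist.destroy_process_group()
```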

File tree

3 files changed: +80 −3 lines


tests/models/test_feature_extractor.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -115,3 +115,48 @@ def test_full_inference(
     # ! else the output values will not exactly be the same (still < 1.0e-4
     # ! of epsilon though)
     assert np.mean(np.abs(features[:4] - _features)) < 1.0e-1
+
+
+@pytest.mark.skipif(
+    toolbox_env.running_on_ci() or not ON_GPU,
+    reason="Local test on machine with GPU.",
+)
+def test_multi_gpu_feature_extraction(remote_sample: Callable, tmp_path: Path) -> None:
+    """Local functionality test for feature extraction using multiple GPUs."""
+    save_dir = tmp_path / "output"
+    mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
+    shutil.rmtree(save_dir, ignore_errors=True)
+
+    # Use multiple GPUs
+    device = select_device(on_gpu=ON_GPU)
+
+    wsi_ioconfig = IOSegmentorConfig(
+        input_resolutions=[{"units": "mpp", "resolution": 0.5}],
+        patch_input_shape=[224, 224],
+        output_resolutions=[{"units": "mpp", "resolution": 0.5}],
+        patch_output_shape=[224, 224],
+        stride_shape=[224, 224],
+    )
+
+    model = TimmBackbone(backbone="UNI", pretrained=True)
+    extractor = DeepFeatureExtractor(
+        model=model,
+        auto_generate_mask=True,
+        batch_size=32,
+        num_loader_workers=4,
+        num_postproc_workers=4,
+    )
+
+    output_list = extractor.predict(
+        [mini_wsi_svs],
+        mode="wsi",
+        device=device,
+        ioconfig=wsi_ioconfig,
+        crash_on_exception=True,
+        save_dir=save_dir,
+    )
+    wsi_0_root_path = output_list[0][1]
+    positions = np.load(f"{wsi_0_root_path}.position.npy")
+    features = np.load(f"{wsi_0_root_path}.features.0.npy")
+    assert len(positions.shape) == 2
+    assert len(features.shape) == 2
```
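As a usage note, here is a hypothetical consumer of the arrays this test writes; the save-path root and the row-alignment check are illustrative assumptions, since the test above only asserts that both arrays are 2-D:

```python
import numpy as np

# Hypothetical save-path root; in the test it comes from output_list[0][1].
wsi_0_root_path = "output/wsi_features"
positions = np.load(f"{wsi_0_root_path}.position.npy")
features = np.load(f"{wsi_0_root_path}.features.0.npy")

# Assumption: row i of `positions` locates the patch whose feature vector
# is row i of `features`.
assert positions.shape[0] == features.shape[0]
print(f"{features.shape[0]} patches, feature dim {features.shape[1]}")
```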

tiatoolbox/models/engine/semantic_segmentor.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -13,6 +13,7 @@
 import joblib
 import numpy as np
 import torch
+import torch.distributed as dist
 import torch.multiprocessing as torch_mp
 import torch.utils.data as torch_data
 import tqdm
@@ -1421,6 +1422,14 @@ def predict(  # noqa: PLR0913
                 logger.warning("Unable to remove %s", self._cache_dir)

         self._memory_cleanup()
+        from tiatoolbox.models.architecture.utils import is_torch_compile_compatible
+
+        if (
+            device == "cuda"
+            and torch.cuda.device_count() > 1
+            and is_torch_compile_compatible()
+        ):  # pragma: no cover
+            dist.destroy_process_group()

         return self._outputs
```
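This teardown mirrors the `init_process_group` call added to `model_to` in `models_abc.py` (below). A slightly more defensive variant, offered here as a suggestion rather than what the commit does, would guard on the group actually existing:

```python
import torch.distributed as dist


def cleanup_process_group() -> None:
    """Tear down the single-process DDP group if one was created."""
    # The guards avoid "process group has not been initialized" errors when
    # the DDP branch was never taken (e.g. a CPU run, or a single GPU).
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
```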

tiatoolbox/models/models_abc.py

Lines changed: 26 additions & 3 deletions
```diff
@@ -2,11 +2,16 @@

 from __future__ import annotations

+import os
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Callable

 import torch
 import torch._dynamo
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel
+
+from tiatoolbox.models.architecture.utils import is_torch_compile_compatible

 torch._dynamo.config.suppress_errors = True  # skipcq: PYL-W0212  # noqa: SLF001

@@ -51,12 +56,30 @@ def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module:
         The model after being moved to specified device.

     """
-    if device != "cpu":
+    torch_device = torch.device(device)
+
+    # Use DDP if multiple GPUs and not on CPU
+    if (
+        device == "cuda"
+        and torch.cuda.device_count() > 1
+        and is_torch_compile_compatible()
+    ):  # pragma: no cover
+        # This assumes a single-process DDP setup for inference
+        model = model.to(torch_device)
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "12355"
+        dist.init_process_group(backend="nccl", rank=0, world_size=1)
+        model = DistributedDataParallel(model, device_ids=[torch_device.index])
+
+    elif device != "cpu":
         # DataParallel works only for cuda
         model = torch.nn.DataParallel(model)
+        model = model.to(torch_device)

-    torch_device = torch.device(device)
-    return model.to(torch_device)
+    else:
+        model = model.to(torch_device)
+
+    return model


 class ModelABC(ABC, torch.nn.Module):
```
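Finally, a hypothetical caller-side view of the patched helper, assuming a multi-GPU host where the DDP branch is taken; `model_to` and its signature come from the diff above, everything else is illustrative:

```python
import torch

from tiatoolbox.models.models_abc import model_to

# On a multi-GPU, compile-compatible host this returns a DDP-wrapped model;
# on CPU it is a plain .to("cpu"); otherwise a DataParallel-wrapped model.
model = model_to(torch.nn.Linear(8, 2), device="cuda")
compiled = torch.compile(model)  # the wrapper switch exists to make this safe

with torch.inference_mode():
    out = compiled(torch.randn(4, 8, device="cuda"))
```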
