Skip to content

Commit f6648a3

Browse files
authored
Improve NVTX Range Naming (#3484)
* Update to not include number for the name of the first range Signed-off-by: Behrooz <[email protected]> * Update CuCIM and TorchVision wrappers to include name Signed-off-by: Behrooz <[email protected]> * Update nvtx range to append underlying class for wrapper transforms Signed-off-by: Behrooz <[email protected]> * Add new test cases to cover changes Signed-off-by: Behrooz <[email protected]> * Update cucim and torchvision check Signed-off-by: Behrooz <[email protected]>
1 parent 7a66f18 commit f6648a3

File tree

4 files changed

+49
-8
lines changed

4 files changed

+49
-8
lines changed

monai/transforms/utility/array.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,7 @@ def __init__(self, name: str, *args, **kwargs) -> None:
10081008
10091009
"""
10101010
super().__init__()
1011+
self.name = name
10111012
transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name)
10121013
self.trans = transform(*args, **kwargs)
10131014

@@ -1196,6 +1197,7 @@ class CuCIM(Transform):
11961197

11971198
def __init__(self, name: str, *args, **kwargs) -> None:
11981199
super().__init__()
1200+
self.name = name
11991201
self.transform, _ = optional_import("cucim.core.operations.expose.transform", name=name)
12001202
self.args = args
12011203
self.kwargs = kwargs

monai/transforms/utility/dictionary.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,6 +1325,7 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F
13251325
13261326
"""
13271327
super().__init__(keys, allow_missing_keys)
1328+
self.name = name
13281329
self.trans = TorchVision(name, *args, **kwargs)
13291330

13301331
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
@@ -1364,6 +1365,7 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F
13641365
13651366
"""
13661367
MapTransform.__init__(self, keys, allow_missing_keys)
1368+
self.name = name
13671369
self.trans = TorchVision(name, *args, **kwargs)
13681370

13691371
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
@@ -1525,6 +1527,7 @@ class CuCIMd(MapTransform):
15251527

15261528
def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
15271529
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
1530+
self.name = name
15281531
self.trans = CuCIM(name, *args, **kwargs)
15291532

15301533
def __call__(self, data):

monai/utils/nvtx.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,15 @@ def __call__(self, obj: Any):
6262
# Define the name to be associated to the range if not provided
6363
if self.name is None:
6464
name = type(obj).__name__
65+
# If CuCIM or TorchVision transform wrappers are being used,
66+
# append the underlying transform to the name for more clarity
67+
if "CuCIM" in name or "TorchVision" in name:
68+
name = f"{name}_{obj.name}"
6569
self.name_counter[name] += 1
66-
self.name = f"{name}_{self.name_counter[name]}"
70+
if self.name_counter[name] > 1:
71+
self.name = f"{name}_{self.name_counter[name]}"
72+
else:
73+
self.name = name
6774

6875
# Define the methods to be wrapped if not provided
6976
if self.methods is None:

tests/test_nvtx_decorator.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,25 @@
1717

1818
from monai.transforms import (
1919
Compose,
20+
CuCIM,
2021
Flip,
2122
FlipD,
2223
RandAdjustContrast,
24+
RandCuCIM,
2325
RandFlip,
2426
Randomizable,
2527
Rotate90,
28+
ToCupy,
29+
TorchVision,
2630
ToTensor,
2731
ToTensorD,
2832
)
2933
from monai.utils import Range, optional_import
3034

3135
_, has_nvtx = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?")
36+
_, has_cp = optional_import("cupy")
37+
_, has_tvt = optional_import("torchvision.transforms")
38+
_, has_cut = optional_import("cucim.core.operations.expose.transform")
3239

3340

3441
TEST_CASE_ARRAY_0 = [np.random.randn(3, 3)]
@@ -40,10 +47,12 @@
4047
TEST_CASE_TORCH_0 = [torch.randn(3, 3)]
4148
TEST_CASE_TORCH_1 = [torch.randn(3, 10, 10)]
4249

50+
TEST_CASE_WRAPPER = [np.random.randn(3, 10, 10)]
4351

52+
53+
@unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!")
4454
class TestNVTXRangeDecorator(unittest.TestCase):
4555
@parameterized.expand([TEST_CASE_ARRAY_0, TEST_CASE_ARRAY_1])
46-
@unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!")
4756
def test_tranform_array(self, input):
4857
transforms = Compose([Range("random flip")(Flip()), Range()(ToTensor())])
4958
# Apply transforms
@@ -65,11 +74,10 @@ def test_tranform_array(self, input):
6574
self.assertIsInstance(output2, torch.Tensor)
6675
self.assertIsInstance(output3, torch.Tensor)
6776
np.testing.assert_equal(output.numpy(), output1.numpy())
68-
np.testing.assert_equal(output.numpy(), output1.numpy())
77+
np.testing.assert_equal(output.numpy(), output2.numpy())
6978
np.testing.assert_equal(output.numpy(), output3.numpy())
7079

7180
@parameterized.expand([TEST_CASE_DICT_0, TEST_CASE_DICT_1])
72-
@unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!")
7381
def test_tranform_dict(self, input):
7482
transforms = Compose([Range("random flip dict")(FlipD(keys="image")), Range()(ToTensorD("image"))])
7583
# Apply transforms
@@ -94,8 +102,32 @@ def test_tranform_dict(self, input):
94102
np.testing.assert_equal(output.numpy(), output2.numpy())
95103
np.testing.assert_equal(output.numpy(), output3.numpy())
96104

105+
@parameterized.expand([TEST_CASE_WRAPPER])
106+
@unittest.skipUnless(has_cp, "Requires CuPy.")
107+
@unittest.skipUnless(has_cut, "Requires cuCIM transforms.")
108+
@unittest.skipUnless(has_tvt, "Requires torchvision transforms.")
109+
def test_wrapper_tranforms(self, input):
110+
transform_list = [
111+
ToTensor(),
112+
TorchVision(name="RandomHorizontalFlip", p=1.0),
113+
ToCupy(),
114+
CuCIM(name="image_flip", spatial_axis=-1),
115+
RandCuCIM(name="rand_image_rotate_90", prob=1.0, max_k=1, spatial_axis=(-2, -1)),
116+
]
117+
118+
transforms = Compose(transform_list)
119+
transforms_range = Compose([Range()(t) for t in transform_list])
120+
121+
# Apply transforms
122+
output = transforms(input)
123+
124+
# Apply transforms with Range
125+
output_r = transforms_range(input)
126+
127+
# Check the outputs
128+
np.testing.assert_equal(output.get(), output_r.get())
129+
97130
@parameterized.expand([TEST_CASE_ARRAY_1])
98-
@unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!")
99131
def test_tranform_randomized(self, input):
100132
# Compose deterministic and randomized transforms
101133
transforms = Compose(
@@ -136,7 +168,6 @@ def test_tranform_randomized(self, input):
136168
break
137169

138170
@parameterized.expand([TEST_CASE_TORCH_0, TEST_CASE_TORCH_1])
139-
@unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!")
140171
def test_network(self, input):
141172
# Create a network
142173
model = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Sigmoid())
@@ -164,7 +195,6 @@ def test_network(self, input):
164195
np.testing.assert_equal(output.numpy(), output3.numpy())
165196

166197
@parameterized.expand([TEST_CASE_TORCH_0, TEST_CASE_TORCH_1])
167-
@unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!")
168198
def test_loss(self, input):
169199
# Create a network and loss
170200
model = torch.nn.Sigmoid()
@@ -194,7 +224,6 @@ def test_loss(self, input):
194224
np.testing.assert_equal(output.numpy(), output2.numpy())
195225
np.testing.assert_equal(output.numpy(), output3.numpy())
196226

197-
@unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!")
198227
def test_context_manager(self):
199228
model = torch.nn.Sigmoid()
200229
loss = torch.nn.BCELoss()

0 commit comments

Comments
 (0)