Skip to content

Commit b770793

Browse files
authored
Update NVTX Range (#2890)
Signed-off-by: Behrooz <[email protected]> * Update special method replacement
1 parent 1ecf5b6 commit b770793

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

monai/utils/nvtx.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _decorate_method(self, obj, method, append_method_name):
9292
name = self.name
9393

9494
# Get the class for special functions
95-
if method.startswith("_"):
95+
if method.startswith("__"):
9696
owner = type(obj)
9797
else:
9898
owner = obj
@@ -109,7 +109,16 @@ def range_wrapper(*args, **kwargs):
109109
return output
110110

111111
# Replace the method with the wrapped version
112-
setattr(owner, method, range_wrapper)
112+
if method.startswith("__"):
113+
# If it is a special method, it requires special attention
114+
class NVTXRangeDecoratedClass(owner):
115+
...
116+
117+
setattr(NVTXRangeDecoratedClass, method, range_wrapper)
118+
obj.__class__ = NVTXRangeDecoratedClass
119+
120+
else:
121+
setattr(owner, method, range_wrapper)
113122

114123
def _get_method(self, obj: Any) -> tuple:
115124
if isinstance(obj, Module):

0 commit comments

Comments
 (0)