Skip to content

Commit 13b621d

Browse files
weifengpypytorchmergebot
authored andcommitted
[DTensor] add __repr__ for CommDebugMode(get_total_count()=) (pytorch#165006)
I just want to print CommDebugMode and know if there is communication. implementing `__repr__` for `print(comm_mode)` ``` comm_mode = CommDebugMode() with comm_mode: out = torch.mm(inps, weight) print(comm_mode) # CommDebugMode(get_total_counts()=0) ``` Tags: Pull Request resolved: pytorch#165006 Approved by: https://github.com/anshul-si ghstack dependencies: pytorch#165024
1 parent 01738a3 commit 13b621d

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

torch/distributed/tensor/debug/_comm_mode.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,3 +734,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
734734
].append(operation_dict)
735735

736736
return out
737+
738+
def __repr__(self):
739+
return f"CommDebugMode(get_total_counts()={self.get_total_counts()})"

0 commit comments

Comments
 (0)