Skip to content

Commit 7720913

Browse files
committed
fix: avoid proxying sub-fields of a Tensor (including their .grad and .data) which has led to repeated trace and memory leak
1 parent eef5e7e commit 7720913

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

mldaikon/proxy_wrapper/proxy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,11 @@ def __getattr__(self, name):
469469
return Proxy.__torch_function__
470470
attr = getattr(self._obj, name)
471471

472+
if isinstance(self._obj, torch.Tensor) and isinstance(attr, torch.Tensor):
473+
# we should not proxy sub-tensor fields for a tensor, this can cause circular reference and memory leak
474+
# one caveat with this is that if the code wants to operate on the sub-tensor separately, we will lose track of their updates when they are not proxied
475+
return attr
476+
472477
if self.__dict__["var_name"] == "":
473478
var_name = name
474479
else:

0 commit comments

Comments
 (0)