Skip to content

Commit eef5e7e

Browse files
committed
fix: save only needed information instead of tensors in Proxy.var_dict to avoid memory leak
1 parent 806b9b8 commit eef5e7e

File tree

2 files changed

+55
-97
lines changed

2 files changed

+55
-97
lines changed

mldaikon/proxy_wrapper/dumper.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def dump_json(
6363
"var_type": var_type,
6464
"mode": change_type, # "new", "update"
6565
"dump_loc": dump_loc,
66-
# "stack_trace": stack_trace,
6766
"process_id": process_id,
6867
"thread_id": thread_id,
6968
"time": time,

mldaikon/proxy_wrapper/proxy.py

Lines changed: 55 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
import copy
2-
import inspect
3-
import json
4-
import json.encoder
5-
import linecache
62
import logging
73
import os
84
import threading
95
import time
106
import types
11-
import typing
127
from typing import Dict
138

149
import torch
@@ -31,8 +26,22 @@
3126
from .utils import print_debug
3227

3328

34-
def get_line(filename, lineno):
35-
return linecache.getline(filename, lineno).strip()
29+
class ProxyObjInfo:
30+
def __init__(self, var_name: str, last_update_timestamp: int, version: int | None):
31+
self.var_name = var_name
32+
self.last_update_timestamp = last_update_timestamp
33+
self.version = version
34+
35+
@staticmethod
36+
def construct_from_proxy_obj(proxy_obj) -> "ProxyObjInfo":
37+
return ProxyObjInfo(
38+
proxy_obj.__dict__["var_name"],
39+
proxy_obj.__dict__["last_update_timestamp"],
40+
proxy_obj._obj._version if hasattr(proxy_obj._obj, "_version") else None,
41+
)
42+
43+
def __repr__(self):
44+
return f"ProxyObjInfo(var_name={self.var_name}, last_update_timestamp={self.last_update_timestamp}, version={self.version})"
3645

3746

3847
def proxy_handler(
@@ -94,7 +103,7 @@ def generator_proxy_handler():
94103

95104

96105
class Proxy:
97-
var_dict: Dict[str, typing.Any] = {}
106+
var_dict: Dict[str, ProxyObjInfo] = {}
98107
loglevel = logging.INFO
99108
jsondumper = dumper(
100109
os.path.join(os.getenv("ML_DAIKON_OUTPUT_DIR", "."), "proxy_log.json") # type: ignore
@@ -122,25 +131,6 @@ def proxy_parameters(module: torch.nn.Module, parent_name="", from_iter=False):
122131
+ f"Proxied {num_params} parameters of '{parent_name + module.__class__.__name__}', duration: {time_end - start_time} seconds"
123132
)
124133

125-
@staticmethod
126-
def get_frame_array(frame):
127-
frame_array = []
128-
while frame:
129-
if "mldaikon" in frame.f_code.co_filename:
130-
frame = frame.f_back
131-
continue
132-
133-
# fetch the frame info
134-
frame_array.append(
135-
(
136-
frame.f_code.co_filename,
137-
frame.f_lineno,
138-
get_line(frame.f_code.co_filename, frame.f_lineno),
139-
)
140-
)
141-
frame = frame.f_back
142-
return frame_array
143-
144134
def register_object(self):
145135
# get_global_registry().add_var(self, self.__dict__["var_name"])
146136
pass
@@ -149,22 +139,20 @@ def dump_trace(
149139
self,
150140
status,
151141
only_record=False,
152-
prev_obj=None,
142+
prev_obj=None, # DEPRECATED: this is necessary for delta dump, but we should do the compare using hashes instead of keeping references to the objects, which leads to memory leaks
153143
prev_trace_info=None,
154144
disable_sampling=False,
155145
dump_loc=None,
156146
):
157147

158-
if Proxy.var_dict.get(self.__dict__["var_name"]) is None:
159-
# create
160-
self.__dict__["last_update_timestamp"] = 0
161-
Proxy.var_dict[self.__dict__["var_name"]] = self
148+
var_name = self.__dict__["var_name"]
149+
assert (
150+
var_name in Proxy.var_dict
151+
), f"var_name {var_name} is not in var_dict, it has not been proxied yet, check Proxy.__init__() for existence of assignment into Proxy.var_dict"
152+
var_proxy_info = Proxy.var_dict[var_name]
162153

163154
if (
164-
get_timestamp_ns()
165-
- Proxy.var_dict[self.__dict__["var_name"]].__dict__[
166-
"last_update_timestamp"
167-
]
155+
get_timestamp_ns() - var_proxy_info.last_update_timestamp
168156
> proxy_config.proxy_update_limit
169157
or disable_sampling
170158
):
@@ -182,6 +170,7 @@ def dump_trace(
182170
if isinstance(prev_obj._obj, torch.Tensor) and isinstance(
183171
self._obj, torch.Tensor
184172
):
173+
# DEPRECATED: delta dump still needs this comparison, but it should compare hashes instead; keeping references to the previous objects leads to memory leaks
185174
if not torch.equal(prev_obj._obj, self._obj):
186175
dump_pre_and_post_trace = True
187176
else:
@@ -195,19 +184,10 @@ def dump_trace(
195184
self.__dict__["last_update_timestamp"] = current_time
196185
self.dump_to_trace(prev_obj, prev_trace_info, dump_loc)
197186

198-
# record the trace info
199-
if proxy_config.debug_mode:
200-
frame = inspect.currentframe()
201-
frame_array = self.get_frame_array(frame)
202-
dumped_frame_array = json.dumps(frame_array)
203-
else:
204-
dumped_frame_array = None
205-
206187
current_time = get_timestamp_ns()
207188
trace_info = {
208189
"time": current_time,
209190
"status": status,
210-
"frame_array": dumped_frame_array,
211191
}
212192

213193
if only_record and status == "pre_observe":
@@ -252,8 +232,6 @@ def dump_to_trace(self, obj, trace_info, dump_loc=None):
252232
status = trace_info["status"]
253233
else:
254234
status = "update"
255-
if "frame_array" not in trace_info:
256-
raise ValueError("frame_array is not provided in trace_info")
257235

258236
var_name = self.__dict__["var_name"]
259237
assert (
@@ -265,15 +243,8 @@ def dump_to_trace(self, obj, trace_info, dump_loc=None):
265243
]
266244
if filter_by_tensor_version and status == "update":
267245
if hasattr(obj, "_version"):
268-
if (
269-
obj._version
270-
== Proxy.var_dict[self.__dict__["var_name"]]._obj._version
271-
):
246+
if obj._version == Proxy.var_dict[self.__dict__["var_name"]].version:
272247
return
273-
# Strong assertion: the previous type and current type of the object should be the same
274-
# assert typename(obj) == typename(
275-
# self._obj
276-
# ), f"Type of the object is changed from {typename(self._obj)} to {typename(obj)}, needs careful check"
277248

278249
if not issubclass(type(obj), torch.nn.Module):
279250
self.jsondumper.dump_json(
@@ -339,12 +310,7 @@ def __init__(
339310
self.__dict__["old_value"] = obj.__dict__["old_value"]
340311
self.__dict__["old_meta_vars"] = obj.__dict__["old_meta_vars"]
341312
return
342-
if proxy_config.debug_mode:
343-
frame = inspect.currentframe()
344-
frame_array = self.get_frame_array(frame)
345-
dumped_frame_array = json.dumps(frame_array)
346-
else:
347-
dumped_frame_array = None
313+
348314
# inherit the var_name from the parent object
349315
if self.__dict__["var_name"] is not None:
350316
current_var_name_list = self.__dict__["var_name"]
@@ -394,22 +360,26 @@ def __init__(
394360
)
395361

396362
current_var_name_list = current_var_name_list
363+
current_time = get_timestamp_ns()
397364
if (
398-
Proxy.var_dict.get(current_var_name_list) is None
365+
current_var_name_list not in Proxy.var_dict
399366
): # if the object is not proxied yet
400-
401367
self.__dict__["_obj"] = obj
368+
369+
self.__dict__["last_update_timestamp"] = current_time
370+
Proxy.var_dict[current_var_name_list] = (
371+
ProxyObjInfo.construct_from_proxy_obj(self)
372+
)
373+
402374
dump_call_return = proxy_config.dump_info_config["dump_call_return"]
403375
dump_iter = proxy_config.dump_info_config["dump_iter"]
404376
if not dump_call_return and from_call:
405377
return
406378
if not dump_iter and from_iter:
407379
return
408380

409-
current_time = get_timestamp_ns()
410381
trace_info = {
411382
"time": current_time,
412-
"frame_array": dumped_frame_array,
413383
}
414384
if dump_trace_info:
415385
if from_call:
@@ -421,8 +391,6 @@ def __init__(
421391
else:
422392
trace_info["status"] = "update"
423393
self.dump_to_trace(obj, trace_info, dump_loc="initing")
424-
self.__dict__["last_update_timestamp"] = current_time
425-
Proxy.var_dict[current_var_name_list] = self
426394

427395
else: # if the object is proxied already
428396
if type(obj) not in [int, float, str, bool] and obj is not None:
@@ -431,14 +399,12 @@ def __init__(
431399
)
432400

433401
print_debug(
434-
lambda: f'Time elapse: {get_timestamp_ns() - Proxy.var_dict[current_var_name_list].__dict__["last_update_timestamp"]}'
402+
lambda: f"Time elapse: {get_timestamp_ns() - Proxy.var_dict[current_var_name_list].last_update_timestamp}"
435403
)
436404
self.__dict__["_obj"] = obj
437405
if (
438406
get_timestamp_ns()
439-
- Proxy.var_dict[current_var_name_list].__dict__[
440-
"last_update_timestamp"
441-
]
407+
- Proxy.var_dict[current_var_name_list].last_update_timestamp
442408
< proxy_config.proxy_update_limit
443409
):
444410
return
@@ -454,7 +420,6 @@ def __init__(
454420

455421
trace_info = {
456422
"time": current_time,
457-
"frame_array": dumped_frame_array,
458423
}
459424
if dump_trace_info:
460425
if from_call:
@@ -468,7 +433,9 @@ def __init__(
468433

469434
del Proxy.var_dict[current_var_name_list]
470435
self.__dict__["last_update_timestamp"] = current_time
471-
Proxy.var_dict[current_var_name_list] = self
436+
Proxy.var_dict[current_var_name_list] = (
437+
ProxyObjInfo.construct_from_proxy_obj(self)
438+
)
472439

473440
@property # type: ignore
474441
def __class__(self): # type: ignore[misc]
@@ -515,23 +482,28 @@ def __setattr__(self, name, value):
515482
if name == "_obj":
516483
self.__dict__[name] = value # Set the attribute directly
517484
else:
518-
if Proxy.var_dict.get(self.__dict__["var_name"]) is None:
519-
self.__dict__["last_update_timestamp"] = 0
520-
Proxy.var_dict[self.__dict__["var_name"]] = self
521-
485+
var_name = self.__dict__["var_name"]
486+
assert (
487+
var_name in Proxy.var_dict
488+
), f"var_name {var_name} is not in var_dict, it has not been proxied yet, check Proxy.__init__() for existence of assignment into Proxy.var_dict"
489+
current_time = get_timestamp_ns()
522490
print_debug(
523-
lambda: f"Time elapse: {get_timestamp_ns() - Proxy.var_dict[self.__dict__['var_name']].__dict__['last_update_timestamp']}"
491+
lambda: f"Time elapse: {get_timestamp_ns() - self.__dict__['last_update_timestamp']}"
524492
)
493+
self.__dict__["last_update_timestamp"] = current_time
494+
Proxy.var_dict[var_name].last_update_timestamp = current_time
525495

526496
# update the timestamp of the current object
527497
self.register_object()
528498

529499
if self.__dict__["var_name"] == "":
530-
var_name = name
500+
global_name = name
531501
else:
532-
var_name = self.__dict__["var_name"] + "." + name
502+
global_name = self.__dict__["var_name"] + "." + name
533503

534-
print_debug(lambda: f"Setting attribute '{name}' to '{value}'")
504+
print_debug(
505+
lambda: f"Setting attribute '{name}' to '{value}', with global name '{global_name}'"
506+
)
535507

536508
# if self._obj is a tensor already, then deproxify the value
537509
if issubclass(type(self._obj), torch.Tensor):
@@ -544,10 +516,10 @@ def __setattr__(self, name, value):
544516
value,
545517
logdir=self.logdir,
546518
log_level=self.log_level,
547-
var_name=var_name,
519+
var_name=global_name,
548520
),
549521
)
550-
# dump frame array
522+
551523
if general_config.should_disable_proxy_dumping():
552524
# do not dump update traces
553525
return None
@@ -604,19 +576,6 @@ def __iter__(self):
604576
__sub__ = proxy_methods.__sub__
605577
__truediv__ = proxy_methods.__truediv__
606578

607-
# max = proxy_methods.max
608-
# min = proxy_methods.min
609-
# size = proxy_methods.size
610-
611-
def print_proxy_dict(self, proxy_dict):
612-
# for debugging purpose: print the var_dict of the proxy object
613-
print_debug(lambda: "logger_proxy: Dump Proxy Dict: ")
614-
for k, value in proxy_dict.items():
615-
if isinstance(value, torch.Tensor):
616-
self.print_tensor(value)
617-
else:
618-
print_debug(lambda: f"logger_proxy: {k}: {value}")
619-
620579
@classmethod
621580
def __torch_function__(cls, func, types, args=(), kwargs=None):
622581
# 🚨 Ensure Proxy does not interfere with PyTorch dispatch

0 commit comments

Comments
 (0)