11import copy
2- import inspect
3- import json
4- import json .encoder
5- import linecache
62import logging
73import os
84import threading
95import time
106import types
11- import typing
127from typing import Dict
138
149import torch
3126from .utils import print_debug
3227
3328
34- def get_line (filename , lineno ):
35- return linecache .getline (filename , lineno ).strip ()
class ProxyObjInfo:
    """Lightweight bookkeeping record kept for a proxied variable.

    Holds only the variable name, the timestamp of its last recorded update
    and the wrapped object's ``_version`` value (``None`` when the wrapped
    object has no such attribute), rather than a reference to the proxy
    or the object it wraps.
    """

    def __init__(
        self,
        var_name: str,
        last_update_timestamp: int,
        # Quoted so the annotation is not evaluated at definition time;
        # an unquoted ``int | None`` raises TypeError on Python < 3.10.
        version: "int | None",
    ):
        """Store the three tracked fields as plain, mutable attributes.

        Args:
            var_name: Name under which the variable is tracked.
            last_update_timestamp: Timestamp of the last recorded update.
            version: ``_version`` of the wrapped object, or ``None``.
        """
        self.var_name = var_name
        self.last_update_timestamp = last_update_timestamp
        self.version = version

    @classmethod
    def construct_from_proxy_obj(cls, proxy_obj) -> "ProxyObjInfo":
        """Build a record from a proxy object.

        Args:
            proxy_obj: Object whose ``__dict__`` contains ``"var_name"`` and
                ``"last_update_timestamp"`` and which exposes the wrapped
                object as ``_obj``.

        Returns:
            A new instance of ``cls`` (so subclasses get their own type).

        Raises:
            KeyError: If either required key is missing from
                ``proxy_obj.__dict__``.
        """
        state = proxy_obj.__dict__
        return cls(
            state["var_name"],
            state["last_update_timestamp"],
            # Single lookup; None when the wrapped object has no `_version`.
            getattr(proxy_obj._obj, "_version", None),
        )

    def __repr__(self):
        """Return a debug string listing the three tracked fields."""
        return (
            f"ProxyObjInfo(var_name={self.var_name}, "
            f"last_update_timestamp={self.last_update_timestamp}, "
            f"version={self.version})"
        )
3645
3746
3847def proxy_handler (
@@ -94,7 +103,7 @@ def generator_proxy_handler():
94103
95104
96105class Proxy :
97- var_dict : Dict [str , typing . Any ] = {}
106+ var_dict : Dict [str , ProxyObjInfo ] = {}
98107 loglevel = logging .INFO
99108 jsondumper = dumper (
100109 os .path .join (os .getenv ("ML_DAIKON_OUTPUT_DIR" , "." ), "proxy_log.json" ) # type: ignore
@@ -122,25 +131,6 @@ def proxy_parameters(module: torch.nn.Module, parent_name="", from_iter=False):
122131 + f"Proxied { num_params } parameters of '{ parent_name + module .__class__ .__name__ } ', duration: { time_end - start_time } seconds"
123132 )
124133
125- @staticmethod
126- def get_frame_array (frame ):
127- frame_array = []
128- while frame :
129- if "mldaikon" in frame .f_code .co_filename :
130- frame = frame .f_back
131- continue
132-
133- # fetch the frame info
134- frame_array .append (
135- (
136- frame .f_code .co_filename ,
137- frame .f_lineno ,
138- get_line (frame .f_code .co_filename , frame .f_lineno ),
139- )
140- )
141- frame = frame .f_back
142- return frame_array
143-
144134 def register_object (self ):
145135 # get_global_registry().add_var(self, self.__dict__["var_name"])
146136 pass
@@ -149,22 +139,20 @@ def dump_trace(
149139 self ,
150140 status ,
151141 only_record = False ,
152- prev_obj = None ,
142+ prev_obj = None , # DEPRECATED: this is necessary for delta dump, but we should do the compare using hashes instead of keeping references to the objects, which leads to memory leaks
153143 prev_trace_info = None ,
154144 disable_sampling = False ,
155145 dump_loc = None ,
156146 ):
157147
158- if Proxy .var_dict .get (self .__dict__ ["var_name" ]) is None :
159- # create
160- self .__dict__ ["last_update_timestamp" ] = 0
161- Proxy .var_dict [self .__dict__ ["var_name" ]] = self
148+ var_name = self .__dict__ ["var_name" ]
149+ assert (
150+ var_name in Proxy .var_dict
151+ ), f"var_name { var_name } is not in var_dict, it has not been proxied yet, check Proxy.__init__() for existence of assignment into Proxy.var_dict"
152+ var_proxy_info = Proxy .var_dict [var_name ]
162153
163154 if (
164- get_timestamp_ns ()
165- - Proxy .var_dict [self .__dict__ ["var_name" ]].__dict__ [
166- "last_update_timestamp"
167- ]
155+ get_timestamp_ns () - var_proxy_info .last_update_timestamp
168156 > proxy_config .proxy_update_limit
169157 or disable_sampling
170158 ):
@@ -182,6 +170,7 @@ def dump_trace(
182170 if isinstance (prev_obj ._obj , torch .Tensor ) and isinstance (
183171 self ._obj , torch .Tensor
184172 ):
173+ # DEPRECATED: this is necessary for delta dump, but we should do the compare using hashes instead of keeping references to the objects, which leads to memory leaks
185174 if not torch .equal (prev_obj ._obj , self ._obj ):
186175 dump_pre_and_post_trace = True
187176 else :
@@ -195,19 +184,10 @@ def dump_trace(
195184 self .__dict__ ["last_update_timestamp" ] = current_time
196185 self .dump_to_trace (prev_obj , prev_trace_info , dump_loc )
197186
198- # record the trace info
199- if proxy_config .debug_mode :
200- frame = inspect .currentframe ()
201- frame_array = self .get_frame_array (frame )
202- dumped_frame_array = json .dumps (frame_array )
203- else :
204- dumped_frame_array = None
205-
206187 current_time = get_timestamp_ns ()
207188 trace_info = {
208189 "time" : current_time ,
209190 "status" : status ,
210- "frame_array" : dumped_frame_array ,
211191 }
212192
213193 if only_record and status == "pre_observe" :
@@ -252,8 +232,6 @@ def dump_to_trace(self, obj, trace_info, dump_loc=None):
252232 status = trace_info ["status" ]
253233 else :
254234 status = "update"
255- if "frame_array" not in trace_info :
256- raise ValueError ("frame_array is not provided in trace_info" )
257235
258236 var_name = self .__dict__ ["var_name" ]
259237 assert (
@@ -265,15 +243,8 @@ def dump_to_trace(self, obj, trace_info, dump_loc=None):
265243 ]
266244 if filter_by_tensor_version and status == "update" :
267245 if hasattr (obj , "_version" ):
268- if (
269- obj ._version
270- == Proxy .var_dict [self .__dict__ ["var_name" ]]._obj ._version
271- ):
246+ if obj ._version == Proxy .var_dict [self .__dict__ ["var_name" ]].version :
272247 return
273- # Strong assertion: the previous type and current type of the object should be the same
274- # assert typename(obj) == typename(
275- # self._obj
276- # ), f"Type of the object is changed from {typename(self._obj)} to {typename(obj)}, needs careful check"
277248
278249 if not issubclass (type (obj ), torch .nn .Module ):
279250 self .jsondumper .dump_json (
@@ -339,12 +310,7 @@ def __init__(
339310 self .__dict__ ["old_value" ] = obj .__dict__ ["old_value" ]
340311 self .__dict__ ["old_meta_vars" ] = obj .__dict__ ["old_meta_vars" ]
341312 return
342- if proxy_config .debug_mode :
343- frame = inspect .currentframe ()
344- frame_array = self .get_frame_array (frame )
345- dumped_frame_array = json .dumps (frame_array )
346- else :
347- dumped_frame_array = None
313+
348314 # inherit the var_name from the parent object
349315 if self .__dict__ ["var_name" ] is not None :
350316 current_var_name_list = self .__dict__ ["var_name" ]
@@ -394,22 +360,26 @@ def __init__(
394360 )
395361
396362 current_var_name_list = current_var_name_list
363+ current_time = get_timestamp_ns ()
397364 if (
398- Proxy . var_dict . get ( current_var_name_list ) is None
365+ current_var_name_list not in Proxy . var_dict
399366 ): # if the object is not proxied yet
400-
401367 self .__dict__ ["_obj" ] = obj
368+
369+ self .__dict__ ["last_update_timestamp" ] = current_time
370+ Proxy .var_dict [current_var_name_list ] = (
371+ ProxyObjInfo .construct_from_proxy_obj (self )
372+ )
373+
402374 dump_call_return = proxy_config .dump_info_config ["dump_call_return" ]
403375 dump_iter = proxy_config .dump_info_config ["dump_iter" ]
404376 if not dump_call_return and from_call :
405377 return
406378 if not dump_iter and from_iter :
407379 return
408380
409- current_time = get_timestamp_ns ()
410381 trace_info = {
411382 "time" : current_time ,
412- "frame_array" : dumped_frame_array ,
413383 }
414384 if dump_trace_info :
415385 if from_call :
@@ -421,8 +391,6 @@ def __init__(
421391 else :
422392 trace_info ["status" ] = "update"
423393 self .dump_to_trace (obj , trace_info , dump_loc = "initing" )
424- self .__dict__ ["last_update_timestamp" ] = current_time
425- Proxy .var_dict [current_var_name_list ] = self
426394
427395 else : # if the object is proxied already
428396 if type (obj ) not in [int , float , str , bool ] and obj is not None :
@@ -431,14 +399,12 @@ def __init__(
431399 )
432400
433401 print_debug (
434- lambda : f' Time elapse: { get_timestamp_ns () - Proxy .var_dict [current_var_name_list ].__dict__ [ " last_update_timestamp" ] } '
402+ lambda : f" Time elapse: { get_timestamp_ns () - Proxy .var_dict [current_var_name_list ].last_update_timestamp } "
435403 )
436404 self .__dict__ ["_obj" ] = obj
437405 if (
438406 get_timestamp_ns ()
439- - Proxy .var_dict [current_var_name_list ].__dict__ [
440- "last_update_timestamp"
441- ]
407+ - Proxy .var_dict [current_var_name_list ].last_update_timestamp
442408 < proxy_config .proxy_update_limit
443409 ):
444410 return
@@ -454,7 +420,6 @@ def __init__(
454420
455421 trace_info = {
456422 "time" : current_time ,
457- "frame_array" : dumped_frame_array ,
458423 }
459424 if dump_trace_info :
460425 if from_call :
@@ -468,7 +433,9 @@ def __init__(
468433
469434 del Proxy .var_dict [current_var_name_list ]
470435 self .__dict__ ["last_update_timestamp" ] = current_time
471- Proxy .var_dict [current_var_name_list ] = self
436+ Proxy .var_dict [current_var_name_list ] = (
437+ ProxyObjInfo .construct_from_proxy_obj (self )
438+ )
472439
473440 @property # type: ignore
474441 def __class__ (self ): # type: ignore[misc]
@@ -515,23 +482,28 @@ def __setattr__(self, name, value):
515482 if name == "_obj" :
516483 self .__dict__ [name ] = value # Set the attribute directly
517484 else :
518- if Proxy .var_dict .get (self .__dict__ ["var_name" ]) is None :
519- self .__dict__ ["last_update_timestamp" ] = 0
520- Proxy .var_dict [self .__dict__ ["var_name" ]] = self
521-
485+ var_name = self .__dict__ ["var_name" ]
486+ assert (
487+ var_name in Proxy .var_dict
488+ ), f"var_name { var_name } is not in var_dict, it has not been proxied yet, check Proxy.__init__() for existence of assignment into Proxy.var_dict"
489+ current_time = get_timestamp_ns ()
522490 print_debug (
523- lambda : f"Time elapse: { get_timestamp_ns () - Proxy . var_dict [ self . __dict__ [ 'var_name' ]] .__dict__ ['last_update_timestamp' ]} "
491+ lambda : f"Time elapse: { get_timestamp_ns () - self .__dict__ ['last_update_timestamp' ]} "
524492 )
493+ self .__dict__ ["last_update_timestamp" ] = current_time
494+ Proxy .var_dict [var_name ].last_update_timestamp = current_time
525495
526496 # update the timestamp of the current object
527497 self .register_object ()
528498
529499 if self .__dict__ ["var_name" ] == "" :
530- var_name = name
500+ global_name = name
531501 else :
532- var_name = self .__dict__ ["var_name" ] + "." + name
502+ global_name = self .__dict__ ["var_name" ] + "." + name
533503
534- print_debug (lambda : f"Setting attribute '{ name } ' to '{ value } '" )
504+ print_debug (
505+ lambda : f"Setting attribute '{ name } ' to '{ value } ', with global name '{ global_name } '"
506+ )
535507
536508 # if self._obj is a tensor already, then deproxify the value
537509 if issubclass (type (self ._obj ), torch .Tensor ):
@@ -544,10 +516,10 @@ def __setattr__(self, name, value):
544516 value ,
545517 logdir = self .logdir ,
546518 log_level = self .log_level ,
547- var_name = var_name ,
519+ var_name = global_name ,
548520 ),
549521 )
550- # dump frame array
522+
551523 if general_config .should_disable_proxy_dumping ():
552524 # do not dump update traces
553525 return None
@@ -604,19 +576,6 @@ def __iter__(self):
604576 __sub__ = proxy_methods .__sub__
605577 __truediv__ = proxy_methods .__truediv__
606578
607- # max = proxy_methods.max
608- # min = proxy_methods.min
609- # size = proxy_methods.size
610-
611- def print_proxy_dict (self , proxy_dict ):
612- # for debugging purpose: print the var_dict of the proxy object
613- print_debug (lambda : "logger_proxy: Dump Proxy Dict: " )
614- for k , value in proxy_dict .items ():
615- if isinstance (value , torch .Tensor ):
616- self .print_tensor (value )
617- else :
618- print_debug (lambda : f"logger_proxy: { k } : { value } " )
619-
620579 @classmethod
621580 def __torch_function__ (cls , func , types , args = (), kwargs = None ):
622581 # 🚨 Ensure Proxy does not interfere with PyTorch dispatch
0 commit comments