Skip to content

Commit 0210651

Browse files
authored
Merge pull request #121 from OrderLab/fix_proxy_changing_model
Fix Proxy Changing Wrapped Model #120
2 parents 7d2ecbe + 73245cc commit 0210651

File tree

10 files changed

+139
-366
lines changed

10 files changed

+139
-366
lines changed

mldaikon/collect_trace.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -352,12 +352,6 @@ def is_path_md_output_dir(output_dir: str) -> bool:
352352
choices=["sampler", "proxy"],
353353
default="proxy",
354354
)
355-
parser.add_argument(
356-
"--proxy-update-limit",
357-
type=float,
358-
default=proxy_config.proxy_update_limit,
359-
help="The threshold for updating the proxy object",
360-
)
361355
parser.add_argument(
362356
"--tensor-dump-format",
363357
choices=["hash", "stats", "full"],
@@ -431,7 +425,6 @@ def is_path_md_output_dir(output_dir: str) -> bool:
431425
# set up adjusted proxy_config
432426
proxy_basic_config: dict[str, int | bool | str] = {}
433427
for configs in [
434-
"proxy_update_limit",
435428
"debug_mode",
436429
"enable_C_level_observer",
437430
]:

mldaikon/config/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def should_disable_proxy_dumping() -> bool:
198198
"__doc__",
199199
"logdir",
200200
"log_level",
201-
"is_root",
201+
"recurse",
202202
"var_name",
203203
"mode",
204204
"process_id",

mldaikon/instrumentor/source_file.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def visit_Assign(self, node):
178178
func=ast.Name(id="Proxy", ctx=ast.Load()),
179179
args=[node.value],
180180
keywords=[
181-
ast.keyword(arg="is_root", value=ast.Constant(value=True)),
181+
ast.keyword(arg="recurse", value=ast.Constant(value=True)),
182182
ast.keyword(
183183
arg="logdir",
184184
value=ast.Attribute(
@@ -187,6 +187,10 @@ def visit_Assign(self, node):
187187
ctx=ast.Load(),
188188
),
189189
),
190+
ast.keyword(
191+
arg="var_name",
192+
value=ast.Constant(value=self.model_name),
193+
),
190194
],
191195
)
192196
return node

mldaikon/proxy_wrapper/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,19 @@
66

77
`import src.proxy_wrapper.proxy as Proxy`
88

9-
- wrap the machine learning model with `is_root = True`
9+
- wrap the machine learning model with `recurse = True`
1010

11-
`model = Proxy(model, is_root = True)`
11+
`model = Proxy(model, recurse = True)`
1212

1313
- Examples:
1414

1515
As shown in line 140 in `./proxyclass_tracer_result/instrumented_mnist.py`
1616

17-
`model=Proxy.Proxy(model, is_root = True, logdir='log-model-proxy-example.log', log_level=logging.INFO)`
17+
`model=Proxy.Proxy(model, recurse = True, logdir='log-model-proxy-example.log', log_level=logging.INFO)`
1818

1919
line 99 in `./proxyclass_tracer_result/instrumented_84911.py`
2020

21-
`model_transfer=Proxy.Proxy(model_transfer, is_root = True, "model_transfer-example.log", log_level = logging.INFO)`
21+
`model_transfer=Proxy.Proxy(model_transfer, recurse = True, "model_transfer-example.log", log_level = logging.INFO)`
2222

2323
Note: Initially, we want to achieve automatic instrumentation via `__new__ wrapper`, wrapping all `torch.nn` modules and add a proxy to the targeted modules inherently. However, it is discovered that this instrumentation would interfere with inherent torch functionality, such as `torch.autograd` behavior.
2424

mldaikon/proxy_wrapper/dumper.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,6 @@ def __call__(cls, *args, **kwargs):
2222
return cls._instances[cls]
2323

2424

25-
class SkippedDumpingObj:
26-
def __init__(self, obj):
27-
self._obj = obj
28-
29-
def __repr__(self):
30-
return f"Skipped Dumping Object: ({self._obj})"
31-
32-
3325
class json_dumper(metaclass=Singleton):
3426
# singleton pattern for shared state
3527
_shared_state = False
@@ -58,12 +50,10 @@ def dump_json(
5850
return
5951

6052
data = {
61-
# "value": var_value,
6253
"var_name": var_name,
6354
"var_type": var_type,
6455
"mode": change_type, # "new", "update"
6556
"dump_loc": dump_loc,
66-
# "stack_trace": stack_trace,
6757
"process_id": process_id,
6858
"thread_id": thread_id,
6959
"time": time,

0 commit comments

Comments
 (0)