Skip to content

Commit 1521cb6

Browse files
ZoroZhao and Essoz authored
Proxy: torch.nn.Parameter subclass (#134)
* feat: basic parameter subclass * fix: avoid proxy or copy the unsuitable object * fix: deepcopy & avoid proxy during dynamo * feat: print the observation trace * feat: dump trace * fix: information collected * fix: update time when setattr * feat: instrument proxyparameter * fix: scan_proxy_in_args * fix: remove the used import * fix: torch.compile compatibility * fix: error message includes proxyparameter * feat: compile mode * fix: checker trace path parsing * fix: make sure all python-level states are copied over when subclassing tensors * add: refine setattr log for better debugging * add: use consistent trace dumping logic for subclass wrapper * Fix proxyparameter deepcopy and dump gating * Route wrappers by tracker style and rename subclass mode * Document subclass model tracker option --------- Co-authored-by: Yuxuan <lessoxx@gmail.com> Co-authored-by: Yuxuan Jiang <jyuxuan@umich.edu>
1 parent 2ab2a3f commit 1521cb6

File tree

15 files changed

+708
-52
lines changed

15 files changed

+708
-52
lines changed

docs/assets/examples/traincheck-collect/gpt2-pretrain-config/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ shscript: run.sh
77
copy_all_files: true
88
models_to_track:
99
- model
10-
model_tracker_style: proxy
10+
model_tracker_style: proxy # [Optional] "proxy" (default), "subclass", or "sampler".

docs/assets/examples/traincheck-collect/mnist-config/config.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ pyscript: mnist.py # The Python entry point of your training program.
44
shscript: run.sh # [Optional] Shell script to launch the program with custom arguments or environment setup.
55
models_to_track: # [Optional] List of variable names for models you want to track. If omitted, model tracking is disabled.
66
- model
7-
model_tracker_style: proxy # [Optional] Method for model tracking. Choose between "proxy" (default) or "sampler".
7+
model_tracker_style: proxy # [Optional] Method for model tracking. Choose between "proxy" (default), "subclass", or "sampler".
88
copy_all_files: false # [Optional] Set to true if your code uses relative paths (e.g., loading local datasets or configs).
99
# This ensures TrainCheck copies the entire working directory before execution.
10-
# Note: TrainCheck automatically handles PYTHONPATH. Default is false.
10+
# Note: TrainCheck automatically handles PYTHONPATH. Default is false.

docs/instr.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ modules_to_instr: # Libraries to instrument. Defaults to ['torch'] if
4848
- torch
4949
models_to_track: # [Optional] Variable names of models to track. Leave empty to disable model tracking.
5050
- model
51-
model_tracker_style: proxy # [Optional] Tracking method: "proxy" (default) or "sampler".
51+
model_tracker_style: proxy # [Optional] Tracking method: "proxy" (default), "subclass", or "sampler".
5252
copy_all_files: false # [Optional] Set true if your code relies on relative paths (e.g., local datasets/configs).
5353
```
5454
@@ -140,4 +140,4 @@ Instructions for defining and injecting meta variables into traces will be provi
140140

141141
## Instrumentation Mechanisms
142142
📌 **[To Be Documented]**
143-
Details about TrainCheck’s instrumentation strategies, supported APIs, and limitations will be covered here later.
143+
Details about TrainCheck’s instrumentation strategies, supported APIs, and limitations will be covered here later.

traincheck/checker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def main():
153153
trace_parent_folders = []
154154
if args.traces is not None:
155155
logger.info("Reading traces from %s", "\n".join(args.traces))
156-
trace_parent_folders = [os.path.basename(os.path.commonpath(args.traces[0]))]
156+
trace_parent_folders = [os.path.basename(os.path.commonpath(args.traces))]
157157
traces.append(read_trace_file(args.traces))
158158
if args.trace_folders is not None:
159159
for trace_folder in args.trace_folders:

traincheck/collect_trace.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def main():
350350
parser.add_argument(
351351
"--model-tracker-style",
352352
type=str,
353-
choices=["sampler", "proxy"],
353+
choices=["sampler", "proxy", "subclass"],
354354
default="proxy",
355355
)
356356
parser.add_argument(
@@ -371,6 +371,11 @@ def main():
371371
action="store_true",
372372
help="Disable automatic variable instrumentation, necessary when the default behavior of the instrumentor is not desired (e.g. cause segmentation fault)",
373373
)
374+
parser.add_argument(
375+
"--use-torch-compile",
376+
action="store_true",
377+
help="Indicate whether torch.compile is used to speed up the model; necessary for torch.compile compatibility",
378+
)
374379

375380
args = parser.parse_args()
376381

@@ -444,7 +449,7 @@ def main():
444449
scan_proxy_in_args = not args.disable_scan_proxy_in_args
445450

446451
# if no models are tracked, or the tracker style is "sampler" (i.e. neither "proxy" nor "subclass"), disable scan_proxy_in_args
447-
if not args.models_to_track or args.model_tracker_style != "proxy":
452+
if not args.models_to_track or args.model_tracker_style == "sampler":
448453
scan_proxy_in_args = False
449454

450455
if args.invariants:
@@ -481,6 +486,7 @@ def main():
481486
output_dir=output_dir,
482487
instr_descriptors=args.instr_descriptors,
483488
no_auto_var_instr=args.no_auto_var_instr,
489+
use_torch_compile=args.use_torch_compile,
484490
)
485491
else:
486492
source_code = instrumentor.instrument_file(
@@ -496,6 +502,7 @@ def main():
496502
output_dir=output_dir,
497503
instr_descriptors=args.instr_descriptors,
498504
no_auto_var_instr=args.no_auto_var_instr,
505+
use_torch_compile=args.use_torch_compile,
499506
)
500507

501508
if args.copy_all_files:

traincheck/config/config.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
]
9090

9191
INSTR_OPTS = None # TODO: set defaults for this variable
92+
MODEL_TRACKER_STYLE: str | None = None
9293

9394
# var dumper related error-backoff configs
9495
TYPE_ERR_THRESHOLD = 3
@@ -105,8 +106,9 @@ def __init__(
105106
assert model_tracker_style in [
106107
"sampler",
107108
"proxy",
109+
"subclass",
108110
None,
109-
], "model_tracker_style should be one of ['sampler', 'proxy', None]"
111+
], "model_tracker_style should be one of ['sampler', 'proxy', 'subclass', None]"
110112

111113
self.funcs_instr_opts: dict[str, dict[str, bool | dict]] = func_instr_opts
112114
self.model_tracker_style = model_tracker_style
@@ -238,6 +240,7 @@ def should_disable_proxy_dumping() -> bool:
238240

239241

240242
INSTR_DESCRIPTORS = False
243+
USE_TORCH_COMPILE = False
241244

242245
ALL_STAGE_NAMES = {
243246
"init",
@@ -249,3 +252,12 @@ def should_disable_proxy_dumping() -> bool:
249252
"preprocessing",
250253
"postprocessing",
251254
}
255+
256+
COMPILE_INTERNAL_MODULE = (
257+
"torch.fx",
258+
# "torch._dynamo",
259+
"torch._inductor",
260+
"torch._subclasses",
261+
"torch._higher_order_ops",
262+
"torch.utils._sympy",
263+
)

traincheck/instrumentor/dumper.py

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818

1919
# if torch.cuda.is_available():
2020
from traincheck.proxy_wrapper.hash import tensor_hash
21+
from traincheck.proxy_wrapper.proxy_basics import is_fake_tensor
2122
from traincheck.proxy_wrapper.proxy_config import (
2223
attribute_black_list,
2324
primitive_types,
25+
proxy_attribute,
2426
tensor_dump_format,
2527
)
26-
from traincheck.utils import get_timestamp_ns, typename
28+
from traincheck.utils import get_timestamp_ns, typename, typename_compile
2729

2830
DEBUG = os.environ.get("ML_DAIKON_DEBUG", False)
2931
THREAD_DATA = threading.local()
@@ -44,12 +46,48 @@
4446
logger = logging.getLogger(__name__)
4547

4648

49+
def _json_default(o):
50+
try:
51+
if type(o).__name__ in ("SymInt", "SymFloat", "SymBool"):
52+
return str(o)
53+
54+
if isinstance(o, torch.device):
55+
return str(o)
56+
if isinstance(o, torch.dtype):
57+
return str(o)
58+
if isinstance(o, torch.Size):
59+
out = []
60+
for d in o:
61+
try:
62+
out.append(int(d))
63+
except Exception:
64+
out.append(str(d))
65+
return out
66+
except Exception:
67+
pass
68+
69+
if isinstance(o, set):
70+
return list(o)
71+
if isinstance(o, tuple):
72+
return list(o)
73+
74+
try:
75+
import numpy as np
76+
77+
if isinstance(o, (np.generic,)):
78+
return o.item()
79+
except Exception:
80+
pass
81+
82+
return repr(o)
83+
84+
4785
def serialize(obj_dict: dict[str, object | str]) -> str:
4886
try:
49-
return orjson.dumps(obj_dict).decode("utf-8")
87+
return orjson.dumps(obj_dict, default=_json_default).decode("utf-8")
5088
except Exception:
5189
# if orjson fails (e.g. cannot handle ints larger than 64-bit), fallback to json
52-
return json.dumps(obj_dict)
90+
return json.dumps(obj_dict, default=_json_default)
5391

5492

5593
def monitor_main_thread(main_thread, stop_event):
@@ -335,6 +373,9 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict
335373
):
336374
continue
337375

376+
if attr_name in proxy_attribute:
377+
continue
378+
338379
if attr_name in attribute_black_list:
339380
continue
340381

@@ -346,12 +387,17 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict
346387

347388
attr = safe_getattr(var, attr_name)
348389
if attr is NOT_FOUND:
349-
logger.warning(
350-
f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute."
351-
)
352-
if var_type not in skip_attrs_due_to_errs:
353-
skip_attrs_due_to_errs[var_type] = set()
354-
skip_attrs_due_to_errs[var_type].add(attr_name)
390+
if not (
391+
attr_name == "data"
392+
and isinstance(var, torch.Tensor)
393+
and not include_tensor_data
394+
):
395+
logger.warning(
396+
f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute."
397+
)
398+
if var_type not in skip_attrs_due_to_errs:
399+
skip_attrs_due_to_errs[var_type] = set()
400+
skip_attrs_due_to_errs[var_type].add(attr_name)
355401
continue
356402

357403
attr_name = str(attr_name)
@@ -395,7 +441,25 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict
395441
return result
396442

397443

444+
def convert_fake_tensor_to_dict(var):
    """Build a small metadata-only description of ``var``.

    Only ``var.shape`` and ``var.dtype`` are read. If either access (or its
    conversion) raises, the corresponding entry is left as ``None``.

    Returns:
        A dict with keys ``"fake"`` (always ``True``), ``"shape"`` (a tuple
        or ``None``) and ``"dtype"`` (a string or ``None``).
    """
    summary = {"fake": True, "shape": None, "dtype": None}

    try:
        summary["shape"] = tuple(var.shape)
    except Exception:
        pass  # keep the None placeholder

    try:
        summary["dtype"] = str(var.dtype)
    except Exception:
        pass  # keep the None placeholder

    return summary
458+
459+
398460
def obj_to_serializable(obj, dump_config=None) -> dict[str, object]:
461+
if is_fake_tensor(obj):
462+
return {typename_compile(obj): convert_fake_tensor_to_dict(obj)}
399463
if (
400464
type(obj) in skip_type_due_to_recursion
401465
and skip_type_due_to_recursion[type(obj)] > RECURSION_ERR_THRESHOLD
@@ -429,6 +493,9 @@ def var_to_serializable(obj, dump_config=None) -> dict[str, object]:
429493
If you want to dump the `data` attribute of a tensor, use `convert_var_to_dict` and set `include_tensor_data=True`.
430494
"""
431495

496+
if is_fake_tensor(obj):
497+
return {typename_compile(obj): convert_fake_tensor_to_dict(obj)}
498+
432499
if issubclass(type(obj), dict) and type(obj) != dict: # noqa E721
433500
return obj_to_serializable(obj, dump_config=dump_config)
434501

traincheck/instrumentor/source_file.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def get_child_parent_map(root) -> dict[ast.AST, ast.AST]:
271271

272272

273273
def instrument_all_model_assignments(
274-
source_code: str, model_name: str, mode: str
274+
source_code: str, model_name: str, mode: str | None
275275
) -> str:
276276
"""
277277
Finds all assignment statements to `model` and inserts a Proxy statement or a VarSampler statement
@@ -292,8 +292,15 @@ def instrument_all_model_assignments(
292292
instr_statement = ast.parse(
293293
f"{model_name}_sampler = VarSampler({model_name}, var_name='{model_name}')"
294294
)
295+
elif mode == "subclass":
296+
instr_statement = ast.parse(
297+
f"proxy_parameter({model_name}, logdir=proxy_config.proxy_log_dir, parent_name='{model_name}')"
298+
)
299+
295300
else:
296-
raise ValueError(f"Invalid mode: {mode}. Must be one of ['proxy', 'sampler']")
301+
raise ValueError(
302+
f"Invalid mode: {mode}. Must be one of ['proxy', 'sampler', 'subclass']"
303+
)
297304

298305
# find all assignment statements to `model`
299306
assignments = []
@@ -348,6 +355,7 @@ def instrument_model_tracker_proxy(
348355
models_to_track: list[str],
349356
adjusted_proxy_config: list[dict[str, int | bool | str]],
350357
no_auto_var_instr: bool,
358+
model_tracker_style: str | None,
351359
):
352360
auto_observer_config: dict[str, int | bool | str] = adjusted_proxy_config[0]
353361
proxy_basic_config: dict[str, int | bool | str] = adjusted_proxy_config[1]
@@ -373,8 +381,13 @@ def instrument_model_tracker_proxy(
373381
tensor_dump_format.update({tensor_dump_format})
374382
"""
375383

376-
proxy_start_code += """
384+
if model_tracker_style == "proxy":
385+
proxy_start_code += """
377386
from traincheck.proxy_wrapper.proxy import Proxy
387+
"""
388+
else:
389+
proxy_start_code += """
390+
from traincheck.proxy_wrapper.subclass import proxy_parameter
378391
"""
379392

380393
if auto_observer_config["enable_auto_observer"]:
@@ -435,7 +448,7 @@ def instrument_model_tracker_proxy(
435448
if not no_auto_var_instr:
436449
for model in models_to_track:
437450
instrumented_source = instrument_all_model_assignments(
438-
instrumented_source, model, "proxy"
451+
instrumented_source, model, model_tracker_style
439452
)
440453

441454
code_head, code_tail = get_code_head_and_tail(instrumented_source)
@@ -797,6 +810,7 @@ def instrument_file(
797810
output_dir: str,
798811
instr_descriptors: bool,
799812
no_auto_var_instr: bool,
813+
use_torch_compile: bool,
800814
) -> str:
801815
"""
802816
Instruments the given file and returns the instrumented source code.
@@ -833,20 +847,28 @@ def instrument_file(
833847
general_config_update = f"""
834848
import traincheck.config.config as general_config
835849
general_config.INSTR_DESCRIPTORS = {instr_descriptors}
850+
general_config.MODEL_TRACKER_STYLE = {model_tracker_style!r}
851+
"""
852+
if use_torch_compile:
853+
torch_compile_config_update = """
854+
general_config.USE_TORCH_COMPILE = True
836855
"""
856+
general_config_update = general_config_update + torch_compile_config_update
837857
# TODO: move the INSTR_DESCRIPTORS to the instr_opts file
838858

839859
if models_to_track:
840860
assert model_tracker_style in [
841861
"proxy",
842862
"sampler",
843-
], f"Invalid model tracker style: {model_tracker_style}, must be one of ['proxy', 'sampler']"
844-
if model_tracker_style == "proxy":
863+
"subclass",
864+
], f"Invalid model tracker style: {model_tracker_style}, must be one of ['proxy', 'sampler', 'subclass']"
865+
if model_tracker_style == "proxy" or model_tracker_style == "subclass":
845866
instrumented_source = instrument_model_tracker_proxy(
846867
instrumented_source,
847868
models_to_track,
848869
adjusted_proxy_config,
849870
no_auto_var_instr,
871+
model_tracker_style,
850872
)
851873
else:
852874
instrumented_source = instrument_model_tracker_sampler(

0 commit comments

Comments
 (0)