
Commit e67a568

fix: torch.compile compatibility
1 parent 88d2ab7 commit e67a568

7 files changed: +139 −34 lines changed


traincheck/config/config.py

Lines changed: 9 additions & 0 deletions
@@ -249,3 +249,12 @@ def should_disable_proxy_dumping() -> bool:
     "preprocessing",
     "postprocessing",
 }
+
+COMPILE_INTERNAL_MODULE = (
+    "torch.fx",
+    # "torch._dynamo",
+    "torch._inductor",
+    "torch._subclasses",
+    "torch._higher_order_ops",
+    "torch.utils._sympy",
+)
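
The new tuple is consumed by is_compile_internal_module, added to traincheck/proxy_wrapper/proxy_basics.py later in this commit, to recognize objects whose types are defined inside torch.compile's internals. A minimal sketch of the same prefix check (looks_compile_internal is only an illustrative stand-in; assumes traincheck and a PyTorch 2.x install):

import torch
import torch.fx

from traincheck.config.config import COMPILE_INTERNAL_MODULE

def looks_compile_internal(obj) -> bool:
    # same prefix match that is_compile_internal_module performs later in this commit
    mod = getattr(type(obj), "__module__", "") or ""
    return any(mod.startswith(prefix) for prefix in COMPILE_INTERNAL_MODULE)

print(looks_compile_internal(torch.fx.Graph()))  # True: type is defined in torch.fx.graph
print(looks_compile_internal(torch.ones(1)))     # False: torch.Tensor's module is "torch"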

traincheck/instrumentor/dumper.py

Lines changed: 72 additions & 9 deletions
@@ -18,13 +18,14 @@

 # if torch.cuda.is_available():
 from traincheck.proxy_wrapper.hash import tensor_hash
+from traincheck.proxy_wrapper.proxy_basics import is_fake_tensor
 from traincheck.proxy_wrapper.proxy_config import (
     attribute_black_list,
     primitive_types,
     proxy_attribute,
     tensor_dump_format,
 )
-from traincheck.utils import get_timestamp_ns, typename
+from traincheck.utils import get_timestamp_ns, typename, typename_compile

 DEBUG = os.environ.get("ML_DAIKON_DEBUG", False)
 THREAD_DATA = threading.local()
@@ -45,12 +46,48 @@
 logger = logging.getLogger(__name__)


+def _json_default(o):
+    try:
+        if type(o).__name__ in ("SymInt", "SymFloat", "SymBool"):
+            return str(o)
+
+        if isinstance(o, torch.device):
+            return str(o)
+        if isinstance(o, torch.dtype):
+            return str(o)
+        if isinstance(o, torch.Size):
+            out = []
+            for d in o:
+                try:
+                    out.append(int(d))
+                except Exception:
+                    out.append(str(d))
+            return out
+    except Exception:
+        pass
+
+    if isinstance(o, set):
+        return list(o)
+    if isinstance(o, tuple):
+        return list(o)
+
+    try:
+        import numpy as np
+
+        if isinstance(o, (np.generic,)):
+            return o.item()
+    except Exception:
+        pass
+
+    return repr(o)
+
+
 def serialize(obj_dict: dict[str, object | str]) -> str:
     try:
-        return orjson.dumps(obj_dict).decode("utf-8")
+        return orjson.dumps(obj_dict, default=_json_default).decode("utf-8")
     except Exception:
         # if orjson fails (e.g. cannot handle ints larger than 64-bit), fallback to json
-        return json.dumps(obj_dict)
+        return json.dumps(obj_dict, default=_json_default)


 def monitor_main_thread(main_thread, stop_event):
@@ -350,12 +387,17 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict

         attr = safe_getattr(var, attr_name)
         if attr is NOT_FOUND:
-            logger.warning(
-                f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute."
-            )
-            if var_type not in skip_attrs_due_to_errs:
-                skip_attrs_due_to_errs[var_type] = set()
-            skip_attrs_due_to_errs[var_type].add(attr_name)
+            if not (
+                attr_name == "data"
+                and isinstance(var, torch.Tensor)
+                and not include_tensor_data
+            ):
+                logger.warning(
+                    f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute."
+                )
+                if var_type not in skip_attrs_due_to_errs:
+                    skip_attrs_due_to_errs[var_type] = set()
+                skip_attrs_due_to_errs[var_type].add(attr_name)
             continue

         attr_name = str(attr_name)
@@ -399,7 +441,25 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict
     return result


+def convert_fake_tensor_to_dict(var):
+    try:
+        shape = tuple(var.shape)
+    except Exception:
+        shape = None
+    try:
+        dtype = str(var.dtype)
+    except Exception:
+        dtype = None
+    return {
+        "fake": True,
+        "shape": shape,
+        "dtype": dtype,
+    }
+
+
 def obj_to_serializable(obj, dump_config=None) -> dict[str, object]:
+    if is_fake_tensor(obj):
+        return {typename_compile(obj): convert_fake_tensor_to_dict(obj)}
     if (
         type(obj) in skip_type_due_to_recursion
         and skip_type_due_to_recursion[type(obj)] > RECURSION_ERR_THRESHOLD
@@ -433,6 +493,9 @@ def var_to_serializable(obj, dump_config=None) -> dict[str, object]:
     If you want to dump the `data` attribute of a tensor, use `convert_var_to_dict` and set `include_tensor_data=True`.
     """

+    if is_fake_tensor(obj):
+        return {typename_compile(obj): convert_fake_tensor_to_dict(obj)}
+
     if issubclass(type(obj), dict) and type(obj) != dict:  # noqa E721
         return obj_to_serializable(obj, dump_config=dump_config)

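
orjson serializes only a fixed set of native types and fails on anything it does not support; with the default= hook wired into serialize(), values produced during torch.compile tracing (SymInt/SymFloat/SymBool, torch.device, torch.dtype, torch.Size, sets, tuples, numpy scalars) now fall back to the conversions in _json_default instead of aborting the dump. A rough, self-contained illustration with a simplified stand-in for the hook (assumes torch and orjson are installed):

import orjson
import torch

def _json_default(o):
    # simplified stand-in for the _json_default hook added in this commit
    if isinstance(o, torch.Size):
        return [int(d) for d in o]
    if isinstance(o, (torch.device, torch.dtype)):
        return str(o)
    return repr(o)

record = {"shape": torch.Size([2, 3]), "device": torch.device("cpu")}
print(orjson.dumps(record, default=_json_default).decode("utf-8"))
# {"shape":[2,3],"device":"cpu"}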

traincheck/instrumentor/tracer.py

Lines changed: 3 additions & 3 deletions
@@ -31,7 +31,7 @@
 )
 from traincheck.proxy_wrapper.proxy_basics import (
     is_proxied,
-    is_proxyparamtetr,
+    is_proxyparameter,
     unproxy_func,
 )
 from traincheck.proxy_wrapper.proxy_config import enable_C_level_observer
@@ -219,7 +219,7 @@ def global_wrapper(

     def find_proxy_in_args(args):
         for i, arg in enumerate(args):
-            if is_proxied(arg) or is_proxyparamtetr(arg):
+            if is_proxied(arg) or is_proxyparameter(arg):
                 proxy_in_args.append(arg)
             elif type(arg) in [list, tuple]:
                 find_proxy_in_args(arg)
@@ -238,7 +238,7 @@ def find_proxy_in_args(args):
     if "proxy_obj_names" not in pre_record:
         pre_record["proxy_obj_names"] = []
     for proxy in proxy_in_args:
-        if is_proxyparamtetr(proxy):
+        if is_proxyparameter(proxy):
             pre_record["proxy_obj_names"].append(
                 [proxy.__dict__["var_name"], "Parameter"]
             )

traincheck/proxy_wrapper/proxy_basics.py

Lines changed: 41 additions & 1 deletion
@@ -4,18 +4,58 @@

 import astor

+from traincheck.config.config import COMPILE_INTERNAL_MODULE
+
+
+def is_compile_internal_module(obj):
+    mod = getattr(type(obj), "__module__", "") or ""
+    if any(mod.startswith(p) for p in COMPILE_INTERNAL_MODULE):
+        return True
+    name = type(obj).__name__
+    if mod.startswith("torch._dynamo") and name != "OptimizedModule":
+        return True
+    return False
+
+
+def is_fake_tensor(x) -> bool:
+    try:
+        from torch._subclasses.fake_tensor import FakeTensor
+        from torch.fx import Proxy as FxProxy
+
+        if isinstance(x, FakeTensor):
+            return True
+        if isinstance(x, FxProxy):
+            return True
+    except Exception:
+        pass
+
+    try:
+        if is_compile_internal_module(x):
+            return True
+    except Exception:
+        return True
+
+    try:
+        return x.device.type == "meta"
+    except Exception:
+        return True
+

 def is_proxied(obj):
     try:
+        if is_fake_tensor(obj):
+            return False
         if obj is not None and "is_traincheck_proxied_obj" in obj.__dict__:
             return True
     except Exception:
         return False
     return False


-def is_proxyparamtetr(obj):
+def is_proxyparameter(obj):
     try:
+        if is_fake_tensor(obj):
+            return False
         if obj is not None and "is_traincheck_proxyparameter" in obj.__dict__:
             return True
     except Exception:
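
The rewritten is_fake_tensor treats FakeTensor instances, FX proxies, and other compile-internal objects as fake, and deliberately answers True when attribute access fails, so objects seen only during tracing get the lightweight fake-tensor summary in the dumper rather than being proxied or walked attribute by attribute. A quick way to exercise the predicate, assuming traincheck is installed and a PyTorch 2.x build where FakeTensorMode (an internal API) is available:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

from traincheck.proxy_wrapper.proxy_basics import is_fake_tensor

with FakeTensorMode() as mode:
    fake = mode.from_tensor(torch.ones(2, 3))  # FakeTensor mirroring the real tensor's metadata

print(is_fake_tensor(fake))              # True: FakeTensor instance
print(is_fake_tensor(torch.ones(2, 3)))  # False: ordinary CPU tensor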

traincheck/proxy_wrapper/proxy_observer.py

Lines changed: 3 additions & 3 deletions
@@ -9,7 +9,7 @@
 from traincheck.proxy_wrapper.proxy import Proxy
 from traincheck.proxy_wrapper.subclass import ProxyParameter

-from .proxy_basics import is_proxied, unproxy_func
+from .proxy_basics import is_proxied, is_proxyparameter, unproxy_func


 def observe_proxy_var(
@@ -41,9 +41,9 @@ def wrapper(*args, **kwargs):
         # if the arg is list or tuple, check if it contains proxied object
         if type(arg) in [list, tuple]:
             for element in arg:
-                if is_proxied(element) or isinstance(element, ProxyParameter):
+                if is_proxied(element) or is_proxyparameter(element):
                     proxied_vars.append(element)
-        if is_proxied(arg) or isinstance(arg, ProxyParameter):
+        if is_proxied(arg) or is_proxyparameter(arg):
             proxied_vars.append(arg)

         # pre observe

traincheck/proxy_wrapper/subclass.py

Lines changed: 3 additions & 18 deletions
@@ -9,6 +9,7 @@
 from traincheck.utils import get_timestamp_ns

 from .dumper import json_dumper as dumper
+from .proxy_basics import is_fake_tensor

 # from .proxy_registry import get_global_registry
 # from .utils import print_debug
@@ -23,22 +24,6 @@ def in_dynamo() -> bool:
     return False


-def is_fake_tensor(x: torch.Tensor) -> bool:
-    try:
-        from torch._subclasses.fake_tensor import FakeTensor  # 2.x
-
-        if isinstance(x, FakeTensor):
-            return True
-    except Exception:
-        pass
-    if getattr(x, "fake_mode", None) is not None:
-        return True
-    if getattr(x, "_is_fake", False):
-        return True
-
-    return isinstance(x, torch.Tensor) and x.device.type == "meta"
-
-
 class ProxyParameter(torch.nn.Parameter):
     loglevel = logging.INFO
     jsondumper = dumper(
@@ -59,13 +44,13 @@ def __new__(
         # TODO
         # from_copy=False,
     ):
+        if isinstance(data, ProxyParameter):
+            return data
         if in_dynamo() or is_fake_tensor(data):
             if isinstance(data, nn.Parameter):
                 return data
             return nn.Parameter(data, requires_grad=data.requires_grad)
         # TODO: verify
-        if isinstance(data, ProxyParameter):
-            return data

         return torch.Tensor._make_subclass(cls, data.detach(), data.requires_grad)

traincheck/utils.py

Lines changed: 8 additions & 0 deletions
@@ -35,6 +35,14 @@ def safe_getattr(obj, attr, default=None):
         raise


+def typename_compile(o):
+    try:
+        mod = getattr(type(o), "__module__", "") or ""
+        return f"{mod}.{type(o).__name__}"
+    except Exception:
+        return "compile_stage"
+
+
 def typename(o, is_runtime=False):
     if isinstance(o, torch.nn.Parameter):
         return "torch.nn.Parameter"
