|
18 | 18 |
|
19 | 19 | # if torch.cuda.is_available(): |
20 | 20 | from traincheck.proxy_wrapper.hash import tensor_hash |
| 21 | +from traincheck.proxy_wrapper.proxy_basics import is_fake_tensor |
21 | 22 | from traincheck.proxy_wrapper.proxy_config import ( |
22 | 23 | attribute_black_list, |
23 | 24 | primitive_types, |
24 | 25 | proxy_attribute, |
25 | 26 | tensor_dump_format, |
26 | 27 | ) |
27 | | -from traincheck.utils import get_timestamp_ns, typename |
| 28 | +from traincheck.utils import get_timestamp_ns, typename, typename_compile |
28 | 29 |
|
29 | 30 | DEBUG = os.environ.get("ML_DAIKON_DEBUG", False) |
30 | 31 | THREAD_DATA = threading.local() |
|
45 | 46 | logger = logging.getLogger(__name__) |
46 | 47 |
|
47 | 48 |
|
| 49 | +def _json_default(o): |
| 50 | + try: |
| 51 | + if type(o).__name__ in ("SymInt", "SymFloat", "SymBool"): |
| 52 | + return str(o) |
| 53 | + |
| 54 | + if isinstance(o, torch.device): |
| 55 | + return str(o) |
| 56 | + if isinstance(o, torch.dtype): |
| 57 | + return str(o) |
| 58 | + if isinstance(o, torch.Size): |
| 59 | + out = [] |
| 60 | + for d in o: |
| 61 | + try: |
| 62 | + out.append(int(d)) |
| 63 | + except Exception: |
| 64 | + out.append(str(d)) |
| 65 | + return out |
| 66 | + except Exception: |
| 67 | + pass |
| 68 | + |
| 69 | + if isinstance(o, set): |
| 70 | + return list(o) |
| 71 | + if isinstance(o, tuple): |
| 72 | + return list(o) |
| 73 | + |
| 74 | + try: |
| 75 | + import numpy as np |
| 76 | + |
| 77 | + if isinstance(o, (np.generic,)): |
| 78 | + return o.item() |
| 79 | + except Exception: |
| 80 | + pass |
| 81 | + |
| 82 | + return repr(o) |
| 83 | + |
| 84 | + |
48 | 85 | def serialize(obj_dict: dict[str, object | str]) -> str: |
49 | 86 | try: |
50 | | - return orjson.dumps(obj_dict).decode("utf-8") |
| 87 | + return orjson.dumps(obj_dict, default=_json_default).decode("utf-8") |
51 | 88 | except Exception: |
52 | 89 | # if orjson fails (e.g. cannot handle ints larger than 64-bit), fallback to json |
53 | | - return json.dumps(obj_dict) |
| 90 | + return json.dumps(obj_dict, default=_json_default) |
54 | 91 |
|
55 | 92 |
|
56 | 93 | def monitor_main_thread(main_thread, stop_event): |
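The `_json_default` hook added above is handed to both `orjson.dumps` and the `json.dumps` fallback, so torch metadata types (`torch.device`, `torch.dtype`, `torch.Size`) and symbolic scalars get stringified instead of raising. A minimal sketch of the behaviour, using a simplified copy of the hook and an illustrative payload rather than TrainCheck's actual dump format:

```python
import json

import orjson
import torch


def _json_default(o):
    # Simplified copy of the hook above: stringify torch metadata types,
    # flatten Size/set/tuple to lists, fall back to repr() for the rest.
    if isinstance(o, (torch.device, torch.dtype)):
        return str(o)
    if isinstance(o, torch.Size):
        return [int(d) for d in o]
    if isinstance(o, (set, tuple)):
        return list(o)
    return repr(o)


meta = {"device": torch.device("cpu"), "dtype": torch.float32, "shape": torch.Size([2, 3])}
print(orjson.dumps(meta, default=_json_default).decode("utf-8"))
# e.g. {"device":"cpu","dtype":"torch.float32","shape":[2,3]}

# orjson rejects integers outside the 64-bit range, which is why serialize()
# keeps the json.dumps fallback; the same hook covers that path too.
print(json.dumps({"step": 2**70, "device": torch.device("cpu")}, default=_json_default))
```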
@@ -350,12 +387,17 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict |
350 | 387 |
|
351 | 388 | attr = safe_getattr(var, attr_name) |
352 | 389 | if attr is NOT_FOUND: |
353 | | - logger.warning( |
354 | | - f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute." |
355 | | - ) |
356 | | - if var_type not in skip_attrs_due_to_errs: |
357 | | - skip_attrs_due_to_errs[var_type] = set() |
358 | | - skip_attrs_due_to_errs[var_type].add(attr_name) |
| 390 | + if not ( |
| 391 | + attr_name == "data" |
| 392 | + and isinstance(var, torch.Tensor) |
| 393 | + and not include_tensor_data |
| 394 | + ): |
| 395 | + logger.warning( |
| 396 | + f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute." |
| 397 | + ) |
| 398 | + if var_type not in skip_attrs_due_to_errs: |
| 399 | + skip_attrs_due_to_errs[var_type] = set() |
| 400 | + skip_attrs_due_to_errs[var_type].add(attr_name) |
359 | 401 | continue |
360 | 402 |
|
361 | 403 | attr_name = str(attr_name) |
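The widened condition above keeps the attribute-error blacklist intact but stops warning when the missing attribute is a tensor's `data` payload that the caller asked to skip via `include_tensor_data=False`. A hedged sketch of the same guard in isolation; `NOT_FOUND`, `safe_getattr`, and the helper below are simplified stand-ins, not the module's own implementations:

```python
import torch

NOT_FOUND = object()  # stand-in sentinel; the real module defines its own


def safe_getattr(obj, name):
    # Stand-in: swallow attribute-access failures and return the sentinel.
    try:
        return getattr(obj, name)
    except Exception:
        return NOT_FOUND


def dump_attr(var, attr_name, include_tensor_data, skip_attrs_due_to_errs):
    attr = safe_getattr(var, attr_name)
    if attr is NOT_FOUND:
        expected_skip = (
            attr_name == "data"
            and isinstance(var, torch.Tensor)
            and not include_tensor_data
        )
        if not expected_skip:
            # Unexpected failure: blacklist the attribute for this variable
            # type so later dumps skip it (mirrors the logic in the diff).
            skip_attrs_due_to_errs.setdefault(type(var), set()).add(attr_name)
        return None
    return attr
```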
@@ -399,7 +441,25 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict |
399 | 441 | return result |
400 | 442 |
|
401 | 443 |
|
| 444 | +def convert_fake_tensor_to_dict(var): |
| 445 | + try: |
| 446 | + shape = tuple(var.shape) |
| 447 | + except Exception: |
| 448 | + shape = None |
| 449 | + try: |
| 450 | + dtype = str(var.dtype) |
| 451 | + except Exception: |
| 452 | + dtype = None |
| 453 | + return { |
| 454 | + "fake": True, |
| 455 | + "shape": shape, |
| 456 | + "dtype": dtype, |
| 457 | + } |
| 458 | + |
| 459 | + |
402 | 460 | def obj_to_serializable(obj, dump_config=None) -> dict[str, object]: |
| 461 | + if is_fake_tensor(obj): |
| 462 | + return {typename_compile(obj): convert_fake_tensor_to_dict(obj)} |
403 | 463 | if ( |
404 | 464 | type(obj) in skip_type_due_to_recursion |
405 | 465 | and skip_type_due_to_recursion[type(obj)] > RECURSION_ERR_THRESHOLD |
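`convert_fake_tensor_to_dict` records only the metadata a fake tensor can answer without real storage, and `obj_to_serializable` keys it by `typename_compile(obj)`. A hedged usage sketch, assuming a fake tensor produced under PyTorch's `torch._subclasses.fake_tensor.FakeTensorMode` (a standard PyTorch API, not something introduced by this PR):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # Tensor constructors under the mode yield FakeTensors: shape and dtype
    # metadata exist, but there is no real storage to hash or dump.
    t = torch.empty(2, 3, dtype=torch.float16)

# Metadata-only payload, mirroring convert_fake_tensor_to_dict above.
print({"fake": True, "shape": tuple(t.shape), "dtype": str(t.dtype)})
# -> {'fake': True, 'shape': (2, 3), 'dtype': 'torch.float16'}
```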
@@ -433,6 +493,9 @@ def var_to_serializable(obj, dump_config=None) -> dict[str, object]: |
433 | 493 | If you want to dump the `data` attribute of a tensor, use `convert_var_to_dict` and set `include_tensor_data=True`. |
434 | 494 | """ |
435 | 495 |
|
| 496 | + if is_fake_tensor(obj): |
| 497 | + return {typename_compile(obj): convert_fake_tensor_to_dict(obj)} |
| 498 | + |
436 | 499 | if issubclass(type(obj), dict) and type(obj) != dict: # noqa E721 |
437 | 500 | return obj_to_serializable(obj, dump_config=dump_config) |
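`is_fake_tensor` comes from `traincheck.proxy_wrapper.proxy_basics` and its body is not shown in this diff. For orientation only, a detection helper of this kind typically reduces to an isinstance check against PyTorch's FakeTensor class; the sketch below is a hypothetical stand-in, not the function's actual implementation:

```python
def is_fake_tensor_sketch(obj) -> bool:
    """Hypothetical stand-in for traincheck.proxy_wrapper.proxy_basics.is_fake_tensor."""
    try:
        from torch._subclasses.fake_tensor import FakeTensor

        return isinstance(obj, FakeTensor)
    except Exception:
        # Older torch builds may not expose FakeTensor; treat as "not fake".
        return False
```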
438 | 501 |
|
|