Commit dedcbd6

run model debugging with forward arg (#39905)
* run model debugging a lot simpler
* fixup
* Update src/transformers/utils/generic.py
* fixup
* more style?
* guard a bit
1 parent 20ce210 commit dedcbd6

File tree

2 files changed: +15 -10 lines


src/transformers/model_debugging_utils.py

Lines changed: 6 additions & 9 deletions
@@ -21,26 +21,23 @@
 from io import StringIO
 from typing import Optional
 
-from safetensors.torch import save_file
-
-from transformers.utils.import_utils import requires
-
-from .utils import is_torch_available
+from .utils.import_utils import is_torch_available, requires
 
 
 if is_torch_available():
     import torch
     import torch.distributed.tensor
+    from safetensors.torch import save_file
 
-
+    # Note to code inspectors: this toolbox is intended for people who add models to `transformers`.
+    _torch_distributed_available = torch.distributed.is_available()
+else:
+    _torch_distributed_available = False
 
 from .utils import logging
 
 
 logger = logging.get_logger(__name__)
 
-# Note to code inspectors: this toolbox is intended for people who add models to `transformers`.
-_torch_distributed_available = torch.distributed.is_available()
-
 
 def _is_rank_zero():
     """Return True if rank=0 or we aren't running distributed."""

src/transformers/utils/generic.py

Lines changed: 9 additions & 1 deletion
@@ -52,6 +52,8 @@
 # required for @can_return_tuple decorator to work with torchdynamo
 import torch  # noqa: F401
 
+from ..model_debugging_utils import model_addition_debugger_context
+
 
 class cached_property(property):
     """
@@ -1032,7 +1034,13 @@ def make_capture_wrapper(module, orig_forward, key, index):
     def wrapped_forward(*args, **kwargs):
         if key == "hidden_states" and len(collected_outputs[key]) == 0:
             collected_outputs[key] += (args[0],)
-        output = orig_forward(*args, **kwargs)
+        if kwargs.get("debug_io", False):
+            with model_addition_debugger_context(
+                module, kwargs.get("debug_io_dir", "~/model_debug"), kwargs.get("prune_layers")
+            ):
+                output = orig_forward(*args, **kwargs)
+        else:
+            output = orig_forward(*args, **kwargs)
         if not isinstance(output, tuple):
             collected_outputs[key] += (output,)
         elif output[index] is not None:
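
With this change, a module whose forward has been wrapped by `make_capture_wrapper` can opt into IO debugging per call: `debug_io=True` runs the original forward inside `model_addition_debugger_context`, writing traces under `debug_io_dir` (default `~/model_debug`), with `prune_layers` forwarded as the pruning setting. A hedged usage sketch, assuming the wrapped forward actually receives these keyword arguments (the model name and prompt are illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

with torch.no_grad():
    outputs = model(
        **inputs,
        debug_io=True,                 # enter the debugger context for this call
        debug_io_dir="~/model_debug",  # where debug traces land
        prune_layers=None,             # pruning setting passed through unchanged
    )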
