
Commit 7fdd754

anijain2305 authored and pytorchmergebot committed
[compile-time traces] Profile large missing gaps in compile time (pytorch#151256)
Pull Request resolved: pytorch#151256 Approved by: https://github.com/bdhirsh, https://github.com/masnesral, https://github.com/zou3519, https://github.com/jansel
1 parent ee096b8 commit 7fdd754

File tree

8 files changed: +81 -29 lines

test/dynamo/test_utils.py

Lines changed: 8 additions & 1 deletion
@@ -230,14 +230,21 @@ def test_dynamo_timed(self, mock_time, mock_time_ns):
  '_recursive_joint_graph_passes': [0.0],
  '_recursive_post_grad_passes': [0.0, 0.0],
  '_recursive_pre_grad_passes': [0.0],
+ 'additional_fake_tensor_prop': [0.0, 0.0],
+ 'aot_collect_metadata': [0.0],
+ 'aot_trace_joint_graph': [0.0],
  'async_compile.wait': [0.0, 0.0],
  'backward._backward_impl': [0.0],
+ 'build_guards': [0.0],
+ 'bytecode_tracing': [0.0],
+ 'compile_attempt_0': [0.0],
  'compile_file': [0.0, 0.0],
  'compile_fx.<locals>.bw_compiler': [0.0],
  'compile_fx.<locals>.fw_compiler_base': [0.0],
  'compile_fx_inner': [0.0, 0.0],
  'create_aot_dispatcher_function': [0.0],
- 'gc': [0.0]}""", # noqa: B950
+ 'gc': [0.0],
+ 'min_cut_rematerialization_partition': [0.0]}""", # noqa: B950
  )

  # Now validate utils.calculate_time_spent(). Formatting the return
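The expected output above is a dict mapping each profiled phase name to a list of per-call durations. As a point of reference, here is a minimal, self-contained sketch of that shape using a hypothetical timed helper; it is an illustration only, not the actual dynamo_timed implementation:

import time
from collections import defaultdict
from contextlib import contextmanager

# Hypothetical stand-in for Dynamo's metrics store; illustration only.
compilation_time_metrics: dict[str, list[float]] = defaultdict(list)

@contextmanager
def timed(phase_name: str):
    # Append the wall-clock duration of the block to the phase's list, so
    # repeated calls produce entries like {'build_guards': [t1, t2, ...]}.
    start = time.time()
    try:
        yield
    finally:
        compilation_time_metrics[phase_name].append(time.time() - start)

with timed("build_guards"):
    pass  # work being profiled

print(dict(compilation_time_metrics))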

torch/_dynamo/backends/inductor.py

Lines changed: 4 additions & 2 deletions
@@ -13,11 +13,13 @@
 """
 
 from torch._dynamo import register_backend
+from torch._dynamo.utils import dynamo_timed
 
 
 @register_backend
 def inductor(*args, **kwargs):
-    # do import here to avoid loading inductor into memory when it is not used
-    from torch._inductor.compile_fx import compile_fx
+    with dynamo_timed("inductor_import", log_pt2_compile_event=True):
+        # do import here to avoid loading inductor into memory when it is not used
+        from torch._inductor.compile_fx import compile_fx
 
     return compile_fx(*args, **kwargs)
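This is the pattern the commit applies throughout: wrap a compile-time region in dynamo_timed with a descriptive phase name and log_pt2_compile_event=True so the region is attributed in compile-time metrics and PT2 compile-event traces. A minimal usage sketch, assuming only the signature visible in this diff (the phase name below is made up):

from torch._dynamo.utils import dynamo_timed

def do_compile_step():
    # Time a previously unattributed chunk of compile-time work, e.g. a
    # heavy lazy import, under a named phase.
    with dynamo_timed("my_heavy_setup", log_pt2_compile_event=True):
        from torch._inductor.compile_fx import compile_fx  # noqa: F401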

torch/_dynamo/convert_frame.py

Lines changed: 13 additions & 8 deletions
@@ -737,6 +737,7 @@ def transform(
    )
 
    try:
+       tracer.output.mark_bytecode_tracing_start()
        with tracing(tracer.output.tracing_context), tracer.set_current_tx():
            tracer.run()
    except exc.UnspecializeRestartAnalysis:
@@ -810,7 +811,10 @@ def log_bytecode(
    for attempt in itertools.count():
        CompileContext.get().attempt = attempt
        try:
-           out_code = transform_code_object(code, transform)
+           with dynamo_timed(
+               f"compile_attempt_{attempt}", log_pt2_compile_event=True
+           ):
+               out_code = transform_code_object(code, transform)
            break
        except exc.RestartAnalysis as e:
            if not isinstance(e, exc.TensorifyScalarRestartAnalysis):
@@ -919,13 +923,14 @@ def count_args(code: CodeType) -> int:
    assert output.guards is not None
    CleanupManager.instance[out_code] = output.cleanups
    nonlocal cache_entry
-   check_fn = CheckFunctionManager(
-       code,
-       output,
-       cache_entry,
-       hooks.guard_fail_fn if hooks else None,
-       hooks.guard_filter_fn if hooks else None,
-   )
+   with dynamo_timed("build_guards", log_pt2_compile_event=True):
+       check_fn = CheckFunctionManager(
+           code,
+           output,
+           cache_entry,
+           hooks.guard_fail_fn if hooks else None,
+           hooks.guard_filter_fn if hooks else None,
+       )
 
    compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
    annotation_str = "Torch-Compiled Region: " + compile_id_str
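A simplified sketch of the per-attempt wrapping above: each restart of the analysis gets its own timed event (compile_attempt_0, compile_attempt_1, ...), which is why compile_attempt_0 now appears in the test expectation. RestartAnalysis and run_one_attempt below are placeholders, not the real Dynamo APIs:

import itertools
from torch._dynamo.utils import dynamo_timed

class RestartAnalysis(Exception):  # placeholder for exc.RestartAnalysis
    pass

def compile_with_retries(run_one_attempt):
    for attempt in itertools.count():
        try:
            # Each attempt is timed separately, so restarted analyses show up
            # as distinct compile_attempt_<n> events in the trace.
            with dynamo_timed(f"compile_attempt_{attempt}", log_pt2_compile_event=True):
                return run_one_attempt()
        except RestartAnalysis:
            continue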

torch/_dynamo/output_graph.py

Lines changed: 15 additions & 0 deletions
@@ -517,6 +517,19 @@ def __init__(
            self.install_builtins_dict_in_fglobals()
        )
 
+       self.compiler_trace_stack = contextlib.ExitStack()
+
+   def mark_bytecode_tracing_start(self):
+       self.compiler_trace_stack.enter_context(
+           dynamo_timed(
+               "bytecode_tracing",
+               log_pt2_compile_event=True,
+           )
+       )
+
+   def mark_bytecode_tracing_stop(self):
+       self.compiler_trace_stack.close()
+
    def install_builtins_dict_in_fglobals(self):
        # f_globals["__builtins__"] can be a dict or a module. This is an
        # implemenation detail -
@@ -1068,6 +1081,8 @@ def compile_subgraph(
        Generate a subgraph to continue execution on user code.
        Automatically restore live variables.
        """
+       # bytecode tracing has finished. Pop the context manager for dynamo_timed
+       self.mark_bytecode_tracing_stop()
        assert reason is not None
 
        from .decorators import disable
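The ExitStack lets the bytecode_tracing timer be opened in one method and closed in another (compile_subgraph, or run() as a fallback), which a single lexical with-block cannot express. A generic, self-contained sketch of that pattern, with a trivial timer standing in for dynamo_timed:

import contextlib
import time

class Profiled:
    def __init__(self):
        # Holds context managers whose lifetime spans multiple method calls.
        self._stack = contextlib.ExitStack()

    def start_phase(self, name: str):
        # Enter the timing context now; it stays open until stop_phase().
        self._stack.enter_context(self._timer(name))

    def stop_phase(self):
        # Closes every context entered on the stack. Safe to call twice:
        # close() on an already-emptied ExitStack is a no-op.
        self._stack.close()

    @contextlib.contextmanager
    def _timer(self, name: str):
        start = time.time()
        try:
            yield
        finally:
            print(f"{name}: {time.time() - start:.6f}s")

p = Profiled()
p.start_phase("bytecode_tracing")
# ... tracing work spanning several method calls ...
p.stop_phase()
p.stop_phase()  # no-op, mirrors the redundant call guarded against in run()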

torch/_dynamo/symbolic_convert.py

Lines changed: 5 additions & 0 deletions
@@ -1374,6 +1374,11 @@ def run(self):
            if isinstance(self, InstructionTranslator):
                self.output.cleanup()
 
+       # Note that this call maybe redundant if compile_subgraph is
+       # called. This is ok, because calling exit stack close()
+       # twice is not an issue (second stop is a no op).
+       self.output.mark_bytecode_tracing_stop()
+
    def push(self, val: Optional[VariableTracker]):
        assert val is None or isinstance(val, VariableTracker), (
            f"push expects VariableTracker, got {typestr(val)}"

torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py

Lines changed: 5 additions & 4 deletions
@@ -23,7 +23,7 @@
 import torch
 import torch.utils.dlpack
 from torch import Tensor
-from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code
+from torch._dynamo.utils import detect_fake_mode, dynamo_timed, lazy_format_graph_code
 from torch._guards import CompileContext, TracingContext
 from torch._logging import getArtifactLogger, trace_structured
 from torch._subclasses import FakeTensor
@@ -792,9 +792,10 @@ def aot_dispatch_autograd(
    )
 
    fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled()
-   fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(
-       flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
-   )
+   with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True):
+       fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(
+           flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
+       )
 
    # Copied from aot_dispatch_autograd_graph.
    disable_amp = torch._C._is_any_autocast_enabled()

torch/_functorch/aot_autograd.py

Lines changed: 11 additions & 1 deletion
@@ -673,7 +673,17 @@ def _dup_fake_script_obj(fake_flat_args):
        ctx = _detect_attribute_assignment(mod)
    else:
        ctx = nullcontext()
-   with ctx:
+
+   if torch._functorch.config.fake_tensor_propagate_real_tensors:
+       # Running dynamo_timed causes fake tensor issues when
+       # propagate real tensor is switched on.
+       dynamo_timed_ctx = nullcontext()
+   else:
+       dynamo_timed_ctx = dynamo_timed(
+           "aot_collect_metadata", log_pt2_compile_event=True
+       )
+
+   with dynamo_timed_ctx, ctx:
        fw_metadata = run_functionalized_fw_and_collect_metadata(
            flat_fn,
            static_input_indices=aot_config.static_input_indices,
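The hunk above selects the timing context up front: nullcontext() when fake_tensor_propagate_real_tensors is on (to avoid the fake-tensor issues noted in the comment), dynamo_timed otherwise. A generic sketch of that select-then-combine pattern using only the standard library:

import time
from contextlib import contextmanager, nullcontext

@contextmanager
def timed(name: str):
    start = time.time()
    try:
        yield
    finally:
        print(f"{name}: {time.time() - start:.6f}s")

def run_with_optional_timing(fn, *, timing_enabled: bool):
    # Choose the timing context up front; nullcontext() turns the `with`
    # statement into a no-op when timing must be skipped.
    timing_ctx = timed("collect_metadata") if timing_enabled else nullcontext()
    other_ctx = nullcontext()  # stands in for the pre-existing `ctx` above
    with timing_ctx, other_ctx:
        return fn()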

torch/_inductor/compile_fx.py

Lines changed: 20 additions & 13 deletions
@@ -1140,12 +1140,15 @@ def codegen_and_compile(
        # .view() call.
        view_to_reshape(gm)
 
-       # It is safe to run FakeTensorProp under no_grad because by the time
-       # we're in inductor, we assume that AOTAutograd has already "taken care"
-       # of autograd, so there should be no more autograd-related API's in the
-       # graph.
-       with torch.no_grad():
-           fake_mode = fake_tensor_prop(gm, example_inputs)
+       with dynamo_timed(
+           "additional_fake_tensor_prop", log_pt2_compile_event=True
+       ):
+           # It is safe to run FakeTensorProp under no_grad because by the time
+           # we're in inductor, we assume that AOTAutograd has already "taken care"
+           # of autograd, so there should be no more autograd-related API's in the
+           # graph.
+           with torch.no_grad():
+               fake_mode = fake_tensor_prop(gm, example_inputs)
 
        record_original_output_strides(gm)
 
@@ -2196,13 +2199,17 @@ def partition_fn(
        static_lifetime_input_indices: Optional[list[int]] = kwargs.pop( # type: ignore[assignment]
            "static_lifetime_input_indices", None
        )
-       return min_cut_rematerialization_partition(
-           gm,
-           joint_inputs,
-           compiler="inductor",
-           static_lifetime_input_indices=static_lifetime_input_indices,
-           **kwargs,
-       )
+
+       with dynamo_utils.dynamo_timed(
+           "min_cut_rematerialization_partition", log_pt2_compile_event=True
+       ):
+           return min_cut_rematerialization_partition(
+               gm,
+               joint_inputs,
+               compiler="inductor",
+               static_lifetime_input_indices=static_lifetime_input_indices,
+               **kwargs,
+           )
 
    @compile_time_strobelight_meta(phase_name="backward")
    def bw_compiler(
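With these regions in place, the accumulated per-phase timings (including the new keys asserted in test_utils.py above) can be inspected after a compile. One way to dump them, assuming torch._dynamo.utils.compile_times() is available in your build:

import torch

@torch.compile
def f(x):
    return (x * 2).sum()

f(torch.randn(8))

# Prints a table of accumulated compile-time phases; the exact set of phase
# names depends on the PyTorch version and which compile paths were taken.
print(torch._dynamo.utils.compile_times())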
