Changes from all commits
177 commits
6605ac9
Backend optimized, naive search setup
mattteochen Jul 2, 2024
7cbecfd
Merge branch 'Lightning-AI:main' into looking-around
mattteochen Jul 2, 2024
4ec2fa0
Single trace region placement impl
mattteochen Jul 4, 2024
8d084a8
Serial exhaustive search
mattteochen Jul 9, 2024
7781d75
Serial greedy search
mattteochen Jul 10, 2024
99eb302
Extended incremental strat with a fusion try after the greedy search …
mattteochen Jul 11, 2024
2dbef84
Fixed key error after fusion_pass (symbol deleted by CSE)
mattteochen Jul 12, 2024
8c4fd7d
Removed from tracking
mattteochen Jul 12, 2024
4651f26
Removed from tracking and updated tests
mattteochen Jul 12, 2024
83a1bea
Merge branch 'main' into automatic-backends-placement
mattteochen Jul 12, 2024
f09f9b3
Removed import
mattteochen Jul 12, 2024
17df099
Change timing function
mattteochen Jul 12, 2024
bdca10b
Cleaned impl and updated debug files generation
mattteochen Jul 12, 2024
a9f0dbd
Moved already present fusion regions count at FusionOperator class le…
mattteochen Jul 12, 2024
3bfb690
Wip
mattteochen Jul 15, 2024
6f93cba
Added torch empty cache during benchmark execution / added trace prov…
mattteochen Jul 16, 2024
0945159
Wip fuser ex focus
mattteochen Jul 16, 2024
24ad509
Fusion executor placements / runtime vs memory placements options
mattteochen Jul 19, 2024
d94e841
Updated test model
mattteochen Jul 19, 2024
77b7871
Merge branch 'main' into automatic-backends-placement
mattteochen Jul 22, 2024
15f51bf
Enhanced bw trace placement search / support for more models (#3)
mattteochen Jul 26, 2024
0f7ced3
Updated test
mattteochen Jul 26, 2024
579888d
Fixed bad list index / removed print
mattteochen Jul 28, 2024
3528e6c
Disabled graphviz / modified test runner
mattteochen Jul 28, 2024
f67bb1c
Using user defined executor list or default as unique executors ref i…
mattteochen Jul 28, 2024
5f76bcf
Computing bw traces taking into consideration every fw traces options (…
mattteochen Jul 29, 2024
5dedbff
Before `transform_for_execution` executors placement autotune / nvsig…
mattteochen Aug 1, 2024
15f914e
Fixed remat visual timing / Switched towards `memory` strat / General…
mattteochen Aug 2, 2024
89bba8d
Testing different compile options for `nvfuser` (#7)
mattteochen Aug 2, 2024
5be42dc
Benchmarking different compile options for nvFuser (#8)
mattteochen Aug 5, 2024
fceed7e
Moved `runtime` or `memory` selection at the end of search when all ca…
mattteochen Aug 6, 2024
53da20c
Refactoring autotune code (#10)
mattteochen Aug 6, 2024
80a11e6
Enlarged benchmark iters and use no remat traces a chance to be the o…
mattteochen Aug 7, 2024
04b3139
Fixed nv fuser compile options, now good traces will be generated (#12)
mattteochen Aug 7, 2024
fd44dc2
Allowing duplicates during eval trace write / Added missing fn call d…
mattteochen Aug 7, 2024
670a582
Updated test runner
mattteochen Aug 7, 2024
dfc7fdc
Prev commit
mattteochen Aug 7, 2024
9b8eb4d
Added comment
mattteochen Aug 7, 2024
823398a
Updated log
mattteochen Aug 9, 2024
afed225
Merge branch 'main' into automatic-backends-placement
mattteochen Aug 12, 2024
a0f0d2b
Defaults empty executors list to all executors
mattteochen Aug 12, 2024
9b32158
Fixed formatting
mattteochen Aug 13, 2024
fdeddc7
Cuda graphs integration / minor changes to benchmarks and profiler / made i…
mattteochen Aug 13, 2024
31c781e
Restore nanogpt config
mattteochen Aug 13, 2024
03188f8
Added nsight to bench
mattteochen Aug 14, 2024
74c5760
Restored old value
mattteochen Aug 14, 2024
51f4592
Updated nsight iter / updated test models
mattteochen Aug 14, 2024
0a101b1
Using cuda graphs only if not disabled
mattteochen Aug 14, 2024
940a704
Updated bench fn
mattteochen Aug 14, 2024
9ef9d0f
Fixed nanogpt test
mattteochen Aug 14, 2024
20992d6
Updated log
mattteochen Aug 14, 2024
0ce7921
Updated log
mattteochen Aug 14, 2024
dcb6326
Added gitignore
mattteochen Aug 14, 2024
9f6ccb1
Added fa3 to autotuner
mattteochen Aug 14, 2024
57e6163
Fixed log
mattteochen Aug 14, 2024
cc96535
Disabled te for now / updated exceptions log
mattteochen Aug 14, 2024
89f8d3f
Refactored and fixed autotuning that requires fw and bw split handlin…
mattteochen Aug 19, 2024
30ef6ca
Disabling reverse search
mattteochen Aug 19, 2024
c43e66a
Transformer Engine support (#15)
mattteochen Aug 20, 2024
79dd4d2
Unified args building fn
mattteochen Aug 20, 2024
e91e0ea
Beam search for fw bw split operators
mattteochen Aug 20, 2024
c0f73c5
Beam search for fw bw split operators
mattteochen Aug 20, 2024
9b1d0cb
Fixed issues about remat and solved sdpa backward pass benchmark issu…
mattteochen Aug 21, 2024
2eeb6fc
Updated test
mattteochen Aug 21, 2024
090b595
Removed print
mattteochen Aug 21, 2024
d030b1e
Updated tests
mattteochen Aug 21, 2024
bb53a5e
Supporting tuples
mattteochen Aug 21, 2024
74ae689
Updated example
mattteochen Aug 21, 2024
7581651
Enabled log and modified tests
mattteochen Aug 21, 2024
fc090c8
Fixed executors list for gradfn picking
mattteochen Aug 21, 2024
df17770
Adding fusion ex to executors list if not present
mattteochen Aug 22, 2024
f7a8b16
Benchmark bw fn with runtime args for all traces
mattteochen Aug 22, 2024
a0cc9ee
Restore input
mattteochen Aug 22, 2024
96a4f93
Updated litgpt runner
mattteochen Aug 22, 2024
2e796a3
Updated litgpt runner
mattteochen Aug 22, 2024
a4d5fa5
Enhanced logs
mattteochen Aug 22, 2024
0782407
Unpacking sequences during search of not used proxies
mattteochen Aug 22, 2024
ed1a2e2
Updated comment
mattteochen Aug 22, 2024
c1560b0
Updated model tests
mattteochen Aug 22, 2024
b890ac6
Updated model tests
mattteochen Aug 22, 2024
4228ff8
Updated log and comment
mattteochen Aug 23, 2024
3682c82
Fixed comment
mattteochen Aug 23, 2024
f6b8d16
Removed file
mattteochen Aug 23, 2024
14074c4
Fixed nsight bench when args need to be cloned as in TE
mattteochen Aug 23, 2024
706ff91
Benchmarking TE on llama
mattteochen Aug 23, 2024
4c32776
nvmath ex, integrated matmul
mattteochen Aug 23, 2024
0b9c5fe
Removed old counter
mattteochen Aug 24, 2024
511fe4b
Restored imports
mattteochen Aug 24, 2024
8cce6df
Restored line order
mattteochen Aug 24, 2024
cc42db2
Updated torch_compile_ex to synch with main
mattteochen Aug 24, 2024
7b55a71
Merge branch 'main' into nvmath
mattteochen Aug 24, 2024
c6bcdcc
Fixed print
mattteochen Aug 24, 2024
67ee055
Skipping single trace region candidate
mattteochen Aug 24, 2024
494fa73
Debug for single trace regions
mattteochen Aug 26, 2024
ed6013a
Print if no nsight
mattteochen Aug 26, 2024
c6270fb
Updated litgpt runner
mattteochen Aug 26, 2024
656e4e3
Fixed cd assertion and print
mattteochen Aug 26, 2024
9468802
Fixed cached update and restore missing args check
mattteochen Aug 26, 2024
f3713b3
Removed visualizer
mattteochen Aug 26, 2024
1d0224e
Prev commit
mattteochen Aug 26, 2024
f93c950
Updated test runner
mattteochen Aug 26, 2024
5d7bd9d
Fixed cache
mattteochen Aug 26, 2024
9317561
Removed print
mattteochen Aug 26, 2024
3aaf44d
Disabled debug
mattteochen Aug 26, 2024
5b6deb0
Updated litgpt
mattteochen Aug 26, 2024
550b639
Unit tests / minor changes for compilation to gain flexibility /…
mattteochen Aug 26, 2024
dfc70b4
Fix appended label
mattteochen Aug 27, 2024
3b3009b
Updated comments and removed import
mattteochen Aug 27, 2024
4ef329b
New tests and linter
mattteochen Aug 27, 2024
e33f29c
Formatter
mattteochen Aug 27, 2024
5ce0002
Fixed tensor device
mattteochen Aug 27, 2024
388d7d0
Added cuda guard
mattteochen Aug 27, 2024
19953cd
Formatter
mattteochen Aug 27, 2024
04124be
Changed file name
mattteochen Aug 27, 2024
44a8939
Restored old value
mattteochen Aug 27, 2024
2bc7546
Restored flag
mattteochen Aug 27, 2024
0fbcbb2
Torch compiler reset
mattteochen Aug 28, 2024
fbf94e0
Updated litgpt runner
mattteochen Aug 28, 2024
be3912a
Added guard for args cloning
mattteochen Aug 28, 2024
7a1cf14
Added guard for name attribute
mattteochen Aug 28, 2024
e746c58
Fixed var overwritten
mattteochen Aug 28, 2024
de67643
Tests
mattteochen Aug 28, 2024
56c8eaf
Updated litgpt
mattteochen Aug 28, 2024
ff06125
Removed not used
mattteochen Aug 28, 2024
637d5ce
Wip on common transformer block replacement
mattteochen Aug 28, 2024
fec126a
Transformer block optimization
mattteochen Aug 30, 2024
cbc4bb6
Fixed bad def value
mattteochen Aug 30, 2024
a14a155
Fixed comment
mattteochen Aug 30, 2024
5e0c379
Fixed comment
mattteochen Aug 30, 2024
dc9d181
Changed log level
mattteochen Aug 30, 2024
b80ffbe
Formatted comment
mattteochen Aug 30, 2024
62b75ef
Updated runner
mattteochen Aug 30, 2024
a9a9a3f
Enabled te and nvFuser compile options from thunder jit / updated tests
mattteochen Aug 30, 2024
448fd8a
Disabled cudagraphs
mattteochen Aug 31, 2024
b08d9e1
Restricting the same executor in vjp pass if common trace block opt i…
mattteochen Aug 31, 2024
4b05a39
Docs
mattteochen Aug 31, 2024
701b36c
Docs and cleaning
mattteochen Aug 31, 2024
cce3ea3
Docs and reorganization
mattteochen Aug 31, 2024
42e4b6b
Using python logger
mattteochen Sep 1, 2024
e494514
Moved cache to compile data
mattteochen Sep 1, 2024
00bff88
Updated logger
mattteochen Sep 1, 2024
2d8eace
Trace configuration dumps and restore (#20)
mattteochen Sep 5, 2024
625b473
Doc
mattteochen Sep 5, 2024
ed12755
Removed trace print
mattteochen Sep 17, 2024
519855a
Enhanced timing measurement / autotuned nvmath matmul
mattteochen Sep 17, 2024
f09a813
Jit doc
mattteochen Sep 17, 2024
a86630c
Renamed class
mattteochen Sep 17, 2024
9383c5e
Removed print
mattteochen Sep 17, 2024
e5f4e7d
Updated doc
mattteochen Sep 17, 2024
0f98689
Updated doc and removed unused imports
mattteochen Sep 17, 2024
38f176a
Restored partial trace benchmark options
mattteochen Sep 17, 2024
a3c037c
Fixed function name typo and renamed dir
mattteochen Sep 17, 2024
46e3a71
Changed optimization type
mattteochen Sep 17, 2024
ff9f661
Formatter
mattteochen Sep 17, 2024
cdc9157
Enhanced logs description
mattteochen Sep 18, 2024
3990619
Merge branch 'main' into develop
mattteochen Sep 18, 2024
6954062
Small fixes to align main
mattteochen Sep 18, 2024
87a79fd
Fixed OOM errors during trace benchmarks leading to a premature end o…
mattteochen Sep 19, 2024
75b856d
Handled nvmath missing installation
mattteochen Sep 19, 2024
f2b934b
Prev commit
mattteochen Sep 19, 2024
42401b1
Prev commit
mattteochen Sep 19, 2024
02f5038
Added comment
mattteochen Sep 19, 2024
fadafe4
Updated test runner
mattteochen Sep 19, 2024
e4707fa
Log applied executors
mattteochen Sep 19, 2024
b728ab6
Torch timer for benchmarks
mattteochen Sep 19, 2024
bb30805
Doc
mattteochen Sep 19, 2024
191f403
Updated Anyproxy hash / updated test runner file
mattteochen Sep 19, 2024
7f64dbf
Formatter
mattteochen Sep 19, 2024
350c451
Restored manual benchmark configuration / added env var for test runner
mattteochen Sep 19, 2024
580050a
Disabled flag
mattteochen Sep 20, 2024
8bbe1ec
Updated doc
mattteochen Sep 20, 2024
df0469e
Autotuner for jit with no autograd
mattteochen Sep 20, 2024
7f57584
Added CUDA barrier for unit test
mattteochen Sep 20, 2024
82017d2
Integrated autotuner in benchmark script
mattteochen Sep 20, 2024
26044a3
Added missing flag
mattteochen Sep 20, 2024
0a17128
Fixed comments
mattteochen Sep 20, 2024
afb5637
Removed comment
mattteochen Sep 20, 2024
4 changes: 4 additions & 0 deletions examples/autotuner/.gitignore
@@ -0,0 +1,4 @@
*.log
*.txt
*.pickle
*.nsys-rep
55 changes: 55 additions & 0 deletions examples/autotuner/LLaMAMLP.py
@@ -0,0 +1,55 @@
"""
This benchmark script is intended to demonstrate the autotuner on a generic model.
No executors are given, leaving full responsibility to Thunder.
"""

import torch
import thunder
from thunder.benchmarks.utils import torch_timer_total_benchmark, torch_total_benchmark


class LLaMAMLP(torch.nn.Module):
def __init__(self, n_embd, intermediate_size) -> None:
super().__init__()
self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False)
self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False)
self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fc_1 = self.fc_1(x)
x_fc_2 = self.fc_2(x)
x = torch.nn.functional.silu(x_fc_1) * x_fc_2
return self.proj(x)


with torch.device("cuda"):
mult = 2
a = 4096 * mult
b = 11008 * mult
x = torch.randn(4, 2048, a, requires_grad=True)

model = LLaMAMLP(a, b)

eager = model
torchcompile = torch.compile(model)
jmodel_def = thunder.jit(model)
jmodel_auto = thunder.jit(
model,
autotune_type="runtime",
autotune_enable_te=True,
autotune_nv_enable_options=True,
model_name="LLaMAMLP",
autotune_save_configuration=True,
)

print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item())
print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item())

iters = 100
callables = [eager, torchcompile, jmodel_def, jmodel_auto]
labels = ["eager", "torchcompile", "Thunder", "Thunder Autotuned"]
inputs = [x, x, x, x]
print("\nResults with torch total benchmark:")
torch_total_benchmark(callables, labels, inputs, iters)
print("\nResults with torch timer benchmark:")
torch_timer_total_benchmark(callables, labels, inputs, "LlamaMLP")
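Since the script above passes `autotune_save_configuration=True`, a later run could presumably reuse the dumped configuration through `autotune_restore_configuration` (documented in the thunder/__init__.py changes further down) instead of repeating the search. A minimal sketch, not part of the diff — it reuses the LLaMAMLP module and input shapes from the script above, and the configuration file name is a placeholder assumption:

```python
import torch
import thunder

# Hypothetical follow-up run: restore the configuration dumped by the
# autotuned compilation above instead of re-running the search.
# "LLaMAMLP_config.pickle" is a placeholder name; the diff does not show
# the naming scheme used by autotune_save_configuration.
with torch.device("cuda"):
    model = LLaMAMLP(4096 * 2, 11008 * 2)  # class defined in the script above
    x = torch.randn(4, 2048, 4096 * 2, requires_grad=True)

jmodel_restored = thunder.jit(
    model,
    autotune_type="runtime",
    autotune_restore_configuration="LLaMAMLP_config.pickle",  # placeholder path
    model_name="LLaMAMLP",
)
print("deviation restored:", (jmodel_restored(x) - model(x)).abs().max().item())
```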
101 changes: 101 additions & 0 deletions examples/autotuner/litGPT.py
@@ -0,0 +1,101 @@
"""
This script benchmarks litGPT models in a simpler way than thunder.benchmarks.benchmark_litgpt.py, using a fake training loop with no optimizer.
"""

from litgpt import GPT
from thunder.benchmarks.utils import torch_total_benchmark, torch_timer_total_benchmark
from thunder.tests.litgpt_model import Config
import thunder
import torch
import time
from pprint import pprint

torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

# import os
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

class LitGPTModelThunderConfig:
def __init__(
self,
layers: int,
autotune_type: str,
batch_size: int,
seq_len: int = -1,
model_name: str = "Llama-3-8B",
executors=None,
optimize_transformer_blocks=True,
optimize_transformer_min_block_size=60, # for llama3
) -> None:
self.layers = layers
self.autotune_type = autotune_type
self.batch_size = batch_size
self.seq_len = seq_len
self.model_name = model_name
self.executors = executors
self.optimize_transformer_blocks = optimize_transformer_blocks
self.optimize_transformer_min_block_size = optimize_transformer_min_block_size


to_run = [
LitGPTModelThunderConfig(
1,
"runtime",
2,
executors=[
"cudnn",
"sdpa",
"fa3",
"nvfuser",
"nvmath",
"torchcompile",
],
),
]

for test in to_run:
try:
cfg = Config.from_name(test.model_name)
cfg.n_layer = test.layers
if test.seq_len != -1:
cfg.block_size = test.seq_len
torch.set_default_dtype(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)
pprint(cfg)
print("Batch size:", test.batch_size)
with torch.device("cuda"):
model = GPT(cfg)
x = torch.randint(1, model.config.vocab_size, (test.batch_size, cfg.block_size))
target = torch.ones_like(x)

eager = model
torch_compile = torch.compile(model)
jmodel_def = thunder.jit(model)
jmodel_auto = thunder.jit(
model,
autotune_type=test.autotune_type,
executors=test.executors,
autotune_optimize_common_blocks=test.optimize_transformer_blocks,
autotune_optimize_common_blocks_min_size=test.optimize_transformer_min_block_size,
)
print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item())
s = time.time_ns()
print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item())
e = time.time_ns()
print("Compilation time:", {(e - s) / 1000000000}, "s")

iters = 100
callables = [eager, torch_compile, jmodel_def, jmodel_auto]
labels = ["eager", "torch.compile", "Thunder", "Thunder Autotuner"]
inputs = [x, x, x, x]
print(f"\nResults torch total benchmark ({iters} iters):")
torch_total_benchmark(callables, labels, inputs, iters, torch.nn.functional.cross_entropy)
print(f"\nResults torch timer benchmark ({iters} iters):")
torch_timer_total_benchmark(callables, labels, inputs, test.model_name, torch.nn.functional.cross_entropy)

print(f'Executors employed: {thunder.executors_applied(jmodel_auto)}')
except Exception as e:
print(f"Benchmark failed:\n{e}")
import traceback

traceback.print_exc()
95 changes: 90 additions & 5 deletions thunder/__init__.py
@@ -266,6 +266,7 @@ def jit(
disable_torch_autograd: bool = False, # TODO Revisit this UX for RC1
transforms: list[Transform] | None = None,
record_history: bool = False,
# autotune_type: Any | None = None,
**compile_options, # TODO RC1 Make this explicit -- dict of options
) -> Callable:
"""Just-in-time compile a callable (function or model).
@@ -292,7 +293,18 @@
- ``"same input"`` - don't check, but just assume that a cached function works if it exists.

transforms: List of transforms to be applied. It should be an instance :class:`thunder.core.transforms.Transform`. Default: ``None``

autotune_type: string selecting the autotuner performance target, either ``"runtime"`` or ``"memory"``.
autotune_nv_enable_options: boolean enabling autotuning of nvFuser compilation options. Currently at most one option will be used. Default: ``False``
autotune_enable_te: boolean enabling autotuning of the TransformerEngine FP8 executor. Default: ``False``
autotune_optimize_common_blocks: boolean enabling the common-block optimization of the trace during compilation (for example transformer layers). This optimization can be used with models that have repeated block structures, such as transformer-based models. You don't need to know
where a block starts or ends; it is handled automatically. Default: ``False``
autotune_optimize_common_blocks_min_size: integer controlling the minimum block length that triggers the common-block optimization. Default: ``-1``
autotune_save_configuration: boolean to produce a configuration file for the current model. This configuration can be loaded afterwards with ``autotune_restore_configuration``. Default: ``False``
autotune_restore_configuration: string containing the name of a cached configuration file, given as a path relative to the script invocation.
model_name: string containing the current model name, used when creating the configuration file with ``autotune_save_configuration``. A default name is used if this is not provided.
"""
from thunder.backend_optimizer.optimizer import OptimizerType

if "executors_list" in compile_options:
warnings.warn("outdated argument executors_list= in call, please use executors=")
@@ -308,6 +320,41 @@
if transforms is None:
transforms = []

required_autotune = compile_options.get("autotune_type", None)
if required_autotune is not None:
if required_autotune not in ["runtime", "memory"]:
raise AssertionError(f"Not supported optimization: {required_autotune}")

compile_options |= {
"autotune_type": OptimizerType.RUNTIME if required_autotune == "runtime" else OptimizerType.MEMORY,
"autotune_executors_placed_by_fw_bw_split": set(),
}

# Default the executors list to all_executors if no options are given
# Otherwise the user restricted choice will be used
from thunder.executors.transformer_engineex import transformer_engine_ex
from thunder.executors.pythonex import ex as python_ex
if not executors:
executors = get_all_executors()
# Remove pythonex
executors = [ex for ex in executors if ex != python_ex]
# Remove transformer_engine if not requested
executors = [
ex
for ex in executors
if ex != transformer_engine_ex
or (ex == transformer_engine_ex and compile_options.get("autotune_enable_te", False))
]
else:
# If TE is in executors list we have to enable the compilation option
if transformer_engine_ex in executors:
compile_options['autotune_enable_te'] = True

from thunder.backend_optimizer.utils import reorder_executors_list
executors = reorder_executors_list(
executors, autotune_enable_te=compile_options.get("autotune_enable_te", False)
)

# Resolve names of executors
executors = resolve_executors(executors)

@@ -450,6 +497,7 @@ def get_computation_and_inputs(*args, **kwargs):
cs.last_traces = comp_traces
cs.last_interpreted_instructions = None
cs.last_interpreter_log = None
cs.last_executors = cd.executors_list
cs.last_prologue_traces = pro_traces
cs.last_prologue = pro
cs.last_prologue_transformation_start = 0
@@ -485,6 +533,7 @@ def get_computation_and_inputs(*args, **kwargs):
cs.last_traces = comp_traces
cs.last_interpreted_instructions = None
cs.last_interpreter_log = None
cs.last_executors = cd.executors_list
cs.last_prologue_traces = pro_traces
cs.last_prologue = pro

@@ -605,6 +654,7 @@ def get_computation_and_inputs(*args, **kwargs):
cs.last_prologue_traces = prologue_traces
cs.last_prologue = pro
cs.last_traces = computation_traces
cs.last_executors = cd.executors_list
backward_traces = []
cs.last_backward_traces = backward_traces
cs.last_interpreter_log = last_interpreter_log
@@ -631,22 +681,44 @@
# Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces
# by split_forward_backward

# Reset the cache for the next compilation
cd.autotuner_bsym_with_gradfn_executor_cache = {}

if backward_trc is None:
from thunder.executors.passes import transform_for_execution as transform_for_execution_pass
from thunder.executors.passes import autotune_transform_for_execution
from thunder.executors.passes import _transform_for_operator_executor_execution
from thunder.distributed.utils import maybe_sort_waits
from thunder.backend_optimizer.optimizer import BackendOptimizer, TraceType

tmp_comp_trc = _transform_for_operator_executor_execution(computation_trc, cd.executors_list)
is_transformed, tmp_comp_trc = maybe_sort_waits(tmp_comp_trc)
if is_transformed:
computation_trc = tmp_comp_trc
computation_traces.append(computation_trc)

extraces = transform_for_execution(
computation_trc,
executors_list=cd.executors_list,
use_del_last_used=False,
)
autotune = cd.compile_options.get('autotune_type', None)
if autotune is None:
extraces = transform_for_execution(
computation_trc,
executors_list=cd.executors_list,
use_del_last_used=False,
)
else:
optimizer_ctx = BackendOptimizer(
priority_executors=cd.executors_list,
apply_bucketing_bw_trace=False,
produce_log=False,
optimizer_type=autotune,
compile_data=cd,
)
extrace = autotune_transform_for_execution(
optimizer_context=optimizer_ctx,
trace=computation_trc,
trace_type=TraceType.FW,
is_computational=True
)
extraces = [extrace]
computation_traces.extend(extraces)
computation_trc = computation_traces[-1]

@@ -834,6 +906,19 @@ def last_prologue_traces(fn) -> TraceCtx:
return cs.last_prologue_traces


def executors_applied(fn) -> Sequence[Executor]:
"""Obtains the list of executors that have been applied to the computational trace.
If the backward trace is not None, the list will also include the executors used in the backward trace.

"""
cs = compile_stats(fn)
if cs is None:
raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.")
if cs.last_executors is None:
raise TypeError(f"{fn} doesn't seem to have been called yet.")
return cs.last_executors


def cache_option(fn) -> CACHE_OPTIONS:
"""Returns the cache options set when JITting the function."""
cd = compile_data(fn)
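Beyond the diff itself, here is a minimal usage sketch of the public surface added in this file — the `autotune_type` compile option and the new `thunder.executors_applied` helper — assuming a CUDA device and the default executor set; the module below is an arbitrary stand-in, not taken from the PR:

```python
import torch
import thunder

# Any nn.Module works; this tiny MLP only exercises the new options.
mlp = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096, bias=False),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024, bias=False),
).to("cuda")
x = torch.randn(8, 1024, device="cuda")

# autotune_type selects the performance target ("runtime" or "memory");
# leaving executors unset lets the autotuner fall back to all available executors.
jmlp = thunder.jit(mlp, autotune_type="memory")
out = jmlp(x)  # the first call triggers compilation and the autotuning search

# After the first call, the executors the autotuner settled on can be inspected.
print(thunder.executors_applied(jmlp))
```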