diff --git a/examples/autotuner/.gitignore b/examples/autotuner/.gitignore
new file mode 100644
index 0000000000..4ec10c2e86
--- /dev/null
+++ b/examples/autotuner/.gitignore
@@ -0,0 +1,4 @@
+*.log
+*.txt
+*.pickle
+*.nsys-rep
diff --git a/examples/autotuner/LLaMAMLP.py b/examples/autotuner/LLaMAMLP.py
new file mode 100644
index 0000000000..9ec217347c
--- /dev/null
+++ b/examples/autotuner/LLaMAMLP.py
@@ -0,0 +1,55 @@
+"""
+This benchmark script demonstrates the autotuner on a generic model.
+No executors are given, leaving full responsibility to Thunder.
+"""
+
+import torch
+import thunder
+from thunder.benchmarks.utils import torch_timer_total_benchmark, torch_total_benchmark
+
+
+class LLaMAMLP(torch.nn.Module):
+    def __init__(self, n_embd, intermediate_size) -> None:
+        super().__init__()
+        self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False)
+        self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False)
+        self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_fc_1 = self.fc_1(x)
+        x_fc_2 = self.fc_2(x)
+        x = torch.nn.functional.silu(x_fc_1) * x_fc_2
+        return self.proj(x)
+
+
+with torch.device("cuda"):
+    mult = 2
+    a = 4096 * mult
+    b = 11008 * mult
+    x = torch.randn(4, 2048, a, requires_grad=True)
+
+    model = LLaMAMLP(a, b)
+
+    eager = model
+    torchcompile = torch.compile(model)
+    jmodel_def = thunder.jit(model)
+    jmodel_auto = thunder.jit(
+        model,
+        autotune_type="runtime",
+        autotune_enable_te=True,
+        autotune_nv_enable_options=True,
+        model_name="LLaMAMLP",
+        autotune_save_configuration=True,
+    )
+
+    print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item())
+    print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item())
+
+    iters = 100
+    callables = [eager, torchcompile, jmodel_def, jmodel_auto]
+    labels = ["eager", "torchcompile", "Thunder", "Thunder Autotuned"]
+    inputs = [x, x, x, x]
+    print("\nResults with torch total benchmark:")
+    torch_total_benchmark(callables, labels, inputs, iters)
+    print("\nResults with torch timer benchmark:")
+    torch_timer_total_benchmark(callables, labels, inputs, "LlamaMLP")
diff --git a/examples/autotuner/litGPT.py b/examples/autotuner/litGPT.py
new file mode 100644
index 0000000000..714b84a5fc
--- /dev/null
+++ b/examples/autotuner/litGPT.py
@@ -0,0 +1,101 @@
+"""
+This script benchmarks litGPT models in a simpler way than thunder.benchmarks.benchmark_litgpt.py, using a fake training loop with no optimizer.
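+
+A minimal sketch of what each configuration below runs (illustrative only; the autotune_* flags are the options added by the autotuner):
+
+    model = GPT(Config.from_name("Llama-3-8B"))
+    jmodel = thunder.jit(model, autotune_type="runtime", executors=["nvfuser", "torchcompile"])
+    logits = jmodel(x)  # the first call triggers compilation and the autotuning search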
+""" + +from litgpt import GPT +from thunder.benchmarks.utils import torch_total_benchmark, torch_timer_total_benchmark +from thunder.tests.litgpt_model import Config +import thunder +import torch +import time +from pprint import pprint + +torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul +torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + +# import os +# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + +class LitGPTModelThunderConfig: + def __init__( + self, + layers: int, + autotune_type: str, + batch_size: int, + seq_len: int = -1, + model_name: str = "Llama-3-8B", + executors=None, + optimize_transformer_blocks=True, + optimize_transformer_min_block_size=60, # for llama3 + ) -> None: + self.layers = layers + self.autotune_type = autotune_type + self.batch_size = batch_size + self.seq_len = seq_len + self.model_name = model_name + self.executors = executors + self.optimize_transformer_blocks = optimize_transformer_blocks + self.optimize_transformer_min_block_size = optimize_transformer_min_block_size + + +to_run = [ + LitGPTModelThunderConfig( + 1, + "runtime", + 2, + executors=[ + "cudnn", + "sdpa", + "fa3", + "nvfuser", + "nvmath", + "torchcompile", + ], + ), +] + +for test in to_run: + try: + cfg = Config.from_name(test.model_name) + cfg.n_layer = test.layers + if test.seq_len != -1: + cfg.block_size = test.seq_len + torch.set_default_dtype(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16) + pprint(cfg) + print("Batch size:", test.batch_size) + with torch.device("cuda"): + model = GPT(cfg) + x = torch.randint(1, model.config.vocab_size, (test.batch_size, cfg.block_size)) + target = torch.ones_like(x) + + eager = model + torch_compile = torch.compile(model) + jmodel_def = thunder.jit(model) + jmodel_auto = thunder.jit( + model, + autotune_type=test.autotune_type, + executors=test.executors, + autotune_optimize_common_blocks=test.optimize_transformer_blocks, + autotune_optimize_common_blocks_min_size=test.optimize_transformer_min_block_size, + ) + print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) + s = time.time_ns() + print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) + e = time.time_ns() + print("Compilation time:", {(e - s) / 1000000000}, "s") + + iters = 100 + callables = [eager, torch_compile, jmodel_def, jmodel_auto] + labels = ["eager", "torch.compile", "Thunder", "Thunder Autotuner"] + inputs = [x, x, x, x] + print(f"\nResults torch total benchmark ({iters} iters):") + torch_total_benchmark(callables, labels, inputs, iters, torch.nn.functional.cross_entropy) + print(f"\nResults torch timer benchmark ({iters} iters):") + torch_timer_total_benchmark(callables, labels, inputs, test.model_name, torch.nn.functional.cross_entropy) + + print(f'Executors employed: {thunder.executors_applied(jmodel_auto)}') + except Exception as e: + print(f"Benchmark failed:\n{e}") + import traceback + + traceback.print_exc() diff --git a/thunder/__init__.py b/thunder/__init__.py index d793cdc52d..86962aaba1 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -266,6 +266,7 @@ def jit( disable_torch_autograd: bool = False, # TODO Revisit this UX for RC1 transforms: list[Transform] | None = None, record_history: bool = False, + # autotune_type: Any | None = None, **compile_options, # TODO RC1 Make this explicit -- dict of options ) -> Callable: """Just-in-time compile a callable (function or model). 
@@ -292,7 +293,18 @@ def jit(
           - ``"same input"`` - don't check, but just assume that a cached function works if it exists.
 
         transforms: List of transforms to be applied. It should be an instance :class:`thunder.core.transforms.Transform`. Default: ``None``
+
+        autotune_type: string selecting the autotuner performance target (``"runtime"`` or ``"memory"``). The autotuner is enabled only when this option is set. Default: ``None``
+        autotune_nv_enable_options: boolean to enable autotuning of nvFuser compilation options. Currently at most one option is enabled at a time. Default: ``False``
+        autotune_enable_te: boolean to enable autotuning of the TransformerEngine FP8 executor. Default: ``False``
+        autotune_optimize_common_blocks: boolean to enable the common-block optimization of the trace during compilation (for example, transformer layers). It is useful for models with repeated block structures, such as transformer-based models; you don't need to know
+            where a block starts or ends, as this is detected automatically. Default: ``False``
+        autotune_optimize_common_blocks_min_size: integer controlling the minimum block length that triggers the common-block optimization. Default: ``-1``
+        autotune_save_configuration: boolean to produce a configuration file for the current model. The configuration can be loaded afterwards with ``autotune_restore_configuration``. Default: ``False``
+        autotune_restore_configuration: string containing the name of a cached configuration file, given as a path relative to the script invocation.
+        model_name: string containing the current model name, used when creating the configuration file with ``autotune_save_configuration``. A default name is used if this is not provided.
     """
+    from thunder.backend_optimizer.optimizer import OptimizerType
 
     if "executors_list" in compile_options:
         warnings.warn("outdated argument executors_list= in call, please use executors=")
@@ -308,6 +320,41 @@ def jit(
     if transforms is None:
         transforms = []
 
+    required_autotune = compile_options.get("autotune_type", None)
+    if required_autotune is not None:
+        if required_autotune not in ["runtime", "memory"]:
+            raise AssertionError(f"Unsupported optimization: {required_autotune}")
+
+        compile_options |= {
+            "autotune_type": OptimizerType.RUNTIME if required_autotune == "runtime" else OptimizerType.MEMORY,
+            "autotune_executors_placed_by_fw_bw_split": set(),
+        }
+
+        # Default the executors list to all executors if none are given
+        # Otherwise the user-restricted choice will be used
+        from thunder.executors.transformer_engineex import transformer_engine_ex
+        from thunder.executors.pythonex import ex as python_ex
+        if not executors:
+            executors = get_all_executors()
+            # Remove pythonex
+            executors = [ex for ex in executors if ex != python_ex]
+            # Remove transformer_engine if not requested
+            executors = [
+                ex
+                for ex in executors
+                if ex != transformer_engine_ex
+                or (ex == transformer_engine_ex and compile_options.get("autotune_enable_te", False))
+            ]
+        else:
+            # If TE is in the executors list we have to enable the compilation option
+            if transformer_engine_ex in executors:
+                compile_options['autotune_enable_te'] = True
+
+        from thunder.backend_optimizer.utils import reorder_executors_list
+        executors = reorder_executors_list(
+            executors, autotune_enable_te=compile_options.get("autotune_enable_te", False)
+        )
+
     # Resolve names of executors
     executors = resolve_executors(executors)
 
@@ -450,6 +497,7 @@ def get_computation_and_inputs(*args, **kwargs):
             cs.last_traces = comp_traces
             cs.last_interpreted_instructions = None
             cs.last_interpreter_log = None
+
cs.last_executors = cd.executors_list cs.last_prologue_traces = pro_traces cs.last_prologue = pro cs.last_prologue_transformation_start = 0 @@ -485,6 +533,7 @@ def get_computation_and_inputs(*args, **kwargs): cs.last_traces = comp_traces cs.last_interpreted_instructions = None cs.last_interpreter_log = None + cs.last_executors = cd.executors_list cs.last_prologue_traces = pro_traces cs.last_prologue = pro @@ -605,6 +654,7 @@ def get_computation_and_inputs(*args, **kwargs): cs.last_prologue_traces = prologue_traces cs.last_prologue = pro cs.last_traces = computation_traces + cs.last_executors = cd.executors_list backward_traces = [] cs.last_backward_traces = backward_traces cs.last_interpreter_log = last_interpreter_log @@ -631,10 +681,15 @@ def get_computation_and_inputs(*args, **kwargs): # Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces # by split_forward_backward + # Reset the cache for the next compilation + cd.autotuner_bsym_with_gradfn_executor_cache = {} + if backward_trc is None: from thunder.executors.passes import transform_for_execution as transform_for_execution_pass + from thunder.executors.passes import autotune_transform_for_execution from thunder.executors.passes import _transform_for_operator_executor_execution from thunder.distributed.utils import maybe_sort_waits + from thunder.backend_optimizer.optimizer import BackendOptimizer, TraceType tmp_comp_trc = _transform_for_operator_executor_execution(computation_trc, cd.executors_list) is_transformed, tmp_comp_trc = maybe_sort_waits(tmp_comp_trc) @@ -642,11 +697,28 @@ def get_computation_and_inputs(*args, **kwargs): computation_trc = tmp_comp_trc computation_traces.append(computation_trc) - extraces = transform_for_execution( - computation_trc, - executors_list=cd.executors_list, - use_del_last_used=False, - ) + autotune = cd.compile_options.get('autotune_type', None) + if autotune is None: + extraces = transform_for_execution( + computation_trc, + executors_list=cd.executors_list, + use_del_last_used=False, + ) + else: + optimizer_ctx = BackendOptimizer( + priority_executors=cd.executors_list, + apply_bucketing_bw_trace=False, + produce_log=False, + optimizer_type=autotune, + compile_data=cd, + ) + extrace = autotune_transform_for_execution( + optimizer_context=optimizer_ctx, + trace=computation_trc, + trace_type=TraceType.FW, + is_computational=True + ) + extraces = [extrace] computation_traces.extend(extraces) computation_trc = computation_traces[-1] @@ -834,6 +906,19 @@ def last_prologue_traces(fn) -> TraceCtx: return cs.last_prologue_traces +def executors_applied(fn) -> Sequence[Executor]: + """Obtains the list of executors that have been applied to the computational trace. + If the backward trace is not None, the list will include also executors used in the backward trace. 
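+
+    Minimal usage sketch (illustrative; ``fn`` and ``args`` are placeholders, and the jitted callable must have been called at least once):
+
+        jfn = thunder.jit(fn, autotune_type="runtime")
+        jfn(*args)
+        print(thunder.executors_applied(jfn))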
+ + """ + cs = compile_stats(fn) + if cs is None: + raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.") + if cs.last_executors is None: + raise TypeError(f"{fn} doesn't seem to have been called yet.") + return cs.last_executors + + def cache_option(fn) -> CACHE_OPTIONS: """Returns the cache options set when JITting the function.""" cd = compile_data(fn) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py new file mode 100644 index 0000000000..37a5076e20 --- /dev/null +++ b/thunder/backend_optimizer/optimizer.py @@ -0,0 +1,1384 @@ +from collections.abc import Callable, Sequence +from enum import Enum +from thunder.backend_optimizer.utils import ( + dump_traces_placement, + map_executors_from_reduced_trace_to_complete_trace, + operation_in_trace, + wrap_fn_with_exeuctor_compile_option, + apply_results_from_file, +) +from thunder.core.compile_data import get_compile_data +from thunder.core.prims import PrimIDs +from thunder.core.proxies import CollectionProxy, FloatProxy, IntegerProxy, TensorProxy +from thunder.core.symbol import BoundSymbol +from thunder.core.trace import from_trace, TraceCtx +from thunder.core.transforms import construct_trace +from thunder.extend import Executor, FusionExecutor, get_always_executors +from typing import Hashable +from thunder.backend_optimizer.utils import benchmark_trace, BenchmarkResult, OptimizerType, TraceType +import logging + +logging.basicConfig(level=logging.INFO, format="[{name}]: {message}", style="{") +logger = logging.getLogger("Thunder Autotuner") + +# Control if single trace regions or partial traces are benchmarked during OperatorExecutor tuning +_benchmark_single_trace_region = False + + +class OptimizationAlgorithm(Enum): + """ + Represents the optimization technique used by the autotuner. + """ + + BEST_FUSER = 0 + + +class FusionCompileOptionsHelper: + """ + Represents compile options for a fusion executor. + + Attributes: + fusion_tag (str): A label representing the fusion ops regarding a compile option (e.g. nv_linear). + symbol_tag (str): The symbol name + id (PrimIDs): The symbol id. + impl (Callable): A callable implementation. + checker (Callable): A callable checker. + """ + + def __init__(self, fusion_tag: str, symbol_tag: str, id: PrimIDs, impl: Callable, checker: Callable) -> None: + self.fusion_tag: str = fusion_tag + self.symbol_tag: str = symbol_tag + self.id: PrimIDs = id + self.impl: Callable = impl + self.checker: Callable = checker + + +class FusionExecutorsPlacementCtx: + """ + Represents a executor placement context. + + Attributes: + placement (list): A list of executors. + compile_options (FusionExecutorsPlacementCtx | None): Any compile options being used for the fusion executor contained in the placement. + """ + + def __init__(self, *, placement: list, compile_options: FusionCompileOptionsHelper | None = None) -> None: + self.placement: list = placement + self.compile_options: FusionCompileOptionsHelper | None = compile_options + + +class TraceCandidate: + """ + Represents an optimal trace candidate. + + Attributes: + trace (TraceCtx): The candidate trace. + ctx (FusionExecutorsPlacementCtx): Trace's placement context. + label (str): A generic label to identify this candidate. 
+ """ + + def __init__( + self, + *, + trace: TraceCtx, + ctx: FusionExecutorsPlacementCtx, + label: str, + ) -> None: + self.trace: TraceCtx = trace + self.ctx: FusionExecutorsPlacementCtx = ctx + self.label: str = label + + +class TraceCandidates: + """ + Represents an optimal pair of trace candidates (compute time and memory consumption). + + Attributes: + best_time (TraceCtx): The trace with the optimal runtime. + best_mem (TraceCtx): The trace with the optimal peak memory consumption. + placement_ctx_time (FusionExecutorsPlacementCtx): Trace placement context with exeuctors and any applied fusion compile options. + placement_ctx_mem (FusionExecutorsPlacementCtx): Trace placement context with exeuctors and any applied fusion compile options. + """ + + def __init__( + self, + best_time: TraceCtx | None = None, + best_mem: TraceCtx | None = None, + placement_ctx_time: FusionExecutorsPlacementCtx | None = None, + placement_ctx_mem: FusionExecutorsPlacementCtx | None = None, + ) -> None: + self.best_time: TraceCtx | None = best_time + self.best_mem: TraceCtx | None = best_mem + self.placement_ctx_time: FusionExecutorsPlacementCtx | None = placement_ctx_time + self.placement_ctx_mem: FusionExecutorsPlacementCtx | None = placement_ctx_mem + + def __repr__(self) -> str: + """ + Give a representation for the current object. + + Returns: + str: A string as the representation of the current object + """ + return f"\nBest runtime candidate:\n{self.best_time}\nBest memory candidate:\n{self.best_mem}" + + def is_set(self) -> bool: + """ + Check that the optimal trace pair has been set. + + Returns: + bool: A flag indicating if the optimal trace is not None. + """ + return False if self.best_time is None or self.best_mem is None else True + + def attach_best_time_candidate(self, trace: TraceCtx, ctx: FusionExecutorsPlacementCtx | None = None): + """ + Attach a new best time trace result. + + Args: + trace (TraceCtx): The trace to assign. + ctx (FusionExecutorsPlacementCtx | None): The trace placement context. + """ + self.best_time = trace + self.placement_ctx_time = ctx + + def attach_best_mem_candidate(self, trace: TraceCtx, ctx: FusionExecutorsPlacementCtx | None = None): + """ + Attach a new best memory trace result. + + Args: + trace (TraceCtx): The trace to assign. + ctx (FusionExecutorsPlacementCtx | None): The trace placement context. + """ + self.best_mem = trace + self.placement_ctx_mem = ctx + + def iterable(self) -> tuple[tuple, tuple]: + """ + Returns an iterable object over the traces paired with their contexts. + + Returns: + tuple: A tuple with paired values of performance metric and its context. + """ + return (self.best_time, self.placement_ctx_time), (self.best_mem, self.placement_ctx_mem) + + def trace_ctx_iterable(self) -> tuple[TraceCtx | None, TraceCtx | None]: + """ + Returns an iterable object over the traces. + + Returns: + tuple: A tuple of traces with time and memory consumption targets. + """ + return self.best_time, self.best_mem + + def placement_ctx_iterable(self) -> tuple[FusionExecutorsPlacementCtx | None, FusionExecutorsPlacementCtx | None]: + """ + Returns an iterable object over the placement contexts. + + Returns: + tuple: A tuple of contexes referring to traces targetting compute time and peak memory consumption. + """ + return self.placement_ctx_time, self.placement_ctx_mem + + +class OutputCandidate: + """ + Represents a final output candidate: forward and backward trace pair. + + Attributes: + fw (TraceCtx): The forward trace. 
+ bw (TraceCtx): The backward trace. + executors_fw (list): The forward trace regions' executors + executors_bw (list): The backward trace regions' executors + compile_opt (FusionExecutorsPlacementCtx | None): Any compile options being used for a fusion executor in the forward trace. + tot_cost (float): The total cost to execute the pair (ms for a time strategy and GB for a memory strategy). + apply_remat (bool): If rematerialization has been applied. + """ + + def __init__( + self, + *, + fw: TraceCtx, + bw: TraceCtx, + executors_fw: list[Executor], + executors_bw: list[Executor], + compile_opt: FusionCompileOptionsHelper | None = None, + cost: float = 0.0, + apply_remat: bool = False, + ) -> None: + self.fw: TraceCtx = fw + self.bw: TraceCtx = bw + self.executors_fw: list[Executor] = executors_fw + self.executors_bw: list[Executor] = executors_bw + self.compile_opt: FusionCompileOptionsHelper | None = compile_opt + self.tot_cost: float = cost + self.apply_remat: bool = apply_remat + + def __repr__(self) -> str: + """ + Give a representation of the current object. + + Returns: + str: A string representing the current object. + """ + return f"Final output candidate: forward trace:\n{self.fw.__repr__()}\nFinal output candidate: backward trace:\n{self.bw.__repr__()}" + + +class FusionStratHelper: + """ + Represents a helper structure for the fusion strategy. + + Attributes: + supported_executors (set): A list of supported fusion executors. + optimized_traces_mem (list): a list of dictionaries containing informations regarding the optimized traces for peak memory consumption. + optimized_traces_mem_benchmark_only (list): a list of dictionaries containing informations regarding the optimized traces for peak memory consumption (used only for internal benchmarking). + optimized_traces_time (list): a list of dictionaries containing informations regarding the optimized traces for total compute time. + optimized_traces_time_benchmark_only (list): a list of dictionaries containing informations regarding the optimized traces for total compute time (used only for internal benchmarking). + """ + + def __init__(self) -> None: + self.supported_executors: set = set(["nvfuser", "torchcompile"]) + self.optimized_traces_mem: list[dict[str | Hashable, tuple[TraceCtx, FusionExecutorsPlacementCtx | None]]] = [] + self.optimized_traces_mem_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [] + self.optimized_traces_time: list[dict[str | Hashable, tuple[TraceCtx, FusionExecutorsPlacementCtx | None]]] = [] + self.optimized_traces_time_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [] + + +class ExecutorPlacementOptions: + """ + Represents an aggregate placement options for executors combining those that targets peak memory consumption and those for total compute time. + + Attributes: + placement_options_mem (list): A list of placement contexts. + placement_options_time (list): A list of placement contexts. + """ + + def __init__(self) -> None: + self.placement_options_mem: list[FusionExecutorsPlacementCtx] = [] + self.placement_options_time: list[FusionExecutorsPlacementCtx] = [] + + +class PlacerBase: + """ + Represents a base (interface) class for a placement class. + + Attributes: + always_executors (tuple): A list of always present executors. + empty_executor_hashable_placeholder (str): A label representing en empty executor. + executors (Sequence): A list of executors to use. + fusion_executors (Sequence): A list of fusion executors to use. 
+ fusion_executors_saved_for_later (Sequence): A helper list containing maybe repeated fusion executors. + debug_msg (str): A dynamic filled log message. + log_file_name (str): The output log file name if generated. + produce_log (bool): A tuning parameter to control log file generation. + optimizer_type (OptimizerType): The optimization target. + active_fw_trace_ctx (tuple): An active forward trace set to optimize backward. + cached_fw_traces (list): Cached optimized forward traces. + best_comp_trace (TraceCtx): The optimized computational trace. + cached_computational_trace (TraceCtx): Original computational trace + cached_computational_backward_trace (TraceCtx): Original computational backward trace + bw_trace_candidates (TraceCandidate): An instance of trace candidates. + best_pair_runtime (OutputCandidate): A final trace pair targetting the compute time. + best_pair_memory (OutputCandidate): A final trace pair targetting the peak memory consumption. + apply_bucketing_bw_trace (bool): A distributed flag. + benchmark_iters (int): Benchmark iteration steps. + compile_data (Any): Thunder compilation data. + """ + + def __init__( + self, + *, + priority_executors: Sequence[Executor], + produce_log: bool = False, + apply_bucketing_bw_trace: bool, + log_file_name: str, + optimizer_type: OptimizerType = OptimizerType.RUNTIME, + compile_data, + ) -> None: + self.always_executors: tuple[Executor, ...] = get_always_executors() + self.empty_executor_hashable_placeholder: str = "empty" + self.executors: Sequence[Executor] = priority_executors + self.fusion_executors: Sequence[FusionExecutor] = [ + ex for ex in self.executors if isinstance(ex, FusionExecutor) + ] + # Helper needed for later + self.fusion_executors_saved_for_later: Sequence[FusionExecutor] = [] + + self.debug_msg: str = "" + self.log_file_name: str = log_file_name + self.produce_log: bool = produce_log + + self.optimizer_type: OptimizerType = optimizer_type + + self.active_fw_trace_ctx: tuple[TraceCtx | None, FusionExecutorsPlacementCtx | None] = None, None + self.cached_fw_traces: list[TraceCandidate] = [] + self.best_comp_trace: TraceCtx = TraceCtx() + self.cached_computational_trace: TraceCtx = TraceCtx() + self.cached_computational_backward_trace: TraceCtx = TraceCtx() + self.bw_trace_candidates: TraceCandidates = TraceCandidates() + self.out_traces_candidates: list[OutputCandidate] = [] + self.best_pair_runtime: OutputCandidate + self.best_pair_memory: OutputCandidate + + self.apply_bucketing_bw_trace: bool = apply_bucketing_bw_trace + + self.benchmark_iters: int = 5 + + self.compile_data = compile_data + + def optimize(self): + """ + Optimize the executor placement for the current trace. + """ + pass + + def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType, apply_dce=True): + """ + Attach a new trace for executors optimization. + + Args: + trace: The trace to attach. + trace_type: Forward or backward trace refrence. + """ + pass + + def get_optimal_fw_traces(self, is_computational=False) -> Sequence[TraceCtx] | TraceCtx: + """ + Retrive the optimal forward traces that the object has tuned. + + Args: + is_computational: The requested forward trace is a computational trace (autograd is disabled). + """ + return [] + + def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: + """ + Retrive the optimal forward and backward trace pair. + """ + return (TraceCtx(), TraceCtx()) + + +class FusionPlacer_BeamSearch(PlacerBase): + """ + Represents a placer targetting the fusion regions. 
+ + Attributes: + fusion_strat_helper: A helper structures to save intermediate values. + executor_placement_options: A helper structures to save different intemediate executor placement. + is_reduced: A flag indicating if the current trace under optimization is a reduced version of a bigger trace (by common blocks reduction). + cached_original_trace: A reference to the original trace if the optmization is performed on a reduced version. + """ + + def __init__( + self, + *, + priority_executors: Sequence[Executor], + produce_log: bool = False, + apply_bucketing_bw_trace: bool, + log_file_name: str, + optimizer_type: OptimizerType = OptimizerType.RUNTIME, + compile_data, + ) -> None: + super().__init__( + priority_executors=priority_executors, + produce_log=produce_log, + apply_bucketing_bw_trace=apply_bucketing_bw_trace, + log_file_name=log_file_name, + optimizer_type=optimizer_type, + compile_data=compile_data, + ) + + # Strat fusion + self.fusion_strat_helper: FusionStratHelper = FusionStratHelper() + self.executor_placement_options: ExecutorPlacementOptions = ExecutorPlacementOptions() + + # nvFuser compile options + if compile_data.compile_options.get("autotune_enable_nvfuser_all", False): + from thunder.executors.nvfuserex_impl import linear, _linear_check + from thunder.executors.nvfuserex_impl import matmul, _matmul_check + + self.known_fusion_ex_compile_options: dict[str | Hashable, list[FusionCompileOptionsHelper]] = { + "nvfuser": [ + FusionCompileOptionsHelper("nv_enable_linear", "linear", PrimIDs.LINEAR, linear, _linear_check), + FusionCompileOptionsHelper("nv_enable_matmul", "matmul", PrimIDs.MATMUL, matmul, _matmul_check), + ] + } + else: + self.known_fusion_ex_compile_options: dict[str | Hashable, list[FusionCompileOptionsHelper]] = { + "nvfuser": [] + } + + # Transformer based models optimization + # For models based on layers of transformer blocks we can optimize the tuning by researching the best placement + # on the model with a single layer and then mirror the configuration to the other layers. + self.is_reduced: bool = False + self.cached_original_trace: TraceCtx | None = None + + """ + ################################################## Internal methods ################################################## + """ + + def _best_runtime_and_memory_candidates(self, candidates: Sequence[OutputCandidate]): + """ + Retrive the best compute time and peak memory consumption trace pairs. + + Args: + candidates: A sequence of possible candidates. 
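+
+        Returns:
+            tuple: The best runtime pair and the best peak-memory pair, each as an ``OutputCandidate``.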
+ """ + from thunder.core.rematerialization import rematerialize_forward_and_backward + from thunder.backend_optimizer.utils import benchmark_trace + + min_value_time: float = float("inf") + min_value_mem: float = float("inf") + best_pair_runtime: OutputCandidate + best_pair_memory: OutputCandidate + pair: OutputCandidate + for pair in candidates: + if pair.compile_opt: + remat_fw, remat_bw = wrap_fn_with_exeuctor_compile_option( + pair.compile_opt, rematerialize_forward_and_backward, pair.fw, pair.bw + ) + else: + remat_fw, remat_bw = rematerialize_forward_and_backward(pair.fw, pair.bw) + # Create pair final options by applying final optimizations: cudagraphs and rematerialization + pair_options: list[ + tuple[TraceCtx, TraceCtx, FusionCompileOptionsHelper | None, list[Executor], list[Executor], bool] + ] = [ + (pair.fw, pair.bw, pair.compile_opt, pair.executors_fw, pair.executors_bw, False), + (remat_fw, remat_bw, pair.compile_opt, pair.executors_fw, pair.executors_bw, True), + ] + # Select the best options + for pair_option in pair_options: + fw, bw, compile_opt, executors_fw, executors_bw, remat_applied = pair_option + + pair_cost_time = 0 + pair_cost_mem = 0 + t, m, _ = benchmark_trace(fw, iters=self.benchmark_iters) + logger.debug(f"Pair fw time: {t} ms, mem: {m/(2**30)} GB") + pair_cost_time = pair_cost_time + t + pair_cost_mem = pair_cost_mem + m + t, m, _ = benchmark_trace(bw, iters=self.benchmark_iters, fw_trace=fw) + logger.debug(f"Pair bw time: {t} ms, mem: {m/(2**30)} GB") + pair_cost_time = pair_cost_time + t + pair_cost_mem = pair_cost_mem + m + + if pair_cost_time < min_value_time: + best_pair_runtime = OutputCandidate( + fw=fw, + bw=bw, + compile_opt=compile_opt, + executors_fw=executors_fw, + executors_bw=executors_bw, + cost=pair_cost_time, + apply_remat=remat_applied, + ) + logger.debug(f"New best runtime pair (no remat):\n{best_pair_runtime}") + min_value_time = pair_cost_time + + if pair_cost_mem < min_value_mem: + best_pair_memory = OutputCandidate( + fw=fw, + bw=bw, + compile_opt=compile_opt, + executors_fw=executors_fw, + executors_bw=executors_bw, + cost=pair_cost_mem, + apply_remat=remat_applied, + ) + logger.debug(f"New best memory pair (no remat):\n{best_pair_memory}") + min_value_mem = pair_cost_mem + + return best_pair_runtime, best_pair_memory + + def _filter_candidates(self): + """ + Reduce the solutions count by comparing different options across different fusion executors. + + For forward traces all the options are cached. 
+ """ + self.debug_msg += "Traces benchmarks:\n\n" + + # We cache every optimized fw traces as they might impact differently on the bw trace + # Number of fw traces to cached are: #fusion_executors * 2 + def fw_benchmark(): + # The optimizer builds the results in order following the self.fusion_executors list order + best_time = BenchmarkResult() + best_mem = BenchmarkResult() + pair_time: dict + pair_mem: dict + for pair_time, pair_mem in zip( + self.fusion_strat_helper.optimized_traces_time, self.fusion_strat_helper.optimized_traces_mem + ): + placement_ctx_time: FusionExecutorsPlacementCtx + placement_ctx_mem: FusionExecutorsPlacementCtx + trc_time: TraceCtx + trc_mem: TraceCtx + trc_time, placement_ctx_time = list(pair_time.values())[0] + trc_mem, placement_ctx_mem = list(pair_mem.values())[0] + label = list(pair_time.keys())[0] + c, m, _ = benchmark_trace(trc_time, self.benchmark_iters) + if c < best_time.runtime: + best_time = BenchmarkResult(time=c, trace=trc_time) + self.debug_msg += ( + f"Trace name = [{label}] - Target: TIME - Time = {c} ms - Mem = {m / (2**30)} GB\n{trc_time}\n\n" + ) + c, m, _ = benchmark_trace(trc_mem, self.benchmark_iters) + if m < best_mem.memory: + best_mem = BenchmarkResult(memory=m, trace=trc_mem) + self.debug_msg += ( + f"Trace name = [{label}] - Target: MEM - Mem = {m / (2**30)} GB - Time = {c} ms\n{trc_mem}\n\n" + ) + # For forward trace we cache the best placement for both runtime and memory for the current Fusion executor (represented by label) + for t, ctx in zip([trc_time, trc_mem], [placement_ctx_time, placement_ctx_mem]): + logger.info( + f"Caching fw candidate [compile option: {ctx.compile_options.fusion_tag if ctx.compile_options else 'None'}]" + ) + self.cached_fw_traces.append( + TraceCandidate( + trace=t, + ctx=ctx, + label=(label + "_enabled_" + ctx.compile_options.fusion_tag) + if ctx.compile_options is not None + else label, + ) + ) + # Assign best computational trace + self.best_comp_trace = best_time.trace if self.optimizer_type == OptimizerType.RUNTIME else best_mem.trace + + # Cache the original fw trace + self.cached_computational_trace = self.trace + + def bw_benchmark(): + time_result = BenchmarkResult() + memory_result = BenchmarkResult() + + # Find best trace for runtime + for i, pair in enumerate(self.fusion_strat_helper.optimized_traces_time_benchmark_only): + # Unpack the dict + label = list(pair.keys())[0] + trace = list(pair.values())[0] + trace_time, trace_mem, _ = benchmark_trace( + trace, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] + ) + self.debug_msg += f"Trace name = [{label}] - Target: TIME - Time = {trace_time} ms - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n" + if trace_time < time_result.runtime: + time_result = BenchmarkResult(time=trace_time, memory=trace_mem, trace=trace, label=label, index=i) + + # Find best trace for memory + for i, pair in enumerate(self.fusion_strat_helper.optimized_traces_mem_benchmark_only): + # Unpack the dict + label = list(pair.keys())[0] + trace = list(pair.values())[0] + + trace_time, trace_mem, _ = benchmark_trace( + trace, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] + ) + self.debug_msg += f"Trace name = [{label}] - Target: MEM - Mem = {trace_mem / (2**30)} GB - Time = {trace_time} ms\n{trace}\n\n" + if trace_mem < memory_result.memory: + memory_result = BenchmarkResult( + time=trace_time, memory=trace_mem, trace=trace, label=label, index=i + ) + + # Here we have to recover the traces without the pass through remat in order to be compliant + # with 
thunder flow as we might have request for no remat. + trc, placement_ctx = list(self.fusion_strat_helper.optimized_traces_time[time_result.index].values())[0] + self.bw_trace_candidates.attach_best_time_candidate(trc, placement_ctx) + trc, placement_ctx = list(self.fusion_strat_helper.optimized_traces_mem[memory_result.index].values())[0] + self.bw_trace_candidates.attach_best_mem_candidate(trc, placement_ctx) + + # Now, finally build the pair fw and bw traces + # The current fw trace is set by the caller and we take it as is. All current bw traces optimizations are made with the fw trace set by the caller. + + assert self.active_fw_trace_ctx[0] is not None and self.active_fw_trace_ctx[1] is not None + + for bw in self.bw_trace_candidates.iterable(): + self.out_traces_candidates.append( + OutputCandidate( + fw=self.active_fw_trace_ctx[0], + bw=bw[0], + executors_fw=self.active_fw_trace_ctx[1].placement, + executors_bw=bw[1].placement, + compile_opt=self.active_fw_trace_ctx[1].compile_options, + ) + ) + + # Cache original backward trace + self.cached_computational_backward_trace = self.trace + + match self.trace_type: + case TraceType.FW: + fw_benchmark() + case TraceType.BW: + bw_benchmark() + + if self.produce_log: + import time + + timestamp: str = str(time.time()) + with open(f"{timestamp}-{self.log_file_name}", "w") as file: + file.write(self.debug_msg) + file.close() + + self.debug_msg = "" + + def _search_candidates(self, increment_factor: int = 1): + """ + For the current trace generate all the placement candidates. + + For each fusion executor the time-memory pair candidates will be generated and cached. + If any compile options for an executor is available, it will be take under consideration. + + Args: + increment_factor: An integer controlling the increment step during the fusion exclusion to speed up the compilation. + """ + from thunder.executors.data_dependent_partition import Node, fuse_bound_symbols + from thunder.backend_optimizer.utils import ( + get_not_used_intermediate_outsputs, + sequence_hash, + can_executor_execute, + get_first_available_operator_executor, + assign_executors, + ) + + def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[BoundSymbol]): + """ + Generates a trace with the requested executors. + + Args: + mapping: a dictionary pointing to the assigned executor for a trace region. + bound_symbols_in: Input trace regions. 
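+
+            Returns:
+                tuple: The placed trace, the ordered region keys, and the matching per-region executor list.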
+ """ + trc = from_trace(self.trace) + trc.bound_symbols = list(bound_symbols_in) + + # For this partial trace we have to return all not used tensors otherwise the dce remove them + tensors = get_not_used_intermediate_outsputs(trc) + + forced_return_bsym = self.trace.bound_symbols[-1].from_bsym(args=tensors) + + executor_configuration = [] + empty_executor = Executor(name=self.empty_executor_hashable_placeholder) + keys = [] + for bsym in trc.bound_symbols: + if bsym.sym.name == "return": + raise AssertionError("Return statement should not be here") + elif isinstance(bsym.output, Sequence): + seq_hash = sequence_hash(bsym.output) + executor_configuration.append(mapping.get(seq_hash, empty_executor)) + keys.append(seq_hash) + elif ( + isinstance(bsym.output, CollectionProxy) + or isinstance(bsym.output, TensorProxy) + or isinstance(bsym.output, IntegerProxy) + or isinstance(bsym.output, FloatProxy) + ): + if bsym.output.name not in mapping: + raise AssertionError(f"Expected key {bsym.output.name} in mapping {mapping}") + executor_configuration.append(mapping[bsym.output.name]) + keys.append(bsym.output.name) + else: + raise AssertionError(f"Type not handled: {type(bsym.output)}") + + if trc.bound_symbols[-1].sym.name != "return": + trc.bound_symbols.append(forced_return_bsym) + executor_configuration.append(Executor(name=self.empty_executor_hashable_placeholder)) + keys.append("return") + + if len(trc.bound_symbols) != len(executor_configuration) or len(keys) != len(executor_configuration): + raise AssertionError( + f"len trc.bound_symbols ({len(trc.bound_symbols)}) != len executor_configuration ({len(executor_configuration)}) != len keys ({len(keys)})" + ) + + placed_trace = assign_executors( + in_trace=trc, + executors_list=executor_configuration, + always_executors=self.always_executors, + empty_str=self.empty_executor_hashable_placeholder, + ) + return placed_trace, keys, executor_configuration + + def _search(ex: FusionExecutor, executor_compile_option: FusionCompileOptionsHelper | None = None): + """ + For the given executor search and cached the best placements. + + Args: + ex: A fusion executor. + executor_placement_options: Any compile option this executor might activate. + """ + + def _should_fuse_nvfuser(a: Node, b: Node): + """ + Fusable fn definition for nvFuser. + + Args: + a: First node. + b: Second node. + """ + + def _can_fuse_node(n: Node): + # if already merged, then node can be fused + if len(n.group_bsyms) > 1: + return True + bsym: BoundSymbol = n.group_bsyms[0] + can_fuse: bool = ex.can_fuse(bsym) + cuda_in_or_out: bool = ex.has_cuda_input_or_output(bsym) + return can_fuse and cuda_in_or_out + + return _can_fuse_node(a) and _can_fuse_node(b) + + def _should_fuse_torchcompile(a: Node, b: Node): + """ + Fusable fn definition for torch.compile. + + Args: + a: First node. + b: Second node. + """ + + def _can_fuse_node(n: Node): + if len(n.group_bsyms) > 1: + return True + bsym: BoundSymbol = n.group_bsyms[0] + return ex.can_fuse(bsym) + + return _can_fuse_node(a) and _can_fuse_node(b) + + def match_bsym_executor(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): + """ + Match a bound symbol to its executor. + + Args: + bsym_in: The bound symbol to match. + dicts: The matrching destination. + ex_in: The executor to assign. 
+ """ + if isinstance(bsym_in.output, Sequence): + for d in dicts: + d[sequence_hash(bsym_in.output)] = ex_in + elif ( + isinstance(bsym_in.output, CollectionProxy) + or isinstance(bsym_in.output, TensorProxy) + or isinstance(bsym_in.output, IntegerProxy) + or isinstance(bsym_in.output, FloatProxy) + ): + for d in dicts: + d[bsym_in.output.name] = ex_in + else: + raise AssertionError(f"Type not handled: {type(bsym_in.output)}") + + merge_fn: Callable + match ex.name: + case "nvfuser": + merge_fn = _should_fuse_nvfuser + case "torchcompile": + merge_fn = _should_fuse_torchcompile + bound_symbol_groups = fuse_bound_symbols(self.trace, merge_fn) + logger.debug(f"Number of Fusion groups = {len(bound_symbol_groups)}") + + # Print fusion groups if requested + # for id, group in enumerate(bound_symbol_groups): + # log(f"Group id: {id}", level=LogLevel.DEBUG) + # for sub in group: + # log(f"{sub.sym.name} -> out: {sub.output}", level=LogLevel.DEBUG) + # if log_level == LogLevel.DEBUG and len(group) > 0: + # print("\n") + + dict_time_strat: dict[str, Executor] = {} + dict_mem_strat: dict[str, Executor] = {} + increasing_symbols = [] + # Tuning starting point: iterate over all the groups. + for group_id, group in enumerate(bound_symbol_groups): + logger.debug(f"Fusion group id: {group_id}") + logger.debug( + f"Fusion group start = [{group[0].output.name if hasattr(group[0].output, 'name') else 'unknown'} = {group[0].sym.name}]" + ) + logger.debug( + f"Fusion group end = [{group[-1].output.name if hasattr(group[-1].output, 'name') else 'unknown'} = {group[-1].sym.name}]" + ) + + if group[0].sym.name != "return": + increasing_symbols += group + + # We assign to a Fusion executor only region with at least 2 elements. Otherwise let the best OperatorExecutor pick the symbol up + if len(group) < 2: + current_bsym = group[0] + logger.debug( + f"--> Single group: [{current_bsym.output.name if hasattr(current_bsym.output, 'name') else 'unknown'} = {current_bsym.sym.name}]" + ) + # Filter out all possible candidates for the current symbol + candidate_executors = [ + ex + for ex in self.executors + if can_executor_execute(ex, current_bsym) and not isinstance(ex, FusionExecutor) + ] + + if current_bsym.sym.id == PrimIDs.RETURN: + dict_time_strat["return"] = Executor(name=self.empty_executor_hashable_placeholder) + dict_mem_strat["return"] = Executor(name=self.empty_executor_hashable_placeholder) + # Add the modified return statement at the end of the for loop + break + + # Not executors available + if not candidate_executors: + match_bsym_executor( + current_bsym, + [dict_time_strat, dict_mem_strat], + Executor(name=self.empty_executor_hashable_placeholder), + ) + continue + else: + logger.debug(f"Available executors for single region:\n{candidate_executors}") + + # Helpers + candidate_best_time = BenchmarkResult() + candidate_best_mem = BenchmarkResult() + + # No choices + if len(candidate_executors) == 1: + candidate_best_time = BenchmarkResult(index=0) + candidate_best_mem = BenchmarkResult(index=0) + else: + if _benchmark_single_trace_region: + # Define the standalone trace in order to benchmark this symbol + subtrace = construct_trace()(current_bsym.sym, *current_bsym.args, **current_bsym.kwargs) + + # Search for best candidate + for i, candidate in enumerate(candidate_executors): + if _benchmark_single_trace_region: + from thunder.common import transform_for_execution + + subtrace_placed = transform_for_execution(subtrace, executors_list=[candidate])[-1] + logger.debug(f"Subtrace to benchmark single 
symbol:\n{subtrace_placed}") + t, m, _ = benchmark_trace( + subtrace_placed, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] + ) + else: + # Match the current candidate into helper dicts to benchmark partial trace + match_bsym_executor(current_bsym, [dict_time_strat, dict_mem_strat], candidate) + # Retrieve partial trace and benchmark, apply remat if possible + trc, _, _ = get_placed_trace(dict_time_strat, increasing_symbols) + t, m, _ = benchmark_trace( + trc, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] + ) + logger.info( + f"Operator excutor [{candidate.name}] candidate perf (is single trace region: {_benchmark_single_trace_region}): {t} ms {m/(2**30)} GB" + ) + # Update results + if t < candidate_best_time.runtime: + candidate_best_time = BenchmarkResult(time=t, index=i) + if m < candidate_best_mem.memory: + candidate_best_mem = BenchmarkResult(memory=m, index=i) + + if candidate_best_time.index == -1 or candidate_best_mem.index == -1: + raise AssertionError( + f"Failed to get optimal single trace region candidate. Available candidates for {current_bsym.sym.name}:\n{candidate_executors}" + ) + + logger.debug( + f"Best time OperatorExecutor for single {current_bsym.sym.name}: {candidate_executors[candidate_best_time.index].name}" + ) + logger.debug( + f"Best mem OperatorExecutor for single {current_bsym.sym.name}: {candidate_executors[candidate_best_mem.index].name}" + ) + + match_bsym_executor(current_bsym, [dict_time_strat], candidate_executors[candidate_best_time.index]) + match_bsym_executor(current_bsym, [dict_mem_strat], candidate_executors[candidate_best_mem.index]) + # Go to next bsym group + continue + + # Inside groups we should have alwasy tensors as out + best_res_time = BenchmarkResult() + best_res_mem = BenchmarkResult() + + best_placement_time = None + best_keys_time = None + best_placement_mem = None + best_keys_mem = None + + def measure_and_update_result(): + nonlocal best_res_time + nonlocal best_placement_time + nonlocal best_keys_time + nonlocal best_res_mem + nonlocal best_placement_mem + nonlocal best_keys_mem + trc, keys, placements = get_placed_trace(dict_time_strat, increasing_symbols) + cost, mem, _ = benchmark_trace(trc, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) + logger.debug(f"Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}") + if cost < best_res_time.runtime or (cost == best_res_time.runtime and mem < best_res_time.memory): + best_res_time = BenchmarkResult(time=cost, memory=mem, trace=trc) + best_placement_time = placements + best_keys_time = keys + if mem < best_res_mem.memory or (mem == best_res_mem.memory and cost < best_res_mem.runtime): + best_res_mem = BenchmarkResult(time=cost, memory=mem, trace=trc) + best_placement_mem = placements + best_keys_mem = keys + + start_idx = 0 + # This is to accomodate the following TODO + # TODO: investigate why is failing with torchcompile if left alone + if ex.name == "torchcompile": + last_embedding_idx = -1 + for idx in range(0, len(group)): + if group[idx].sym.name == "embedding_backward": + last_embedding_idx = idx + logger.debug(f"last embedding idx: {last_embedding_idx}") + if last_embedding_idx != -1: + # Until last_embedding_idx (included) assigned to current fusion ex + for i in range(0, last_embedding_idx + 1, 1): + match_bsym_executor(group[i], [dict_time_strat, dict_mem_strat], ex) + + if last_embedding_idx == len(group) - 1: + # Benchmark + measure_and_update_result() + + start_idx = last_embedding_idx + 1 + + n_missing_bsyms = len(group) - 
start_idx + # Tune a single fusion group. + # NOTE: currently this is disabled for backward traces + for i in range(0, n_missing_bsyms, n_missing_bsyms - 1 if self.trace_type == TraceType.BW else 1): + if ex.name == "torchcompile": + import torch + + torch.compiler.reset() + + # for i in range(0, n_missing_bsyms): + # From top to bottom (this will include the whole region) + # -> First iteration is the one with fusion region with single element + # -> Last iteration gives the complete fusion region + for j in range(start_idx, start_idx + i + 1, increment_factor): + match_bsym_executor(group[j], [dict_time_strat, dict_mem_strat], ex) + for k in range(start_idx + i + 1, len(group), increment_factor): + match_bsym_executor( + group[k], + [dict_time_strat, dict_mem_strat], + # In order to benchmark the fusion placecement, we can use any executor for the excluded bsym from the fusion region + # TODO: consider tuning the single trace regions removed from the fusion one + get_first_available_operator_executor( + bsym=group[k], + executors=self.executors, + empty_hash=self.empty_executor_hashable_placeholder, + ), + ) + # Benchmark + measure_and_update_result() + + if best_placement_time is None or best_keys_time is None: + raise AssertionError("Failed to get best time placement") + if best_placement_mem is None or best_keys_mem is None: + raise AssertionError("Failed to get best placement") + + logger.debug( + f"For group {group_id} best placement with time cost = {best_res_time.runtime} ms:\n{best_res_time.trace}" + ) + logger.debug( + f"For group {group_id} best placement with mem cost = {best_res_mem.memory / (2**30)} GB:\n{best_res_mem.trace}" + ) + + # Update our dict + for n, p in zip(best_keys_time, best_placement_time): + dict_time_strat |= {n: p} + # Update our dict + for n, p in zip(best_keys_mem, best_placement_mem): + dict_mem_strat |= {n: p} + + # Generate the placement + executors_time = [] + executors_mem = [] + for bsym in self.trace.bound_symbols: + if bsym.sym.id == PrimIDs.RETURN: + # TODO (matteochen): Aggregate them + if "return" not in dict_time_strat or "return" not in dict_mem_strat: + raise AssertionError(f"Expected key return in mapping {dict_time_strat} and {dict_mem_strat}") + executors_time.append(dict_time_strat["return"]) + executors_mem.append(dict_mem_strat["return"]) + elif isinstance(bsym.output, Sequence): + seq_hash = sequence_hash(bsym.output) + if seq_hash not in dict_time_strat or seq_hash not in dict_mem_strat: + raise AssertionError( + f"Expected key {seq_hash} in mapping {dict_time_strat} and {dict_mem_strat}" + ) + executors_time.append(dict_time_strat[seq_hash]) + executors_mem.append(dict_mem_strat[seq_hash]) + elif ( + isinstance(bsym.output, CollectionProxy) + or isinstance(bsym.output, TensorProxy) + or isinstance(bsym.output, IntegerProxy) + or isinstance(bsym.output, FloatProxy) + ): + if bsym.output.name not in dict_time_strat or bsym.output.name not in dict_mem_strat: + raise AssertionError( + f"Expected key {bsym.output.name} in mapping {dict_time_strat} and {dict_mem_strat}" + ) + executors_time.append(dict_time_strat[bsym.output.name]) + executors_mem.append(dict_mem_strat[bsym.output.name]) + else: + raise AssertionError(f"Type not handled: {type(bsym.output)}") + + # For the forward trace we benchmark (memory) the mocked return statement as we don't know which + # tensor will be returned after the rematerialize_forward_and_backward call in order to do not under/over-estimate the memory consumption + trace = self.trace + if self.trace_type 
== TraceType.FW: + trace = from_trace(self.trace) + trace.bound_symbols = list(self.trace.bound_symbols) + trace.bound_symbols.pop() + trace.bound_symbols.append( + self.trace.bound_symbols[-1].from_bsym(args=get_not_used_intermediate_outsputs(trace)) + ) + # Save the optimal traces (both for runtime and memory consumption) that we have found + for executors, container in zip( + [executors_mem, executors_time], + [ + self.fusion_strat_helper.optimized_traces_mem_benchmark_only, + self.fusion_strat_helper.optimized_traces_time_benchmark_only, + ], + ): + trc = assign_executors( + in_trace=trace, + executors_list=executors, + always_executors=self.always_executors, + empty_str=self.empty_executor_hashable_placeholder, + ) + container.append({ex.name: trc}) + + # We add any provided compile option reference + self.executor_placement_options.placement_options_time.append( + FusionExecutorsPlacementCtx(placement=executors_time, compile_options=executor_compile_option) + ) + self.executor_placement_options.placement_options_mem.append( + FusionExecutorsPlacementCtx(placement=executors_mem, compile_options=executor_compile_option) + ) + + # If any compile options is used we will need to have duplicated executors inside the executors list to maintain the matching. + # TODO: integrate torchcompile_cat alongside with nvFuser. This should speed up the autotuner too. + self.fusion_executors_saved_for_later = [] + ex: FusionExecutor + for ex in self.fusion_executors: + if ex.name not in self.fusion_strat_helper.supported_executors: + continue + + logger.info(f"Searching best placement for fusion executor = {ex.name}") + + # We try to enable fusion specific compile options only for fw traces + # Backward traces will follow fw traces options + ex_compile_opts = ( + self.known_fusion_ex_compile_options.get(ex.name, []) if self.trace_type == TraceType.FW else [] + ) + self.fusion_executors_saved_for_later.append(ex) + + # Always search with option disabled (standard flow) + _search(ex) + + # Currently we are enabling one compile option at the time as testing all the permutations might need too much time. + # TODO: Consider implementing patterns based on the executor under investingation + if ex_compile_opts: + logger.info(f"{ex.name} compile options: {[option.fusion_tag for option in ex_compile_opts]}") + for opt in ex_compile_opts: + # Search only if we have an instruction related to the compile option + op_in_trace: bool = operation_in_trace(trace=self.trace, op=opt.symbol_tag) + if op_in_trace: + self.fusion_executors_saved_for_later.append(ex) + wrap_fn_with_exeuctor_compile_option(opt, _search, ex, opt) + + logger.info(f"Searching best placement for fusion executor = {ex.name} ended.") + + """ + ################################################## Public methods ################################################## + """ + + def get_optimal_fw_traces(self, is_computational=False) -> Sequence[TraceCtx] | TraceCtx: + if not self.cached_fw_traces: + raise AssertionError("Failed to obtain optimal fw traces") + if not is_computational: + return [candidate.trace for candidate in self.cached_fw_traces] + return self.best_comp_trace + + def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: + restore_file = self.compile_data.compile_options.get("autotune_restore_configuration", "") + + # We apply the dce transform as it will be applied to the cached traces during the past optimization + # (dce has been applied to the traces saved in the configuration). 
+ if restore_file: + from thunder.core.transforms import dce + + fw_extrace, bw_extrace = apply_results_from_file( + fw_trace=dce(self.cached_computational_trace), + bw_trace=dce(self.cached_computational_backward_trace), + file=restore_file, + ) + return fw_extrace, bw_extrace + return ( + (self.best_pair_runtime.fw, self.best_pair_runtime.bw) + if self.optimizer_type == OptimizerType.RUNTIME + else (self.best_pair_memory.fw, self.best_pair_memory.bw) + ) + + def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType, apply_dce: bool = True): + from thunder.core.transform_common import dce + + self.trace_type = trace_type + # dce for the backward trace will be passed afterwards as we might modify it before + self.trace: TraceCtx = dce(trace) if apply_dce else trace + + match self.trace_type: + case TraceType.FW: + logger.info(f"New forward trace to optimize (strat = {self.optimizer_type})") + case TraceType.BW: + if not self.compile_data.compile_options.get("autotune_restore_configuration", ""): + if not self.cached_fw_traces: + raise AssertionError("Can not optimize backward traces before forward traces") + logger.info(f"New backward trace to optimize (strat = {self.optimizer_type})") + + def optimize(self): + from thunder.core.transform_common import dce + from thunder.executors.torch_autograd import update_bw_from_forward_optimization + from thunder.backend_optimizer.utils import assign_executors + from thunder.backend_optimizer.utils import repetead_trace_blocks, reduce_common_trace_blocks + + def _optimize(): + # Reset fusion helpers + self.fusion_strat_helper = FusionStratHelper() + # Reset helpers data structures + self.executor_placement_options = ExecutorPlacementOptions() + + cd = get_compile_data() + # Check if common blocks optimization is requested + optimize_common_blocks = ( + False if cd is None else cd.compile_options.get("autotune_optimize_common_blocks", False) + ) + optimize_common_blocks_min_size = ( + -1 if cd is None else cd.compile_options.get("autotune_optimize_common_blocks_min_size", -1) + ) + + # Cut the compilation time if possible + common_trace_blocks = repetead_trace_blocks( + trace=self.trace, min_block_size=optimize_common_blocks_min_size if optimize_common_blocks else -1 + ) + # A valid block is defined with at least 2 trace regions + if len(common_trace_blocks) >= 2 and optimize_common_blocks: + logger.info( + f"Running optimization with common blocks reduction. Found block indices in trace: {common_trace_blocks}" + ) + reduced_trace = reduce_common_trace_blocks(trace=self.trace, common_blocks_in=common_trace_blocks) + logger.info("Operating on reduced trace (by cutting common transformer blocks)") + self.is_reduced = True + self.cached_original_trace = self.trace + self.trace = reduced_trace + else: + logger.info( + "Optimizing the whole trace directly. No common transformer block optimization will be applied." + ) + + # This performs executor tuning + self._search_candidates() + + # From now on we have the optimized executors for each trace region. Apply them... + if len(self.executor_placement_options.placement_options_time) != len( + self.fusion_executors_saved_for_later + ): + raise AssertionError( + f"Unexpected time placement options size: {len(self.executor_placement_options.placement_options_time)}. 
Expected: {len(self.fusion_executors_saved_for_later)}" + ) + if len(self.executor_placement_options.placement_options_mem) != len(self.fusion_executors_saved_for_later): + raise AssertionError( + f"Unexpected mem placement options size: {len(self.executor_placement_options.placement_options_mem)}. Expected: {len(self.fusion_executors_saved_for_later)}" + ) + + # If we optimized the reduced trace we now can share the placing with other blocks + if self.is_reduced and self.cached_original_trace is not None: + for placement_ctx in self.executor_placement_options.placement_options_time: + placement = map_executors_from_reduced_trace_to_complete_trace( + self.cached_original_trace, common_trace_blocks, placement_ctx.placement + ) + placement_ctx.placement = placement + + for placement_ctx in self.executor_placement_options.placement_options_mem: + placement = map_executors_from_reduced_trace_to_complete_trace( + self.cached_original_trace, common_trace_blocks, placement_ctx.placement + ) + placement_ctx.placement = placement + + # Reset original trace + self.trace = self.cached_original_trace + # We will create the best compute time and peak memory consumption placement for each fusion executor + for placement_ctx, ex in zip( + self.executor_placement_options.placement_options_time, self.fusion_executors_saved_for_later + ): + trc = assign_executors( + in_trace=self.trace, + executors_list=placement_ctx.placement, + always_executors=self.always_executors, + empty_str=self.empty_executor_hashable_placeholder, + compile_data=self.compile_data, + fusion_executor_compile_options_to_activate=placement_ctx.compile_options, + ) + self.fusion_strat_helper.optimized_traces_time.append({ex.name: tuple([trc, placement_ctx])}) + for placement_ctx, ex in zip( + self.executor_placement_options.placement_options_mem, self.fusion_executors_saved_for_later + ): + trc = assign_executors( + in_trace=self.trace, + executors_list=placement_ctx.placement, + always_executors=self.always_executors, + empty_str=self.empty_executor_hashable_placeholder, + compile_data=self.compile_data, + fusion_executor_compile_options_to_activate=placement_ctx.compile_options, + ) + self.fusion_strat_helper.optimized_traces_mem.append({ex.name: tuple([trc, placement_ctx])}) + + # Filter out the optimal candidates for the current serach iteration + self._filter_candidates() + + restore_file_name = self.compile_data.compile_options.get("autotune_restore_configuration", "") + + match self.trace_type: + case TraceType.FW: + # Perform optimization only if we don't restore it from a past configuration + if restore_file_name: + self.cached_computational_trace = self.trace + logger.info("Skipping forward trace optimization as it will be restored from a configuration file.") + return + + # Clear any previous results + self.cached_fw_traces = [] + _optimize() + # We have multiple cached optimized fw traces, this iteration will create a fw-bw pair for + # every cached forward trace. At the end the best one will be picked up. + case TraceType.BW: + # Perform optimization only if we don't restore it from a past configuration + if restore_file_name: + logger.info( + "Skipping backward trace optimization as it will be restored from a configuration file." 
+ ) + self.cached_computational_backward_trace = self.trace + return + + # Clear any previous results + self.out_traces_candidates = [] + + # Cached the bw trace as we need to modify the self.trace during the loop + cached_self_trace = from_trace(self.trace) + cached_self_trace.bound_symbols = list(self.trace.bound_symbols) + + # Now we can generate backward solutions from the cached fw traces + for fw_trace_candidate in self.cached_fw_traces: + logger.info(f"Backward optimization with fw from {fw_trace_candidate.label}") + # Restore the original bw trace + self.trace = from_trace(cached_self_trace) + self.trace.bound_symbols = list(cached_self_trace.bound_symbols) + # Set the current active cached forward trace context + # logger.info( + # f"Current fw cached ctx:\n{fw_trace_candidate.trace}\nOptions: {fw_trace_candidate.ctx.compile_options.fusion_tag if fw_trace_candidate.ctx.compile_options is not None else 'None'}" + # ) + self.active_fw_trace_ctx = fw_trace_candidate.trace, fw_trace_candidate.ctx + + logger.debug(f"Input bw trace:\n{self.trace}") + + self.trace = update_bw_from_forward_optimization(fw=fw_trace_candidate.trace, bw=self.trace) + + # Taken from: https://github.com/Lightning-AI/lightning-thunder/blob/339a782e3d75061a065a3d2e47b5206f23aea7c3/thunder/executors/torch_autograd.py#L222 + if self.apply_bucketing_bw_trace: + from thunder.distributed.transforms import FSDPCommBucketing + + self.trace = FSDPCommBucketing.apply_bucketing_to_backward_trace(self.trace) + + # Not called in the constructor for bw traces + self.trace = dce(self.trace) + + # Enable any forward active compilation flag + if fw_trace_candidate.ctx.compile_options: + wrap_fn_with_exeuctor_compile_option(fw_trace_candidate.ctx.compile_options, _optimize) + else: + _optimize() + + # For every pair being generated filter out the best choice. + self.best_pair_runtime, self.best_pair_memory = self._best_runtime_and_memory_candidates( + self.out_traces_candidates + ) + + # Save the tuning if requested + do_save = self.compile_data.compile_options.get("autotune_save_configuration", False) + if do_save: + model_name = self.compile_data.compile_options.get("model_name", "unknown") + file_name = f"{model_name}_runtime.json" + dump_traces_placement( + fw_trace=self.cached_computational_trace, + bw_trace=self.cached_computational_backward_trace, + file_name=file_name, + apply_remat=self.best_pair_runtime.apply_remat, + exs_fw=self.best_pair_runtime.executors_fw, + exs_bw=self.best_pair_runtime.executors_bw, + ) + file_name = f"{model_name}_memory.json" + dump_traces_placement( + fw_trace=self.cached_computational_trace, + bw_trace=self.cached_computational_backward_trace, + file_name=file_name, + apply_remat=self.best_pair_memory.apply_remat, + exs_fw=self.best_pair_memory.executors_fw, + exs_bw=self.best_pair_memory.executors_bw, + ) + + +class BackendOptimizer: + """ + Represents a generic backend optimizer. + + Attributes: + optimizer: An optimizer instance based on the configurations. 
+ """ + + def __init__( + self, + *, + priority_executors: Sequence[Executor], + produce_log=False, + apply_bucketing_bw_trace: bool, + log_file_name="autotune_debug.log", + optimizer_type: OptimizerType = OptimizerType.RUNTIME, + optimizer_algorithm: OptimizationAlgorithm = OptimizationAlgorithm.BEST_FUSER, + compile_data, + ) -> None: + if optimizer_algorithm != OptimizationAlgorithm.BEST_FUSER: + raise AssertionError(f"Optimization {optimizer_algorithm} not implemented") + self.optimizer: PlacerBase = FusionPlacer_BeamSearch( + priority_executors=priority_executors, + produce_log=produce_log, + apply_bucketing_bw_trace=apply_bucketing_bw_trace, + log_file_name=log_file_name, + optimizer_type=optimizer_type, + compile_data=compile_data, + ) + + logger.info(f"Executors: {[ex.name for ex in priority_executors]}") + + def optimize(self): + """ + Optimize the executor placement for the current trace. + """ + self.optimizer.optimize() + + def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType, apply_dce=True): + """ + Attach a new trace for executors optimization. + + Args: + trace: The trace to attach. + trace_type: Forward or backward trace refrence. + """ + self.optimizer.attach_trace(trace=trace, trace_type=trace_type) + + def get_optimal_fw_traces(self, is_computational=False) -> Sequence[TraceCtx] | TraceCtx: + """ + Retrive the optimal forward traces that the object has tuned. + + Args: + is_computational: The requested forward trace is a computational trace (autograd is disabled). + """ + return self.optimizer.get_optimal_fw_traces(is_computational) + + def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: + """ + Retrive the optimal forward and backward trace pair. + """ + return self.optimizer.get_optimal_fw_bw_traces() diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py new file mode 100644 index 0000000000..4a163aa588 --- /dev/null +++ b/thunder/backend_optimizer/utils.py @@ -0,0 +1,1709 @@ +from collections.abc import Callable, Hashable, Sequence +from typing import Any + +from thunder.core.compile_data import get_compile_data +from thunder.core.dtypes import to_torch_dtype +from thunder.core.prims import PrimIDs +from thunder.core.proxies import ( + AnyProxy, + CollectionProxy, + FloatProxy, + IntegerProxy, + NumberProxy, + Proxy, + TensorProxy, + Variable, + variableify, +) +from thunder.core.symbol import BoundSymbol, Symbol +from thunder.core.trace import TraceCtx, from_trace, get_tracectx, reset_tracectx, set_tracectx +from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors +from thunder.core.utils import check, safe_map_flat +import thunder.core.transforms as transforms +from itertools import chain +import torch +from torch.utils.benchmark import Timer, Compare +from thunder.core.dtypes import dtype +from enum import Enum + + +class TraceType(Enum): + """ + Represents the nature of a trace, if forward (computational) or backward. + """ + + FW = 0 + BW = 1 + + +class BenchmarkResult: + """ + Represents a trace benchmark result information. + + Attributes: + time: Benchmark computation time. + memory: Benchmark peak memory usage. + trace: Computaiton trace. + label: A generic label. + index: A generic index in a sequence. 
+ """ + + def __init__( + self, + *, + time: float = float("inf"), + memory: float = float("inf"), + trace: TraceCtx = TraceCtx(), + label: str | Hashable = "", + index: int = -1, + ) -> None: + self.runtime: float = time + self.memory: float = memory + self.trace: TraceCtx = trace + self.label: str | Hashable = label + self.index: int = index + + +class OptimizerType(Enum): + """ + Represents the autotuner target. + """ + + MEMORY = 0 + RUNTIME = 1 + + +# Maybe we can use id(s) +def sequence_hash(s: Sequence) -> str: + """ + Create a fake hash for a sequence of elements. + A fake hash is created because it relies on the elements metadata and not on a specific hash function. + + Args: + s: A sequence to hash. + """ + + def rec(s) -> str: + name = "[" + for e in s: + if e is None: + name += "None#" + elif hasattr(e, "name"): + name += e.name + "#" + elif isinstance(e, Sequence) and not isinstance(e, str): + name += rec(e) + elif isinstance(e, int): + name += "int" + str(e) + "#" + else: + raise AssertionError(f"Unsupported type = {type(e)}") + name += "]" + return name + + return rec(s) + + +def can_executor_execute(ex: Executor, bsym: BoundSymbol) -> bool: + """ + Wrap the `can_execute` call of the `Executor`. + + Args: + ex: The executor to test. + bsym: The bound symbol to test. + """ + try: + return ex.can_execute(bsym) + except Exception: + return False + + +def get_first_available_operator_executor( + *, bsym: BoundSymbol, executors: Sequence[Executor], empty_hash: str = "empty" +): + """ + Returns the first available executor which can execute the given bound symbol. + + Args: + bsym: The bound symbol to execute. + executors: A list of possible executors. + empty_hash: A label representing an empty executor if none will be found. + """ + for ex in executors: + if isinstance(ex, FusionExecutor): + continue + if can_executor_execute(ex, bsym): + return ex + return Executor(name=empty_hash) + + +def flatten_sequence(sequence: Sequence) -> list: + """ + Flat a sequence containing sub sequences with a dfs search. + By default None elements will be skipped. + + Args: + sequence: The sequence to flatten. + """ + res = [] + for e in sequence: + if isinstance(e, Sequence): + res.extend(flatten_sequence(e)) + # Skip Nones as they are not useful + elif e is not None: + res.append(e) + return res + + +def get_not_used_intermediate_outsputs(trace_in: TraceCtx) -> list[Proxy]: + """ + Returns all the intermediate outputs that are not used or returned in the input trace. + This can be usefull if we want to force a specific TensorProxy to be returned in a modfied trace to avoid the dce. + + Args: + trace_in: A generic trace. 
+ """ + + def is_in_sequence(seq: Sequence[Any], t: Proxy): + for e in seq: + if hasattr(e, "name") and hasattr(t, "name") and e.name == t.name: + return True + return False + + def unpack_output(out) -> Sequence[Proxy]: + if issubclass(type(out), Proxy): + return [out] + elif isinstance(out, Sequence): + return flatten_sequence(out) + else: + raise RuntimeError(f"Unpack operation not defined for {type(out)}") + + ans: list[Proxy] = [] + for a in trace_in.bound_symbols: + f = False + unpacked_out = unpack_output(a.output) + for e in unpacked_out: + # None values are checked inside the unpack_output fn + for b in trace_in.bound_symbols: + if b.args is not None and isinstance(b.args, Sequence) and is_in_sequence(b.args, e): + f = True + break + if not f: + ans.append(e) + from thunder.backend_optimizer.optimizer import logger + + logger.debug(f"Returning not used proxies: {[p.name if hasattr(p, 'name') else p for p in ans ]}") + return ans + + +def assign_executors( + *, + in_trace: TraceCtx, + executors_list: list[Executor | FusionExecutor | OperatorExecutor] + | tuple[Executor | FusionExecutor | OperatorExecutor, ...], + always_executors: list[Executor] | tuple[Executor, ...], + empty_str: str | Hashable, + compile_data=None, + fusion_executor_compile_options_to_activate: Any | None = None, +) -> TraceCtx: + """ + Given a not optimized trace (original computation trace) generate a transformed trace with the requested executors. + + Args: + in_trace: The computation trace. + executors_list: A list of executors, one for each trace region. The size of this list is expected to be equal to the number of bound symbols inside the trace. + always_executors: A list of always executors to pick up symbols not picked up by any specific executor. + empty_str: A label representing an empty executor in the executors_list. + compile_data: A reference to the current compilation data. + fusion_executor_compile_options_to_activate: Any fusion exeuctor compilation options that can be enabled during the trace generation (for example nvFuser). 
+ """ + + from thunder.executors.passes import _transform_for_operator_executor_execution + + def _assign_executors(): + swapmap: dict[Variable, Proxy] = {} + + def restore_correct_args(trace_in: TraceCtx): + def args_eq(a, b) -> bool: + if len(a) != len(b): + return False + for obj_a, obj_b in zip(a, b): + if type(obj_a) == type(obj_b) and isinstance(obj_a, TensorProxy): + if obj_a.name != obj_b.name: + return False + elif type(obj_a) == type(obj_b) and not isinstance(obj_a, TensorProxy): + if obj_a != obj_b: + raise AssertionError(f"What do you want to do here:\nobj_a:\n{obj_a}\nobj_b:{obj_b}") + return True + + def clear(bsym: BoundSymbol, input): + size = len(bsym.subsymbols) + if size > 0: + for subsym in bsym.subsymbols: + if not args_eq(subsym.args, input): + subsym.args = tuple(list(input)) + clear(subsym, input) + + for bsym in trace_in.bound_symbols: + if isinstance(bsym.sym.executor, OperatorExecutor): + clear(bsym, bsym.args) + + def update_swapmap(o: Any, no: Any) -> None: + if isinstance(o, Proxy): + check( + isinstance(no, Proxy), + lambda: f"Expected an execution transform to produce outputs with the same type, but found {type(o)} and {type(no)}", + ) + + vo = variableify(o) + vno = variableify(no) + if vo == vno: + return + swapmap[vno] = o + + def preserve_bsym(bsym: BoundSymbol) -> Any: + trace: TraceCtx | None = get_tracectx() + if trace is None: + raise AssertionError("None trace context") + trace.scopes[-1].append(bsym) + for p in chain(bsym.flat_proxy_outs, bsym.flat_proxy_args): + trace.names.add(p.name) + return bsym.output + + def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: + if bsym.sym.python_impl is not None: + return None + + # We have mapped this at previous stages + if ex.name == empty_str: + return None + + execution_transform: None | Callable = ex.get_execution_transform(bsym.sym) + out: Any + if execution_transform is not None: + out = execution_transform(*bsym.args, **bsym.kwargs) + elif isinstance(ex, OperatorExecutor): + # Calls the operator executor's operation + op: Symbol | None = ex.implmap[bsym.sym.id].symbol + if op is None: + raise AssertionError("op is None") + out = op(*bsym.args, **bsym.kwargs) + elif isinstance(ex, FusionExecutor): + # Preserves the symbol as is (it will be handled in the fusion pass) + out = preserve_bsym(bsym) + else: + raise AssertionError("Unknown executor") + + safe_map_flat(update_swapmap, bsym.output, out) + + return True + + def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: + return transforms.VISIT_TYPE.NO_OP if visit_helper(bsym, ex) is None else transforms.VISIT_TYPE.REPLACE + + if len(executors_list) != len(in_trace.bound_symbols): + raise AssertionError("len(executors_list) != len(in_trace.bound_symbols)") + + cached_subsymbols: dict[str, Sequence[BoundSymbol]] = {} + executor_mapping: dict[str, Executor] = {} + unique_fusion_executors = set() + + # Input should have equal length + if len(executors_list) != len(in_trace.bound_symbols): + raise AssertionError("len(executors_list) != len(extrace.bound_symbols)") + + for b, e in zip(in_trace.bound_symbols, executors_list): + if isinstance(e, FusionExecutor): + unique_fusion_executors.add(e) + if isinstance(b.output, TensorProxy): + executor_mapping[b.output.name] = e + + extrace = transforms.visitor_transform_paired(in_trace, visit, zip(in_trace.bound_symbols, executors_list)) + + # Restores original variables + bound_symbols: list[BoundSymbol] = [] + for bsym in extrace.bound_symbols: + nbsym: BoundSymbol = 
bsym.from_bsym_swap_proxies(swapmap) + bound_symbols.append(nbsym) + extrace.bound_symbols = bound_symbols + + for bsym in extrace.bound_symbols: + if isinstance(bsym.output, TensorProxy): + t_name = bsym.output.name + if t_name not in executor_mapping: + # Symbol added by the visitor + continue + saved_ex = executor_mapping[t_name] + if isinstance(saved_ex, OperatorExecutor): + cached_subsymbols[t_name] = list(bsym.subsymbols) + # This will leave out these symbols from the fusion pass + bsym.subsymbols = [] + + # Perform fusion pass + for ex in unique_fusion_executors: + extrace = ex.fusion_pass(extrace) + + # Restore subsymbols + # TODO (matteochen): Improve this search + for k, v in cached_subsymbols.items(): + # NOTE: Some symbols may be cut out by the fusion pass -> CSE + # For example: + # a = 1 + 1 + # b = 1 + 1 + # c = a + b + # being replaced by c = a + a + for bsym in extrace.bound_symbols: + if isinstance(bsym.output, TensorProxy) and bsym.output.name == k: + bsym.subsymbols = v + + restore_correct_args(extrace) + + # Apply always executors + extrace = _transform_for_operator_executor_execution(extrace, always_executors) + + return extrace + + if fusion_executor_compile_options_to_activate: + return wrap_fn_with_exeuctor_compile_option(fusion_executor_compile_options_to_activate, _assign_executors) + return _assign_executors() + + +def operation_in_trace(*, trace: TraceCtx, op: str, prefix: bool = False) -> bool: + """ + Test if an operation is being used inside a trace. + + Args: + trace: A computation trace. + op: The operation name to be tested. + prefix: Test only the prefix label. + """ + + # This is to query nv_enable_bookend (https://github.com/Lightning-AI/lightning-thunder/blob/339a782e3d75061a065a3d2e47b5206f23aea7c3/thunder/executors/nvfuserex_impl.py#L807) + # as there won't be any references about this in a trace. + always_true = set(["bookend"]) + + if op in always_true: + return True + for b in trace.bound_symbols: + if prefix: + if b.sym.name.startswith(op): + return True + else: + if b.sym.name == op: + return True + return False + + +def is_te_used(trace: TraceCtx) -> bool: + """ + Test if transformer engine is being used inside a trace. + + Args: + trace: A computation trace. + """ + from thunder.executors.transformer_engineex import linear_bound_symbol_name_prefix + from thunder.executors.transformer_engineex import te_functional_linear_backward_name + + if operation_in_trace(trace=trace, op=te_functional_linear_backward_name) or operation_in_trace( + trace=trace, op=linear_bound_symbol_name_prefix, prefix=True + ): + return True + + return False + + +def is_backward_trace(trace: TraceCtx) -> bool: + """ + Test if a trace is a backward trace from its signature. + + Args: + trace: A computation trace. + """ + sig = trace.signature_with_no_ctx() + return sig.find("backward") >= 0 + + +def benchmark_trace( + trace: TraceCtx, + iters: int = 1, + show_func=False, + apply_del_last_used=True, + snapshot=False, + snapshot_name="", + nsight: bool = False, + nsight_fn_name: str = "", + **kwargs, +) -> tuple[float, float, Any]: + """ + Benchmark a generic computation trace compute time and peak memory usage. + nsight profiles can be generated if requested. + + If a backward trace is benchmarked, its paired forward trace is requested (with kwargs) as we don't generate inputs + for the backward call from the static args but with the dynamic arguments returned by the forward trace. + + Args: + trace: A computation trace. + iters: Benchmark iterations. 
+ show_func: Print the executed trace if True. + apply_del_last_used: A flag to control if the trace should be executed after a deletion of not used vars call. + snapshot: A flag controlling if memory usage snapshots should be created (https://pytorch.org/docs/stable/torch_cuda_memory.html). + snapshot_name: A label for the generated snapshot. + nsight: A flag contolling if nvsigh profiles should be generated or not. + nsight_fn_name: A label for the nsight iteration name during benchmark loop. + """ + from thunder.executors.passes import del_last_used + import inspect + + warm_up_iters = 10 + + torch.compiler.reset() + + # TODO: If TE is used inside the trace we have to clone the input arguments as + # we are currently seeing benchmarking issues at the iteration i > 0 + def clone_args_if_needed(args): + te_used = is_te_used(trace) + if not te_used: + return args + res = [] + # Detatching the tensors as for standalone trace benchmarks we are not interested in the gradients + for arg in args: + if isinstance(arg, Sequence): + res.append(clone_args_if_needed(arg)) + else: + if isinstance(arg, torch.Tensor): + res.append(arg.clone().detach()) + else: + res.append(arg) + return tuple(res) + + def warm_up(fn: Callable, args: Sequence): + for _ in range(warm_up_iters): + new_args = clone_args_if_needed(args) + fn(*new_args) + + def memory_snapshot(fn: Callable, args: Sequence, file_name: str): + new_args = clone_args_if_needed(args) + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.memory._record_memory_history() + fn(*new_args) + torch.cuda.memory._dump_snapshot(file_name + "_benchmark.pickle") + torch.cuda.memory._record_memory_history(enabled=None) + + def compute_time_cost_nsight(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: + try: + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Warm up cycles + warm_up(fn, args) + + # Benchmark + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.cudart().cudaProfilerStart() + for i in range(iters): + new_args = clone_args_if_needed(args) + torch.cuda.empty_cache() + torch.cuda.nvtx.range_push(f"thunder benchmark fn:{nsight_fn_name}, iter{i}") + fn(*new_args) + torch.cuda.nvtx.range_pop() + torch.cuda.cudart().cudaProfilerStop() + + return float("inf"), float("inf"), None + except Exception as e: + import inspect + + trc = inspect.getsource(fn) + print(f"Trace execution failed for nsight (error: {e}):\n\nTrace executed:\n{trc}") + raise e + + def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[float, float, Any]: + try: + current_iter = 0 + out = None + + # Warm up cycles + warm_up(fn, args) + + # Snapshot request + if snapshot: + memory_snapshot(fn, args, snapshot_name) + + # Save output + new_args = clone_args_if_needed(args) + out = fn(*new_args) + + # Benchmark + stream = torch.cuda.current_stream() + max_allocated_bytes = 0 + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + torch.cuda.synchronize() + for i in range(iters): + current_iter = i + new_args = clone_args_if_needed(args) + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + fn(*new_args) + end_events[i].record(stream) + max_allocated_bytes = max( + max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device()) + ) + + torch.cuda.synchronize() + times = [s.elapsed_time(e) for s, e in 
zip(start_events, end_events)] + tot_time = sum(times) / iters + return tot_time, max_allocated_bytes, out + except Exception as e: + print(f"Trace execution failed at iter {current_iter} (error: {e})\n\nTrace executed:\n{repr}") + raise e + + def compute_time_cost_ms_torchtimer(fn: Callable, repr: str, *args) -> tuple[float, float, Any]: + try: + out = None + + # Warm up cycles + warm_up(fn, args) + + # Snapshot request + if snapshot: + memory_snapshot(fn, args, snapshot_name) + + # Measure memory consumption + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + new_args = clone_args_if_needed(args) + # Cache the output + out = fn(*new_args) + max_allocated_bytes = torch.cuda.max_memory_allocated(torch.cuda.current_device()) + + # Benchmark + new_args = clone_args_if_needed(args) + # Omit any labels as we are not going to print the Timer result + t = Timer( + stmt=""" + fn(*new_args) + """, + globals={"fn": fn, "new_args": new_args}, + ) + t = t.blocked_autorange(min_run_time=1) + return t.median, max_allocated_bytes, out + except Exception as e: + print(f"Trace execution failed (error: {e})\n\nTrace executed:\n{repr}") + raise e + + def build_static_args(sequence: Sequence, **kwargs) -> list: + return transform_proxies_to_real(sequence, level=0, **kwargs) + + def backward_trace_args_preprocess() -> list | None: + if "fw_trace" not in kwargs: + raise RuntimeError( + "Set the associated forward trace in order to benchmark backward pass with sdpa executor" + ) + fw_trace = kwargs.get("fw_trace", None) + if not isinstance(fw_trace, TraceCtx): + raise AssertionError(f"forward trace is not a TraceCtx. Received: {type(fw_trace)}") + # Run the fw trace and get the outputs + fw_output = benchmark_trace(fw_trace, apply_del_last_used=False)[2] + # If any issue with the forward trace benchmark we have to stop this backward benchmark too (usually OOM errors) + if fw_output is None: + return None + + # Check if the fw trace is a final trace or an intermediate one (used for single trace region benchmarks) + sig = fw_trace.signature_with_no_ctx() + is_fw_final_trace = sig.startswith("def augmented") + + # Filter the C0 tuple + # These location might change if the implementation of the automatic + # differentiation transform changes. The saved tensors are the second output + # of the return statement. There's a prototype changing the saved tensors to + # be part of the output of a special symbol + # https://github.com/Lightning-AI/lightning-thunder/pull/214 + saved_for_bw_C0 = fw_output[1] if not is_fw_final_trace else fw_output[1][0] + + # The underlying API will generate TE.Float8 tensors also, hence it must know if TE executor is used or not + input_args = build_static_args(trace.args, te_used=te_used) + + # Now, we expected that if the fw trace is a final trace also the bw trace is a final one. And vice versa + if is_fw_final_trace: + # Swap saved_for_backward_traces + saved_for_bw = saved_for_bw_C0, fw_output[1][1] # Saved for backward tuple unpacks in (C0, _) + # Subsitute the static inputs for saved_for_backward with the runtime ones + input_args.pop(0) + input_args.insert(0, saved_for_bw) + else: + # Currently single trace region backward trace receives as input the saved_for_bw tensors plus some others. + # They are indexed like [saved_for_bw, others...]. + # NOTE: This may change in the future. 
+ """ + Example: + @torch.no_grad() + @no_autocast + def _cudnn_sdpa_bwd_wrapper(query, key, value, attn_mask, dropout_p=0.0, is_causal=False, *, scale=None): + # query: "cuda:0 bf16[32, 8, 128, 64]" + # key: "cuda:0 bf16[32, 8, 128, 64]" + # value: "cuda:0 bf16[32, 8, 128, 64]" + # dropout_p: "float 0.0" + # is_causal: "bool False" + (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, dropout_p, is_causal, scale=None) + return (t0, [query, key, value, dropout_p, is_causal, t0, t1, t2, t3]) + + @torch.no_grad() + @no_autocast + def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causal, t0, t1, t2, t3, t4): + (t5, t6, t7) = cudnn_sdpa_bwd(t4, query, key, value, None, dropout_p, is_causal, t0, t1, t2, t3, scale=None, cat_grad_qkv=False) + return {'query': t5, 'key': t6, 'value': t7, 'attn_mask': None, 'dropout_p': None, 'is_causal': None, 'scale': None} + + See how the backward trace needs t4 as argument recoveered from the static args + """ + updated_input_args = [t for t in saved_for_bw_C0] + updated_input_args.extend( + input_args[len(updated_input_args) :] + ) # Should be only one variable but leave this dyanamic + input_args = updated_input_args + + return input_args + + # Check for correctness + if trace.bound_symbols[-1].sym.id != PrimIDs.RETURN: + raise AssertionError("Missing return statement") + + if apply_del_last_used: + trace = del_last_used(trace) + + # Handle TE traces + cd = get_compile_data() + # We might benchmarking a partial trace where the TE symbol is not included yet, in this case rely on the compile option which tells us + # that afterwards at least one TE symbol will be included + # NOTE: compile data could be None if this benchmark util is used outside the compilation process. If this is the case we are benchmarking + # a whole trace (in theory) and is_te_used API will return the needed result. + te_used = (cd.compile_options.get("te_used", False) if cd else False) or is_te_used(trace) + if te_used: + cached_te_fp8_autocast_value = trace._include_te_fp8_autocast + trace._include_te_fp8_autocast = True + + # Build trace arguments: forward trace will receive compile time tensors while + # backward trace will receive dynamic inputs (runtime) to match real training env. + if is_backward_trace(trace): + input_args = backward_trace_args_preprocess() + # Forward or computational trace, parse the compile time input args... + else: + input_args = build_static_args(trace.args, te_used=te_used) + + # Can not parse input args (usually due to OOM errors in upstream calls) + if input_args is None: + return float("inf"), float("inf"), None + + # Obtain the python executable string + executable_str = trace.python() + executable = trace.python_callable() + if show_func: + print(inspect.getsource(executable)) + + trace_tok = set_tracectx(trace) + + t = float("inf") + m = float("inf") + answer = None + try: + if nsight: + t, m, answer = compute_time_cost_nsight(executable, iters, *input_args) + else: + # # By default torch.utils.benchmark.Timer is employed for measurement but if TE FP8 is being used we have to used our custom measurer. 
+ # # https://github.com/mattteochen/lightning-thunder/blob/b728ab6416aca9a6fd621101a4fc68842b3ed60e/thunder/backend_optimizer/utils.py#L459 + # if is_te_used(trace): + # t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) + # else: + # t, m, answer = compute_time_cost_ms_torchtimer(executable, executable_str, *input_args) + + t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) + except Exception: + import traceback + + traceback.print_exc() + finally: + reset_tracectx(trace_tok) + + # Restore the autocast value to not mess up the input trace + if te_used: + trace._include_te_fp8_autocast = cached_te_fp8_autocast_value + return t, m, answer + + +def _register_impl_executor(ex: Executor, id: PrimIDs, fn: Callable, checker: Callable) -> None: + if ex.name == "nvfuser": + from thunder.executors.nvfuserex_impl import register_supported + + register_supported(id, fn, checker) + + +def _recover_ex_from_compile_option(option: str) -> Executor: + if option.startswith("nv"): + from thunder.executors.nvfuserex_impl import ex + + return ex + else: + raise AssertionError(f"Compile option not recognized: {option}") + + +def wrap_fn_with_exeuctor_compile_option(option, fn: Callable | None = None, *args): + """ + Wraps a function call enabling a compile option for a specific executor. + The compile option will be restored after the function completes. + This can be usefull if we want to benchmark a specific compile option. + + Args: + option: The option to be enabled. + fn: A callable function. + args: Function arguments. + """ + from thunder.core import compile_data + + cd = compile_data.get_compile_data() + if option is not None: + # Update compile option context + if cd is None: + raise AssertionError("compile_data is None") + old_opt: bool | None = cd.compile_options.get(option.fusion_tag, None) + new_opt = True if old_opt is None or old_opt is False else False + cd.compile_options[option.fusion_tag] = new_opt + # Register the impl for the executor in order to be able to execute the id + _register_impl_executor( + _recover_ex_from_compile_option(option.fusion_tag), + option.id, + option.impl, + option.checker, + ) + # Call fn and return output + if fn: + out = fn(*args) + else: + out = None + # Restore compile option + if option is not None: + cd.compile_options[option.fusion_tag] = old_opt + + return out + + +def print_trace_args(trace: TraceCtx): + """ + Utility to display a trace arguments. + + Args: + trace: A computation trace. + """ + print_nested_sequence(trace.args) + + +def print_nested_sequence(args, show_dicts=False): + """ + Utility to display a sequence of elements with possible nested sequences. + Elements will be retrieved in a dfs manner. + + Args: + args: The input sequence. + show_dicts: Control if dict types should be printed. 
+ """ + + import pprint + + def is_tensor(t): + return isinstance(t, torch.Tensor) or isinstance(t, TensorProxy) + + if not isinstance(args, Sequence): + return + print("###################################### Sequence start") + + def _print(args, level): + tabs = "\t" * level + print(f"Level {level} start") + for arg in args: + if isinstance(arg, Sequence): + _print(arg, level + 1) + else: + tensor_shape = arg.shape if is_tensor(arg) else None + dtype = arg.dtype if is_tensor(arg) else None + name = arg.name if isinstance(arg, TensorProxy) else "" + print( + f'{tabs}{name + ": " if name else ""}{type(arg)}{pprint.pformat(arg) if isinstance(arg, dict) and show_dicts else ""} {tensor_shape if tensor_shape else ""} {dtype if dtype else ""}' + ) + print(f"Level {level} end") + + _print(args, 0) + print("###################################### Debug args\n") + + +def update_compile_options_executor_list_after_fw_bw_split() -> None: + """ + Updates the compile options with the executors that have been placed by the forward-backward split pass. + This utility can be used to save all the executors that have been effectively placed in a trace. + """ + + cd = get_compile_data() + assert cd + + # Get all the possible options that the vjp_optimization pass will use + options: dict = get_fw_bw_split_backends_options( + autotune_enable_te=cd.compile_options.get("autotune_enable_te", False) + ) + executors_list = list(cd.executors_list) + + # Remove all the initial options + for _, v in options.items(): + for ex in v: + if ex in executors_list: + executors_list.remove(ex) + + # Putting at the front even though order does not matter + for ex in cd.compile_options["autotune_executors_placed_by_fw_bw_split"]: + executors_list.insert(0, ex) + + # Assign new compilation executors options + cd.executors_list = executors_list + + +def transform_tensor(arg: TensorProxy, **kwargs) -> torch.Tensor: + """ + Retrive the associated torch.Tensor from a proxy tensor by reading its metadata. + This will allocate the real tensor in memory. + This utility can read transformer engine compilation requests and generate the associated FP8 tensor if needed. + + Args: + arg: The proxy tensor. + """ + from thunder.core.dtypes import is_float_dtype, is_signedinteger_dtype, is_boolean_dtype + + dtype = arg.dtype + shape = arg.shape + device = arg.device + requires_grad = arg.requires_grad + torch_dtype = to_torch_dtype(dtype) + if torch_dtype is None: + raise AssertionError(f"Unrecognized thunder dtype: {dtype}") + if is_float_dtype(dtype): + # Use TE Float8 if TE is enabled, it has float32 torch dtype + te_used = kwargs.get("te_used", False) + if te_used: + tensor: torch.Tensor = torch.randn( + shape, + dtype=torch_dtype if dtype.bytes > 1 else torch.float32, + device=device.device_str(), + requires_grad=requires_grad, + ) + if dtype.bytes == 1: + import transformer_engine.pytorch as te + + tensor = te.float8_tensor.Float8Tensor.to_float8(tensor) + # Support standard float tensors + else: + tensor: torch.Tensor = torch.randn( + shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad + ) + elif is_signedinteger_dtype(dtype): + tensor: torch.Tensor = torch.randint( + 0, 8, shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad + ) + elif is_boolean_dtype(dtype): + # TODO (matteochen): maybe random? 
+        tensor: torch.Tensor = torch.zeros(
+            *shape, dtype=torch.bool, device=device.device_str(), requires_grad=requires_grad
+        )
+    else:
+        raise AssertionError(f"dtype {dtype} not supported yet")
+
+    return tensor
+
+
+def transform_proxies_to_real(sequence: Sequence, level=0, **kwargs) -> tuple | list:
+    """
+    Retrieve a sequence of real arguments corresponding to a sequence of proxy arguments.
+    Nested sequences are supported recursively.
+
+    Args:
+        sequence: The input proxy sequence.
+        level: A helper integer representing the recursion depth.
+    """
+    from thunder.executors.transformer_engineex import Context as C
+
+    res = []
+    for e in sequence:
+        if isinstance(e, Sequence):
+            res.append(transform_proxies_to_real(e, level + 1, **kwargs))
+        else:
+            if isinstance(e, TensorProxy):
+                res.append(transform_tensor(e, **kwargs))
+            elif isinstance(e, IntegerProxy):
+                if e.python_type is bool:
+                    res.append(False if e.value is None else e.value)
+                else:
+                    res.append(0 if e.value is None else e.value)
+            elif isinstance(e, FloatProxy):
+                res.append(0.0 if e.value is None else e.value)
+            # Transformer engine Context object
+            #
+            # This branch populates the args with a dummy context, which is not correct in theory.
+            # For benchmarking purposes (where this fn is currently used) this does not impact runtime correctness, as in the end we
+            # will use the cached runtime contexts from the forward pass.
+            # We need this only to generate a context for the static inputs (which are discarded afterwards).
+            #
+            # Backward args: (saved_for_backward, cotangents)
+            # saved_for_backward -> replaced by the runtime tuple
+            # cotangents -> static inputs will be used
+            # If the static input generator becomes able to generate only the cotangents, this branch will no longer be used.
+            #
+            # Currently an option to supply a custom (possibly real) context is left.
+            elif hasattr(e, "name") and isinstance(e, AnyProxy) and e.name.startswith("ctx_te"):
+                required_context = kwargs.get("cached_fw_te_ctx_out", None)
+                res.append(required_context if required_context is not None else C())
+            elif e is None:
+                res.append(None)
+            else:
+                raise AssertionError(
+                    f'Input arg type not recognized: {type(e)} with name: {e.name if hasattr(e, "name") else "unknown"} with value: {e}'
+                )
+    # Outer container must be a list
+    return tuple(res) if level > 0 else res
+
+
+def reorder_executors_list(executors: Sequence, **kwargs):
+    """
+    Reorders an arbitrary executors list to be compatible with the autotuner compilation flow.
+    This puts all the executors with a grad fn at the front of the returned list.
+    All the other executors will be appended afterwards.
+
+    If no fusion executor is present in the input list, a default one will be added in order to trigger the autotuning process.
+
+    Args:
+        executors: The executors to be reordered.
+    """
+    from thunder.executors.torch_compile import torch_compile_ex
+    from thunder.executors.nvfuserex_impl import ex as nvfuser_ex
+
+    reordered = []
+    options = get_fw_bw_split_backends_options(**kwargs)
+
+    are_inputs_names = isinstance(executors[0], str)
+
+    # Put these in front to be picked up by _get_gradfn_and_executor
+    for _, v in options.items():
+        for ex in v:
+            if are_inputs_names:
+                if ex.name in executors:
+                    reordered.append(ex.name)
+            elif ex in executors:
+                reordered.append(ex)
+
+    # Add others
+    for ex in executors:
+        if ex not in reordered:
+            reordered.append(ex)
+
+    # NOTE: Currently the autotuner expects at least one fusion executor, otherwise it won't work.
+ # If other techniques will be added then this constraint will not be necessary + found = False + for ex in reordered: + if are_inputs_names and (ex == nvfuser_ex.name or ex == torch_compile_ex.name): + found = True + elif ex == nvfuser_ex or ex == torch_compile_ex: + found = True + if not found: + reordered.insert(0, nvfuser_ex.name if are_inputs_names else nvfuser_ex) + + return reordered + + +def symbol_hash( + *, + bsym: BoundSymbol, + ignore_returns_meta: bool = False, + ignore_unpacks_meta: bool = False, + ignore_unpacks: bool = False, +): + """ + Hash a bound symbol relying on its metadata (symbol name, bound symbol inputsa and outputs). + No hash functions will be applied in order to leave the output readable. + + Args: + bsym: A bound symbol. + ignore_returns_meta: If True, return statement metadata will be ignored + ignore_unpacks_meta: If True, unpack statements metadata will be ignored + ignore_unpacks: If True, unpack symbols will not be included. + """ + + def _tensor_hash(t: TensorProxy) -> str: + assert t.dtype + shapes = [str(s) for s in t.shape] + return "{" + "-".join(shapes) + "," + str(t.device) + "," + t.dtype.full_name + "," + str(t.requires_grad) + "}" + + def _collection_hash(c: CollectionProxy) -> str: + return "{Collection," + c.name + "," + str((type(c.collection()))) + "," + str(len(c.collection())) + "}" + + def _number_hash(t: NumberProxy) -> str: + return "{" + str(t.value) + "}" + + def _any_proxy_hash(p: AnyProxy) -> str: + # We are not using class' __repr__ as it might contain memory addresses and those could change during different iterations + return "{AnyProxy}" + + def _sequence_hash(s: Sequence | None) -> str: + if s is None: + return "None" + + ret = "[" + for e in s: + if e is None: + ret += "{None}," + elif isinstance(e, TensorProxy): + ret += _tensor_hash(e) + "," + elif isinstance(e, NumberProxy): + ret += _number_hash(e) + "," + elif isinstance(e, Sequence): + ret += _sequence_hash(e) + "," + elif isinstance(e, AnyProxy): + ret += _any_proxy_hash(e) + "," + elif isinstance(e, CollectionProxy): + ret += _collection_hash(e) + "," + elif isinstance(e, int) or isinstance(e, float) or isinstance(e, bool): + ret += "{" + f"{(type(e))}" + "}," + elif isinstance(e, dtype): + ret += "{" + f"{(type(e))}" + "}," + else: + raise RuntimeError(f"Not implemented {type(e)}. 
Failed bsym: {bsym}")
+        return ret + "]"
+
+    def _hash(bsym: BoundSymbol) -> str:
+        match = {
+            TensorProxy: _tensor_hash,
+            tuple: _sequence_hash,
+            list: _sequence_hash,
+            Sequence: _sequence_hash,
+            CollectionProxy: _collection_hash,
+        }
+
+        if ignore_returns_meta and bsym.sym.id == PrimIDs.RETURN:
+            return "{return}"
+
+        if ignore_unpacks and bsym.sym.name.startswith("unpack"):
+            return ""
+        elif ignore_unpacks_meta and bsym.sym.name.startswith("unpack"):
+            if bsym is not None and bsym.output is not None:
+                if isinstance(bsym.output, Sequence) and len(bsym.output) < 1:
+                    return ""
+            return "{general_unpack}"
+
+        h = bsym.sym.name
+        # Handle tensor as output or sequences
+        if type(bsym.output) not in match.keys():
+            raise RuntimeError(f"type {type(bsym.output)} not implemented")
+        h += (
+            "#out:"
+            + match[type(bsym.output)](bsym.output)
+            + "#in:"
+            # Args is always a tuple
+            + _sequence_hash(bsym.args)
+        )
+        return h
+
+    h = _hash(bsym)
+    return ("{" + h + "}") if h else h
+
+
+# Both lhs and rhs are included in the range
+# TODO: known_points can be used to detect start and end of a block sequence
+def repetead_trace_blocks(
+    *, trace: TraceCtx, min_block_size=2, known_points: tuple[BoundSymbol, BoundSymbol] | None = None
+) -> list[tuple[int, int]]:
+    """
+    Detects whether there are repeated sections inside a given trace.
+    This utility can be employed on traces of transformer-based models where the layers are repeated N times.
+
+    The returned list contains, for each repeated block, a tuple of two indices (into the computation trace) marking where the block starts and ends (both included).
+
+    The variable min_block_size can be tuned so that this function does not capture unwanted (small) sections when no repeated transformer layers are present.
+
+    Args:
+        trace: A computation trace.
+        min_block_size: The minimum block length, by default 2.
+        known_points: If a practitioner already knows where a transformer layer starts and ends inside a given trace, these points can be supplied in order to speed up the search. Currently not implemented.
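+
+    Example (illustrative): for a backward trace with two repeated transformer blocks the result
+    could look like [(32, 155), (157, 280)], i.e. two blocks of equal length separated by a
+    one-symbol gap (see _regions_between_blocks below).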
+ """ + if min_block_size < 2: + return [] + + if known_points is not None: + raise RuntimeError("known_points research is not supported.") + + symbols = [ + s + for s in trace.bound_symbols + if not s.sym.name.startswith("python_del") and not s.sym.name.startswith("unpack") + ] + + def _tuple_name(tup: Sequence): + ret = "(" + for e in tup: + if e is None: + ret += "None, " + elif hasattr(e, "name"): + ret += e.name + ", " + elif isinstance(e, Sequence): + ret += _tuple_name(e) + ", " + elif isinstance(e, int) or isinstance(e, float) or isinstance(e, bool): + ret += str(e) + ", " + else: + raise RuntimeError(f"Not implemented {type(e)}") + return ret + ")" + + # Only bsym that have inputs and outputs + original_map_indexes = { + str(bsym.output.name) if isinstance(bsym.output, TensorProxy) else _tuple_name(bsym.output): i + for i, bsym in enumerate(trace.bound_symbols) + if not (bsym.output is None or not bsym.args) and bsym.sym.id != PrimIDs.RETURN + } + + def _lcs(start_indexes) -> int: + max_last_len = len(symbols) - 1 + max_first_len = start_indexes[1] + + lcs = 0 + while start_indexes[0] < max_first_len and start_indexes[-1] < max_last_len: + # Get all the hashes + hashes = [symbol_hash(bsym=symbols[i]) for i in start_indexes] + # Advance if all the hashes coincides + uniques = set(hashes) + if len(uniques) == 1: + start_indexes = [i + 1 for i in start_indexes] + lcs += 1 + else: + return lcs + return max(lcs, 1) + + def _skip(bsym: BoundSymbol) -> bool: + return bsym.output is None or not bsym.args + + bsym_indexes: dict[str, list[int]] = {} + for i, bsym in enumerate(symbols): + if i == len(symbols) - 1: + break + if _skip(bsym): + continue + h = symbol_hash(bsym=bsym) + if h in bsym_indexes: + bsym_indexes[h].append(i) + else: + bsym_indexes[h] = [i] + + def _range_seen(index: int, s: set): + for r in s: + if index >= r[0] and index <= r[1]: + return True + return False + + seen_hashes = set() + seen_ranges = set() + max_lcs = 0 + res = [] + for i, bsym in enumerate(symbols): + if i == len(symbols) - 1: + break + if _skip(bsym): + continue + + h = symbol_hash(bsym=bsym) + # Normally, bsym are expected to output a TensorProxy + if not isinstance(bsym.output, Proxy) or h in seen_hashes or _range_seen(i, seen_ranges): + continue + + indexes = bsym_indexes.get(h, []) + seen_hashes.add(h) + if len(indexes) < 2: + continue + + # Now we can find the longest common sequence between all the occurences + lcs = _lcs(indexes) + # print('\n####################') + # for index in indexes: + # print(f'For index {index} lcs: {lcs}') + # print(f'Starting bsym: {symbols[index]}') + # print(f'Ending bsym: {symbols[index + lcs - 1]}') + # print('\n####################') + if lcs > 1: + # Push every seen ranges to ignore all the subranges + for i in indexes: + seen_ranges.add((i, i + lcs - 1)) + + # Set result + if lcs > max_lcs: + max_lcs = lcs + res = [(i, i + lcs - 1) for i in indexes] + + if max_lcs < min_block_size: + return [] + + from thunder.backend_optimizer.optimizer import logger + + logger.debug(f"Max block lcs fouund: {max_lcs}") + logger.debug(f"{[(symbols[r[0]].output.name, symbols[r[1]].output.name) for r in res]}") + + return [ + (original_map_indexes[symbols[t[0]].output.name], original_map_indexes[symbols[t[1]].output.name]) for t in res + ] + + +def _regions_between_blocks(trace: TraceCtx, common_blocks: list[tuple]) -> int: + """ + Retrieve the size of a gap region between common blocks. + + What is regions_between_blocks? 
+    They are trace regions between one transformer block and the next one (usually found in the backward trace). Since these regions are not present
+    after the last transformer block, they are needed in order to prepare shapes or strides
+    for the block at i+1 from the output of block i.
+    For example, if the common blocks look like [(32, 155), (157, 280)],
+    the symbol at index 156 is the gap (for the torch.float32 dtype; if another dtype is used the trace may contain other ops in this region, leading to a larger gap).
+    The forward trace does not have these gaps (so far).
+
+    In the example above the returned value will be 1.
+
+    Args:
+        trace: A computation trace.
+        common_blocks: A list containing the common blocks for the given trace.
+
+    """
+
+    def _assert_args(seq_a: Sequence, seq_b: Sequence):
+        assert len(seq_a) == len(seq_b)
+        for a, b in zip(seq_a, seq_b):
+            assert type(a) == type(b)
+            if isinstance(a, TensorProxy):
+                assert a.shape == b.shape
+                assert a.dtype == b.dtype
+            elif isinstance(a, Sequence):
+                _assert_args(a, b)
+
+    regions_between_blocks = common_blocks[1][0] - common_blocks[0][1] - 1
+    trace_region_between_common_blocks = trace.bound_symbols[common_blocks[0][1] + 1 : common_blocks[1][0]]
+    for i in range(1, len(common_blocks)):
+        if not common_blocks[i][0] - common_blocks[i - 1][1] - 1 == regions_between_blocks:
+            raise AssertionError(
+                "Trace configuration not supported. All the trace regions between common blocks are expected to have the same number of instructions."
+            )
+
+        # Check that the trace regions are equal
+        test_trace_regions = trace.bound_symbols[common_blocks[i - 1][1] + 1 : common_blocks[i][0]]
+        assert len(test_trace_regions) == len(trace_region_between_common_blocks)
+        for a, b in zip(test_trace_regions, trace_region_between_common_blocks):
+            assert a.sym.name == b.sym.name
+            _assert_args(a.args, b.args)
+
+    return regions_between_blocks
+
+
+def _indices_to_exclude_between_common_blocks(common_blocks: list[tuple]) -> list:
+    """
+    Retrieve the indices referring to the gaps between one common block and the next one.
+
+    Args:
+        common_blocks: A computed common block list for a given trace.
+    """
+    if len(common_blocks) < 2:
+        return []
+
+    ret = []
+    for i in range(1, len(common_blocks)):
+        start_gap_index = common_blocks[i - 1][1] + 1
+        end_gap_index = common_blocks[i][0] - 1
+        ret.extend([j for j in range(start_gap_index, end_gap_index + 1)])
+    return ret
+
+
+def reduce_common_trace_blocks(
+    *, trace: TraceCtx, common_blocks_in: list[tuple], skip_between_blocks: bool = True
+) -> TraceCtx:
+    """
+    Generate a reduced trace (fewer computation nodes) given a common block pattern.
+
+    This can be useful to speed up the executor tuning for models with repeated layers.
+
+    Args:
+        trace: A computation trace.
+        common_blocks_in: A previously computed common block pattern.
+        skip_between_blocks: A flag to control whether gaps between common blocks should be included in the output trace or not. See _regions_between_blocks.
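+
+    Example (illustrative) of the intended call pattern, mirroring how the optimizer uses it:
+        blocks = repetead_trace_blocks(trace=trace, min_block_size=2)
+        if len(blocks) >= 2:
+            reduced = reduce_common_trace_blocks(trace=trace, common_blocks_in=blocks)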
+ """ + + def _exclude(blocks: list[tuple[int, int]], index: int, black_list: set): + # Exclude if the index is in a repeated block + for block in blocks: + if index >= block[0] and index <= block[1]: + return True + + # Exclude if it marked as to remove + if index in black_list and skip_between_blocks: + return True + return False + + def _find_bsym_index(out_name: str, space: Sequence[BoundSymbol]) -> int: + for i, b in enumerate(space): + if b.output is not None and hasattr(b.output, "name") and b.output.name == out_name: + return i + raise RuntimeError(f"Can not found bsym with output {out_name} in the search space.") + + common_blocks = list(common_blocks_in) + if len(common_blocks) < 2: + trc = from_trace(trace) + trc.bound_symbols = list(trace.bound_symbols) + return trc + + # Create a mapping where we can easily find to which block a specific output belongs + output_to_block: dict[str, tuple[int, int]] = {} + for n_block, block in enumerate(common_blocks): + for i in range(block[0], block[1] + 1): + bsym = trace.bound_symbols[i] + if not hasattr(bsym.output, "name"): + continue + output_to_block[bsym.output.name] = (n_block, i - block[0]) + + # Check that we maintain the pattern + regions_between_blocks = _regions_between_blocks(trace, common_blocks) + + # We have to exlude these gaps indices from the reduce trace + index_gaps_to_exclude = [] + if regions_between_blocks: + index_gaps_to_exclude = _indices_to_exclude_between_common_blocks(common_blocks) + # Make it fast to search in + index_gaps_to_exclude = set(index_gaps_to_exclude) + + # Create reduced trace regions + bound_symbols: list[BoundSymbol] = [ + b for i, b in enumerate(trace.bound_symbols) if not _exclude(common_blocks[1:], i, index_gaps_to_exclude) + ] + + # Now, we have to update the trace region inputs after the last block to accepts the outputs of the first block, if it's not the return statement. 
+ if trace.bound_symbols[common_blocks[-1][1] + 1].sym.id != PrimIDs.RETURN: + symbol_to_correct_index = _find_bsym_index( + trace.bound_symbols[common_blocks[-1][1] + 1].output.name, bound_symbols + ) + symbol_to_correct = bound_symbols[symbol_to_correct_index] + + def _correct_args(target: BoundSymbol): + args = [] + for arg in target.args: + if arg is None: + args.append(None) + elif hasattr(arg, "name") and arg.name in output_to_block: + _, index_in_block = output_to_block[arg.name] + # Recover the argument from the first block + args.append(trace.bound_symbols[common_blocks[0][0] + index_in_block].output) + elif isinstance(arg, Sequence): + raise RuntimeError("Not implemented") + else: + args.append(arg) + return args + + def _correct_bsym(bsym: BoundSymbol) -> BoundSymbol: + bsym = bsym.from_bsym(args=_correct_args(bsym)) + return bsym + + new_subsymbols = [] + for sub in symbol_to_correct.subsymbols: + new_sub = _correct_bsym(sub) + new_subsymbols.append(new_sub) + + bound_symbols[symbol_to_correct_index] = symbol_to_correct.from_bsym( + args=_correct_args(symbol_to_correct), subsymbols=new_subsymbols + ) + + # We need to check also the return statements as we have fewer args now + flatten_bsyms = flatten_sequence([b.output for b in bound_symbols]) + args_remained = set([b.name for b in flatten_bsyms if b is not None and hasattr(b, "name")]) + # Fw trace + if isinstance(bound_symbols[-1].args[0], dict): + saved_for_backward = tuple( + [e for e in bound_symbols[-1].args[1][0] if hasattr(e, "name") and e.name in args_remained] + ) + if isinstance(bound_symbols[-1].args[0]["output"], Sequence): + output = tuple( + [o for o in bound_symbols[-1].args[0]["output"] if hasattr(o, "name") and o.name in args_remained] + ) + else: + output = bound_symbols[-1].args[0]["output"] + flat_output = tuple( + [o for o in bound_symbols[-1].args[0]["flat_output"] if hasattr(o, "name") and o.name in args_remained] + ) + new_dict = {"output": output, "flat_output": flat_output, "flat_args": bound_symbols[-1].args[0]["flat_args"]} + + # Create the new args and substitute return symbol + bsym = bound_symbols[-1].from_bsym(args=(new_dict, (saved_for_backward, bound_symbols[-1].args[1][1]))) + bound_symbols[-1] = bsym + # Bw trace + else: + + def _returned(seq: Sequence) -> tuple: + ret = [] + for e in seq: + if e is None: + ret.append(None) + elif isinstance(e, Sequence): + ret.append(_returned(e)) + elif isinstance(e, Proxy) and e.name in args_remained: + ret.append(e) + elif not isinstance(e, Proxy): + raise RuntimeError(f"type not recognized: {type(e)}") + + return tuple(ret) + + # Backward output is a tuple, and generally a tuple of tuple (()) + original_returned = bound_symbols[-1].args + returned = _returned(original_returned) + bound_symbols[-1] = bound_symbols[-1].from_bsym(args=returned) + + extrace: TraceCtx = from_trace(trace) + extrace.bound_symbols = bound_symbols + return extrace + + +def map_executors_from_reduced_trace_to_complete_trace( + complete_trace: TraceCtx, common_blocks: list[tuple], ex_mappings: list[Executor] +) -> list[Executor]: + """ + Generate executors mappings (trace region -> executor) for the complete trace once the optimization has been performed on a reduced trace. + + This implementation currently relies on the fact that transformer blocks are contiguous in trace + or they have a common gap region between them (in case for bw trace). + + The output executor list has size equal to the complete trace regions size. + + Args: + complete_trace: A computation trace. 
+        common_blocks: A previously computed common block pattern.
+        ex_mappings: The executor mappings for the reduced trace.
+    """
+    from thunder.executors.torchex import ex as torch_ex
+
+    if len(common_blocks) < 2:
+        raise AssertionError("No common block found")
+
+    # Check that we maintain the pattern
+    regions_between_blocks = _regions_between_blocks(complete_trace, common_blocks)
+
+    # These are the trace region indices (relative to the complete trace) that we have excluded from the reduced trace optimization.
+    # We also have to integrate their executors.
+    # By default torchex will be used, as currently no complex (optimizable) ops are present in these regions (they are usually reshape ops).
+    indices_excluded: list = _indices_to_exclude_between_common_blocks(common_blocks)
+
+    # Correctness assertion
+    if regions_between_blocks:
+        assert len(indices_excluded) % regions_between_blocks == 0
+        assert len(indices_excluded) // regions_between_blocks == len(common_blocks) - 1
+
+    # Solution starting point: copy up to the end of the first common block
+    complete_trace_executors: list[Executor] = ex_mappings[: common_blocks[0][1] + 1]
+    # Get the executors sequence to share from the first block to all the other equal blocks.
+    to_share: list[Executor] = []
+    for i in range(len(common_blocks) - 1):
+        # First region between blocks, added here as it was not present in the reduced trace (not found in the ex_mappings structure)
+        if i == 0:
+            to_share.extend([torch_ex] * regions_between_blocks)
+
+        to_share.extend(ex_mappings[common_blocks[0][0] : common_blocks[0][1] + 1])
+
+        # We have to add back the excluded regions (see the comment above).
+        if i < len(common_blocks) - 2:
+            to_share.extend([torch_ex] * regions_between_blocks)
+
+    # Extend by sharing mappings of transformer blocks
+    complete_trace_executors.extend(to_share)
+    # Extend with the remaining bsyms
+    complete_trace_executors.extend(ex_mappings[common_blocks[0][1] + 1 :])
+
+    # Check that we have all the executors needed
+    len_got = len(complete_trace_executors)
+    len_expected = len(complete_trace.bound_symbols)
+    if len_got != len_expected:
+        raise AssertionError(
+            f"Trace regions size differs from the obtained executors length: {len_expected} - {len_got}"
+        )
+
+    return complete_trace_executors
+
+
+# This fn is used before compile data is set, so it relies on kwargs
+def get_fw_bw_split_backends_options(bsym: BoundSymbol | None = None, **kwargs) -> list | dict:
+    """
+    Retrieves the executor tuning options for the vector jacobian product pass.
+    These executors must be tuned at the vjp stage as we have to choose the corresponding backward grad function.
+
+    To support new executors, the following lists can be expanded.
+
+    A guard is put on transformer_engine_ex, as its usage should not be tuned unless explicitly requested.
+
+    Args:
+        bsym: The query bound symbol.
+    """
+    from thunder.executors.sdpaex import sdpa_ex
+    from thunder.executors.cudnnex import cudnn_ex
+    from thunder.executors.fa3ex import fa3_ex
+    from thunder.executors.transformer_engineex import transformer_engine_ex
+
+    if kwargs is None or not kwargs.get("autotune_enable_te", False):
+        options: dict[str, list] = {
+            "scaled_dot_product_attention": [sdpa_ex, cudnn_ex, fa3_ex],
+        }
+    else:
+        options: dict[str, list] = {
+            "linear": [transformer_engine_ex],
+            "scaled_dot_product_attention": [sdpa_ex, cudnn_ex, fa3_ex],
+        }
+
+    return options.get(bsym.sym.name, []) if bsym else options
+
+
+def trace_symbolic_hash(trace: TraceCtx) -> str:
+    res = ""
+    for b in trace.bound_symbols:
+        # Ignore unpacks: when a tuple has size zero, there are cases where None is given as static args/output and cases where a zero-sized tuple is returned.
+        res += symbol_hash(bsym=b, ignore_returns_meta=True, ignore_unpacks_meta=True)
+    return res
+
+
+supported_file_modes = set(["json"])
+
+
+def dump_traces_placement(
+    *,
+    fw_trace: TraceCtx,
+    bw_trace: TraceCtx,
+    exs_fw: list[Executor],
+    exs_bw: list[Executor],
+    apply_remat: bool,
+    file_name: str,
+    output_mode: str = "json",
+) -> str:
+    """
+    Creates an output configuration file where the current forward and backward trace optimizations are saved.
+
+    Args:
+        fw_trace: A forward trace.
+        bw_trace: A backward trace.
+        exs_fw: Forward trace region executors.
+        exs_bw: Backward trace region executors.
+        apply_remat: Whether the forward and backward traces are the output of rematerialize_forward_and_backward.
+        file_name: The output file name.
+        output_mode: The output file format. Must be one of ['json'].
+    """
+    assert output_mode in supported_file_modes
+
+    if output_mode == "json":
+        # We define a unique trace identity by reading its bsym metadata; proxy names are ignored as they may
+        # change while the overall computation stays the same.
+        fw_hash = trace_symbolic_hash(fw_trace)
+        bw_hash = trace_symbolic_hash(bw_trace)
+
+        executors_fw_name = [ex.name if (ex and ex.name != "empty") else "None" for ex in exs_fw]
+        executors_bw_name = [ex.name if (ex and ex.name != "empty") else "None" for ex in exs_bw]
+
+        assert len(fw_trace.bound_symbols) == len(executors_fw_name)
+        assert len(bw_trace.bound_symbols) == len(executors_bw_name)
+
+        from thunder.backend_optimizer.optimizer import logger
+
+        logger.info(
+            f"Size match between len(fw_trace.bound_symbols)[{len(fw_trace.bound_symbols)}] and len(executors_fw_name)[{len(executors_fw_name)}]"
+        )
+        logger.info(
+            f"Size match between len(bw_trace.bound_symbols)[{len(bw_trace.bound_symbols)}] and len(executors_bw_name)[{len(executors_bw_name)}]"
+        )
+        logger.info(f"Saving configuration in {file_name}")
+
+        data = {
+            "forward": {
+                "hash": fw_hash,
+                "executors": executors_fw_name,
+            },
+            "backward": {
+                "hash": bw_hash,
+                "executors": executors_bw_name,
+            },
+            "rematerialize": apply_remat,
+        }
+        try:
+            with open(file_name, "w") as file:
+                import json
+
+                json.dump(data, file)
+        except Exception:
+            from thunder.backend_optimizer.optimizer import logger
+            import traceback
+
+            err = traceback.format_exc()
+            logger.error(f"Cannot dump {file_name} file:\n{err}")
+            return ""
+        return file_name
+    return ""
+
+
+def apply_results_from_file(
+    *, fw_trace: TraceCtx, bw_trace: TraceCtx, file: str, input_mode: str = "json"
+) -> tuple[TraceCtx, TraceCtx]:
+    """
+    Generate a transformed forward and backward trace from a configuration file.
+    A compatibility check is performed on both traces.
+ + Args: + fw_trace: The original augmented forward trace. + bw_trace: The original backward trace. + file: The configuration file. + input_mode: The configuration structure. Must be one of ['json']. + """ + import json + from thunder.executors.torchex import ex as torch_ex + from thunder.executors.pythonex import ex as python_ex + from thunder.executors.sdpaex import sdpa_ex + from thunder.executors.cudnnex import cudnn_ex + from thunder.executors.fa3ex import fa3_ex + from thunder.executors.nvfuserex_impl import ex as nvfuser_ex + from thunder.executors.torch_compile import torch_compile_ex + from thunder.executors.torch_autograd import update_bw_from_forward_optimization + + assert input_mode in supported_file_modes + + # Extend this if more executors will be added + conversion_map: dict[str | Hashable, Executor] = { + "None": Executor("empty"), + torch_ex.name: torch_ex, + python_ex.name: python_ex, + nvfuser_ex.name: nvfuser_ex, + torch_compile_ex.name: torch_compile_ex, + sdpa_ex.name: sdpa_ex, + cudnn_ex.name: cudnn_ex, + fa3_ex.name: fa3_ex, + } + + if input_mode == "json": + data = json.load(open(file, "r")) + + fw_hash = trace_symbolic_hash(fw_trace) + bw_hash = trace_symbolic_hash(bw_trace) + assert fw_hash == data["forward"]["hash"] + assert bw_hash == data["backward"]["hash"] + + fw_executors_recovered: list[str] = data["forward"]["executors"] + extrace_fw = assign_executors( + in_trace=fw_trace, + executors_list=[conversion_map[ex] for ex in fw_executors_recovered], + empty_str="empty", + always_executors=get_always_executors(), + ) + bw_executors_recovered: list[str] = data["backward"]["executors"] + bw_trace = update_bw_from_forward_optimization(fw=extrace_fw, bw=bw_trace) + extrace_bw = assign_executors( + in_trace=bw_trace, + executors_list=[conversion_map[ex] for ex in bw_executors_recovered], + empty_str="empty", + always_executors=get_always_executors(), + ) + + if data["rematerialize"]: + from thunder.core.rematerialization import rematerialize_forward_and_backward + + return rematerialize_forward_and_backward(extrace_fw, extrace_bw) + return extrace_fw, extrace_bw diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index f63a66f5ce..bc97c3f0c9 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -238,6 +238,8 @@ def __init__( use_torchao_fp8_linear: bool = False, use_torchao_fp8_allgather: bool = False, use_torchao_fp8_precompute_scale_for_fsdp: bool = False, + autotune: str = "", + save_autotune_cfg: bool = False ): seed = 1337 torch.manual_seed(seed) @@ -358,6 +360,13 @@ def __init__( self.profiler_start = profiler_start self.profiler_stop = profiler_stop + # Autotuner + supported_autotuning = set(['runtime', 'memory', '']) + if autotune not in supported_autotuning: + raise AssertionError(f"Autotuning configuration not supported. 
Available ones are: {[a for a in supported_autotuning if a]}") + self.autotune_type = autotune + self.save_autotune_cfg = save_autotune_cfg + if n_layers is not None: self.config.n_layer = n_layers @@ -569,6 +578,16 @@ def setup_compile(self, model): executors.insert(0, transformer_engine_ex) + if "fa3" in self.compile: + from thunder.executors.fa3ex import fa3_ex + + executors.insert(0, fa3_ex) + + if "nvmath" in self.compile: + from thunder.executors.nvmathex import nvmath_ex + + executors.insert(0, nvmath_ex) + if "dynamo" in self.compile: if self.distributed_mode == "fsdp2": print("Resetting cache size for when fsdp2 and using thunder as backend torch.compile") @@ -595,7 +614,19 @@ def setup_compile(self, model): # so we are using the lower level torch._dynamo.optimize function model = torch._dynamo.optimize(backend=backend)(model) else: - model = thunder.jit(model, executors=executors) + if self.autotune_type: + # nvFuser compile options to be enabled if wanted with: autotune_nv_enable_options=True + model = thunder.jit( + model, + autotune_type=self.autotune_type, + executors=executors, + autotune_optimize_common_blocks=True, + autotune_optimize_common_blocks_min_size=20, # This is quite low for a traced transformer block but will do the job + autotune_save_configuration=self.save_autotune_cfg, + autotune_enable_te="transformerengine" in self.compile + ) + else: + model = thunder.jit(model, executors=executors) elif self.compile != "eager": raise ValueError(f"Invalid compile option: {self.compile}") diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py new file mode 100644 index 0000000000..b6e165cd66 --- /dev/null +++ b/thunder/benchmarks/utils.py @@ -0,0 +1,295 @@ +from collections.abc import Callable +import torch +from thunder.backend_optimizer.utils import benchmark_trace +from torch.utils.benchmark import Timer, Compare + +warm_up_iters = 50 + + +class SplitFwBwBenchmarkUtils: + """ + Represents a benchmark result container. + It should be used when a single trace region is benchmarked as it can store an optimal executor (referred to the bsym under investigation). + + Attributes: + cost: The benchmark result. Can be compute time or peak memory usage. + fw_fn: Storage for a forward trace. + bw_fn: Storage for a backward trace. + executor: An OperatorExecutor. + """ + + def __init__( + self, *, cost: float = float("inf"), fw_fn: Callable | None = None, bw_fn: Callable | None = None, executor=None + ) -> None: + self.cost: float = cost + self.fw_fn: Callable | None = fw_fn + self.bw_fn: Callable | None = bw_fn + self.executor = executor + + +def _run_loss(model, input, target, loss_fn): + logits = model(input) + logits = logits.reshape(-1, logits.size(-1)) + target = target.reshape(-1) + loss = loss_fn(logits, target) + loss.backward() + + +def _run_autograd(model, input): + y = model(input) + torch.autograd.grad(y, input, grad_outputs=torch.ones_like(y)) + + +def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iters: int, loss) -> None: + """ + Benchmark a mock trainig loop of the given models. The loss function is defined as a naive torch.sum(). + This util will generate nvsight system profiles. + + Args: + models: a list of Callable models to benchmark. + labels: a list of labels (names) referring to the models. + inputs: a list of inputs to give to models' forward pass. + iters: benchmark iterations. 
+ """ + + for m, input, label in zip(models, inputs, labels): + # Warm up + for _ in range(warm_up_iters): + y = m(input) + y.sum().backward() + + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.cudart().cudaProfilerStart() + for i in range(iters): + torch.cuda.empty_cache() + torch.cuda.nvtx.range_push(f"torch training {label}, iter {i}") + torch.cuda.nvtx.range_push("forward") + y = m(input) + torch.cuda.nvtx.range_pop() + loss = y.sum() + torch.cuda.nvtx.range_push("backward") + loss.backward() + torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_pop() + torch.cuda.cudart().cudaProfilerStop() + + +def torch_fw_bw_benchmark( + models: list, labels: list, inputs: list, iters: int, loss_fn: Callable | None = None +) -> None: + """ + Benchmark a mock trainig loop of the given models. Time measurements will be performed by using cuda events. + A loss function is applied to trigger backward if provided. Otherwise torch.autograd will be used. + Forward and backward pass will be recorded separately. + + Args: + models: a list of Callable models to benchmark. + labels: a list of labels (names) referring to the models. + inputs: a list of inputs to give to models' forward pass. + iters: benchmark iterations. + loss_fn: a Pytorch loss function. + """ + for m, input, label in zip(models, inputs, labels): + # Warm up + target = torch.ones_like(input) + for _ in range(warm_up_iters): + if loss_fn is not None: + _run_loss(m, input, target, loss_fn) + else: + _run_autograd(m, input) + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + stream = torch.cuda.current_stream() + max_allocated_bytes = 0 + torch.cuda.synchronize() + for i in range(iters): + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + + start_events[i].record(stream) + y = m(input) + end_events[i].record(stream) + + max_allocated_bytes = max(max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device())) + + torch.cuda.synchronize() + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(tot) / iters + print(f"{label} forward mean time: {tot_time} ms") + print(f"{label} peak forward allocated memory: {max_allocated_bytes / (2**30)} GB") + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + stream = torch.cuda.current_stream() + max_allocated_bytes = 0 + torch.cuda.synchronize() + for i in range(iters): + target = torch.ones_like(input) + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + + y = m(input) + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + start_events[i].record(stream) + if loss_fn is not None: + y = y.reshape(-1, y.size(-1)) + target = target.reshape(-1) + loss = loss_fn(y, target) + loss.backward() + else: + torch.autograd.grad(y, input, grad_outputs=torch.ones_like(y)) + end_events[i].record(stream) + + max_allocated_bytes = max(max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device())) + + torch.cuda.synchronize() + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(tot) / iters + print(f"{label} backward mean time: {tot_time} ms") + print(f"{label} peak backward allocated memory: {max_allocated_bytes / (2**30)} GB") + + +def torch_timer_total_benchmark( + models: list, labels: list, inputs: 
list, name: str = "Model", loss_fn: Callable | None = None +) -> None: + """ + Benchmark a mock trainig loop time of the given models. Measurements will be computed by using torch.utils.benchmark.Timer, median times will be provided. + A loss function is applied to trigger backward if provided. Otherwise torch.autograd will be used. + + Args: + models: a list of Callable models to benchmark. + labels: a list of labels (names) referring to the models. + inputs: a list of inputs to give to models' forward pass. + name: the model name + loss_fn: a Pytorch loss function + """ + results = [] + for m, l, i in zip(models, labels, inputs): + t = Timer( + stmt=""" + _run_loss(m, i, target, loss_fn) + """ + if loss_fn is not None + else """ + _run_autograd(m, i) + """, + globals={ + "i": i, + "m": m, + "target": torch.zeros_like(i), + "_run_loss": _run_loss, + "_run_autograd": _run_autograd, + "loss_fn": loss_fn, + }, + label=name, + description=l, + ) + results.append(t.blocked_autorange(min_run_time=1)) + compare = Compare(results) + compare.colorize(rowwise=True) + compare.print() + + +def torch_total_benchmark( + models: list, labels: list, inputs: list, iters: int, loss_fn: Callable | None = None +) -> None: + """ + Benchmark a mock trainig loop of the given models. Time measurements will be performed by using cuda events. + A loss function is applied to trigger backward if provided. Otherwise torch.autograd will be used. + The complete time will be recorded with no split between forward pass and backward pass. + + Args: + models: a list of Callable models to benchmark. + labels: a list of labels (names) referring to the models. + inputs: a list of inputs to give to models' forward pass. + iters: benchmark iterations. + loss_fn: a Pytorch loss function. + """ + for m, input, label in zip(models, inputs, labels): + # Warm up + target = torch.ones_like(input) + for _ in range(warm_up_iters): + if loss_fn is not None: + _run_loss(m, input, target, loss_fn) + else: + _run_autograd(m, input) + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + stream = torch.cuda.current_stream() + max_allocated_bytes = 0 + torch.cuda.synchronize() + for i in range(iters): + target = torch.ones_like(input) + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + + start_events[i].record(stream) + y = m(input) + if loss_fn is not None: + y = y.reshape(-1, y.size(-1)) + target = target.reshape(-1) + loss = loss_fn(y, target) + loss.backward() + else: + torch.autograd.grad(y, input, grad_outputs=torch.ones_like(y)) + end_events[i].record(stream) + + max_allocated_bytes = max(max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device())) + + torch.cuda.synchronize() + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(tot) / iters + print(f"{label} forward+backward mean time: {tot_time} ms") + print(f"{label} peak forward+backward allocated memory: {max_allocated_bytes / (2**30)} GB") + + +def thunder_fw_bw_benchmark( + fw_traces: list, bw_traces: list, fw_labels: list, bw_labels: list, iters: int, nvsight: bool = False +) -> None: + """ + Benchmark a foward and backward trace pair. + The requested inputs are TraceCtx objects. + A nvsight profile can be generate if requested. + + Args: + fw_traces: a list of TraceCtx. + bw_traces: a list of TraceCtx. 
+ fw_labels: a list of labels (names) referring to the forward traces. + bw_labels: a list of labels (names) referring to the backward traces. + iters: benchmark iterations. + nvsight: flag to control nvsight profile generation. + """ + assert len(fw_traces) == len(bw_traces) == len(fw_labels) == len(bw_labels) + for trc, label in zip(fw_traces, fw_labels): + c, m, _ = benchmark_trace( + trc, + apply_del_last_used=False, + snapshot=True, + snapshot_name=label, + iters=iters, + nvsight=nvsight, + nvsight_fn_name=label, + ) + if not nvsight: + print(f"Executing {label} trace:\n{c} ms, {m / (2**30)} GB") + + i = 0 + for trc, label in zip(bw_traces, bw_labels): + c, m, _ = benchmark_trace( + trc, + apply_del_last_used=False, + snapshot=True, + snapshot_name=label, + iters=iters, + nvsight=nvsight, + nvsight_fn_name=label, + fw_trace=fw_traces[i], + ) + if not nvsight: + print(f"Executing {label} trace:\n{c} ms, {m / (2**30)} GB") + i += 1 diff --git a/thunder/common.py b/thunder/common.py index 168c415fbb..3a4468757d 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -71,6 +71,7 @@ class CompileStats: last_prologue_traces (Sequence[TraceCtx]): last_interpreted_instructions (Generator[dist.Instruction, None, None] | None): last_interpreter_log (list[InterpreterLogItem] | None): + last_executors (Sequence[Executor] | None): last_backward_traces (Sequence[TraceCtx]): last_trace_host_start (int): last_trace_host_stop (int): @@ -102,6 +103,7 @@ def __init__(self): self.last_prologue_traces = None self.last_interpreted_instructions: Generator[dis.Instruction, None, None] | None = None self.last_interpreter_log: list[InterpreterLogItem] | None = None + self.last_executors: Sequence[Executor] | None = None # torch.autograd.Function specific data self.last_backward_traces = None @@ -268,6 +270,8 @@ def __init__( self.additional_return_names = None self.num_constant_args = 0 + self.autotuner_bsym_with_gradfn_executor_cache: dict = {} + assert disable_preprocessing, "please use thunder.compile if you need preprocessing" diff --git a/thunder/core/codeutils.py b/thunder/core/codeutils.py index 66e674ab77..a95e242f3c 100644 --- a/thunder/core/codeutils.py +++ b/thunder/core/codeutils.py @@ -461,3 +461,17 @@ class NamedBindings: si.defaultdict = default_dict si.unwrapped_fn = unwrapped return si + + +def get_siginfo_name(trace) -> str: + try: + name = "" + if trace.fn is not None: + siginfo: SigInfo = get_siginfo(trace.fn, trace.args, trace.kwargs) + name = siginfo.name + else: + name = "unknown" + + return name + except Exception as e: + raise AssertionError(f"Is input trace an instance of TraceCtx?\n{e}") diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index d0524c87a2..c91c2f032c 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -1577,6 +1577,7 @@ def thunder_general_jit( process_group_for_ddp=process_group_for_ddp, executor_lookasides=executor_lookasides, ) + jfn = interpret( fn, fn_lookaside=general_jit_lookaside, diff --git a/thunder/core/trace.py b/thunder/core/trace.py index dd179e479a..4611060f90 100644 --- a/thunder/core/trace.py +++ b/thunder/core/trace.py @@ -340,6 +340,11 @@ def set_current_source_location(self, filename: str | None, positions: Positions self._current_source_filename = filename self._current_source_positions = positions + def signature_with_no_ctx(self) -> str: + si = self.siginfo() + signature_str = si.prettyprint(trace=self) + return signature_str + # TODO Account for multi-line signatures # TODO issue "Add type annotations to Python 
function produced by traces" # Consider extending the signature with type information, in particular the diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 6731d72bb7..314254563c 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -7,6 +7,7 @@ from functools import partial import thunder +from thunder.core.compile_data import get_compile_data import thunder.core.prims as prims from thunder.core.baseutils import BoundSymbolInterface from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify @@ -96,65 +97,76 @@ def check(inp, log_str): def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace: start_time_ns = time.perf_counter_ns() - producer_map: ProxyDict = producers(trace) + cd = get_compile_data() + disabled = not (not cd or (cd and not cd.compile_options.get("disable_dce", None))) + if not disabled: + producer_map: ProxyDict = producers(trace) - flat_trace_outputs, _ = tree_flatten(trace.output) - if needed_proxies is None: - needed_proxies: set[Variable] = set(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy))) - else: - needed_proxies.update(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy))) - dced = [] - - bsym: BoundSymbol - for bsym in reversed(trace.bound_symbols): - # Preserves symbols that should never be collected - if has_tags(bsym, {prims.OpTags.DONT_DCE}): - needed = True + flat_trace_outputs, _ = tree_flatten(trace.output) + if needed_proxies is None: + needed_proxies: set[Variable] = set( + tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy)) + ) else: - needed = False + needed_proxies.update(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy))) + dced = [] - # NOTE This block is run even if we know we're preserving the operation, because it - # may mark some of the operation's outputs as unused - some_unused = False - for out in bsym.flat_proxy_outs: - if variableify(out) in needed_proxies and producer_map[out] == bsym: + bsym: BoundSymbol + for bsym in reversed(trace.bound_symbols): + # Preserves symbols that should never be collected + if has_tags(bsym, {prims.OpTags.DONT_DCE}): needed = True else: - some_unused = True - - if needed: - nbsym: BoundSymbol = bsym - - # Replaces unused Proxy outputs with None - if some_unused: - - def _helper(x): - if isinstance(x, Proxy) and (variableify(x) not in needed_proxies or producer_map[x] != bsym): - return None - return x - - nbsym_output = tree_map(_helper, bsym.output) - nbsym = bsym.from_bsym(output=nbsym_output) - - # Eliminates no-op subsymbols - # NOTE In general editing subsymbols doesn't do anything, but no-op subsymbols are a pain - # for transforms to deal with. Transforms typically look for a "flattened" version of an - # operator for which they can apply their rules, and no-op subsymbols have no - # flattening, requiring each transform handle them explicitly or DCE them themselves - # while flattening. 
- _remove_noop_subsymbols(nbsym) - - dced.append(nbsym) - for x in nbsym.flat_proxy_args: - needed_proxies.add(variableify(x)) - - dcetrace = from_trace(trace) - dcetrace.bound_symbols = list(reversed(dced)) + needed = False + + # NOTE This block is run even if we know we're preserving the operation, because it + # may mark some of the operation's outputs as unused + some_unused = False + for out in bsym.flat_proxy_outs: + if variableify(out) in needed_proxies and producer_map[out] == bsym: + needed = True + else: + some_unused = True + + if needed: + nbsym: BoundSymbol = bsym + + # Replaces unused Proxy outputs with None + if some_unused: + + def _helper(x): + if isinstance(x, Proxy) and (variableify(x) not in needed_proxies or producer_map[x] != bsym): + return None + return x + + nbsym_output = tree_map(_helper, bsym.output) + nbsym = bsym.from_bsym(output=nbsym_output) + + # Eliminates no-op subsymbols + # NOTE In general editing subsymbols doesn't do anything, but no-op subsymbols are a pain + # for transforms to deal with. Transforms typically look for a "flattened" version of an + # operator for which they can apply their rules, and no-op subsymbols have no + # flattening, requiring each transform handle them explicitly or DCE them themselves + # while flattening. + _remove_noop_subsymbols(nbsym) + + dced.append(nbsym) + for x in nbsym.flat_proxy_args: + needed_proxies.add(variableify(x)) + + dcetrace = from_trace(trace) + dcetrace.bound_symbols = list(reversed(dced)) + else: + dcetrace = trace end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 - dcetrace.set_provenance(TraceProvenance(f"Dead Code Elimination (took {elapsed_time_millis} milliseconds)")) + dcetrace.set_provenance( + TraceProvenance( + f"Dead Code Elimination{' Skipped Per Compile Options' if disabled else ''} (took {elapsed_time_millis} milliseconds)" + ) + ) return dcetrace diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index dafd882d0b..82e7c808c9 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -352,6 +352,66 @@ class VISIT_TYPE(Enum): NO_OP = auto() +# Creates a new trace from "trace_from" by calling "visit" on its bound symbols ("bsyms") paired with an assigned executor. +# visit(bsym: BoundSymbolInterface, ex: Executor) -> VISIT_TYPE should call operations +# as if executing a program, and those operations will be recorded into the +# new trace. +# If visit() returns INSERT_AFTER for a bsym then that bsym will be copied +# to the new trace before visit() is called. This is useful when augmenting the bound +# symbols in an existing trace. +# If visit() returns INSERT_BEFORE for a bsym then that bsym will be copied to the new trace +# after visit() is called. This is also useful when augmenting the bound symbols in an existing +# trace. +# If visit() returns REPLACE for a bsym then that bsym will not be copied to the new trace. +# TODO Suggest a mechanism to preserve the original bound symbol with operations +# recorded both before and after it. This could be done by passing the (sub)scope to visit() for +# direct modification, acquiring the trace's current scope through the trace ctx and modifying it +# directly (this can be done today), or adding a record() function that is a sugar for the previous +# approach. Perhaps both passing the scope directly to visit() and adding record() would be helpful. 
+# TODO(crcrpar): Think about providing a guide how to let thunder "claim" if this is called after +# `thunder.executors.transform_for_execution`. +def visitor_transform_paired(trace_from: Trace, visit: Callable, zipped: zip, *, provenance: None | str = None): + trc: Trace = from_trace(trace_from) + + try: + tracectx_tok = set_tracectx(trc) + + for bsym, ex in zipped: + try: + # Creates a temporary scope to support copying the original bsym BEFORE + # the operations performed by visit(), even though this doesn't know whether to + # copy the original bsym until after visit() completes + old_scope = trc.scopes + scope = [] + trc.scopes = [scope] + + # This can be simpler? We currently trigger all the flow for the substitution + visit_type = visit(bsym, ex) + + if visit_type is VISIT_TYPE.INSERT_AFTER: + trc.bound_symbols.append(bsym) + + if visit_type is not VISIT_TYPE.NO_OP: + trc.bound_symbols.extend(scope) + else: + trc.bound_symbols.append(bsym) + + if visit_type is VISIT_TYPE.INSERT_BEFORE: + trc.bound_symbols.append(bsym) + + finally: + # Restores the trc's scope + trc.scopes = old_scope + + if provenance is not None: + trc.set_provenance(TraceProvenance(provenance)) + + return trc + + finally: + reset_tracectx(tracectx_tok) + + # Creates a new trace from "trace_from" by calling "visit" on its bound symbols ("bsyms"). # visit(bsym: BoundSymbolInterface) -> VISIT_TYPE should call operations # as if executing a program, and those operations will be recorded into the diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index ad44218028..1131a69bf9 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -4,13 +4,16 @@ from inspect import Parameter, Signature from itertools import chain +from thunder.backend_optimizer.utils import symbol_hash from thunder.core import prims, utils +from thunder.core.compile_data import get_compile_data from thunder.core.prims import PrimIDs from thunder.core.proxies import Proxy, variableify, TensorProxy from thunder.core.pytree import tree_flatten, tree_map from thunder.core.symbol import BoundSymbol from thunder.core.trace import from_trace, TraceCtx from thunder.core.transform_common import dce +from thunder.extend import Executor _cache = {} @@ -49,146 +52,258 @@ def make_aug_forward_and_backward(bsym: BoundSymbol) -> tuple[Callable, Callable from thunder.common import _make_cache_key from thunder.core.transforms import _get_gradfn_and_executor, eval_trace - joint_forward_backward, executor = _get_gradfn_and_executor(bsym) - utils.check( - joint_forward_backward is not None, - lambda: f"Cannot generate forward and backward functions for {bsym.sym.name}", - ) - key = (bsym.sym, executor, subkey := _make_cache_key(bsym.args, bsym.kwargs)) - cached_result = _cache.get(key, None) if subkey is not None else None - if cached_result is not None and not getattr(joint_forward_backward, "_disable_caching", False): - return cached_result - - joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs) - consumers = utils.consumers(joint_trace) - - def find_backward_input(forward_output): - output_consumers = consumers.get(forward_output, None) - if output_consumers is None or not output_consumers: - return None - get_grad_bsym = next( - filter(lambda bsym: bsym.sym.id == PrimIDs.GET_GRAD, output_consumers), - None, + def _make_aug_forward_and_backward(*, return_traces=False, update_cache=True) -> tuple: + joint_forward_backward, executor = _get_gradfn_and_executor(bsym) + 
utils.check( + joint_forward_backward is not None, + lambda: f"Cannot generate forward and backward functions for {bsym.sym.name}", ) - return get_grad_bsym.output if get_grad_bsym is not None else None - - def find_backward_output(forward_input): - forward_input_consumers = consumers.get(forward_input, None) - if forward_input_consumers is None or not forward_input_consumers: - return None - put_grad_bsym = next( - filter(lambda bsym: bsym.sym.id == PrimIDs.PUT_GRAD, forward_input_consumers), - None, + key = (bsym.sym, executor, subkey := _make_cache_key(bsym.args, bsym.kwargs)) + # If we update the cache we are not using the autotuner hence cache values for the key entry generated above is valid. + # If autotuner is used, each bsym has an unique key id hence this cache entry is not valid anymore. + if update_cache: + cached_result = _cache.get(key, None) if subkey is not None else None + if cached_result is not None and not getattr(joint_forward_backward, "_disable_caching", False): + return cached_result + + joint_trace = thunder.trace(inline_trace=False, use_dce=False)( + joint_forward_backward, *bsym.args, **bsym.kwargs ) - return put_grad_bsym.args[1] if put_grad_bsym is not None else None - - bw_inputs = tree_map(find_backward_input, utils.sequencify(joint_trace.output)) - bw_outputs_args = tree_map(find_backward_output, joint_trace.args) - bw_outputs_kwargs = tree_map(find_backward_output, joint_trace.kwargs) - meta_parameters = inspect.signature(bsym.sym.meta).parameters - meta_parameters = { - name: param - for name, param in meta_parameters.items() - if param.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY) - } - bw_outputs = {name: bw_output for name, bw_output in utils.safe_zip(meta_parameters, bw_outputs_args)} - bw_outputs = bw_outputs | bw_outputs_kwargs - flat_bw_outputs, _ = tree_flatten(bw_outputs) - - backward_bsyms = utils.find_producer_symbols(joint_trace, flat_bw_outputs, tree_flatten(bw_inputs)[0]) - skip = ( - prims.PrimIDs.UNPACK_EMPTY_DICT, - prims.PrimIDs.UNPACK_KEY, - prims.PrimIDs.UNPACK_SEQUENCE, - prims.PrimIDs.UNPACK_TRIVIAL, - prims.PrimIDs.GET_GRAD, - ) - backward_bsyms = [bsym for bsym in backward_bsyms if bsym.sym.id not in skip] - backward_bsyms.append(prims.python_return.bind(bw_outputs, output=())) - - forward_input_proxies = tree_flatten((joint_trace.args, joint_trace.kwargs))[0] - forward_input_proxies = [arg for arg in forward_input_proxies if isinstance(arg, Proxy)] - forward_bsyms = utils.find_producer_symbols(joint_trace, tree_flatten(joint_trace.output)[0], forward_input_proxies) - backward_bsyms = [bsym for bsym in backward_bsyms if bsym not in forward_bsyms] - - # Find required info from forward trace for backward trace - backward_producers = utils.producers(backward_bsyms) - saved_for_backward = [] - for backward_bsym in backward_bsyms: - for arg in backward_bsym.flat_args: - if not isinstance(arg, Proxy): - continue - if arg not in backward_producers and variableify(arg) not in map(variableify, tree_flatten(bw_inputs)[0]): - saved_for_backward.append(arg) - - saved_for_backward = list({variableify(arg): arg for arg in saved_for_backward}.values()) - - # Augment forward trace to include saved_for_backward as output - augmented_forward_trace = from_trace(joint_trace) - augmented_forward_trace.bound_symbols = [ - b for b in joint_trace.bound_symbols if b.sym.id not in (PrimIDs.PUT_GRAD, PrimIDs.GET_GRAD) - ] - return_bsym = augmented_forward_trace.bound_symbols[-1] - assert return_bsym.sym.id == PrimIDs.RETURN - 
augmented_forward_trace.bound_symbols[-1] = prims.python_return.bind( - (joint_trace.output, saved_for_backward), output=() - ) - # Remove put/get grad and backward symbols from augmented forward trace - augmented_forward_trace = dce(augmented_forward_trace) - - # Check that the number of outputs of the original forward function is the - # same as the number of primal outputs of the augmented forward trace - utils.check( - len(utils.sequencify(bsym.output)) == len(utils.sequencify(augmented_forward_trace.output[0])), - lambda: f"While generating forward and backward functions for {bsym.sym.name}, encountered an error.\n" - "The number of outputs of the original forward function must be the same as the number of primal outputs of the augmented forward trace.\n" - f"Number of outputs of the original forward function: {len(utils.sequencify(bsym.output))}\n" - f"Number of primal outputs of the augmented forward trace: {len(utils.sequencify(augmented_forward_trace.output[0]))}\n" - "Please check the forward function and the augmented forward trace to ensure that they have the same number of outputs.", - ) + consumers = utils.consumers(joint_trace) + + def find_backward_input(forward_output): + output_consumers = consumers.get(forward_output, None) + if output_consumers is None or not output_consumers: + return None + get_grad_bsym = next( + filter(lambda bsym: bsym.sym.id == PrimIDs.GET_GRAD, output_consumers), + None, + ) + return get_grad_bsym.output if get_grad_bsym is not None else None + + def find_backward_output(forward_input): + forward_input_consumers = consumers.get(forward_input, None) + if forward_input_consumers is None or not forward_input_consumers: + return None + put_grad_bsym = next( + filter(lambda bsym: bsym.sym.id == PrimIDs.PUT_GRAD, forward_input_consumers), + None, + ) + return put_grad_bsym.args[1] if put_grad_bsym is not None else None + + bw_inputs = tree_map(find_backward_input, utils.sequencify(joint_trace.output)) + bw_outputs_args = tree_map(find_backward_output, joint_trace.args) + bw_outputs_kwargs = tree_map(find_backward_output, joint_trace.kwargs) + meta_parameters = inspect.signature(bsym.sym.meta).parameters + meta_parameters = { + name: param + for name, param in meta_parameters.items() + if param.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY) + } + bw_outputs = {name: bw_output for name, bw_output in utils.safe_zip(meta_parameters, bw_outputs_args)} + bw_outputs = bw_outputs | bw_outputs_kwargs + flat_bw_outputs, _ = tree_flatten(bw_outputs) + + backward_bsyms = utils.find_producer_symbols(joint_trace, flat_bw_outputs, tree_flatten(bw_inputs)[0]) + skip = ( + prims.PrimIDs.UNPACK_EMPTY_DICT, + prims.PrimIDs.UNPACK_KEY, + prims.PrimIDs.UNPACK_SEQUENCE, + prims.PrimIDs.UNPACK_TRIVIAL, + prims.PrimIDs.GET_GRAD, + ) + backward_bsyms = [bsym for bsym in backward_bsyms if bsym.sym.id not in skip] + backward_bsyms.append(prims.python_return.bind(bw_outputs, output=())) - # Check if any of the bound symbols in the backward trace are also in the - # augmented forward trace - # If so, remove them from the backward trace - same_bsyms = set(augmented_forward_trace.bound_symbols) & set(backward_bsyms) - if same_bsyms: - backward_bsyms = [bsym for bsym in backward_bsyms if bsym not in same_bsyms] - additional_saved = [o for bsym in same_bsyms for o in bsym.flat_proxy_outs] - saved_for_backward += list({variableify(arg): arg for arg in additional_saved}.values()) + forward_input_proxies = tree_flatten((joint_trace.args, joint_trace.kwargs))[0] + 
forward_input_proxies = [arg for arg in forward_input_proxies if isinstance(arg, Proxy)] + forward_bsyms = utils.find_producer_symbols( + joint_trace, tree_flatten(joint_trace.output)[0], forward_input_proxies + ) + backward_bsyms = [bsym for bsym in backward_bsyms if bsym not in forward_bsyms] + + # Find required info from forward trace for backward trace + backward_producers = utils.producers(backward_bsyms) + saved_for_backward = [] + for backward_bsym in backward_bsyms: + for arg in backward_bsym.flat_args: + if not isinstance(arg, Proxy): + continue + if arg not in backward_producers and variableify(arg) not in map( + variableify, tree_flatten(bw_inputs)[0] + ): + saved_for_backward.append(arg) + + saved_for_backward = list({variableify(arg): arg for arg in saved_for_backward}.values()) + + # Augment forward trace to include saved_for_backward as output + augmented_forward_trace = from_trace(joint_trace) + augmented_forward_trace.bound_symbols = [ + b for b in joint_trace.bound_symbols if b.sym.id not in (PrimIDs.PUT_GRAD, PrimIDs.GET_GRAD) + ] + return_bsym = augmented_forward_trace.bound_symbols[-1] + assert return_bsym.sym.id == PrimIDs.RETURN augmented_forward_trace.bound_symbols[-1] = prims.python_return.bind( (joint_trace.output, saved_for_backward), output=() ) + # Remove put/get grad and backward symbols from augmented forward trace + augmented_forward_trace = dce(augmented_forward_trace) + + # Check that the number of outputs of the original forward function is the + # same as the number of primal outputs of the augmented forward trace + utils.check( + len(utils.sequencify(bsym.output)) == len(utils.sequencify(augmented_forward_trace.output[0])), + lambda: f"While generating forward and backward functions for {bsym.sym.name}, encountered an error.\n" + "The number of outputs of the original forward function must be the same as the number of primal outputs of the augmented forward trace.\n" + f"Number of outputs of the original forward function: {len(utils.sequencify(bsym.output))}\n" + f"Number of primal outputs of the augmented forward trace: {len(utils.sequencify(augmented_forward_trace.output[0]))}\n" + "Please check the forward function and the augmented forward trace to ensure that they have the same number of outputs.", + ) - backward_params = [ - Parameter(getattr(x, "name", f"arg{i}"), Parameter.POSITIONAL_OR_KEYWORD) - for i, x in enumerate(chain(saved_for_backward, bw_inputs)) - ] - backward_signature = Signature(backward_params) - - def backward_fn(): - pass - - backward_fn.__signature__ = backward_signature - backward_fn.__name__ = bsym.sym.name + "_backward" - - # Finally, build the backward trace - backward_trace = TraceCtx(backward_fn) - backward_trace.args = (*saved_for_backward, *bw_inputs) - backward_trace.kwargs = {} - backward_trace.bound_symbols = backward_bsyms - - # Creating new functions instead of using partial to avoid limitations in - # codeutils.get_siginfo - # https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/core/codeutils.py#L349-L353 - def fw_fn(*args, **kwargs): - return eval_trace(augmented_forward_trace, *args, **kwargs) - - def bw_fn(*args, **kwargs): - return eval_trace(backward_trace, *args, **kwargs) - - _cache[key] = fw_fn, bw_fn - - return fw_fn, bw_fn + # Check if any of the bound symbols in the backward trace are also in the + # augmented forward trace + # If so, remove them from the backward trace + same_bsyms = set(augmented_forward_trace.bound_symbols) & set(backward_bsyms) + if same_bsyms: + backward_bsyms = 
[bsym for bsym in backward_bsyms if bsym not in same_bsyms] + additional_saved = [o for bsym in same_bsyms for o in bsym.flat_proxy_outs] + saved_for_backward += list({variableify(arg): arg for arg in additional_saved}.values()) + augmented_forward_trace.bound_symbols[-1] = prims.python_return.bind( + (joint_trace.output, saved_for_backward), output=() + ) + + backward_params = [ + Parameter(getattr(x, "name", f"arg{i}"), Parameter.POSITIONAL_OR_KEYWORD) + for i, x in enumerate(chain(saved_for_backward, bw_inputs)) + ] + backward_signature = Signature(backward_params) + + def backward_fn(): + pass + + backward_fn.__signature__ = backward_signature + backward_fn.__name__ = bsym.sym.name + "_backward" + + # Finally, build the backward trace + backward_trace = TraceCtx(backward_fn) + backward_trace.args = (*saved_for_backward, *bw_inputs) + backward_trace.kwargs = {} + backward_trace.bound_symbols = backward_bsyms + + # Creating new functions instead of using partial to avoid limitations in + # codeutils.get_siginfo + # https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/core/codeutils.py#L349-L353 + def fw_fn(*args, **kwargs): + return eval_trace(augmented_forward_trace, *args, **kwargs) + + def bw_fn(*args, **kwargs): + return eval_trace(backward_trace, *args, **kwargs) + + if update_cache: + _cache[key] = fw_fn, bw_fn + + if not return_traces: + return fw_fn, bw_fn + return fw_fn, bw_fn, augmented_forward_trace, backward_trace + + cd = get_compile_data() + # No autotuning + if not cd or not cd.compile_options.get("autotune_type", None): + return _make_aug_forward_and_backward() + + # This search will be performed on the requested executors list + is_backend_available: bool = _get_gradfn_and_executor(bsym)[1] is not None + if not is_backend_available: + key = (bsym.sym, None, subkey := _make_cache_key(bsym.args, bsym.kwargs)) + # Cached will be checked in the inner fn if not miss + fw_fn, bw_fn = _make_aug_forward_and_backward() + return fw_fn, bw_fn + # We have a backend + else: + from thunder.backend_optimizer.optimizer import OptimizerType + from thunder.backend_optimizer.optimizer import logger + from thunder.backend_optimizer.utils import benchmark_trace + from thunder.backend_optimizer.utils import get_fw_bw_split_backends_options + from thunder.benchmarks.utils import SplitFwBwBenchmarkUtils + + # In order define this unique trace region we need an unique id + key = (bsym.sym, Executor(f"{id(bsym)}-autotuned"), subkey := _make_cache_key(bsym.args, bsym.kwargs)) + # We do check the cache here as the key in the inner fn does not know about this special id + cached_result = _cache.get(key, None) if subkey is not None else None + # NOTE: cache is always enabled here + if cached_result is not None: + return cached_result + + # Get the possible backends for the current bsym + backends = get_fw_bw_split_backends_options( + bsym, autotune_enable_te=cd.compile_options.get("autotune_enable_te", False) + ) + if not backends: + raise AssertionError( + f"No enabled backends found for {bsym.sym.name} but an executor for that symbol it is present in the executors list. Either remove that from the executors list or enable at least one backend for {bsym.sym.name} inside 'get_fw_bw_split_backends_options'." 
+ ) + + cached_executors_list = list(cd.executors_list) + # Retrieve all the executors which are requested to be used + requested_executors_list_for_bsym = [ex for ex in cached_executors_list if ex in backends] + + best = SplitFwBwBenchmarkUtils() + + # Do we have a common transformer block optimization enabled? + # If yes we have to restrict the same executor on every bsym + # in the transformer block (e.g. every scaled_dot_product in every transformer block will have the same executor + # as they are expected to work on same input size, shape and dtype). + optmimizer_common_transformer_block = cd.compile_options.get('autotune_optimize_common_blocks', False) + # The generated hash will rely on the operation, input args metadata and output metadata + h = symbol_hash(bsym=bsym) + # Recover the cache stored in the compile data + autotuner_bsym_with_gradfn_executor_cache = cd.autotuner_bsym_with_gradfn_executor_cache + + # Run the search only if not already visited before + if h in autotuner_bsym_with_gradfn_executor_cache and optmimizer_common_transformer_block: + best = autotuner_bsym_with_gradfn_executor_cache[h] + else: + # Restrict the search space + backends = list(requested_executors_list_for_bsym) + + logger.info(f"Search space for bsym {bsym.sym.name}: {backends}") + for b in backends: + logger.info(f"Benchmarking executor {b.name} for {bsym.sym.name}") + # Let downstream fn to pick up this + requested_executors_list_for_bsym.remove(b) + requested_executors_list_for_bsym.insert(0, b) + cd.executors_list = requested_executors_list_for_bsym + fw_fn, bw_fn, fw_trace, bw_trace = _make_aug_forward_and_backward(return_traces=True, update_cache=False) + # What should be the optimal iter? + # TODO: make benchmark info taken from an autotuner config + fw_time, fw_mem, _ = benchmark_trace(fw_trace, iters=100, apply_del_last_used=False) + bw_time, bw_mem, _ = benchmark_trace(bw_trace, iters=100, apply_del_last_used=False, fw_trace=fw_trace) + cost = ( + fw_time + bw_time if cd.compile_options["autotune_type"] == OptimizerType.RUNTIME else fw_mem + bw_mem + ) + if cost < best.cost: + best = SplitFwBwBenchmarkUtils(cost=cost, fw_fn=fw_fn, bw_fn=bw_fn, executor=b) + + assert best.cost != float("inf") + + logger.info(f"Best executor for symbol [{bsym.output.name} = {bsym.sym.name}]: {best.executor.name}") + + # Cache the bsym result for common trace's common block reductions + # At this stage we are tuning trace regions for these symbols name: linear and scaled_dot_product_attention + if bsym.sym.name in get_fw_bw_split_backends_options().keys() and optmimizer_common_transformer_block: + autotuner_bsym_with_gradfn_executor_cache[h] = best + + # Update the compile options + cd.compile_options["autotune_executors_placed_by_fw_bw_split"].add(best.executor) + from thunder.executors.transformer_engineex import transformer_engine_ex + + cd.compile_options |= {"te_used": True if best.executor == transformer_engine_ex else False} + # Restore executor list for downstream optimizations + cd.executors_list = cached_executors_list + # The executors used in this pass will be updated after the termination of the forward_and_backward_from_trace call + + _cache[key] = best.fw_fn, best.bw_fn + return best.fw_fn, best.bw_fn def get_saved_for_backward_tensors(trace: TraceCtx) -> tuple[TensorProxy]: diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 4e43a0b57a..55dcdd2c96 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -823,9 
+823,7 @@ def _can_fuse_node(n: Node): bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse) - # Counts how many fusions (per executor) have been constructed - # (Used to name fusions like nvFusion0, nvFusion1, ...) - fusion_counter: int = 0 + fusion_counter = 0 for bsyms in bound_symbol_groups: # TODO The following allows generating single node fusions, which # may be suboptimal for real-world performance. diff --git a/thunder/executors/nvmathex.py b/thunder/executors/nvmathex.py new file mode 100644 index 0000000000..388d716736 --- /dev/null +++ b/thunder/executors/nvmathex.py @@ -0,0 +1,68 @@ +from importlib.metadata import version +from thunder.core.prims import PrimIDs +import logging +import thunder +import thunder.torch as ltorch +import torch + +try: + import nvmath + HAS_NVMATH = True + version = version('nvmath-python') +except: + pass + HAS_NVMATH = False + version = None + +logger = logging.getLogger("Thunder nvmath_ex") +# Disable nvmath logs +logger.disabled = True + +nvmath_ex = thunder.extend.OperatorExecutor("nvmath", version=version) +thunder.extend.register_executor(nvmath_ex) + +_cache = {} + +def _cache_key(a: torch.Tensor, b: torch.Tensor) -> str: + def _get_shape_str(t: tuple): + return '_'.join(str(num) for num in t) + + return f'{_get_shape_str(a.size())}-{_get_shape_str(b.size())}' + +def _nvmath_linalg_advanced_matmul_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + options = nvmath.linalg.advanced.MatmulOptions(logger=logger) + # Check if these shapes have been cached + k = _cache_key(a, b) + if k in _cache: + algo = _cache[k] + with nvmath.linalg.advanced.Matmul(a, b, options=options) as mm: + # Provide the optimized algorithms directly to plan. + mm.plan(algorithms=algo) + # Execute the multiplication + return mm.execute() + + # Compute a new shape and cache the result + with nvmath.linalg.advanced.Matmul(a, b, options=options) as mm: + preferences = nvmath.linalg.advanced.MatmulPlanPreferences(limit=25) + mm.plan(preferences=preferences) + mm.autotune(iterations=10) + # Execute the multiplication + result = mm.execute() + _cache[k] = mm.algorithms + return result + +def _nvmath_linalg_advanced_matmul_checker(*args, **kwargs) -> bool: + return HAS_NVMATH + +nvmath_linalg_advanced_matmul = nvmath_ex.register_operator( + "nvmath_linalg_advanced_matmul", + like=ltorch.matmul, + fn=_nvmath_linalg_advanced_matmul_impl, +) +nvmath_ex.register_implementation( + ltorch.matmul, nvmath_linalg_advanced_matmul, checker=_nvmath_linalg_advanced_matmul_checker +) + +nvmath_ex.register_implementation( + PrimIDs.MATMUL, nvmath_linalg_advanced_matmul, checker=_nvmath_linalg_advanced_matmul_checker +) diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 77c6a3d8f4..20633f9928 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -1,3 +1,4 @@ +from pprint import pprint from typing import Dict, Any, List, Tuple, Optional from collections.abc import Callable from collections.abc import Sequence @@ -7,20 +8,23 @@ from functools import partial import time -from thunder.core.trace import TraceCtx, from_trace, TraceProvenance, VariableInterface +from thunder.core.compile_data import get_compile_data +from thunder.core.trace import TraceCtx, from_trace, TraceProvenance, VariableInterface, reset_tracectx, set_tracectx +from thunder.core.codeutils import SigInfo import thunder.core.dtypes as dtypes import thunder.core.utils as cutils from thunder.core.utils import ProxyDict, check, safe_map_flat from thunder.core.symbol 
import BoundSymbol from thunder.core.pytree import tree_flatten, tree_unflatten, tree_map import thunder.core.prims as prims -from thunder.core.proxies import Proxy, variableify, unvariableify, Variable, CollectionProxy +from thunder.core.proxies import Proxy, TensorProxy, variableify, unvariableify, Variable, CollectionProxy import thunder.core.transforms as transforms from thunder.core.transform_common import dce from thunder.core.trace import get_tracectx from thunder.executors.pythonex import clear_mutable_collection -from thunder.extend import Executor, get_always_executors, OperatorExecutor, FusionExecutor +from thunder.extend import Executor, get_all_executors, get_always_executors, OperatorExecutor, FusionExecutor +from thunder.backend_optimizer.optimizer import BackendOptimizer, OptimizerType, TraceCandidates, TraceType comment_symbols = {prims.PrimIDs.COMMENT, prims.PrimIDs.UNPACK_TRIVIAL} @@ -69,6 +73,7 @@ def visit_helper_(bsym: BoundSymbol) -> None | bool: for ex in executors_list: # TODO Consider allowing operator executors to claim portions of operations # TODO Should FusionExecutors be allowed to claim bsym with bsym.sym.executor? + if (isinstance(ex, OperatorExecutor) and ex.can_execute(bsym)) or ( isinstance(ex, FusionExecutor) and ex.can_fuse(bsym) ): @@ -91,6 +96,7 @@ def visit_helper_(bsym: BoundSymbol) -> None | bool: raise AssertionError("Unknown executor") safe_map_flat(update_swapmap, bsym.output, out) + return True if bsym.sym.executor is not None: @@ -133,6 +139,71 @@ def visit_(bsym: BoundSymbol) -> transforms.VISIT_TYPE: return extrace +# Autotuned transform_for_execution version +def autotune_transform_for_execution( + *, optimizer_context: BackendOptimizer, trace: TraceCtx, trace_type: TraceType, is_computational: bool = False +) -> tuple[TraceCtx, TraceCtx] | TraceCtx | None: + import torch + + start_time_ns = time.perf_counter_ns() + + # Recover the function name + sig_name = cutils.get_siginfo_name(trace) + + if torch.distributed.is_available(): + # Apply AllReduce bucketing if possible & needed + from thunder.distributed.transforms.ddp import apply_bucketing_to_grad_allreduce + + trace = apply_bucketing_to_grad_allreduce(trace) + + # Attach new trace and set the debug file name + optimizer_context.attach_trace(trace=trace, trace_type=trace_type, apply_dce=trace_type == TraceType.FW) + optimizer_context.log_file_name = f"autotune_transform_for_execution_{sig_name}.log" + # Forward traces are cached inside the context + optimizer_context.optimize() + + # Retrive the optimized traces. If backward trace is requested then the forward trace will be given only together with the backward one. + # This is because the optimal forward does not always lead to an optimal backward. + # If this is a computational trace (no autograd) then the forward (computational) trace will be ready and returned. 
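Editor's note (hedged sketch, not part of this patch): the comments above imply the following calling convention; ctx, fw_trace, bw_trace, and trc are hypothetical placeholders for a BackendOptimizer instance and the traces being optimized.

# Forward pass under autograd: candidates are cached inside the optimizer; nothing is returned yet.
autotune_transform_for_execution(optimizer_context=ctx, trace=fw_trace, trace_type=TraceType.FW)
# Backward pass: the jointly optimal forward/backward pair is returned in one call.
fw_extrace, bw_extrace = autotune_transform_for_execution(optimizer_context=ctx, trace=bw_trace, trace_type=TraceType.BW)
# Computational-only trace (no autograd): the optimized trace is returned directly.
extrace = autotune_transform_for_execution(optimizer_context=ctx, trace=trc, trace_type=TraceType.FW, is_computational=True)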
+ match trace_type: + case TraceType.FW: + if not is_computational: + pass + else: + fw_trace: TraceCtx = optimizer_context.get_optimal_fw_traces(is_computational) + # When optimizing the backward pass, the optimizer will return the best fw and bw traces based on the requested autotune_type, no need to choose the fw pass manually + case TraceType.BW: + fw_extrace, bw_extrace = optimizer_context.get_optimal_fw_bw_traces() + + end_time_ns = time.perf_counter_ns() + elapsed_time_ns = end_time_ns - start_time_ns + elapsed_time_millis = elapsed_time_ns // 1000000 + + # Assign the trace provenance + match trace_type: + case TraceType.FW: + if not is_computational: + cd = get_compile_data() + # Only for fresh tuning + if not cd or not cd.compile_options.get('autotune_restore_configuration', ""): + # We are assigning the provenance to all the possible candidates as at this stage we + # don't know which trace will be returned at the end of the optimization + fw_traces: list = optimizer_context.get_optimal_fw_traces() + for trc in fw_traces: + trc.set_provenance( + TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)") + ) + return None + else: + fw_trace.set_provenance(TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)")) + return fw_trace + case TraceType.BW: + bw_extrace.set_provenance( + TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)") + ) + return fw_extrace, bw_extrace + + def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx: import torch @@ -297,21 +368,32 @@ def del_last_used(trace: TraceCtx, *, clear_mutable_collections=False) -> TraceC Returns: list: transformed trace """ - start_time_ns = time.perf_counter_ns() - del_trace = from_trace(trace) + # If dce is disabled, we have to disable this pass also + cd = get_compile_data() + disabled = not (not cd or (cd and not cd.compile_options.get("disable_dce", None))) - outs = cutils.sequencify(trace.output) - flat_outs, _ = tree_flatten(outs) + start_time_ns = time.perf_counter_ns() - del_trace.bound_symbols = _del_last_used( - trace.bound_symbols, flat_outs, clear_mutable_collections=clear_mutable_collections - ) + if not disabled: + del_trace = from_trace(trace) + outs = cutils.sequencify(trace.output) + flat_outs, _ = tree_flatten(outs) + + del_trace.bound_symbols = _del_last_used( + trace.bound_symbols, flat_outs, clear_mutable_collections=clear_mutable_collections + ) + else: + del_trace = trace end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 - del_trace.set_provenance(TraceProvenance(f"Delete Last Used (took {elapsed_time_millis} milliseconds)")) + del_trace.set_provenance( + TraceProvenance( + f"Delete Last Used{' Skipped Per Compile Options' if disabled else ''} (took {elapsed_time_millis} milliseconds)" + ) + ) return del_trace diff --git a/thunder/executors/sdpaex.py b/thunder/executors/sdpaex.py index b912acaed9..dc34db7483 100644 --- a/thunder/executors/sdpaex.py +++ b/thunder/executors/sdpaex.py @@ -321,6 +321,7 @@ def _scaled_dot_product_efficient_attention_backward_impl( ) +sdpaex_scaled_dot_product_efficient_attention_backward_name = "sdpaex_scaled_dot_product_efficient_attention_backward" sdpea_bwd = sdpa_ex.register_operator( "sdpaex_scaled_dot_product_efficient_attention_backward", meta=_scaled_dot_product_efficient_attention_backward_meta, @@ -401,8 +402,9 @@ def 
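The `disable_dce` compile option consulted in `del_last_used` above is the same flag the new autotuner tests pass through `thunder.jit`. A small sketch of the intended effect (assuming the option simply flows in via `**compile_options`; the function here is illustrative):

import torch
import thunder

def fn(a, b):
    t_unused = a - b  # would normally be removed as dead code
    return a + b

jfn = thunder.jit(fn, disable_dce=True)
jfn(torch.randn(2, 2), torch.randn(2, 2))

# With DCE disabled, the "Delete Last Used" pass is skipped as well, so the
# unused intermediate remains visible in the final trace.
print(thunder.last_traces(jfn)[-1])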
_scaled_dot_product_flash_attention_backward_impl( return (_sdpa_slice_head_dimension(g, value.shape[-1]) for g in grads) +sdpafx_scaled_dot_product_efficient_attention_backward_name = "sdpafx_scaled_dot_product_efficient_attention_backward" sdpfa_bwd = sdpa_ex.register_operator( - "sdpafx_scaled_dot_product_efficient_attention_backward", + sdpafx_scaled_dot_product_efficient_attention_backward_name, meta=_scaled_dot_product_flash_attention_backward_meta, fn=_scaled_dot_product_flash_attention_backward_impl, ) diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 336c946515..756df1477b 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -107,12 +107,44 @@ def backward(ctx, *args): return (None, None, None, None, None, *([None] * n_grads)) +def update_bw_from_forward_optimization(*, fw: TraceCtx, bw: TraceCtx) -> TraceCtx: + # Some of the optimization passes change proxies in the trace and + # any change in the forward trace must be reflected in the backward + # trace. + original_bw_saved_tensors_for_backward = bw.args[0][0] + new_fw_saved_tensors_for_backward = get_saved_for_backward_tensors(fw) + swap_map = { + variableify(x): y + for x, y in zip(original_bw_saved_tensors_for_backward, new_fw_saved_tensors_for_backward) + if variableify(x) != variableify(y) + } + new_bsyms = replace_redundant_inputs(swap_map, bw.bound_symbols) + # replace_redundant_inputs doesn't replace the output of + # UNPACK_SEQUENCE so we do it manually. Here we have certain + # assumptions about the structure of the backward trace. + assert bw.bound_symbols[0].sym.id == PrimIDs.UNPACK_TRIVIAL + assert bw.bound_symbols[0].kwargs["name"] == "saved_for_backward" + assert bw.bound_symbols[4].sym.id == PrimIDs.UNPACK_SEQUENCE + assert bw.bound_symbols[4].args[0].name == "C0" + new_bsyms[4] = new_bsyms[4].from_bsym_swap_proxies( + swap_map, + skip_inputs=False, + skip_output=False, + skip_subsymbols=False, + ) + bw.bound_symbols = new_bsyms + + return bw + + def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, /, *flat_args): + from thunder.backend_optimizer.optimizer import TraceType, BackendOptimizer + from thunder.backend_optimizer.utils import update_compile_options_executor_list_after_fw_bw_split from thunder.core.rematerialization import rematerialize_all_gather, rematerialize_forward_and_backward from thunder.core.transforms import forward_and_backward_from_trace from thunder.distributed.transforms import FSDPCommBucketing from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops - from thunder.executors.passes import del_last_used, transform_for_execution + from thunder.executors.passes import del_last_used, transform_for_execution, autotune_transform_for_execution utils.check(compile_data is not None, lambda: "`compile_data` is required") # NOTE: This function is rather slow, so it's intended to be used @@ -123,9 +155,12 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat if not any(requires_grad_mask): raise RuntimeError("PyTorch's Autograd interface requires at least one tensor input with requires_grad=True") + autotune_type = compile_data.compile_options.get("autotune_type", None) + primal_trace = computation_trc primal_trace = sort_data_parallel_syncs(primal_trace) + # Handled by the caller if autotune is not None if compile_stats is not None: compile_stats.last_traces.append(primal_trace) @@ -136,6 +171,10 @@ def 
split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # the forward trace and inputs of the backward trace. fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True) + # Update the autotuned executors list + if autotune_type: + update_compile_options_executor_list_after_fw_bw_split() + fw_traces = [fw_trace] bw_traces = [bw_trace] @@ -164,73 +203,60 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat _fsdp_comm_bucketing = FSDPCommBucketing(compile_data, computation_trc) fw_trace = _fsdp_comm_bucketing.apply_bucketing_to_forward_trace(fw_trace) + do_apply_bucketing_bw_trace: bool = getattr(compile_data.fn, "use_fsdp", False) + # Now we can run the optimization passes on the forward trace - # TODO Restore request for no rematerialization - fw_extrace = transform_for_execution( - fw_trace, - executors_list=compile_data.executors_list, + backend_optimizer_ctx: BackendOptimizer | None = ( + None + if autotune_type is None + else BackendOptimizer( + priority_executors=compile_data.executors_list, + apply_bucketing_bw_trace=do_apply_bucketing_bw_trace, + produce_log=False, + optimizer_type=autotune_type, + compile_data=compile_data, + ) ) - fw_traces.append(fw_extrace) - - # Some of the optimization passes change proxies in the trace and - # any change in the forward trace must be reflected in the backward - # trace. - original_bw_saved_tensors_for_backward = bw_trace.args[0][0] - new_fw_saved_tensors_for_backward = get_saved_for_backward_tensors(fw_extrace) - # saved meta data (this could also contain proxies) - original_bw_saved_meta_for_backward = bw_trace.args[0][1] - new_fw_saved_meta_for_backward = fw_extrace.output[1][1] + # Get optimzied fw trace + fw_extrace = ( + transform_for_execution(fw_trace, executors_list=compile_data.executors_list) + if autotune_type is None + else autotune_transform_for_execution( + optimizer_context=backend_optimizer_ctx, trace=fw_trace, trace_type=TraceType.FW + ) + ) - saved_tensors_swap_map = { - variableify(x): y - for x, y in zip(original_bw_saved_tensors_for_backward, new_fw_saved_tensors_for_backward) - if variableify(x) != variableify(y) - } + # If in default mode, otherwise the best fw will be returned only at the end + if autotune_type is None: + # Here fw_extrace is not None - saved_metadata_swap_map = { - variableify(x): y - for x, y in zip(original_bw_saved_meta_for_backward, new_fw_saved_meta_for_backward) - if variableify(x) != variableify(y) - } - swap_map = saved_tensors_swap_map | saved_metadata_swap_map - - new_bsyms = replace_redundant_inputs(swap_map, bw_trace.bound_symbols) - # replace_redundant_inputs doesn't replace the output of - # UNPACK_SEQUENCE so we do it manually. Here we have certain - # assumptions about the structure of the backward trace. 
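For context, the `autotune_type` branch above is what routes `split_forward_backward` through the autotuner instead of the default `transform_for_execution`. A minimal sketch of supplying that option (mirroring the tests added below; "runtime" and "memory" are the objectives exercised in this patch, while the module, shapes, and executor choice are illustrative):

import torch
import thunder

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(64, 64)

    def forward(self, x):
        return torch.nn.functional.silu(self.fc(x))

with torch.device("cuda"):
    model = MLP()
    x = torch.randn(8, 64, requires_grad=True)

# The executors list seeds the search space; autotune_type picks the objective.
jmodel = thunder.jit(model, autotune_type="runtime", executors=["nvfuser", "torch"])
y = jmodel(x)          # the first call runs the forward/backward split and the tuning
y.sum().backward()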
- assert bw_trace.bound_symbols[0].sym.id == PrimIDs.UNPACK_TRIVIAL - assert bw_trace.bound_symbols[0].kwargs["name"] == "saved_for_backward" - assert bw_trace.bound_symbols[4].sym.id == PrimIDs.UNPACK_SEQUENCE - assert bw_trace.bound_symbols[4].args[0].name == "C0" - assert bw_trace.bound_symbols[5].sym.id == PrimIDs.UNPACK_SEQUENCE - assert bw_trace.bound_symbols[5].args[0].name == "C1" - new_bsyms[4] = new_bsyms[4].from_bsym_swap_proxies( - swap_map, - skip_inputs=False, - skip_output=False, - skip_subsymbols=False, - ) - new_bsyms[5] = new_bsyms[5].from_bsym_swap_proxies( - swap_map, - skip_inputs=False, - skip_output=False, - skip_subsymbols=False, - ) - bw_trace.bound_symbols = new_bsyms + fw_traces.append(fw_extrace) - if getattr(compile_data.fn, "use_fsdp", False): - bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace) + # If autotuning is activated, it will take care of the following 2 calls + bw_trace = update_bw_from_forward_optimization(fw=fw_extrace, bw=bw_trace) + if do_apply_bucketing_bw_trace: + bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace) # Now we can run the optimization passes on the backward trace - # TODO Restore request for no rematerialization - bw_extrace = transform_for_execution( - bw_trace, - executors_list=compile_data.executors_list, - ) + if autotune_type is not None: + fw_extrace, bw_extrace = autotune_transform_for_execution( + optimizer_context=backend_optimizer_ctx, trace=bw_trace, trace_type=TraceType.BW + ) + fw_traces.append(fw_extrace) + else: + bw_extrace = transform_for_execution( + bw_trace, + executors_list=compile_data.executors_list, + ) bw_traces.append(bw_extrace) - fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) + if autotune_type is None: + # TODO Restore request for no rematerialization + fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) + # Autotuner has been taken care of remat + else: + pass fw_traces.append(fw_extrace) bw_traces.append(bw_extrace) @@ -296,9 +322,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat bw_trace = rename_bwd_trace_outputs(bw_extrace, fw_extrace) + # This is moved to the caller if autotune is enabled if compile_stats is not None: compile_stats.last_traces += fw_traces compile_stats.last_backward_traces += bw_traces + compile_stats.last_executors = compile_data.executors_list # Enable wrapping with `te.fp8_autocast`. fw_extrace._include_te_fp8_autocast = True diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py index df620c08c1..bdb5472ead 100644 --- a/thunder/executors/transformer_engineex.py +++ b/thunder/executors/transformer_engineex.py @@ -404,8 +404,10 @@ def _te_functional_linear_backward_meta( ) +te_functional_linear_backward_name: str = "te_functional_linear_backward" + te_functional_linear_backward = transformer_engine_ex.register_operator( - "te_functional_linear_backward", meta=_te_functional_linear_backward_meta, fn=_te_functional_linear_backward_impl + te_functional_linear_backward_name, meta=_te_functional_linear_backward_meta, fn=_te_functional_linear_backward_impl ) LINEAR_CALLS_COUNTER = 0 @@ -417,13 +419,16 @@ def _te_functional_linear_backward_meta( FP8_RECIPE_KEY = "te_fp8_recipe" +linear_bound_symbol_name_prefix: str = "te_linear" + + # Creates a new stateful operator for each invocation of `linear`. 
def _create_fp8_linear_bound_symbol( a: TensorProxy, w: TensorProxy, b: TensorProxy, is_grad_enabled=False ) -> tuple[torch.Tensor, AnyProxy | None]: linear_fn = partial(TELinear(w.shape[1], w.shape[0]), is_grad_enabled=is_grad_enabled) global LINEAR_CALLS_COUNTER - name = f"te_linear_{LINEAR_CALLS_COUNTER}" + name = f"{linear_bound_symbol_name_prefix}_{LINEAR_CALLS_COUNTER}" desc = "transformer_engine_ex: Optional fp8_recipe for `fp8_autocast` context manager." if (fp8_recipe := get_compile_option(FP8_RECIPE_KEY, desc)) is None: diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py index 1551bfb880..4c18f167fd 100644 --- a/thunder/extend/__init__.py +++ b/thunder/extend/__init__.py @@ -362,6 +362,7 @@ def get_all_executors() -> tuple[Executor, ...]: torchex, transformer_engineex, triton_crossentropy, + nvmathex ) return tuple(_executor_map.values()) diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py new file mode 100644 index 0000000000..fa60343037 --- /dev/null +++ b/thunder/tests/test_autotuner.py @@ -0,0 +1,716 @@ +from thunder.backend_optimizer.utils import get_fw_bw_split_backends_options +from thunder.core.dtypes import to_torch_dtype +from thunder.core.prims import PrimIDs +from thunder.core.proxies import FloatProxy, IntegerProxy, TensorProxy +from thunder.core.symbol import BoundSymbol, Symbol +from thunder.core.trace import TraceCtx +from thunder.executors.cudnnex import cudnn_ex +from thunder.executors.fa3ex import fa3_ex +from thunder.executors.nvfuserex import nvfuserex +from thunder.executors.pythonex import ex as pythonex +from thunder.executors.sdpaex import sdpa_ex +from thunder.executors.torch_compile import torch_compile_ex +from thunder.executors.torchex import ex as torchex +from thunder.executors.transformer_engineex import transformer_engine_ex +from thunder.extend import Executor, get_always_executors +from thunder.tests.framework import requiresCUDA +from typing import Callable, Sequence +import pytest +import thunder +import thunder.backend_optimizer.utils as aut_utils +import torch + + +class DummyProxy: + def __init__(self, name) -> None: + self.name = name + + +@pytest.mark.parametrize( + "data,expected", + [ + ([DummyProxy("a"), DummyProxy("b")], "[a#b#]"), + ([DummyProxy("a"), DummyProxy("b"), 90], "[a#b#int90#]"), + ([DummyProxy("a"), DummyProxy("b"), 90, None], "[a#b#int90#None#]"), + ([DummyProxy("a"), DummyProxy("b"), 90, [DummyProxy("c"), [DummyProxy("d")]]], "[a#b#int90#[c#[d#]]]"), + ], +) +def test_sequence_hash(data, expected): + assert aut_utils.sequence_hash(data) == expected + + +@pytest.mark.parametrize( + "data,expected", + [ + ([DummyProxy("a"), "b"], "[a#b#]"), + ], +) +def test_sequence_hash_bad_input(data, expected): + with pytest.raises(AssertionError): + assert aut_utils.sequence_hash(data) == expected + + +@pytest.mark.parametrize( + "data,expected_sum,expected_others", + [ + ([nvfuserex, torch_compile_ex], Executor(name="empty"), Executor(name="empty")), + ([nvfuserex, torchex], torchex, Executor(name="empty")), + ], +) +def test_first_available_operator_executor(data, expected_sum, expected_others): + def fn(a: torch.Tensor, b: torch.Tensor): + return a + b + + a = torch.randn(1, 1) + b = torch.randn(1, 1) + jitted = thunder.jit(fn) + jitted(a, b) + trace = thunder.last_traces(jitted)[-1] + for bsym in trace.bound_symbols: + if bsym.sym.id == PrimIDs.ADD: + assert ( + aut_utils.get_first_available_operator_executor(bsym=bsym, executors=data, empty_hash="empty") + == expected_sum + ) + else: + 
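The expected strings in the `sequence_hash` parametrization above pin down the hashing scheme: proxies contribute their name, numbers are tagged with their Python type, `None` is kept literal, and nested sequences recurse inside brackets. A reference sketch consistent with those cases (illustrative only; the real helper lives in `thunder.backend_optimizer.utils` and additionally asserts on unsupported inputs such as plain strings):

from collections.abc import Sequence

def sequence_hash_sketch(seq) -> str:
    out = "["
    for item in seq:
        if item is None:
            out += "None#"
        elif isinstance(item, Sequence) and not isinstance(item, str):
            out += sequence_hash_sketch(item)  # nested sequences recurse
        elif isinstance(item, (int, float)):
            out += f"{type(item).__name__}{item}#"  # e.g. 90 -> "int90#"
        else:
            out += f"{item.name}#"  # proxies are identified by their name
    return out + "]"

With the `DummyProxy` helper from the tests, `sequence_hash_sketch([DummyProxy("a"), DummyProxy("b"), 90, [DummyProxy("c"), [DummyProxy("d")]]])` yields "[a#b#int90#[c#[d#]]]", matching the expected value above.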
assert ( + aut_utils.get_first_available_operator_executor(bsym=bsym, executors=data, empty_hash="empty") + == expected_others + ) + + +@pytest.mark.parametrize( + "test,expected", + [ + ([1, 2, 3], [1, 2, 3]), + ([1, 2, [3, 4]], [1, 2, 3, 4]), + ([1, 2, [3, 4, [None]]], [1, 2, 3, 4]), + ], +) +def test_flatten_sequence(test, expected): + assert aut_utils.flatten_sequence(test) == expected + + +def test_get_not_used_intermediate_outputs(): + # Flat outputs + def fn(a: torch.Tensor, b: torch.Tensor): + t1 = a - b + t2 = a * b + t3 = a / b + return (a + b) + t2 + + a = torch.randn(1, 1) + b = torch.randn(1, 1) + jitted = thunder.jit(fn, disable_dce=True) + jitted(a, b) + trace = thunder.last_traces(jitted)[-1] + + not_used = aut_utils.get_not_used_intermediate_outsputs(trace) + # We have not used t1, t3 in trace + not_used_labels = ["t1", "t3"] + assert len(not_used) == 2 + for t in not_used: + assert t.name in not_used_labels + not_used_labels.remove(t.name) + + +def _assign_executors_fn(a: torch.Tensor): + t0 = a * 2 + t1 = a * a + t3 = t0 + t1 + return t3 + + +@pytest.mark.parametrize( + "fn, args, executors", + [ + ( + _assign_executors_fn, + torch.randn(1, 1), + [Executor("empty"), torchex, torchex, torchex, Executor("empty")], + ), + ( + _assign_executors_fn, + torch.randn(1, 1), + [Executor("empty"), torch_compile_ex, torch_compile_ex, torch_compile_ex, Executor("empty")], + ), + ( + _assign_executors_fn, + torch.randn(1, 1), + [Executor("empty"), torch_compile_ex, torch_compile_ex, torchex, Executor("empty")], + ), + ( + _assign_executors_fn, + torch.randn(1, 1), + [Executor("empty"), torchex, torch_compile_ex, torch_compile_ex, Executor("empty")], + ), + ], +) +def test_assign_executors(fn, args, executors): + trace: TraceCtx = thunder.trace(inline_trace=True)(fn, args) + placed: TraceCtx = aut_utils.assign_executors( + in_trace=trace, executors_list=executors, always_executors=get_always_executors(), empty_str="empty" + ) + + def _id(bsym: BoundSymbol): + res = bsym.sym.name + if isinstance(bsym.output, Sequence): + res += "#" + aut_utils.sequence_hash(bsym.output) + else: + res += "#" + bsym.output.name + + return res + + # Unapacks and return symbols are filtered out + executor_map = { + _id(b): e if e.name != "empty" else None + for b, e in zip(trace.bound_symbols, executors) + if b.output is not None and b.sym.id != PrimIDs.RETURN and b.args is not None + } + + for b in placed.bound_symbols: + # print(b) + if b.sym.is_fusion: + # Search in every subsymbol + for sub in b.subsymbols: + identif = _id(sub) + assert b.sym.executor == executor_map[identif] + elif b.sym.id != PrimIDs.RETURN and b.args: + identif = _id(b) + assert b.sym.executor == executor_map[identif] + + +class Linear(torch.nn.Module): + def __init__(self, a, b) -> None: + super().__init__() + self.linear = torch.nn.Linear(a, b) + + def forward(self, x): + return self.linear(x) + + +class Matmul(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + return x @ x + + +@pytest.mark.parametrize( + "model, x, op, expected", + [ + (Linear(8, 8), torch.randn(8, 8), "linear", True), + (Linear(8, 8), torch.randn(8, 8), "add", False), + (Matmul(), torch.randn(8, 8), "matmul", True), + ], +) +def test_operation_in_trace(model, x, op, expected): + jitted = thunder.jit(model) + jitted(x) + # jitted(args if not isinstance(args, Sequence) else *args) + trace = thunder.last_traces(jitted)[-1] + + assert aut_utils.operation_in_trace(trace=trace, op=op) == expected + + +class 
Sdpa(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, q, k, v): + return torch.nn.functional.scaled_dot_product_attention(q, k, v) + + +@pytest.mark.parametrize( + "model, q, k, v, executors, expected", + [ + ( + Sdpa(), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + ["cudnn", "sdpa", "fa3"], + 1, + ), + ( + Sdpa(), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + ["cudnn", "sdpa"], + 1, + ), + ( + Sdpa(), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + [ + "cudnn", + ], + 1, + ), + ( + Sdpa(), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + [], + 0, + ), + ], +) +@requiresCUDA +# Currently these executors are: cudnn, spda, fa3, TE +def test_update_compile_options_executor_list_after_fw_bw_split(model, q, k, v, executors, expected): + jitted = thunder.jit(model, autotune_type="runtime", executors=executors) + jitted(q, k, v) + + assigned: Sequence[Executor] = thunder.executors_applied(jitted) + + count = 0 + for ex in assigned: + count += 1 if ex.name in executors else 0 + + assert count == expected + + +def _test_transform_proxies_to_real_fn_1(a: torch.Tensor, b: torch.Tensor, k: int): + t0 = a * b + return t0 * k + + +def _test_transform_proxies_to_real_fn_2( + a: torch.Tensor, b: torch.Tensor, c: tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] +): + t0 = c[0] + c[1][0] + t1 = t0 * c[1][1] + return t1 - a + b + + +def _test_transform_proxies_to_real_common( + fn: Callable, torch_args: tuple, executors: list, has_backward: bool, **kwargs +): + jitted = thunder.jit(fn, executors=executors) + jitted(*torch_args) + + trace_static_args = thunder.last_traces(jitted)[-1].args + assert trace_static_args + + transformed_args = aut_utils.transform_proxies_to_real(trace_static_args, **kwargs) + + assert isinstance(transformed_args, list) + + def _comp(thunder_seq: Sequence, torch_seq: Sequence): + assert len(thunder_seq) == len(torch_seq) + + for a, b in zip(thunder_seq, torch_seq): + if isinstance(a, TensorProxy): + # handle TE 
fp32 + # Static type for fp8 is torch.float8 but the runtime is TE Float8 if TE is being used + if a.dtype.bytes == 1 and kwargs.get("te_used"): + assert b.dtype == torch.float32 + else: + assert b.dtype == to_torch_dtype(a.dtype) + assert a.device.device_str() == str(b.device) + assert a.shape == b.shape + assert a.requires_grad == b.requires_grad + elif isinstance(a, IntegerProxy) or isinstance(a, FloatProxy): + assert a.value == b + + if isinstance(a, Sequence): + assert isinstance(b, Sequence) + _comp(a, b) + + _comp(trace_static_args, transformed_args) + + if has_backward: + trace_static_args = thunder.last_backward_traces(jitted)[-1].args + assert trace_static_args + + transformed_args = aut_utils.transform_proxies_to_real(trace_static_args, **kwargs) + print(trace_static_args) + # print(transformed_args) + + _comp(trace_static_args, transformed_args) + + +@pytest.mark.parametrize( + "fn, torch_args, executors, has_backward", + [ + (_test_transform_proxies_to_real_fn_1, tuple([torch.randn(1, 1), torch.randn(1, 1), 10]), [], False), + ( + _test_transform_proxies_to_real_fn_2, + tuple([torch.randn(1, 1), torch.randn(1, 1), (torch.randn(1, 1), (torch.randn(1, 1), torch.rand(1, 1)))]), + [], + False, + ), + ( + Sdpa(), + ( + torch.randn([10, 128, 4, 32], dtype=torch.float16, requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, requires_grad=True), + ), + [], + True, + ), + ], +) +def test_transform_proxies_to_real(fn: Callable, torch_args: tuple, executors: list, has_backward: bool): + _test_transform_proxies_to_real_common(fn, torch_args, executors, has_backward) + + +@requiresCUDA +def test_transform_proxies_to_real_TE(): + class Model(torch.nn.Module): + def __init__(self, in_features, out_features) -> None: + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + model = Model(4096, 4096) + model.to("cuda") + + _test_transform_proxies_to_real_common( + model, + tuple([torch.randn(4096, 4096, requires_grad=True, device="cuda")]), + ["transformer_engine"], + True, + te_used=True, + ) + + +@pytest.mark.parametrize( + "executors, expected, use_te", + [ + (["python"], ["nvfuser", "python"], False), + (["nvfuser", "cudnn"], ["cudnn", "nvfuser"], False), + (["torch", "nvfuser", "sdpa"], ["sdpa", "torch", "nvfuser"], False), + (["transformer_engine", "nvfuser", "sdpa"], ["transformer_engine", "sdpa", "nvfuser"], True), + ], +) +# We might not have nvfuser in non cuda envs +@requiresCUDA +def test_reorder_executors_list(executors, expected, use_te): + assert aut_utils.reorder_executors_list(executors, autotune_enable_te=use_te) == expected + + +@pytest.mark.parametrize( + "name, expected", + [("linear", [transformer_engine_ex]), ("scaled_dot_product_attention", [sdpa_ex, cudnn_ex, fa3_ex])], +) +def test_get_fw_bw_split_backends_options(name: str, expected): + symbol = Symbol(name=name) + bsym = BoundSymbol(symbol, (), {}, None) + options = get_fw_bw_split_backends_options(bsym, autotune_enable_te=True) + assert all(map(lambda v: v in options, expected)) + + +class Model_1(torch.nn.Module): + def __init__(self, in_f, out_f) -> None: + super().__init__() + self.linear = torch.nn.Linear(in_f, out_f) + + def forward(self, x): + t0 = self.linear(x) + return torch.nn.functional.silu(t0) + + +class Model_2(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.n_head = 12 + self.n_embd = 3072 + 
self.c_attn = torch.nn.Linear(self.n_embd, 3 * self.n_embd, bias=False) + + def forward(self, x): + B, T, C = x.size() + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + return torch.nn.functional.scaled_dot_product_attention(q, k, v) + + +@pytest.mark.parametrize( + "model, tensor_shape, dtype, autotune_type, executors, expected_executors, use_te", + [ + ( + Model_1(32, 32), + (32, 32), + torch.float32, + "runtime", + [nvfuserex], + [[nvfuserex, torchex, pythonex]], + False, + ), + ( + Model_1(32, 32), + (32, 32), + torch.float32, + "memory", + [torch_compile_ex], + [[torch_compile_ex, torchex, pythonex]], + False, + ), + ( + Model_1(4096, 4096), + (128, 4096), + torch.float32, + "runtime", + [transformer_engine_ex], + [[transformer_engine_ex, nvfuserex, torchex, pythonex]], + True, + ), + ( + Model_2(), + (16, 1024, 3072), + torch.float16, + "runtime", + [sdpa_ex, cudnn_ex], + [[sdpa_ex, nvfuserex, torchex, pythonex], [cudnn_ex, nvfuserex, torchex, pythonex]], + False, + ), + ( + Model_2(), + (16, 1024, 3072), + torch.float32, + "runtime", + [sdpa_ex, transformer_engine_ex], + [ + [sdpa_ex, transformer_engine_ex, nvfuserex, torchex, pythonex], + [transformer_engine_ex, sdpa_ex, nvfuserex, torchex, pythonex], + ], + True, + ), + ], +) +@requiresCUDA +def test_autotuner( + model: torch.nn.Module, + tensor_shape: tuple, + dtype: torch.dtype, + autotune_type: str, + executors: list, + expected_executors: list[list], + use_te: bool, +): + def _run(): + model.to("cuda") + x = torch.randn(tensor_shape, dtype=dtype, device="cuda") + jitted_def = thunder.jit(model, executors=executors) + jitted_auto = thunder.jit( + model, + autotune_type=autotune_type, + executors=executors, + autotune_enable_te=use_te, + ) + y_def = jitted_def(x) + y_auto = jitted_auto(x) + + te_used = aut_utils.is_te_used(thunder.last_traces(jitted_auto)[-1]) + got = thunder.executors_applied(jitted_auto) + print("got", got) + print("expected", expected_executors) + assert any([t == got for t in expected_executors]) + # With TE enabled deviation ((y_def - y_auto).abs().max().item()) is between tensors are ~0.2 + # For the else branch: https://pytorch.org/docs/stable/testing.html + torch.testing.assert_close(y_def, y_auto, atol=2 * 1e-1 if te_used else 1e-5, rtol=1e-1 if te_used else 1.3e-6) + + if dtype != torch.get_default_dtype(): + with torch.autocast(device_type="cuda"): + _run() + else: + _run() + + +""" +The longest repeated block is: + t2 = x @ y + t3 = t0 + t0 + t4 = t1 * t1 +""" + + +def _test_repetead_transformer_blocks_fn(x: torch.Tensor, y: torch.Tensor): + t0 = x + x + t1 = y * y + t2 = x @ y + t3 = t0 + t0 + t4 = t1 * t1 + t5 = t2 @ t2 + t6 = t3 + t3 + t7 = t4 * t4 + t8 = t6 - t7 + return t8, t5 + + +def test_repetead_transformer_blocks(): + device = "cpu" + + a = torch.randn(2, 2, device=device) + b = torch.randn(2, 2, device=device) + + jitted = thunder.jit(_test_repetead_transformer_blocks_fn, disable_dce=True) + jitted(a, b) + + trace = thunder.last_traces(jitted)[-1] + print(trace) + blocks = aut_utils.repetead_trace_blocks(trace=trace) + assert len(blocks) == 2 + assert blocks[0][1] - blocks[0][0] + 1 == 3 + + +def test_reduce_common_trace_blocks(): + device = "cpu" + + a = torch.randn(2, 2, device=device) + b = torch.randn(2, 2, device=device) + + jitted = 
thunder.jit(_test_repetead_transformer_blocks_fn, disable_dce=True) + jitted(a, b) + + trace = thunder.last_traces(jitted)[-1] + blocks = aut_utils.repetead_trace_blocks(trace=trace) + reduced_trace = aut_utils.reduce_common_trace_blocks( + trace=trace, common_blocks_in=blocks, skip_between_blocks=False + ) + + # We expect that t5, t6, t7 have been removed + should_remove = set(["t5", "t6", "t7"]) + for b in reduced_trace.bound_symbols: + if hasattr(b.output, "name"): + assert b.output.name not in should_remove + +@requiresCUDA +def test_save_configuration_cuda(): + class _LLaMAMLP(torch.nn.Module): + def __init__(self, n_embd, intermediate_size) -> None: + super().__init__() + self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False) + self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False) + self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fc_1 = self.fc_1(x) + x_fc_2 = self.fc_2(x) + x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + return self.proj(x) + + with torch.device("cuda"): + model = _LLaMAMLP(4, 4) + jitted = thunder.jit( + model, + autotune_type="memory", + model_name="llamamlp", + autotune_save_configuration=True, + ) + jitted_recovered = thunder.jit( + model, + autotune_type="runtime", + autotune_restore_configuration="llamamlp_memory.json", + ) + + x = torch.randn(4, 4) + a = jitted(x) + b = jitted_recovered(x) + + torch.testing.assert_close(a, b) + + for bsym_a, bsym_b in zip( + thunder.last_traces(jitted)[-1].bound_symbols, thunder.last_traces(jitted_recovered)[-1].bound_symbols + ): + assert bsym_a.sym.executor == bsym_b.sym.executor + + +@requiresCUDA +# Currently inside the autotuner flow nvfuser will be imported which will lead to import errors +def test_no_autograd_trace_autotuning(): + def _fn(a, b): + t0 = a + b + t1 = a + t0 + t2 = t1 * t1 + t3 = b - t2 + return b @ t3 + + executors = ['torch', 'torchcompile'] + jfn_def = thunder.jit(_fn, executors=executors) + jfn_auto = thunder.jit(_fn, autotune_type='runtime', disable_torch_autograd=True, exeuctors=executors) + a = torch.randn(4,4) + b = torch.randn(4,4) + + y_def = jfn_def(a, b) + y_auto = jfn_auto(a, b) + + applied = set() + trace = thunder.last_traces(jfn_auto)[-1] + for b in trace.bound_symbols: + if b.sym.executor is not None: + applied.add(b.sym.executor.name) + + assert (applied == set(['torch']) or applied == set(['torchcompile', 'torch'])) + torch.testing.assert_close(y_def, y_auto) diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py index 30015e76f4..58250a8a53 100644 --- a/thunder/tests/test_extend.py +++ b/thunder/tests/test_extend.py @@ -127,6 +127,7 @@ def test_get_all_executors_includes_all_native_executors(): "apex", "cudnn", "fa3", + "nvmath", "torch", "cudnn_layernorm", "sdpa",
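From `test_save_configuration_cuda` above, the tuned executor placement appears to be written to a file named `<model_name>_<autotune_type>.json` (e.g. "llamamlp_memory.json") and can be reloaded with `autotune_restore_configuration`. A usage sketch under that naming assumption (the function, shapes, and the "silu_matmul" name are illustrative; a CUDA device is assumed):

import torch
import thunder

def fn(a, b):
    return torch.nn.functional.silu(a @ b)

with torch.device("cuda"):
    a = torch.randn(64, 64, requires_grad=True)
    b = torch.randn(64, 64)

# First run: tune for the "memory" objective and persist the chosen configuration.
jfn = thunder.jit(fn, autotune_type="memory", model_name="silu_matmul", autotune_save_configuration=True)
jfn(a, b)

# Later runs: reuse the stored placement instead of re-tuning.
jfn_restored = thunder.jit(fn, autotune_type="memory", autotune_restore_configuration="silu_matmul_memory.json")
jfn_restored(a, b)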