From a5dbd8a53a144f8e2f201f80284987cf58531c65 Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Fri, 19 Sep 2025 14:42:07 +0200 Subject: [PATCH 01/13] Exported TE states to ThunderModule --- thunder/dev_utils/te_states_reporter.py | 412 ++++++++++++++++++ .../executors/transformer_engineex_impl.py | 67 +++ thunder/tests/distributed/test_ddp.py | 108 +++++ thunder/tests/distributed/test_fsdp.py | 135 ++++++ ...st_transformer_engine_executor_reporter.py | 297 +++++++++++++ 5 files changed, 1019 insertions(+) create mode 100644 thunder/dev_utils/te_states_reporter.py create mode 100644 thunder/tests/test_transformer_engine_executor_reporter.py diff --git a/thunder/dev_utils/te_states_reporter.py b/thunder/dev_utils/te_states_reporter.py new file mode 100644 index 0000000000..b8ba4754ec --- /dev/null +++ b/thunder/dev_utils/te_states_reporter.py @@ -0,0 +1,412 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple + +import thunder +import torch + +import transformer_engine as te + +import transformer_engine_torch +from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch.constants import MXFP8_BLOCK_SCALING_SIZE +from transformer_engine.pytorch.fp8 import ( + FP8GlobalStateManager, + RecipeState, + get_fp8_torch_dtype, +) + + +def summarize_recipe(recipe: Recipe) -> Dict[str, Any]: + """Create a compact, serializable summary of a TE FP8 recipe. + + The summary captures the recipe class name and a small set of key + configuration fields depending on the recipe family (DelayedScaling, + Float8CurrentScaling, MXFP8BlockScaling, Float8BlockScaling). For delayed + and current-scaling variants, the effective FP8 torch dtypes for forward + and backward are also included. + + Args: + recipe: A TransformerEngine `Recipe` instance from `transformer_engine.common.recipe`. + + Returns: + A dictionary with fields describing the recipe. 
+ """ + summary: Dict[str, Any] = { + "type": recipe.__class__.__name__, + "fp8_format": getattr(recipe, "fp8_format", None), + } + + if recipe.delayed(): + summary.update( + { + "margin": getattr(recipe, "margin", None), + "amax_history_len": getattr(recipe, "amax_history_len", None), + "amax_compute_algo": getattr(recipe, "amax_compute_algo", None), + "scaling_factor_compute_algo": getattr(recipe, "scaling_factor_compute_algo", None), + "reduce_amax": getattr(recipe, "reduce_amax", None), + "fp8_dpa": getattr(recipe, "fp8_dpa", None), + "fp8_mha": getattr(recipe, "fp8_mha", None), + # Effective FP8 dtypes per pass + "fwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, True)), + "bwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, False)), + } + ) + elif recipe.float8_current_scaling(): + summary.update( + { + "fp8_quant_fwd_inp": getattr(recipe, "fp8_quant_fwd_inp", None), + "fp8_quant_fwd_weight": getattr(recipe, "fp8_quant_fwd_weight", None), + "fp8_quant_bwd_grad": getattr(recipe, "fp8_quant_bwd_grad", None), + "fp8_gemm_fprop": getattr(recipe, "fp8_gemm_fprop", None), + "fp8_gemm_dgrad": getattr(recipe, "fp8_gemm_dgrad", None), + "fp8_gemm_wgrad": getattr(recipe, "fp8_gemm_wgrad", None), + "fp8_dpa": getattr(recipe, "fp8_dpa", None), + "fp8_mha": getattr(recipe, "fp8_mha", None), + "fwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, True)), + "bwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, False)), + } + ) + elif recipe.mxfp8(): + summary.update( + { + "margin": getattr(recipe, "margin", None), + "fwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, True)), + "bwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, False)), + } + ) + elif recipe.float8_block_scaling(): + summary.update( + { + "x_block_scaling_dim": getattr(recipe, "x_block_scaling_dim", None), + "w_block_scaling_dim": getattr(recipe, "w_block_scaling_dim", None), + "grad_block_scaling_dim": getattr(recipe, "grad_block_scaling_dim", None), + "fp8_quant_fwd_inp": getattr(recipe, "fp8_quant_fwd_inp", None), + "fp8_quant_fwd_weight": getattr(recipe, "fp8_quant_fwd_weight", None), + "fp8_quant_bwd_grad": getattr(recipe, "fp8_quant_bwd_grad", None), + "fp8_gemm_fprop": getattr(recipe, "fp8_gemm_fprop", None), + "fp8_gemm_dgrad": getattr(recipe, "fp8_gemm_dgrad", None), + "fp8_gemm_wgrad": getattr(recipe, "fp8_gemm_wgrad", None), + "fwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, True)), + "bwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, False)), + } + ) + + return summary + + +def summarize_state(state: RecipeState) -> Dict[str, Any]: + """Summarize a runtime FP8 `RecipeState` object. + + Captures the state class, mode (forward/backward/None), dtype, number of + quantizers, and optionally basic tensor shape/device information for scale + and amax history tensors when present. + + Args: + state: A `RecipeState` produced by TransformerEngine during execution. + + Returns: + A dictionary with essential metadata about the state for reporting. 
+ """ + out: Dict[str, Any] = { + "cls": state.__class__.__name__, + "mode": getattr(state, "mode", None), + "dtype": str(getattr(state, "dtype", None)), + "num_quantizers": getattr(state, "num_quantizers", None), + } + scale = getattr(state, "scale", None) + if isinstance(scale, torch.Tensor): + out["scale_shape"] = tuple(scale.shape) + out["scale_device"] = str(scale.device) + amax_hist = getattr(state, "amax_history", None) + if isinstance(amax_hist, torch.Tensor): + out["amax_history_shape"] = tuple(amax_hist.shape) + out["amax_history_device"] = str(amax_hist.device) + return out + + +def summarize_quantizer(quantizer: Any) -> Dict[str, Any]: + """Summarize an FP8 quantizer instance. + + Extracts commonly useful fields across different quantizer implementations + (rowwise/columnwise usage, internal flag, dtype) and, when available, + additional configuration such as amax reduction info. Tensor shape/device + metadata for `scale` and `amax` is included if present. + + Args: + quantizer: A quantizer-like object from TransformerEngine runtime. + + Returns: + A dictionary describing the quantizer in a compact, readable form. + """ + base: Dict[str, Any] = { + "cls": quantizer.__class__.__name__, + "rowwise_usage": getattr(quantizer, "rowwise_usage", None), + "columnwise_usage": getattr(quantizer, "columnwise_usage", None), + "internal": getattr(quantizer, "internal", None), + "dtype": str(getattr(quantizer, "dtype", None)), + } + # Optional attributes by quantizer class + if hasattr(quantizer, "with_amax_reduction"): + base["with_amax_reduction"] = getattr(quantizer, "with_amax_reduction") + base["amax_reduction_group"] = str(getattr(quantizer, "amax_reduction_group", None)) + if hasattr(quantizer, "force_pow_2_scales"): + base["force_pow_2_scales"] = getattr(quantizer, "force_pow_2_scales") + if hasattr(quantizer, "amax_epsilon"): + base["amax_epsilon"] = getattr(quantizer, "amax_epsilon") + # Shapes (when available) + for attr in ("scale", "amax"): + tensor = getattr(quantizer, attr, None) + if isinstance(tensor, torch.Tensor): + base[f"{attr}_shape"] = tuple(tensor.shape) + base[f"{attr}_device"] = str(tensor.device) + return base + + +def build_global_context() -> Dict[str, Any]: + """Collect global FP8 runtime context and environment details. + + Queries `FP8GlobalStateManager` and related sources to produce a stable + snapshot of the current FP8 configuration and availability, along with + environment metadata such as CUDA/cuBLASLt versions, device compute + capability, world size, and package versions. The result is intended to be + recorded once per session for reporting correlation. + + Returns: + A dictionary of global context fields suitable for rendering in reports. 
+ """ + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() + high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val() + fp8_group = FP8GlobalStateManager.get_fp8_group() + autocast_depth = FP8GlobalStateManager.FP8_AUTOCAST_DEPTH + graph_capturing = FP8GlobalStateManager.fp8_graph_capturing() + + # Availability and reasons + fp8_avail, reason_no_fp8 = FP8GlobalStateManager.is_fp8_available() + mxfp8_avail, reason_no_mx = FP8GlobalStateManager.is_mxfp8_available() + fp8blk_avail, reason_no_blk = FP8GlobalStateManager.is_fp8_block_scaling_available() + + # Versions / device + cuda_version = getattr(torch.version, "cuda", None) + cublaslt_version = transformer_engine_torch.get_cublasLt_version() + device_cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else None + + # Dist info + world_size = None + if torch.distributed.is_available() and torch.distributed.is_initialized(): + try: + world_size = torch.distributed.get_world_size(group=fp8_group) + except (RuntimeError, ValueError, TypeError): + world_size = torch.distributed.get_world_size() + + # Package versions + te_version = getattr(te, "__version__", None) if te is not None else None + thunder_version = None + thunder_version = getattr(thunder, "__version__", None) + + return { + "fp8_enabled": fp8_enabled, + "fp8_calibration": fp8_calibration, + "with_fp8_parameters": with_fp8_params, + "high_precision_init_val": high_precision_init_val, + "fp8_group": str(fp8_group), + "world_size": world_size, + "autocast_depth": autocast_depth, + "graph_capturing": graph_capturing, + "fp8_available": fp8_avail, + "reason_no_fp8": reason_no_fp8, + "mxfp8_available": mxfp8_avail, + "reason_no_mxfp8": reason_no_mx, + "fp8_block_scaling_available": fp8blk_avail, + "reason_no_fp8_block_scaling": reason_no_blk, + "cuda_version": cuda_version, + "cublaslt_version": cublaslt_version, + "device_compute_capability": device_cc, + "te_version": te_version, + "thunder_version": thunder_version, + } + + +@dataclass +class TEStateReporter: + """Accumulates TE runtime summaries and renders a report.""" + + global_ctx: Optional[Dict[str, Any]] = None + recipe_summaries: List[Dict[str, Any]] = field(default_factory=list) + state_summaries_forward: List[Dict[str, Any]] = field(default_factory=list) + state_summaries_backward: List[Dict[str, Any]] = field(default_factory=list) + quantizer_summaries: List[Dict[str, Any]] = field(default_factory=list) + shape_policy: Dict[str, Any] = field(default_factory=lambda: {"mxfp8_block": MXFP8_BLOCK_SCALING_SIZE}) + seen_fw_states: Set[Tuple[int, int]] = field(default_factory=set) + seen_bw_states: Set[Tuple[int, int]] = field(default_factory=set) + seen_quantizers: Set[Tuple[int, int]] = field(default_factory=set) + + def update_from_runtime( + self, + *, + holder, + recipe: Optional[Recipe] = None, + states: Sequence[RecipeState] | None = None, + mode: Optional[str] = None, + quantizers: Sequence[Any] | None = None, + ) -> None: + """Update the reporter with data observed during runtime. + + This method is called one or more times during forward/backward passes + to incrementally collect summaries. The first invocation also captures + the global context snapshot. + + Args: + holder: The holder object (TERecipe, TERecipeState, or TEQuantizerState) + that owns the runtime data being reported. + recipe: Optional recipe active for the current autocast session. 
+ states: Optional sequence of `RecipeState` objects observed. + mode: Optional mode string ("forward" or "backward") indicating the + execution phase when states are captured. + quantizers: Optional sequence of quantizer objects observed. + """ + if self.global_ctx is None: + self.global_ctx = build_global_context() + + # Collect recipe summaries only when called from the main recipe holder (TERecipe class). + # This avoids duplicate recipe entries when the same recipe is referenced by quantizer + # or state holders, ensuring we track each unique recipe configuration exactly once. + if recipe is not None and not quantizers: + summary = summarize_recipe(recipe) + if summary not in self.recipe_summaries: + self.recipe_summaries.append(summary) + + # Each trace execution can contain multiple forward and backward states for different recipes. + # We track unique combinations of (holder_id, recipe_id) to avoid duplicate state summaries + # while ensuring we capture all distinct recipe configurations used during runtime. + if states: + if mode == "forward": + if (id(holder), id(recipe)) not in self.seen_fw_states: + self.seen_fw_states.add((id(holder), id(recipe))) + self.state_summaries_forward.extend(summarize_state(s) for s in states) + elif mode == "backward": + if (id(holder), id(recipe)) not in self.seen_bw_states: + self.seen_bw_states.add((id(holder), id(recipe))) + self.state_summaries_backward.extend(summarize_state(s) for s in states) + + # Quantizers are reused across multiple trace executions but their behavior depends on the active recipe. + # While the quantizer object instances remain the same, different recipes can affect their configuration + # and internal state. We track unique combinations of (holder_id, recipe_id) to ensure we capture + # quantizer summaries for each distinct recipe configuration, avoiding both duplicates and missed + # configurations when recipes change during runtime. + if quantizers: + if (id(holder), id(recipe)) not in self.seen_quantizers: + self.seen_quantizers.add((id(holder), id(recipe))) + self.quantizer_summaries.extend(summarize_quantizer(q) for q in quantizers) + + def render_report(self) -> str: + """Render a human-readable multi-section report of collected data. + + The report includes global context, recipes, forward/backward state + summaries, quantizer summaries, and shape policy information. + + Returns: + A formatted string suitable for console logging or test output. 
+ """ + lines: List[str] = [] + + def add(line: str = "") -> None: + lines.append(line) + + # Global Context + ctx = self.global_ctx or {} + add("Global Context:") + add(f" • FP8 Enabled: {ctx.get('fp8_enabled')}") + add(f" • FP8 Calibration: {ctx.get('fp8_calibration')}") + add(f" • FP8 Parameters: {ctx.get('with_fp8_parameters')}") + add(f" • High Precision Init: {ctx.get('high_precision_init_val')}") + add(f" • FP8 Group: {ctx.get('fp8_group')}") + add(f" • World Size: {ctx.get('world_size')}") + add(f" • Autocast Depth: {ctx.get('autocast_depth')}") + add(f" • Graph Capturing: {ctx.get('graph_capturing')}") + add("") + add(" Availability:") + add(f" - FP8: {ctx.get('fp8_available')}") + add(f" - MXFP8: {ctx.get('mxfp8_available')}") + add(f" - FP8 Block Scaling: {ctx.get('fp8_block_scaling_available')}") + if not ctx.get("fp8_block_scaling_available", True): + add(f" Reason: {ctx.get('reason_no_fp8_block_scaling')}") + add("") + add(" Versions:") + add(f" - CUDA: {ctx.get('cuda_version')} cuBLASLt: {ctx.get('cublaslt_version')}") + add(f" - Compute Capability: {ctx.get('device_compute_capability')}") + add(f" - TransformerEngine: {ctx.get('te_version')} Thunder: {ctx.get('thunder_version')}") + add("") + + # Recipes + add(f"Recipes ({len(self.recipe_summaries)}):") + for idx, rs in enumerate(self.recipe_summaries, 1): + add(f" [{idx}] {rs.get('type')} - {rs.get('fp8_format')}") + # Print a compact subset + for key in ( + "margin", + "amax_history_len", + "amax_compute_algo", + "reduce_amax", + "fp8_dpa", + "fp8_mha", + "fwd_fp8_torch_dtype", + "bwd_fp8_torch_dtype", + "x_block_scaling_dim", + "w_block_scaling_dim", + "grad_block_scaling_dim", + ): + if rs.get(key) is not None: + add(f" {key}: {rs.get(key)}") + add("") + + # States + add(f"Forward States ({len(self.state_summaries_forward)}):") + for idx, ss in enumerate(self.state_summaries_forward, 1): + add(f" [{idx}] Mode: {ss.get('mode')} DType: {ss.get('dtype')} Quantizers: {ss.get('num_quantizers')}") + if ss.get("scale_shape") is not None: + add(f" Scale: {ss.get('scale_shape')} on {ss.get('scale_device')}") + else: + add(" Note: no per-tensor scale (likely MXFP8/blockwise)") + if ss.get("amax_history_shape") is not None: + add(f" Amax History: {ss.get('amax_history_shape')} on {ss.get('amax_history_device')}") + add("") + + # Backward States (if any) + if self.state_summaries_backward: + add(f"Backward States ({len(self.state_summaries_backward)}):") + for idx, ss in enumerate(self.state_summaries_backward, 1): + add( + f" [{idx}] Mode: {ss.get('mode')} DType: {ss.get('dtype')} Quantizers: {ss.get('num_quantizers')}" + ) + if ss.get("scale_shape") is not None: + add(f" Scale: {ss.get('scale_shape')} on {ss.get('scale_device')}") + if ss.get("amax_history_shape") is not None: + add(f" Amax History: {ss.get('amax_history_shape')} on {ss.get('amax_history_device')}") + add("") + + # Quantizers + add(f"Quantizers ({len(self.quantizer_summaries)}):") + for idx, qs in enumerate(self.quantizer_summaries, 1): + add( + f" [{idx}] {qs.get('cls')} - {qs.get('dtype')}\n" + f" Rowwise: {qs.get('rowwise_usage')}\n" + f" Columnwise: {qs.get('columnwise_usage')}\n" + f" Internal: {qs.get('internal')}" + ) + if qs.get("with_amax_reduction") is not None: + add(f" Amax Reduction: {qs.get('with_amax_reduction')} group={qs.get('amax_reduction_group')}") + for attr in ("scale", "amax"): + if qs.get(f"{attr}_shape") is not None: + add(f" {attr.capitalize()}: {qs.get(f'{attr}_shape')} on {qs.get(f'{attr}_device')}") + add("") + + # Shape policy + 
add("Shape Policy:") + add(f" • mxfp8_block: {self.shape_policy.get('mxfp8_block')}") + + return "\n".join(lines) + + +__all__ = ["TEStateReporter"] diff --git a/thunder/executors/transformer_engineex_impl.py b/thunder/executors/transformer_engineex_impl.py index 9d913446d7..b8e1802458 100644 --- a/thunder/executors/transformer_engineex_impl.py +++ b/thunder/executors/transformer_engineex_impl.py @@ -1,8 +1,10 @@ import time from typing import TYPE_CHECKING +import weakref import torch.distributed as torch_dist +from thunder.core.compile_data import get_compile_data from thunder.core.prims import linear as linear_prim from thunder.core.prims import get_grad, put_grad from thunder.core.proxies import AnyProxy, TensorProxy @@ -20,6 +22,7 @@ from thunder.core.transform_common import cse_single_bsym from thunder.executors.passes import del_last_used import thunder.core.utils as utils +from thunder.dev_utils.te_states_reporter import TEStateReporter if TYPE_CHECKING: from thunder.core.trace import VariableInterface @@ -35,14 +38,39 @@ RecipeState, FP8GlobalStateManager, ) + from transformer_engine.pytorch.ops import BasicLinear from transformer_engine.pytorch.tensor import Quantizer from transformer_engine.pytorch.utils import check_dim_for_fp8_exec +from thunder.core.module import get_thunder_module transformer_engine_ex = StatefulExecutor("transformer_engine") register_executor(transformer_engine_ex) +def _export_te_states(*, recipe=None, states=None, mode = None, quantizers=None, holder=None): + tm = None + if holder is not None: + tm_ref = getattr(holder, "_tm_ref", None) + if tm_ref is not None and isinstance(tm_ref, weakref.ref): + tm = tm_ref() + + # Cannot fallback to compile data as during execution we don't have compile_data context + if tm is None: + import warnings + warnings.warn("No ThunderModule found for exporting TE states", UserWarning) + return + + if not hasattr(tm, "te_reporter"): + tm.te_reporter = TEStateReporter() + + tm.te_reporter.update_from_runtime( + holder=holder, + recipe=recipe, + states=states, + mode=mode, + quantizers=quantizers, + ) def _te_fp8_recipe_meta() -> AnyProxy: return AnyProxy(None, prefix="r") @@ -61,6 +89,9 @@ def __call__(self) -> Recipe: if not self.fp8_recipe or self.fp8_recipe is not te_fp8_recipe: self.fp8_recipe = te_fp8_recipe + # Duplicate recipies are handled by the TEStateReporter as we don't have any early return logic here + _export_te_states(recipe=self.fp8_recipe, holder=self) + return self.fp8_recipe @@ -83,10 +114,14 @@ def __init__(self): def __call__(self, recipe_state: RecipeState, num_quantizers: int) -> list[Quantizer]: if self.quantizers and self.parent_recipe_state is recipe_state: return self.quantizers + quantizers = recipe_state.make_quantizers() self.quantizers = quantizers self.parent_recipe_state = recipe_state + + # Export only new quantizers + _export_te_states(recipe=recipe_state.recipe, quantizers=quantizers, holder=self) return quantizers @@ -120,6 +155,9 @@ def __call__(self, recipe: Recipe, mode: str, num_quantizers: int) -> RecipeStat self.state = recipe_state self.parent_recipe = recipe + # Export only new states + _export_te_states(recipe=recipe, states=(recipe_state,), mode=mode, holder=self) + return recipe_state @@ -343,6 +381,11 @@ def __init__(self): self.rhs_to_bsym_map: dict[BoundSymbolRHS, BoundSymbol] = {} self.redundant_map: dict[Variable, Proxy] = {} self.new_saved_for_backward = None + self._tm_ref = None + + def transform_module(self, model) -> None: + # Cache a weakref to the ThunderModule for 
later runtime export + self._tm_ref = weakref.ref(model) def reset(self): self.fp8_recipe = None @@ -351,6 +394,20 @@ def reset(self): self.redundant_map = {} self.new_saved_for_backward = None + def _stamp_te_refs_to_bsym(self, tr): + for bsym in tr.bound_symbols: + call_ctx = getattr(bsym, "_call_ctx", None) + if not call_ctx: + continue + te_prefixes = ["get_te_fp8_recipe", "get_te_fp8_state", "get_te_fp8_quantizers"] + if not any(bsym.sym.name.startswith(prefix) for prefix in te_prefixes): + continue + state_obj = next(iter(call_ctx.values())) + + assert getattr(state_obj, "_tm_ref", None) is None + + setattr(state_obj, "_tm_ref", self._tm_ref) # stamp the ThunderModule weakref here + def transform_trace_post_optimization(self, computation_trace, **kwargs): """ Finds and replaces TE executor recipe calls and replaces them with one. @@ -365,6 +422,14 @@ def transform_trace_post_optimization(self, computation_trace, **kwargs): if "transformer_engine" not in map(lambda x: x.name, kwargs["executors_list"]): return computation_trace + # Ensure we have a ThunderModule weakref available + if self._tm_ref is None: + cd = get_compile_data() + if cd is not None and getattr(cd, "is_module", False): + tm = get_thunder_module(cd.fn) + if tm is not None: + self._tm_ref = weakref.ref(tm) + start_time_ns = time.perf_counter_ns() new_trace = from_trace(computation_trace) @@ -422,6 +487,8 @@ def transform_trace_post_optimization(self, computation_trace, **kwargs): sync_trace = del_last_used(new_trace) + self._stamp_te_refs_to_bsym(sync_trace) + end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index c1eb9e8d93..9ee950212d 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -774,6 +774,99 @@ def _test_ddp_transformer_engine_llama_sanity(input_data): return None +def _test_ddp_transformer_engine_reporter(input_data): + # Test Description: + # Verify that the TEStateReporter correctly captures and reports TransformerEngine + # FP8 state information during DDP training execution, including global context, + # recipe summaries, and forward/backward state summaries across distributed processes. 
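+    # Counting sketch for the assertions below (a rough guide, assuming the default recipe and
+    # the two nn.functional.linear call sites in the module defined further down):
+    #   - 1 unique recipe summary (the same fp8_recipe is reused on every iteration)
+    #   - 2 forward and 2 backward state summaries (one per linear call site and direction)
+    #   - 6 quantizer summaries (2 forward + 1 backward quantizer per linear call site)
+    # These counts are iteration-invariant and are checked on rank 0 after gathering the payloads.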
+ + init_method, world_size, rank, _executor, device, _dtype, _unused_kwargs = input_data + devicetype = devices.device_from_string(device).devicetype + pg = init_per_process_distributed(init_method, devicetype, world_size, rank) + + fp8_recipe = get_default_fp8_recipe() + + torch.cuda.set_device(rank) + torch_device = torch.device("cuda", rank) + + dtype_t = torch.bfloat16 + if rank == 0: + x = torch.randn(3, 768, 4096, device=torch_device, dtype=dtype_t, requires_grad=True) + else: + x = torch.randn(2, 768, 4096, device=torch_device, dtype=dtype_t, requires_grad=True) + + class Module(nn.Module): + def __init__(self): + super(Module, self).__init__() + self.w1 = nn.Parameter(torch.randn(4096, 4096, device=torch_device, dtype=dtype_t)) + self.w2 = nn.Parameter(torch.randn(2048, 4096, device=torch_device, dtype=dtype_t)) + + def forward(self, x): + o = torch.nn.functional.linear(x, self.w1) + added = o + x + return torch.nn.functional.linear(added, self.w2) + + model = Module() + jmodel = thunder.distributed.ddp( + thunder.jit(model, executors=[transformer_engine_ex], transforms=[TransformerEngineTransform()]) + ) + + def train(model): + iters = 10 + for _ in range(iters): + with fp8_autocast(fp8_recipe=fp8_recipe): + y = model(x) + y.backward(torch.ones_like(y)) + + train(jmodel) + + rep = getattr(jmodel, "te_reporter", None) + if rep is None: + if rank == 0: + return [AssertionError("TransformerEngine reporter not found")] + return None + + payload = { + "rank": rank, + "error": None if rep is not None else "no reporter", + "global": {} if rep is None else rep.global_ctx, + "recipes": 0 if rep is None else len(rep.recipe_summaries), + "forward": 0 if rep is None else len(rep.state_summaries_forward), + "backward": 0 if rep is None else len(rep.state_summaries_backward), + "quantizers": 0 if rep is None else len(rep.quantizer_summaries), + "has_sections": ( + { + "global": "Global Context:" in rep.render_report(), + "forward": "Forward States (" in rep.render_report(), + "backward": "Backward States (" in rep.render_report(), + } + if rep is not None + else {"global": False, "forward": False, "backward": False} + ), + } + + gathered = [None] * world_size if rank == 0 else None + tdist.gather_object(payload, object_gather_list=gathered, dst=0, group=pg) + + tdist.barrier(pg) + tdist.destroy_process_group(pg) + + if rank == 0: + exceptions = [] + for p in gathered: + print(p) + if p["error"]: + exceptions.append(AssertionError(p["error"])) + continue + assert p["has_sections"]["global"] + assert p["recipes"] == 1 + assert p["forward"] == 2 # 2 forward linear ops + assert p["backward"] == 2 # 2 backward linear ops + assert p["quantizers"] == 6 # 2 quantizers per forward linear op, 1 quantizers per backward linear op + return exceptions + return None + + # NOTE This is just a stub, see the NOTE for ddp_wrapper @instantiate( dtypes=(thunder.float32,), @@ -877,5 +970,20 @@ def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype): pass +@instantiate( + dtypes=(thunder.float32,), + num_devices=2, + devicetypes=(devices.DeviceType.CUDA,), + executors=(TorchExecutor,), + decorators=( + pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices"), + unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}), + ), +) +@distributed_wrapper("test_ddp_transformer_engine_reporter", _test_ddp_transformer_engine_reporter) +def test_ddp_transformer_engine_reporter(executor, devices, dtype): + pass + + if __name__ == "__main__": common_utils.run_tests() diff --git 
a/thunder/tests/distributed/test_fsdp.py b/thunder/tests/distributed/test_fsdp.py index 5acc79cdaa..4e762ad770 100644 --- a/thunder/tests/distributed/test_fsdp.py +++ b/thunder/tests/distributed/test_fsdp.py @@ -1237,6 +1237,112 @@ def _test_fsdp_transformer_engine_bucketing(input_data): return None +def _test_fsdp_transformer_engine_reporter(input_data): + # Test Description: + # Verify that the TEStateReporter correctly captures and reports TransformerEngine + # FP8 state information during DDP training execution, including global context, + # recipe summaries, and forward/backward state summaries across distributed processes. + + init_method, world_size, rank, executor, device, _dtype, kwargs = input_data + thunder_fsdp_strategy, intermediate_activation_sharding = kwargs["thunder_fsdp_strategy_and_intermediate_sharding"] + devicetype = devices.device_from_string(device).devicetype + + # Setting LOCAL_RANK is necessary for thunder.distributed.fsdp + with unittest.mock.patch.dict(os.environ, {"LOCAL_RANK": str(rank)}): + pg = init_per_process_distributed(init_method, devicetype, world_size, rank) + torch.cuda.set_device(rank) + torch_device = torch.device("cuda", rank) + dtype_t = torch.bfloat16 + + fp8_recipe = get_default_fp8_recipe() + + dim = 256 + + class ThunderModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(dim, dim, bias=False, device=torch_device, dtype=dtype_t) + self.fc2 = torch.nn.Linear(dim, dim, bias=False, device=torch_device, dtype=dtype_t) + self.fc3 = torch.nn.Linear(dim, dim, bias=False, device=torch_device, dtype=dtype_t) + + def forward(self, x): + return self.fc3(self.fc2(torch.nn.functional.relu(self.fc1(x)))) + + # Inputs (different input on different rank). + if rank == 0: + x = torch.arange(dim * dim, dtype=dtype_t, device=torch_device).view(dim, dim) + if rank == 1: + x = torch.randn(dim, dim, device=torch_device, dtype=dtype_t) * 100 + + model = ThunderModel() + jmodel = thunder.distributed.fsdp( + thunder.jit( + model, + executors=[ + transformer_engine_ex, + ] + + executor.executors_list(), + fp8_shard_intermediate_activation=intermediate_activation_sharding, + transforms=[TransformerEngineTransform()], + ), + sharding_strategy=thunder_fsdp_strategy, + ) + + def train(model): + iters = 10 + for _ in range(iters): + with fp8_autocast(fp8_recipe=fp8_recipe): + y = model(x) + y.backward(torch.ones_like(y)) + + train(jmodel) + + rep = getattr(jmodel, "te_reporter", None) + if rep is None: + if rank == 0: + return [AssertionError("TransformerEngine reporter not found")] + return None + + payload = { + "rank": rank, + "error": None if rep is not None else "no reporter", + "global": {} if rep is None else rep.global_ctx, + "recipes": 0 if rep is None else len(rep.recipe_summaries), + "forward": 0 if rep is None else len(rep.state_summaries_forward), + "backward": 0 if rep is None else len(rep.state_summaries_backward), + "quantizers": 0 if rep is None else len(rep.quantizer_summaries), + "has_sections": ( + { + "global": "Global Context:" in rep.render_report(), + "forward": "Forward States (" in rep.render_report(), + "backward": "Backward States (" in rep.render_report(), + } + if rep is not None + else {"global": False, "forward": False, "backward": False} + ), + } + + gathered = [None] * world_size if rank == 0 else None + tdist.gather_object(payload, object_gather_list=gathered, dst=0, group=pg) + + tdist.barrier(pg) + tdist.destroy_process_group(pg) + + if rank == 0: + exceptions = [] + for p in gathered: 
+ if p["error"]: + exceptions.append(AssertionError(p["error"])) + continue + assert p["has_sections"]["global"] + assert p["recipes"] == 1 + assert p["forward"] == 3 # 2 forward linear ops + assert p["backward"] == 3 # 2 backward linear ops + assert p["quantizers"] == 9 # 2 quantizers per forward linear op, 1 quantizers per backward linear op + return exceptions + return None + + @instantiate( dtypes=(thunder.float32,), num_devices=2, @@ -1269,6 +1375,35 @@ def test_fsdp_transformer_engine(executor, devices, dtype, thunder_fsdp_strategy pass +@instantiate( + dtypes=(thunder.float32,), + num_devices=2, + devicetypes=(devices.DeviceType.CUDA,), + executors=(TorchExecutor,), + decorators=( + pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices"), + unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}), + pytest.mark.parametrize( + "thunder_fsdp_strategy_and_intermediate_sharding", + ( + (FSDPType.ZERO2, False), + (FSDPType.ZERO3, False), + # Intermediate sharding is only availabe TE v1.8 onwards + pytest.param( + (FSDPType.ZERO3, True), + marks=pytest.mark.skip("Intermediate sharding is errors in TE 2.0 (also with eager)."), + ), + ), + ), + pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."), + pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason), + ), +) +@distributed_wrapper("test_fsdp_transformer_engine_reporter", _test_fsdp_transformer_engine_reporter) +def test_fsdp_transformer_engine_reporter(executor, devices, dtype, thunder_fsdp_strategy_and_intermediate_sharding): + pass + + @instantiate( dtypes=(thunder.float32,), num_devices=2, diff --git a/thunder/tests/test_transformer_engine_executor_reporter.py b/thunder/tests/test_transformer_engine_executor_reporter.py new file mode 100644 index 0000000000..5607f99c71 --- /dev/null +++ b/thunder/tests/test_transformer_engine_executor_reporter.py @@ -0,0 +1,297 @@ +import pytest +import torch +import torch.nn as nn + +import thunder +from thunder.tests.framework import requiresCUDA + + +# NOTE: On SM120/121, TE defaults to using Float8BlockScaling +# which is currently unsupported in thunder, we skip the tests for these SM architectures. +from thunder.tests.utils import skip_on_sm120_and_sm121, is_sm120_orsm121 + +transformer_engine_module = pytest.importorskip( + "transformer_engine", reason="transformer_engine was not found, skipping the tests." +) + +from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform +from transformer_engine.common import recipe +import transformer_engine.pytorch as te + +# FP8 is supported on compute arch 8.9 onwards. +# MXFP8 is supported on compute arch 10.0 onwards. +# Skip the tests if current hardware is not supported. +is_fp8_supported, msg_fp8 = te.fp8.check_fp8_support() +is_mxfp8_supported, msg_mxfp8 = te.fp8.check_mxfp8_support() +if not is_fp8_supported: + pytest.skip(msg_fp8, allow_module_level=True) + +hybrid_fp8_delayed_scaling_recipe = recipe.DelayedScaling() +mxfp8_e4m3_recipe = recipe.MXFP8BlockScaling() + +# `None` is used to test the default recipe. 
+recipes = (None, hybrid_fp8_delayed_scaling_recipe, mxfp8_e4m3_recipe) +recipe_ids = ("default", "delayed_scaling", "mxfp8_e4m3") + + +@requiresCUDA +@pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) +@skip_on_sm120_and_sm121 +def test_te_linear_forward_backward(fp8_recipe: recipe.Recipe): + if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): + pytest.skip(msg_mxfp8) + + if is_sm120_orsm121 and fp8_recipe is None: + pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported") + + # Test Description: + # Verify that the TEStateReporter correctly captures and reports TransformerEngine + # FP8 state information during forward pass execution, including global context, + # recipe summaries, and forward state summaries. + + dtype = torch.bfloat16 + device = "cuda" + + # Inputs (3D input) + x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) + + class Module(nn.Module): + def __init__(self): + super(Module, self).__init__() + self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) + self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) + + def forward(self, x): + o = torch.nn.functional.linear(x, self.w1) + added = o + x + return torch.nn.functional.linear(added, self.w2) + + model = Module() + + jmodel = thunder.jit(model, executors=[transformer_engine_ex], transforms=[TransformerEngineTransform()]) + + # Enable autocasting for the forward pass + with te.fp8_autocast(fp8_recipe=fp8_recipe): + y = jmodel(x) + + # Validate TE reporter populated as expected + assert hasattr(jmodel, "te_reporter"), "ThunderModule should expose te_reporter" + rep = jmodel.te_reporter + + # Global context is captured + assert rep.global_ctx is not None, "Global context should be populated" + assert "fp8_available" in rep.global_ctx + assert "mxfp8_available" in rep.global_ctx + assert "fp8_block_scaling_available" in rep.global_ctx + + # Recipes captured; type should be one of known TE recipe classes + assert len(rep.recipe_summaries) >= 1 + recipe_types = {rs.get("type") for rs in rep.recipe_summaries} + known_types = {"DelayedScaling", "Float8BlockScaling", "MXFP8BlockScaling", "Float8CurrentScaling"} + assert recipe_types & known_types, f"Unexpected recipe types collected: {recipe_types}" + + # If a specific recipe is requested, ensure it's reflected + if isinstance(fp8_recipe, recipe.DelayedScaling): + assert "DelayedScaling" in recipe_types + if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + assert "MXFP8BlockScaling" in recipe_types + + # Forward states and quantizers should be recorded; no backward states without backward pass + assert len(rep.state_summaries_forward) == 2 + assert all(ss.get("mode") in (None, "forward") for ss in rep.state_summaries_forward) + assert any(ss.get("num_quantizers") in (1, 2) for ss in rep.state_summaries_forward) + assert len(rep.state_summaries_backward) == 0 + assert len(rep.quantizer_summaries) == 4 + assert all("cls" in qs and "dtype" in qs for qs in rep.quantizer_summaries) + + # Rendered report contains key sections + report_txt = rep.render_report() + assert "Global Context:" in report_txt + assert "Recipes (" in report_txt + assert "Forward States (" in report_txt + assert "Quantizers (" in report_txt + + grad_output = torch.randn_like(y) + y.backward(grad_output) + + report_txt = rep.render_report() + # After backward pass, backward states should be recorded and reported + assert len(rep.state_summaries_forward) == 2 # Forward states not 
changed + assert len(rep.state_summaries_backward) == 2 + assert all(ss.get("mode") in (None, "backward") for ss in rep.state_summaries_backward) + assert "Backward States (" in report_txt + + +@requiresCUDA +@pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) +@skip_on_sm120_and_sm121 +def test_te_linear_forward_backward_multiple_iteration(fp8_recipe: recipe.Recipe): + if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): + pytest.skip(msg_mxfp8) + + if is_sm120_orsm121 and fp8_recipe is None: + pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported") + + # Test Description: + # Run multiple forward/backward iterations under a single recipe configuration and + # verify that the TE reporter does not grow with the iteration count. The recipe + # list should contain one unique entry, and state/quantizer summaries should reflect + # the two linear call sites exactly once per direction, independent of iterations. + + dtype = torch.bfloat16 + device = "cuda" + + # Inputs and model + x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) + + class Module(nn.Module): + def __init__(self): + super(Module, self).__init__() + self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) + self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) + + def forward(self, x): + o = torch.nn.functional.linear(x, self.w1) + added = o + x + return torch.nn.functional.linear(added, self.w2) + + model = Module() + + jmodel = thunder.jit(model, executors=[transformer_engine_ex], transforms=[TransformerEngineTransform()]) + + num_iters = 10 + for _ in range(num_iters): + # Forward under FP8 autocast + with te.fp8_autocast(fp8_recipe=fp8_recipe): + y = jmodel(x) + # Backward with unit upstream gradient + y.backward(torch.ones_like(y)) + + # Validate reporter after multiple iterations + assert hasattr(jmodel, "te_reporter") + rep = jmodel.te_reporter + + # Global context present + assert rep.global_ctx is not None + + # Recipes captured + assert len(rep.recipe_summaries) == 1 + + # Forward/backward states recorded (may be cached, so at least one each) + assert len(rep.state_summaries_forward) == 2 + assert len(rep.state_summaries_backward) == 2 + + # Quantizers observed at least once + assert len(rep.quantizer_summaries) == 6 + + # Report reflects sections + rpt = rep.render_report() + assert "Forward States (" in rpt + assert "Backward States (" in rpt + + +@requiresCUDA +def test_te_linear_forward_backward_multiple_recipies_iteration(): + # Test Description: + # Alternate between two different recipes across iterations and ensure the reporter + # records both recipe configurations exactly once each. Verify forward/backward states + # and quantizers reflect both linear call sites per recipe, independent of iteration count. 
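+    # Counting sketch for the assertions below (assuming both recipes are supported on this device):
+    # the two linear call sites are tracked once per recipe, giving 2 * 2 = 4 forward and 4 backward
+    # state summaries and 2 * 6 = 12 quantizer summaries (2 forward + 1 backward quantizer per linear
+    # call site per recipe), independent of the number of iterations.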
+ + recipes = [recipe.DelayedScaling()] + supports_mxfp8, _ = te.fp8.check_mxfp8_support() + + if supports_mxfp8: + recipes += [recipe.MXFP8BlockScaling()] + + if len(recipes) < 2: + pytest.skip("platform does not support two different recipes") + + dtype = torch.bfloat16 + device = "cuda" + + # Inputs and model + x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) + + class Module(nn.Module): + def __init__(self): + super(Module, self).__init__() + self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) + self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) + + def forward(self, x): + o = torch.nn.functional.linear(x, self.w1) + added = o + x + return torch.nn.functional.linear(added, self.w2) + + model = Module() + iters = 10 + def train_model(model): + for iter_n in range(iters): + te_recipe = recipes[iter_n % 2] + y = model(x, te_recipe) + y.backward(torch.ones_like(y)) + + jmodel = thunder.jit(model, executors=[transformer_engine_ex], transforms=[TransformerEngineTransform()]) + + def thunder_model(x, fp8_recipe): + with te.fp8_autocast(fp8_recipe=fp8_recipe): + return jmodel(x) + + train_model(thunder_model) + + rep_str = jmodel.te_reporter + assert len(rep_str.recipe_summaries) == len(recipes) + assert len(rep_str.state_summaries_forward) == 4 + assert len(rep_str.state_summaries_backward) == 4 + assert len(rep_str.quantizer_summaries) == 12 + +@requiresCUDA +def test_te_linear_forward_backward_same_recipe_not_reported_twice(): + # Test Description: + # Alternate between two separate DelayedScaling instances that are equivalent in configuration. + # Ensure the reporter treats them as the same effective recipe and does not duplicate entries + # across iterations. Forward/backward states should reflect the two linear call sites once each, + # and quantizers should be counted once per site, independent of iteration count. + + delayed_scaling_recipe_a = recipe.DelayedScaling() + delayed_scaling_recipe_b = recipe.DelayedScaling() + + dtype = torch.bfloat16 + device = "cuda" + + # Inputs and model + x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) + + class Module(nn.Module): + def __init__(self): + super(Module, self).__init__() + self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) + self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) + + def forward(self, x): + o = torch.nn.functional.linear(x, self.w1) + added = o + x + return torch.nn.functional.linear(added, self.w2) + + model = Module() + + def train_model(model): + # Run for `iterations`. 
+ for iter_n in range(3): + y = model(x, delayed_scaling_recipe_a if iter_n%2 == 0 else delayed_scaling_recipe_b) + + y.backward(torch.ones_like(y)) + + jmodel = thunder.jit(model, executors=[transformer_engine_ex], transforms=[TransformerEngineTransform()]) + + def thunder_model(x, fp8_recipe = None): + with te.fp8_autocast(fp8_recipe=fp8_recipe): + return jmodel(x) + + train_model(thunder_model) + + rep_str = jmodel.te_reporter + assert len(rep_str.recipe_summaries) == 1 + assert len(rep_str.state_summaries_forward) == 4 + assert len(rep_str.state_summaries_backward) == 4 + assert len(rep_str.quantizer_summaries) == 12 From 394e9deec4ff6fe8c281036b3fef648091ab4358 Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Thu, 25 Sep 2025 14:38:04 +0200 Subject: [PATCH 02/13] TE reporter with torch compile and thunder backend test --- ...st_transformer_engine_executor_reporter.py | 99 +++++++++++++++++-- 1 file changed, 93 insertions(+), 6 deletions(-) diff --git a/thunder/tests/test_transformer_engine_executor_reporter.py b/thunder/tests/test_transformer_engine_executor_reporter.py index 5607f99c71..7c55240ddd 100644 --- a/thunder/tests/test_transformer_engine_executor_reporter.py +++ b/thunder/tests/test_transformer_engine_executor_reporter.py @@ -15,6 +15,7 @@ ) from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform +from thunder.dynamo import ThunderCompiler from transformer_engine.common import recipe import transformer_engine.pytorch as te @@ -37,7 +38,7 @@ @requiresCUDA @pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) @skip_on_sm120_and_sm121 -def test_te_linear_forward_backward(fp8_recipe: recipe.Recipe): +def test_te_reporter_linear_forward_backward(fp8_recipe: recipe.Recipe): if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): pytest.skip(msg_mxfp8) @@ -125,7 +126,7 @@ def forward(self, x): @requiresCUDA @pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) @skip_on_sm120_and_sm121 -def test_te_linear_forward_backward_multiple_iteration(fp8_recipe: recipe.Recipe): +def test_te_reporter_linear_forward_backward_multiple_iteration(fp8_recipe: recipe.Recipe): if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): pytest.skip(msg_mxfp8) @@ -191,7 +192,7 @@ def forward(self, x): @requiresCUDA -def test_te_linear_forward_backward_multiple_recipies_iteration(): +def test_te_reporter_linear_forward_backward_multiple_recipies_iteration(): # Test Description: # Alternate between two different recipes across iterations and ensure the reporter # records both recipe configurations exactly once each. Verify forward/backward states @@ -225,6 +226,7 @@ def forward(self, x): model = Module() iters = 10 + def train_model(model): for iter_n in range(iters): te_recipe = recipes[iter_n % 2] @@ -245,8 +247,9 @@ def thunder_model(x, fp8_recipe): assert len(rep_str.state_summaries_backward) == 4 assert len(rep_str.quantizer_summaries) == 12 + @requiresCUDA -def test_te_linear_forward_backward_same_recipe_not_reported_twice(): +def test_te_reporter_linear_forward_backward_same_recipe_not_reported_twice(): # Test Description: # Alternate between two separate DelayedScaling instances that are equivalent in configuration. # Ensure the reporter treats them as the same effective recipe and does not duplicate entries @@ -278,13 +281,13 @@ def forward(self, x): def train_model(model): # Run for `iterations`. 
for iter_n in range(3): - y = model(x, delayed_scaling_recipe_a if iter_n%2 == 0 else delayed_scaling_recipe_b) + y = model(x, delayed_scaling_recipe_a if iter_n % 2 == 0 else delayed_scaling_recipe_b) y.backward(torch.ones_like(y)) jmodel = thunder.jit(model, executors=[transformer_engine_ex], transforms=[TransformerEngineTransform()]) - def thunder_model(x, fp8_recipe = None): + def thunder_model(x, fp8_recipe=None): with te.fp8_autocast(fp8_recipe=fp8_recipe): return jmodel(x) @@ -295,3 +298,87 @@ def thunder_model(x, fp8_recipe = None): assert len(rep_str.state_summaries_forward) == 4 assert len(rep_str.state_summaries_backward) == 4 assert len(rep_str.quantizer_summaries) == 12 + + +@requiresCUDA +@pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) +@skip_on_sm120_and_sm121 +def test_te_reporter_with_torch_compile_and_thunder_backend(fp8_recipe: recipe.Recipe): + # Test Description: + # Use torch.compile with Thunder as backend (ThunderCompiler) to run the model + # under FP8 autocast. Verify that TE runtime states are exported and available + # from the Thunder-compiled subgraphs via `te_reporter`, and that forward/backward + # summaries match expectations (iteration-invariant). + + if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): + pytest.skip(msg_mxfp8) + + if is_sm120_orsm121 and fp8_recipe is None: + pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported") + + dtype = torch.bfloat16 + device = "cuda" + + x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) + + class Module(nn.Module): + def __init__(self): + super(Module, self).__init__() + self.attention = nn.MultiheadAttention(4096, 64, device=device, dtype=dtype, batch_first=True) + self.norm1 = nn.LayerNorm(4096, device=device, dtype=dtype) + self.norm2 = nn.LayerNorm(4096, device=device, dtype=dtype) + self.mlp = nn.Sequential( + nn.Linear(4096, 16384, device=device, dtype=dtype), + nn.GELU(), + nn.Linear(16384, 4096, device=device, dtype=dtype), + ) + + def forward(self, x): + attn_out, _ = self.attention(x, x, x) + x = self.norm1(x + attn_out) + mlp_out = self.mlp(x) + x = self.norm2(x + mlp_out) + return x + + model = Module() + + # Compile with torch.compile using Thunder as backend + backend = ThunderCompiler(executors=[transformer_engine_ex], transforms=[TransformerEngineTransform()]) + compiled_model = torch.compile(model, backend=backend) + + # Run one forward/backward under FP8 autocast + def train_model(model): + iters = 10 + for _ in range(iters): + with te.fp8_autocast(fp8_recipe=fp8_recipe): + y = model(x) + y.backward(torch.ones_like(y)) + + train_model(compiled_model) + + print(compiled_model.__class__) + + # Collect TE reporters from Thunder-compiled subgraphs + reporters = [] + for sinfo in backend.subgraph_infos: + if sinfo.thunder_compiled_fns: + for fn in sinfo.thunder_compiled_fns: + if hasattr(fn, "te_reporter"): + reporters.append(fn.te_reporter) + + # We expect at least one Thunder subgraph using TE + assert len(reporters) >= 1 + + # Aggregate counts across subgraphs + total_recipes = sum(len(r.recipe_summaries) for r in reporters) + total_fw_states = sum(len(r.state_summaries_forward) for r in reporters) + total_bw_states = sum(len(r.state_summaries_backward) for r in reporters) + total_quantizers = sum(len(r.quantizer_summaries) for r in reporters) + + # Recipe presence + assert total_recipes >= 1 + # Two linear call sites leading to two forward and two backward states in total + assert total_fw_states == 2 + 
assert total_bw_states == 2 + # Quantizers (2 per forward, 1 per backward site leading to 6 total) + assert total_quantizers == 6 From 09db0bdd5066317c9046aefda04d67ce5add6037 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Sep 2025 12:40:42 +0000 Subject: [PATCH 03/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/dev_utils/te_states_reporter.py | 41 ++++++++++--------- .../executors/transformer_engineex_impl.py | 7 +++- thunder/tests/distributed/test_ddp.py | 2 +- ...st_transformer_engine_executor_reporter.py | 10 ++--- 4 files changed, 32 insertions(+), 28 deletions(-) diff --git a/thunder/dev_utils/te_states_reporter.py b/thunder/dev_utils/te_states_reporter.py index b8ba4754ec..f0d999bdc3 100644 --- a/thunder/dev_utils/te_states_reporter.py +++ b/thunder/dev_utils/te_states_reporter.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple +from typing import Any +from collections.abc import Sequence import thunder import torch @@ -16,7 +17,7 @@ ) -def summarize_recipe(recipe: Recipe) -> Dict[str, Any]: +def summarize_recipe(recipe: Recipe) -> dict[str, Any]: """Create a compact, serializable summary of a TE FP8 recipe. The summary captures the recipe class name and a small set of key @@ -31,7 +32,7 @@ def summarize_recipe(recipe: Recipe) -> Dict[str, Any]: Returns: A dictionary with fields describing the recipe. """ - summary: Dict[str, Any] = { + summary: dict[str, Any] = { "type": recipe.__class__.__name__, "fp8_format": getattr(recipe, "fp8_format", None), } @@ -94,7 +95,7 @@ def summarize_recipe(recipe: Recipe) -> Dict[str, Any]: return summary -def summarize_state(state: RecipeState) -> Dict[str, Any]: +def summarize_state(state: RecipeState) -> dict[str, Any]: """Summarize a runtime FP8 `RecipeState` object. Captures the state class, mode (forward/backward/None), dtype, number of @@ -107,7 +108,7 @@ def summarize_state(state: RecipeState) -> Dict[str, Any]: Returns: A dictionary with essential metadata about the state for reporting. """ - out: Dict[str, Any] = { + out: dict[str, Any] = { "cls": state.__class__.__name__, "mode": getattr(state, "mode", None), "dtype": str(getattr(state, "dtype", None)), @@ -124,7 +125,7 @@ def summarize_state(state: RecipeState) -> Dict[str, Any]: return out -def summarize_quantizer(quantizer: Any) -> Dict[str, Any]: +def summarize_quantizer(quantizer: Any) -> dict[str, Any]: """Summarize an FP8 quantizer instance. Extracts commonly useful fields across different quantizer implementations @@ -138,7 +139,7 @@ def summarize_quantizer(quantizer: Any) -> Dict[str, Any]: Returns: A dictionary describing the quantizer in a compact, readable form. """ - base: Dict[str, Any] = { + base: dict[str, Any] = { "cls": quantizer.__class__.__name__, "rowwise_usage": getattr(quantizer, "rowwise_usage", None), "columnwise_usage": getattr(quantizer, "columnwise_usage", None), @@ -162,7 +163,7 @@ def summarize_quantizer(quantizer: Any) -> Dict[str, Any]: return base -def build_global_context() -> Dict[str, Any]: +def build_global_context() -> dict[str, Any]: """Collect global FP8 runtime context and environment details. 
Queries `FP8GlobalStateManager` and related sources to produce a stable @@ -232,23 +233,23 @@ def build_global_context() -> Dict[str, Any]: class TEStateReporter: """Accumulates TE runtime summaries and renders a report.""" - global_ctx: Optional[Dict[str, Any]] = None - recipe_summaries: List[Dict[str, Any]] = field(default_factory=list) - state_summaries_forward: List[Dict[str, Any]] = field(default_factory=list) - state_summaries_backward: List[Dict[str, Any]] = field(default_factory=list) - quantizer_summaries: List[Dict[str, Any]] = field(default_factory=list) - shape_policy: Dict[str, Any] = field(default_factory=lambda: {"mxfp8_block": MXFP8_BLOCK_SCALING_SIZE}) - seen_fw_states: Set[Tuple[int, int]] = field(default_factory=set) - seen_bw_states: Set[Tuple[int, int]] = field(default_factory=set) - seen_quantizers: Set[Tuple[int, int]] = field(default_factory=set) + global_ctx: dict[str, Any] | None = None + recipe_summaries: list[dict[str, Any]] = field(default_factory=list) + state_summaries_forward: list[dict[str, Any]] = field(default_factory=list) + state_summaries_backward: list[dict[str, Any]] = field(default_factory=list) + quantizer_summaries: list[dict[str, Any]] = field(default_factory=list) + shape_policy: dict[str, Any] = field(default_factory=lambda: {"mxfp8_block": MXFP8_BLOCK_SCALING_SIZE}) + seen_fw_states: set[tuple[int, int]] = field(default_factory=set) + seen_bw_states: set[tuple[int, int]] = field(default_factory=set) + seen_quantizers: set[tuple[int, int]] = field(default_factory=set) def update_from_runtime( self, *, holder, - recipe: Optional[Recipe] = None, + recipe: Recipe | None = None, states: Sequence[RecipeState] | None = None, - mode: Optional[str] = None, + mode: str | None = None, quantizers: Sequence[Any] | None = None, ) -> None: """Update the reporter with data observed during runtime. @@ -309,7 +310,7 @@ def render_report(self) -> str: Returns: A formatted string suitable for console logging or test output. 
""" - lines: List[str] = [] + lines: list[str] = [] def add(line: str = "") -> None: lines.append(line) diff --git a/thunder/executors/transformer_engineex_impl.py b/thunder/executors/transformer_engineex_impl.py index b8e1802458..9c27ce0253 100644 --- a/thunder/executors/transformer_engineex_impl.py +++ b/thunder/executors/transformer_engineex_impl.py @@ -48,7 +48,8 @@ transformer_engine_ex = StatefulExecutor("transformer_engine") register_executor(transformer_engine_ex) -def _export_te_states(*, recipe=None, states=None, mode = None, quantizers=None, holder=None): + +def _export_te_states(*, recipe=None, states=None, mode=None, quantizers=None, holder=None): tm = None if holder is not None: tm_ref = getattr(holder, "_tm_ref", None) @@ -58,6 +59,7 @@ def _export_te_states(*, recipe=None, states=None, mode = None, quantizers=None, # Cannot fallback to compile data as during execution we don't have compile_data context if tm is None: import warnings + warnings.warn("No ThunderModule found for exporting TE states", UserWarning) return @@ -72,6 +74,7 @@ def _export_te_states(*, recipe=None, states=None, mode = None, quantizers=None, quantizers=quantizers, ) + def _te_fp8_recipe_meta() -> AnyProxy: return AnyProxy(None, prefix="r") @@ -119,7 +122,7 @@ def __call__(self, recipe_state: RecipeState, num_quantizers: int) -> list[Quant self.quantizers = quantizers self.parent_recipe_state = recipe_state - + # Export only new quantizers _export_te_states(recipe=recipe_state.recipe, quantizers=quantizers, holder=self) diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index 9ee950212d..6d43a926bd 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -797,7 +797,7 @@ def _test_ddp_transformer_engine_reporter(input_data): class Module(nn.Module): def __init__(self): - super(Module, self).__init__() + super().__init__() self.w1 = nn.Parameter(torch.randn(4096, 4096, device=torch_device, dtype=dtype_t)) self.w2 = nn.Parameter(torch.randn(2048, 4096, device=torch_device, dtype=dtype_t)) diff --git a/thunder/tests/test_transformer_engine_executor_reporter.py b/thunder/tests/test_transformer_engine_executor_reporter.py index 7c55240ddd..12f9a07017 100644 --- a/thunder/tests/test_transformer_engine_executor_reporter.py +++ b/thunder/tests/test_transformer_engine_executor_reporter.py @@ -58,7 +58,7 @@ def test_te_reporter_linear_forward_backward(fp8_recipe: recipe.Recipe): class Module(nn.Module): def __init__(self): - super(Module, self).__init__() + super().__init__() self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) @@ -147,7 +147,7 @@ def test_te_reporter_linear_forward_backward_multiple_iteration(fp8_recipe: reci class Module(nn.Module): def __init__(self): - super(Module, self).__init__() + super().__init__() self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) @@ -215,7 +215,7 @@ def test_te_reporter_linear_forward_backward_multiple_recipies_iteration(): class Module(nn.Module): def __init__(self): - super(Module, self).__init__() + super().__init__() self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) @@ -267,7 +267,7 @@ def test_te_reporter_linear_forward_backward_same_recipe_not_reported_twice(): class Module(nn.Module): 
def __init__(self): - super(Module, self).__init__() + super().__init__() self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) @@ -323,7 +323,7 @@ def test_te_reporter_with_torch_compile_and_thunder_backend(fp8_recipe: recipe.R class Module(nn.Module): def __init__(self): - super(Module, self).__init__() + super().__init__() self.attention = nn.MultiheadAttention(4096, 64, device=device, dtype=dtype, batch_first=True) self.norm1 = nn.LayerNorm(4096, device=device, dtype=dtype) self.norm2 = nn.LayerNorm(4096, device=device, dtype=dtype) From 536c282521165e617ddfe05344326308b6a126b0 Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Fri, 26 Sep 2025 11:55:07 +0200 Subject: [PATCH 04/13] Setting fp8 block scaling default availability value --- thunder/dev_utils/te_states_reporter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/thunder/dev_utils/te_states_reporter.py b/thunder/dev_utils/te_states_reporter.py index f0d999bdc3..1900579b5c 100644 --- a/thunder/dev_utils/te_states_reporter.py +++ b/thunder/dev_utils/te_states_reporter.py @@ -186,7 +186,8 @@ def build_global_context() -> dict[str, Any]: # Availability and reasons fp8_avail, reason_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_avail, reason_no_mx = FP8GlobalStateManager.is_mxfp8_available() - fp8blk_avail, reason_no_blk = FP8GlobalStateManager.is_fp8_block_scaling_available() + # Thunder does not support fp8 block scaling: https://github.com/Lightning-AI/lightning-thunder/issues/2476 + fp8blk_avail, reason_no_blk = (False, "Thunder does not support fp8 block scaling") # Versions / device cuda_version = getattr(torch.version, "cuda", None) From a4157d5b423d96cf7a4245103f7fd3338101e161 Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Mon, 29 Sep 2025 07:59:31 +0000 Subject: [PATCH 05/13] Added missing TE availability guard --- thunder/tests/distributed/test_ddp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index 6d43a926bd..99e1ce4b1f 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -976,6 +976,8 @@ def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype): devicetypes=(devices.DeviceType.CUDA,), executors=(TorchExecutor,), decorators=( + pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."), + pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason), pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices"), unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}), ), From b5090f037753aa3cb5e1aae3e72ceb53fb54dd74 Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Thu, 2 Oct 2025 15:27:33 +0000 Subject: [PATCH 06/13] Generalized stateful executor state export through a centralzied transform --- thunder/__init__.py | 40 +- .../dev_utils/export_stateful_ex_transform.py | 64 +++ thunder/dev_utils/te_states_reporter.py | 414 ------------------ .../executors/transformer_engineex_impl.py | 218 ++++++--- thunder/tests/distributed/test_ddp.py | 62 ++- thunder/tests/distributed/test_fsdp.py | 46 +- .../test_export_stateful_ex_transform.py | 323 ++++++++++++++ ...st_transformer_engine_executor_reporter.py | 384 ---------------- 8 files changed, 622 insertions(+), 929 deletions(-) create mode 100644 thunder/dev_utils/export_stateful_ex_transform.py delete mode 100644 
thunder/dev_utils/te_states_reporter.py create mode 100644 thunder/tests/test_export_stateful_ex_transform.py delete mode 100644 thunder/tests/test_transformer_engine_executor_reporter.py diff --git a/thunder/__init__.py b/thunder/__init__.py index 662336b129..fe95b7ec48 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -55,6 +55,7 @@ unwrap_return_value, wrap_return_value_together_with_arguments, ) +from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform from thunder.core.update_aliases import insert_alias_updates from thunder.executors.torch_autograd import connect_to_autograd import thunder.extend as extend @@ -845,6 +846,28 @@ def wrapped(*args, **kwargs): # For more context see `NOTE: Split autograd.Function` disable_split_autograd: bool = compile_options.get("thunderfx_disable_split_autograd", False) + def wrapped_compiled_fn(fn: Callable, mode: str, *args): + """ + Wraps the compiled function to run the export stateful executors transform after the function is executed. + Stateful executors materialize information about the state at runtime. + """ + out = fn(*args) + cs = weakref_cs() + cd = weakref_cd() + assert cd is not None, "cd has been cleared." + assert cs is not None, "cs has been cleared." + if cs is None or cd is None: + return out + + trace = cs.last_traces[-1] if mode == "forward" else cs.last_backward_traces[-1] + if not trace: + return out + for transform in cd.transforms or []: + if isinstance(transform, ExportStatefulExecutorsTransform): + transform.transform_trace_post_execution(trace) + break + return out + def maybe_connect_to_autograd(cache_entry, result): if cache_entry.backward_fn: # If the backward function is available, we need to connect the @@ -854,8 +877,11 @@ def maybe_connect_to_autograd(cache_entry, result): is_differentiable_outputs = compile_options.get("is_differentiable_outputs", None) + def wrapped_backward_fn(saved_and_other, args): + return wrapped_compiled_fn(cache_entry.backward_fn, "backward", saved_and_other, args) + connect_to_autograd( - backward_fn=cache_entry.backward_fn, + backward_fn=cache_entry.backward_fn if not has_to_export() else wrapped_backward_fn, flat_args=data_for_autograd["flat_args"], flat_output=data_for_autograd["flat_output"], saved_tensors=saved_tensors, @@ -872,6 +898,11 @@ def call_epilogue(cache_entry, comp_result, pro_to_epi): result = cache_entry.epilogue_fn(*pro_to_epi, *comp_result) return result + def has_to_export(): + cd = weakref_cd() + assert cd is not None, "cd has been cleared." 
+ return any(isinstance(t, ExportStatefulExecutorsTransform) for t in cd.transforms or []) + @wraps(fn) @update_call_statistics def fn_(*args, **kwargs) -> Any: @@ -886,7 +917,12 @@ def fn_(*args, **kwargs) -> Any: cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) - result = cache_entry.computation_fn(*inps) + def wrapped_computation_fn(*args): + return wrapped_compiled_fn(cache_entry.computation_fn, "forward", *args) + + computation_fn = cache_entry.computation_fn if not has_to_export() else wrapped_computation_fn + + result = computation_fn(*inps) result = maybe_connect_to_autograd(cache_entry, result) result = call_epilogue(cache_entry, result, pro_to_epi) diff --git a/thunder/dev_utils/export_stateful_ex_transform.py b/thunder/dev_utils/export_stateful_ex_transform.py new file mode 100644 index 0000000000..cec3628d79 --- /dev/null +++ b/thunder/dev_utils/export_stateful_ex_transform.py @@ -0,0 +1,64 @@ +import weakref +from typing import Callable, Dict + +from thunder.core.transform_common import ( + Transform, +) +from thunder.core.trace import TraceCtx as Trace + + +class ExportStatefulExecutorsTransform(Transform): + """Export runtime state from stateful executors after a trace executes. + + - Singleton transform with a registry of export callbacks + - Register via `register_export_callback(name, callback)` + - Callbacks receive `(computation_trace, thunder_module)` and may attach + serialized state to the module (e.g., `module.te_fp8_stats`) + - Safe: callback errors are swallowed; export never blocks execution + + Example (TransformerEngine): a callback collects FP8 amax/scale and + quantizer metadata from `python_ctx` and records them under + `module.te_fp8_stats = {"forward": [...], "backward": [...]}`. + + Usage: + 1) Register once at import/init time: + ExportStatefulExecutorsTransform.register_export_callback("my_exec", my_cb) + 2) Enable at compile time: + thunder.jit(model, executors=[...], transforms=[..., ExportStatefulExecutorsTransform()]) + 3) Read exported fields from the compiled module in tests/tools. + """ + + _instance = None + _callbacks: Dict[str, Callable] = {} + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + self.tm_ref = None + + @classmethod + def register_export_callback(cls, name: str, callback: Callable) -> None: + cls._callbacks[name] = callback + + def transform_module(self, model) -> None: + # Cache a weakref to the ThunderModule for later runtime export + self.tm_ref = weakref.ref(model) + + def transform_trace_post_execution(self, computation_trace: Trace, **kwargs): + # Resolve ThunderModule from weakref; if unavailable, skip + tm = self.tm_ref() if self.tm_ref is not None else None + if tm is None: + return computation_trace + + # Invoke all registered export callbacks. + for _, cb in self._callbacks.items(): + try: + cb(computation_trace, tm) + except Exception: + # Swallow errors from individual exporters to avoid breaking execution. 
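                # Hypothetical example of this behavior (illustration only): a callback registered as
                #     ExportStatefulExecutorsTransform.register_export_callback("noisy", lambda tr, tm: 1 / 0)
                # raises inside this loop, is skipped here, and the remaining callbacks and the
                # traced computation still complete normally.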
+ pass + + return computation_trace diff --git a/thunder/dev_utils/te_states_reporter.py b/thunder/dev_utils/te_states_reporter.py deleted file mode 100644 index 1900579b5c..0000000000 --- a/thunder/dev_utils/te_states_reporter.py +++ /dev/null @@ -1,414 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any -from collections.abc import Sequence - -import thunder -import torch - -import transformer_engine as te - -import transformer_engine_torch -from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch.constants import MXFP8_BLOCK_SCALING_SIZE -from transformer_engine.pytorch.fp8 import ( - FP8GlobalStateManager, - RecipeState, - get_fp8_torch_dtype, -) - - -def summarize_recipe(recipe: Recipe) -> dict[str, Any]: - """Create a compact, serializable summary of a TE FP8 recipe. - - The summary captures the recipe class name and a small set of key - configuration fields depending on the recipe family (DelayedScaling, - Float8CurrentScaling, MXFP8BlockScaling, Float8BlockScaling). For delayed - and current-scaling variants, the effective FP8 torch dtypes for forward - and backward are also included. - - Args: - recipe: A TransformerEngine `Recipe` instance from `transformer_engine.common.recipe`. - - Returns: - A dictionary with fields describing the recipe. - """ - summary: dict[str, Any] = { - "type": recipe.__class__.__name__, - "fp8_format": getattr(recipe, "fp8_format", None), - } - - if recipe.delayed(): - summary.update( - { - "margin": getattr(recipe, "margin", None), - "amax_history_len": getattr(recipe, "amax_history_len", None), - "amax_compute_algo": getattr(recipe, "amax_compute_algo", None), - "scaling_factor_compute_algo": getattr(recipe, "scaling_factor_compute_algo", None), - "reduce_amax": getattr(recipe, "reduce_amax", None), - "fp8_dpa": getattr(recipe, "fp8_dpa", None), - "fp8_mha": getattr(recipe, "fp8_mha", None), - # Effective FP8 dtypes per pass - "fwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, True)), - "bwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, False)), - } - ) - elif recipe.float8_current_scaling(): - summary.update( - { - "fp8_quant_fwd_inp": getattr(recipe, "fp8_quant_fwd_inp", None), - "fp8_quant_fwd_weight": getattr(recipe, "fp8_quant_fwd_weight", None), - "fp8_quant_bwd_grad": getattr(recipe, "fp8_quant_bwd_grad", None), - "fp8_gemm_fprop": getattr(recipe, "fp8_gemm_fprop", None), - "fp8_gemm_dgrad": getattr(recipe, "fp8_gemm_dgrad", None), - "fp8_gemm_wgrad": getattr(recipe, "fp8_gemm_wgrad", None), - "fp8_dpa": getattr(recipe, "fp8_dpa", None), - "fp8_mha": getattr(recipe, "fp8_mha", None), - "fwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, True)), - "bwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, False)), - } - ) - elif recipe.mxfp8(): - summary.update( - { - "margin": getattr(recipe, "margin", None), - "fwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, True)), - "bwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, False)), - } - ) - elif recipe.float8_block_scaling(): - summary.update( - { - "x_block_scaling_dim": getattr(recipe, "x_block_scaling_dim", None), - "w_block_scaling_dim": getattr(recipe, "w_block_scaling_dim", None), - "grad_block_scaling_dim": getattr(recipe, "grad_block_scaling_dim", None), - "fp8_quant_fwd_inp": getattr(recipe, "fp8_quant_fwd_inp", None), - "fp8_quant_fwd_weight": getattr(recipe, "fp8_quant_fwd_weight", None), - "fp8_quant_bwd_grad": getattr(recipe, "fp8_quant_bwd_grad", None), - "fp8_gemm_fprop": getattr(recipe, "fp8_gemm_fprop", None), - 
"fp8_gemm_dgrad": getattr(recipe, "fp8_gemm_dgrad", None), - "fp8_gemm_wgrad": getattr(recipe, "fp8_gemm_wgrad", None), - "fwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, True)), - "bwd_fp8_torch_dtype": str(get_fp8_torch_dtype(recipe, False)), - } - ) - - return summary - - -def summarize_state(state: RecipeState) -> dict[str, Any]: - """Summarize a runtime FP8 `RecipeState` object. - - Captures the state class, mode (forward/backward/None), dtype, number of - quantizers, and optionally basic tensor shape/device information for scale - and amax history tensors when present. - - Args: - state: A `RecipeState` produced by TransformerEngine during execution. - - Returns: - A dictionary with essential metadata about the state for reporting. - """ - out: dict[str, Any] = { - "cls": state.__class__.__name__, - "mode": getattr(state, "mode", None), - "dtype": str(getattr(state, "dtype", None)), - "num_quantizers": getattr(state, "num_quantizers", None), - } - scale = getattr(state, "scale", None) - if isinstance(scale, torch.Tensor): - out["scale_shape"] = tuple(scale.shape) - out["scale_device"] = str(scale.device) - amax_hist = getattr(state, "amax_history", None) - if isinstance(amax_hist, torch.Tensor): - out["amax_history_shape"] = tuple(amax_hist.shape) - out["amax_history_device"] = str(amax_hist.device) - return out - - -def summarize_quantizer(quantizer: Any) -> dict[str, Any]: - """Summarize an FP8 quantizer instance. - - Extracts commonly useful fields across different quantizer implementations - (rowwise/columnwise usage, internal flag, dtype) and, when available, - additional configuration such as amax reduction info. Tensor shape/device - metadata for `scale` and `amax` is included if present. - - Args: - quantizer: A quantizer-like object from TransformerEngine runtime. - - Returns: - A dictionary describing the quantizer in a compact, readable form. - """ - base: dict[str, Any] = { - "cls": quantizer.__class__.__name__, - "rowwise_usage": getattr(quantizer, "rowwise_usage", None), - "columnwise_usage": getattr(quantizer, "columnwise_usage", None), - "internal": getattr(quantizer, "internal", None), - "dtype": str(getattr(quantizer, "dtype", None)), - } - # Optional attributes by quantizer class - if hasattr(quantizer, "with_amax_reduction"): - base["with_amax_reduction"] = getattr(quantizer, "with_amax_reduction") - base["amax_reduction_group"] = str(getattr(quantizer, "amax_reduction_group", None)) - if hasattr(quantizer, "force_pow_2_scales"): - base["force_pow_2_scales"] = getattr(quantizer, "force_pow_2_scales") - if hasattr(quantizer, "amax_epsilon"): - base["amax_epsilon"] = getattr(quantizer, "amax_epsilon") - # Shapes (when available) - for attr in ("scale", "amax"): - tensor = getattr(quantizer, attr, None) - if isinstance(tensor, torch.Tensor): - base[f"{attr}_shape"] = tuple(tensor.shape) - base[f"{attr}_device"] = str(tensor.device) - return base - - -def build_global_context() -> dict[str, Any]: - """Collect global FP8 runtime context and environment details. - - Queries `FP8GlobalStateManager` and related sources to produce a stable - snapshot of the current FP8 configuration and availability, along with - environment metadata such as CUDA/cuBLASLt versions, device compute - capability, world size, and package versions. The result is intended to be - recorded once per session for reporting correlation. - - Returns: - A dictionary of global context fields suitable for rendering in reports. 
- """ - fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() - fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() - with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() - high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val() - fp8_group = FP8GlobalStateManager.get_fp8_group() - autocast_depth = FP8GlobalStateManager.FP8_AUTOCAST_DEPTH - graph_capturing = FP8GlobalStateManager.fp8_graph_capturing() - - # Availability and reasons - fp8_avail, reason_no_fp8 = FP8GlobalStateManager.is_fp8_available() - mxfp8_avail, reason_no_mx = FP8GlobalStateManager.is_mxfp8_available() - # Thunder does not support fp8 block scaling: https://github.com/Lightning-AI/lightning-thunder/issues/2476 - fp8blk_avail, reason_no_blk = (False, "Thunder does not support fp8 block scaling") - - # Versions / device - cuda_version = getattr(torch.version, "cuda", None) - cublaslt_version = transformer_engine_torch.get_cublasLt_version() - device_cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else None - - # Dist info - world_size = None - if torch.distributed.is_available() and torch.distributed.is_initialized(): - try: - world_size = torch.distributed.get_world_size(group=fp8_group) - except (RuntimeError, ValueError, TypeError): - world_size = torch.distributed.get_world_size() - - # Package versions - te_version = getattr(te, "__version__", None) if te is not None else None - thunder_version = None - thunder_version = getattr(thunder, "__version__", None) - - return { - "fp8_enabled": fp8_enabled, - "fp8_calibration": fp8_calibration, - "with_fp8_parameters": with_fp8_params, - "high_precision_init_val": high_precision_init_val, - "fp8_group": str(fp8_group), - "world_size": world_size, - "autocast_depth": autocast_depth, - "graph_capturing": graph_capturing, - "fp8_available": fp8_avail, - "reason_no_fp8": reason_no_fp8, - "mxfp8_available": mxfp8_avail, - "reason_no_mxfp8": reason_no_mx, - "fp8_block_scaling_available": fp8blk_avail, - "reason_no_fp8_block_scaling": reason_no_blk, - "cuda_version": cuda_version, - "cublaslt_version": cublaslt_version, - "device_compute_capability": device_cc, - "te_version": te_version, - "thunder_version": thunder_version, - } - - -@dataclass -class TEStateReporter: - """Accumulates TE runtime summaries and renders a report.""" - - global_ctx: dict[str, Any] | None = None - recipe_summaries: list[dict[str, Any]] = field(default_factory=list) - state_summaries_forward: list[dict[str, Any]] = field(default_factory=list) - state_summaries_backward: list[dict[str, Any]] = field(default_factory=list) - quantizer_summaries: list[dict[str, Any]] = field(default_factory=list) - shape_policy: dict[str, Any] = field(default_factory=lambda: {"mxfp8_block": MXFP8_BLOCK_SCALING_SIZE}) - seen_fw_states: set[tuple[int, int]] = field(default_factory=set) - seen_bw_states: set[tuple[int, int]] = field(default_factory=set) - seen_quantizers: set[tuple[int, int]] = field(default_factory=set) - - def update_from_runtime( - self, - *, - holder, - recipe: Recipe | None = None, - states: Sequence[RecipeState] | None = None, - mode: str | None = None, - quantizers: Sequence[Any] | None = None, - ) -> None: - """Update the reporter with data observed during runtime. - - This method is called one or more times during forward/backward passes - to incrementally collect summaries. The first invocation also captures - the global context snapshot. 
- - Args: - holder: The holder object (TERecipe, TERecipeState, or TEQuantizerState) - that owns the runtime data being reported. - recipe: Optional recipe active for the current autocast session. - states: Optional sequence of `RecipeState` objects observed. - mode: Optional mode string ("forward" or "backward") indicating the - execution phase when states are captured. - quantizers: Optional sequence of quantizer objects observed. - """ - if self.global_ctx is None: - self.global_ctx = build_global_context() - - # Collect recipe summaries only when called from the main recipe holder (TERecipe class). - # This avoids duplicate recipe entries when the same recipe is referenced by quantizer - # or state holders, ensuring we track each unique recipe configuration exactly once. - if recipe is not None and not quantizers: - summary = summarize_recipe(recipe) - if summary not in self.recipe_summaries: - self.recipe_summaries.append(summary) - - # Each trace execution can contain multiple forward and backward states for different recipes. - # We track unique combinations of (holder_id, recipe_id) to avoid duplicate state summaries - # while ensuring we capture all distinct recipe configurations used during runtime. - if states: - if mode == "forward": - if (id(holder), id(recipe)) not in self.seen_fw_states: - self.seen_fw_states.add((id(holder), id(recipe))) - self.state_summaries_forward.extend(summarize_state(s) for s in states) - elif mode == "backward": - if (id(holder), id(recipe)) not in self.seen_bw_states: - self.seen_bw_states.add((id(holder), id(recipe))) - self.state_summaries_backward.extend(summarize_state(s) for s in states) - - # Quantizers are reused across multiple trace executions but their behavior depends on the active recipe. - # While the quantizer object instances remain the same, different recipes can affect their configuration - # and internal state. We track unique combinations of (holder_id, recipe_id) to ensure we capture - # quantizer summaries for each distinct recipe configuration, avoiding both duplicates and missed - # configurations when recipes change during runtime. - if quantizers: - if (id(holder), id(recipe)) not in self.seen_quantizers: - self.seen_quantizers.add((id(holder), id(recipe))) - self.quantizer_summaries.extend(summarize_quantizer(q) for q in quantizers) - - def render_report(self) -> str: - """Render a human-readable multi-section report of collected data. - - The report includes global context, recipes, forward/backward state - summaries, quantizer summaries, and shape policy information. - - Returns: - A formatted string suitable for console logging or test output. 
- """ - lines: list[str] = [] - - def add(line: str = "") -> None: - lines.append(line) - - # Global Context - ctx = self.global_ctx or {} - add("Global Context:") - add(f" • FP8 Enabled: {ctx.get('fp8_enabled')}") - add(f" • FP8 Calibration: {ctx.get('fp8_calibration')}") - add(f" • FP8 Parameters: {ctx.get('with_fp8_parameters')}") - add(f" • High Precision Init: {ctx.get('high_precision_init_val')}") - add(f" • FP8 Group: {ctx.get('fp8_group')}") - add(f" • World Size: {ctx.get('world_size')}") - add(f" • Autocast Depth: {ctx.get('autocast_depth')}") - add(f" • Graph Capturing: {ctx.get('graph_capturing')}") - add("") - add(" Availability:") - add(f" - FP8: {ctx.get('fp8_available')}") - add(f" - MXFP8: {ctx.get('mxfp8_available')}") - add(f" - FP8 Block Scaling: {ctx.get('fp8_block_scaling_available')}") - if not ctx.get("fp8_block_scaling_available", True): - add(f" Reason: {ctx.get('reason_no_fp8_block_scaling')}") - add("") - add(" Versions:") - add(f" - CUDA: {ctx.get('cuda_version')} cuBLASLt: {ctx.get('cublaslt_version')}") - add(f" - Compute Capability: {ctx.get('device_compute_capability')}") - add(f" - TransformerEngine: {ctx.get('te_version')} Thunder: {ctx.get('thunder_version')}") - add("") - - # Recipes - add(f"Recipes ({len(self.recipe_summaries)}):") - for idx, rs in enumerate(self.recipe_summaries, 1): - add(f" [{idx}] {rs.get('type')} - {rs.get('fp8_format')}") - # Print a compact subset - for key in ( - "margin", - "amax_history_len", - "amax_compute_algo", - "reduce_amax", - "fp8_dpa", - "fp8_mha", - "fwd_fp8_torch_dtype", - "bwd_fp8_torch_dtype", - "x_block_scaling_dim", - "w_block_scaling_dim", - "grad_block_scaling_dim", - ): - if rs.get(key) is not None: - add(f" {key}: {rs.get(key)}") - add("") - - # States - add(f"Forward States ({len(self.state_summaries_forward)}):") - for idx, ss in enumerate(self.state_summaries_forward, 1): - add(f" [{idx}] Mode: {ss.get('mode')} DType: {ss.get('dtype')} Quantizers: {ss.get('num_quantizers')}") - if ss.get("scale_shape") is not None: - add(f" Scale: {ss.get('scale_shape')} on {ss.get('scale_device')}") - else: - add(" Note: no per-tensor scale (likely MXFP8/blockwise)") - if ss.get("amax_history_shape") is not None: - add(f" Amax History: {ss.get('amax_history_shape')} on {ss.get('amax_history_device')}") - add("") - - # Backward States (if any) - if self.state_summaries_backward: - add(f"Backward States ({len(self.state_summaries_backward)}):") - for idx, ss in enumerate(self.state_summaries_backward, 1): - add( - f" [{idx}] Mode: {ss.get('mode')} DType: {ss.get('dtype')} Quantizers: {ss.get('num_quantizers')}" - ) - if ss.get("scale_shape") is not None: - add(f" Scale: {ss.get('scale_shape')} on {ss.get('scale_device')}") - if ss.get("amax_history_shape") is not None: - add(f" Amax History: {ss.get('amax_history_shape')} on {ss.get('amax_history_device')}") - add("") - - # Quantizers - add(f"Quantizers ({len(self.quantizer_summaries)}):") - for idx, qs in enumerate(self.quantizer_summaries, 1): - add( - f" [{idx}] {qs.get('cls')} - {qs.get('dtype')}\n" - f" Rowwise: {qs.get('rowwise_usage')}\n" - f" Columnwise: {qs.get('columnwise_usage')}\n" - f" Internal: {qs.get('internal')}" - ) - if qs.get("with_amax_reduction") is not None: - add(f" Amax Reduction: {qs.get('with_amax_reduction')} group={qs.get('amax_reduction_group')}") - for attr in ("scale", "amax"): - if qs.get(f"{attr}_shape") is not None: - add(f" {attr.capitalize()}: {qs.get(f'{attr}_shape')} on {qs.get(f'{attr}_device')}") - add("") - - # Shape policy - 
add("Shape Policy:") - add(f" • mxfp8_block: {self.shape_policy.get('mxfp8_block')}") - - return "\n".join(lines) - - -__all__ = ["TEStateReporter"] diff --git a/thunder/executors/transformer_engineex_impl.py b/thunder/executors/transformer_engineex_impl.py index 9c27ce0253..a0d575dc2f 100644 --- a/thunder/executors/transformer_engineex_impl.py +++ b/thunder/executors/transformer_engineex_impl.py @@ -1,10 +1,11 @@ import time from typing import TYPE_CHECKING -import weakref +import warnings +from collections import defaultdict +import torch import torch.distributed as torch_dist -from thunder.core.compile_data import get_compile_data from thunder.core.prims import linear as linear_prim from thunder.core.prims import get_grad, put_grad from thunder.core.proxies import AnyProxy, TensorProxy @@ -22,7 +23,6 @@ from thunder.core.transform_common import cse_single_bsym from thunder.executors.passes import del_last_used import thunder.core.utils as utils -from thunder.dev_utils.te_states_reporter import TEStateReporter if TYPE_CHECKING: from thunder.core.trace import VariableInterface @@ -30,6 +30,7 @@ from thunder.core.proxies import TensorProxy import transformer_engine.pytorch as te +import transformer_engine.common.recipe as te_recipe from transformer_engine.pytorch.constants import MXFP8_BLOCK_SCALING_SIZE from transformer_engine.pytorch.fp8 import ( _amax_and_scale_update, @@ -38,43 +39,18 @@ RecipeState, FP8GlobalStateManager, ) - from transformer_engine.pytorch.ops import BasicLinear from transformer_engine.pytorch.tensor import Quantizer from transformer_engine.pytorch.utils import check_dim_for_fp8_exec -from thunder.core.module import get_thunder_module +from thunder.dev_utils.export_stateful_ex_transform import ( + ExportStatefulExecutorsTransform as _ExportSETransform, +) transformer_engine_ex = StatefulExecutor("transformer_engine") register_executor(transformer_engine_ex) -def _export_te_states(*, recipe=None, states=None, mode=None, quantizers=None, holder=None): - tm = None - if holder is not None: - tm_ref = getattr(holder, "_tm_ref", None) - if tm_ref is not None and isinstance(tm_ref, weakref.ref): - tm = tm_ref() - - # Cannot fallback to compile data as during execution we don't have compile_data context - if tm is None: - import warnings - - warnings.warn("No ThunderModule found for exporting TE states", UserWarning) - return - - if not hasattr(tm, "te_reporter"): - tm.te_reporter = TEStateReporter() - - tm.te_reporter.update_from_runtime( - holder=holder, - recipe=recipe, - states=states, - mode=mode, - quantizers=quantizers, - ) - - def _te_fp8_recipe_meta() -> AnyProxy: return AnyProxy(None, prefix="r") @@ -92,9 +68,6 @@ def __call__(self) -> Recipe: if not self.fp8_recipe or self.fp8_recipe is not te_fp8_recipe: self.fp8_recipe = te_fp8_recipe - # Duplicate recipies are handled by the TEStateReporter as we don't have any early return logic here - _export_te_states(recipe=self.fp8_recipe, holder=self) - return self.fp8_recipe @@ -117,15 +90,11 @@ def __init__(self): def __call__(self, recipe_state: RecipeState, num_quantizers: int) -> list[Quantizer]: if self.quantizers and self.parent_recipe_state is recipe_state: return self.quantizers - quantizers = recipe_state.make_quantizers() self.quantizers = quantizers self.parent_recipe_state = recipe_state - # Export only new quantizers - _export_te_states(recipe=recipe_state.recipe, quantizers=quantizers, holder=self) - return quantizers @@ -158,9 +127,6 @@ def __call__(self, recipe: Recipe, mode: str, num_quantizers: int) -> 
RecipeStat self.state = recipe_state self.parent_recipe = recipe - # Export only new states - _export_te_states(recipe=recipe, states=(recipe_state,), mode=mode, holder=self) - return recipe_state @@ -312,14 +278,35 @@ def _view_input_as_2d(x): fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + supported_recipes = (te_recipe.DelayedScaling, te_recipe.MXFP8BlockScaling) + if hasattr(te_recipe, "NVFP4BlockScaling"): + supported_recipes = (*supported_recipes, te_recipe.NVFP4BlockScaling) + + if not isinstance(fp8_recipe, supported_recipes): + warnings.warn(f"{type(fp8_recipe)} is not supported by TE executor, TE wont be used.") + return False + def check_valid_fp8_shapes(a): - # DelayedScaling and MXFP8BlockScaling have different shape requirements. + # Each recipe type has different shape requirements. if fp8_recipe.delayed(): return check_dim_for_fp8_exec(a) - assert fp8_recipe.mxfp8() shape = a.shape - return shape[0] % MXFP8_BLOCK_SCALING_SIZE == 0 and shape[1] % MXFP8_BLOCK_SCALING_SIZE == 0 + + if fp8_recipe.mxfp8(): + return shape[0] % MXFP8_BLOCK_SCALING_SIZE == 0 and shape[1] % MXFP8_BLOCK_SCALING_SIZE == 0 + + if hasattr(fp8_recipe, "nvfp4") and fp8_recipe.nvfp4(): + from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE + + # Check inherited from TE https://github.com/ksivaman/TransformerEngine-1/blob/1af7dd88aae5afb45e82148089038e1d1de9675d/transformer_engine/pytorch/tensor/nvfp4_tensor.py#L176-L184 + return ( + len(shape) >= 2 + and shape[0] % NVFP4_BLOCK_SCALING_SIZE == 0 + and shape[1] % NVFP4_BLOCK_SCALING_SIZE == 0 + ) + + return False # Inputs must be on CUDA and # input sizes must satisfy size constraints based on the recipe. @@ -384,11 +371,6 @@ def __init__(self): self.rhs_to_bsym_map: dict[BoundSymbolRHS, BoundSymbol] = {} self.redundant_map: dict[Variable, Proxy] = {} self.new_saved_for_backward = None - self._tm_ref = None - - def transform_module(self, model) -> None: - # Cache a weakref to the ThunderModule for later runtime export - self._tm_ref = weakref.ref(model) def reset(self): self.fp8_recipe = None @@ -397,19 +379,120 @@ def reset(self): self.redundant_map = {} self.new_saved_for_backward = None - def _stamp_te_refs_to_bsym(self, tr): - for bsym in tr.bound_symbols: - call_ctx = getattr(bsym, "_call_ctx", None) - if not call_ctx: - continue - te_prefixes = ["get_te_fp8_recipe", "get_te_fp8_state", "get_te_fp8_quantizers"] - if not any(bsym.sym.name.startswith(prefix) for prefix in te_prefixes): + @staticmethod + def export_state(computation_trace, tm) -> None: + """ + Extracts and exports the FP8 amax/scale state information from TransformerEngine (TE) holders + present in the Python context of a computation trace. + + This method is intended to be called after a TE-enabled computation has executed, in order to + serialize and record the relevant FP8 state (such as amax and scale tensors) and quantizer + information for later inspection, debugging, or export. + + Args: + computation_trace: The Thunder computation trace object containing the Python context + with TE state and quantizer holders. + tm: The ThunderModule object. + + Returns: + None. 
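Illustrative sketch of consuming the exported payload (key names follow the entries built below; concrete values depend on the active recipe, and `jmodel` is assumed to be a Thunder-compiled module that has already run under `fp8_autocast`):

    stats = getattr(jmodel, "te_fp8_stats", {"forward": [], "backward": []})
    for entry in stats["forward"]:
        for d in entry.get("delayed", []):          # DelayedScaling: per-state scale/amax values
            print(d["scale_shape"], d["scale"], d["amax"])
        for q in entry.get("mxfp8_or_block", []):   # MXFP8/block scaling: per-quantizer metadata
            print(q["cls"], q["dtype"], q["rowwise_usage"], q["columnwise_usage"])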
+ """ + # Extract FP8 amax/scale information from TE holders available in python context + python_ctx = computation_trace.python_ctx() + + # Helper: serialize small tensors; skip oversized payloads + def _to_list_limited(t, max_numel: int = 8192): + if not isinstance(t, torch.Tensor): + return None + try: + n = min(t.numel(), max_numel) + if t.numel() > max_numel: + warnings.warn( + f"TE Stateful Executor: Exporting only first {max_numel} elements of tensor with {t.numel()} elements", + UserWarning, + ) + flat = t.detach().float().cpu().view(-1)[:n].tolist() + return flat + except Exception: + return None + + # Infer context mode from available TE functional symbols + te_mode = None + if "te_functional_linear_fwd" in python_ctx: + te_mode = "forward" + elif "te_functional_linear_bwd" in python_ctx: + te_mode = "backward" + + delayed_entries: list[dict] = [] + block_entries: list[dict] = [] + + # Gather state and quantizer holders from context + state_holders = [v for k, v in python_ctx.items() if isinstance(k, str) and k.startswith("get_te_fp8_state")] + quantizer_holders = [ + v for k, v in python_ctx.items() if isinstance(k, str) and k.startswith("get_te_fp8_quantizers") + ] + + # Map RecipeState -> quantizers (if materialized) + state_to_quantizers: dict[int, list] = {} + for qh in quantizer_holders: + prs = getattr(qh, "parent_recipe_state", None) + qs = getattr(qh, "quantizers", None) + if prs is not None and qs: + state_to_quantizers.setdefault(id(prs), []).extend(qs) + + for sh in state_holders: + recipe = getattr(sh, "parent_recipe", None) + state = getattr(sh, "state", None) + if recipe is None: continue - state_obj = next(iter(call_ctx.values())) - assert getattr(state_obj, "_tm_ref", None) is None + # Determine recipe family + is_delayed = bool(recipe.delayed()) + is_mxfp8_or_block = bool(recipe.mxfp8()) + + # DelayedScaling: values live on state.scale and state.amax_history + if is_delayed and state is not None: + scale_vals = _to_list_limited(getattr(state, "scale", None)) + amax_hist = getattr(state, "amax_history", None) + amax_vals = None + if isinstance(amax_hist, torch.Tensor) and amax_hist.numel() > 0: + amax_slice = amax_hist[-1] if amax_hist.dim() >= 1 else amax_hist + amax_vals = _to_list_limited(amax_slice) + delayed_entries.append( + { + "scale_shape": getattr(getattr(state, "scale", None), "shape", None), + "scale": scale_vals, + "amax_shape": getattr(getattr(state, "amax_history", None), "shape", None), + "amax": amax_vals, + } + ) - setattr(state_obj, "_tm_ref", self._tm_ref) # stamp the ThunderModule weakref here + # MXFP8/Float8 block scaling: values live on quantizers + elif is_mxfp8_or_block and state is not None: + qs = state_to_quantizers.get(id(state), []) + for q in qs: + rowwise_usage = getattr(q, "rowwise_usage", None) + columnwise_usage = getattr(q, "columnwise_usage", None) + block_entries.append( + { + "cls": q.__class__.__name__, + "rowwise_usage": rowwise_usage, + "columnwise_usage": columnwise_usage, + "dtype": str(getattr(q, "dtype", None)), + } + ) + + entry = defaultdict(list) + if delayed_entries: + entry["delayed"] = delayed_entries + if block_entries: + entry["mxfp8_or_block"] = block_entries + + collected = getattr(tm, "te_fp8_stats", None) + if collected is None: + tm.te_fp8_stats = {"forward": [], "backward": []} + if entry["delayed"] or entry["mxfp8_or_block"]: + tm.te_fp8_stats[te_mode].append(entry) def transform_trace_post_optimization(self, computation_trace, **kwargs): """ @@ -425,14 +508,6 @@ def 
transform_trace_post_optimization(self, computation_trace, **kwargs): if "transformer_engine" not in map(lambda x: x.name, kwargs["executors_list"]): return computation_trace - # Ensure we have a ThunderModule weakref available - if self._tm_ref is None: - cd = get_compile_data() - if cd is not None and getattr(cd, "is_module", False): - tm = get_thunder_module(cd.fn) - if tm is not None: - self._tm_ref = weakref.ref(tm) - start_time_ns = time.perf_counter_ns() new_trace = from_trace(computation_trace) @@ -490,8 +565,6 @@ def transform_trace_post_optimization(self, computation_trace, **kwargs): sync_trace = del_last_used(new_trace) - self._stamp_te_refs_to_bsym(sync_trace) - end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 @@ -560,3 +633,10 @@ def _te_activation_checkpointing_transform(joint_trace: TraceCtx) -> TraceCtx: new_trace.bound_symbols = [bsym.from_bsym_swap_proxies(swapmap) for bsym in reversed(reversed_bsyms)] return new_trace + + +# Register TE export callback with the singleton export transform +try: + _ExportSETransform.register_export_callback("transformer_engine", TransformerEngineTransform.export_state) +except Exception: + pass diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index 99e1ce4b1f..011fed94d3 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -26,6 +26,7 @@ ) from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform +from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform is_fp8_supported: bool = False @@ -774,11 +775,12 @@ def _test_ddp_transformer_engine_llama_sanity(input_data): return None -def _test_ddp_transformer_engine_reporter(input_data): +def _test_ddp_transformer_engine_state_export(input_data): # Test Description: - # Verify that the TEStateReporter correctly captures and reports TransformerEngine - # FP8 state information during DDP training execution, including global context, - # recipe summaries, and forward/backward state summaries across distributed processes. + # This test ensures that the ExportStatefulExecutorsTransform correctly collects and exports + # TransformerEngine FP8 state information (such as amax/scale and quantizer stats) during + # distributed DDP training. It verifies that the state reporter works as expected across + # multiple processes, capturing both forward and backward state summaries for each rank. 
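For reference, a condensed single-process sketch of the flow this test exercises (the body below wraps the same pattern with thunder.distributed.ddp and gathers per-rank results; `model`, `x`, and `iters` stand in for the objects defined further down in this test):

    jmodel = thunder.jit(
        model,
        executors=[transformer_engine_ex],
        transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()],
    )
    for _ in range(iters):
        with fp8_autocast(fp8_recipe=get_default_fp8_recipe()):
            y = jmodel(x)
        y.backward(torch.ones_like(y))
    stats = jmodel.te_fp8_stats  # {"forward": [...], "backward": [...]}, one entry per executed trace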
init_method, world_size, rank, _executor, device, _dtype, _unused_kwargs = input_data devicetype = devices.device_from_string(device).devicetype @@ -808,11 +810,18 @@ def forward(self, x): model = Module() jmodel = thunder.distributed.ddp( - thunder.jit(model, executors=[transformer_engine_ex], transforms=[TransformerEngineTransform()]) + thunder.jit( + model, + executors=[transformer_engine_ex], + transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()], + ) ) + # Use default TE recipe for this test + fp8_recipe = get_default_fp8_recipe() + + iters = 10 def train(model): - iters = 10 for _ in range(iters): with fp8_autocast(fp8_recipe=fp8_recipe): y = model(x) @@ -820,29 +829,19 @@ def train(model): train(jmodel) - rep = getattr(jmodel, "te_reporter", None) - if rep is None: + stats = getattr(jmodel, "te_fp8_stats", None) + if stats is None: if rank == 0: - return [AssertionError("TransformerEngine reporter not found")] + return [AssertionError("TransformerEngine FP8 stats not found")] return None payload = { "rank": rank, - "error": None if rep is not None else "no reporter", - "global": {} if rep is None else rep.global_ctx, - "recipes": 0 if rep is None else len(rep.recipe_summaries), - "forward": 0 if rep is None else len(rep.state_summaries_forward), - "backward": 0 if rep is None else len(rep.state_summaries_backward), - "quantizers": 0 if rep is None else len(rep.quantizer_summaries), - "has_sections": ( - { - "global": "Global Context:" in rep.render_report(), - "forward": "Forward States (" in rep.render_report(), - "backward": "Backward States (" in rep.render_report(), - } - if rep is not None - else {"global": False, "forward": False, "backward": False} - ), + "error": None, + "forward": len(stats.get("forward", [])), + "backward": len(stats.get("backward", [])), + # Include minimal info for delayed-scaling check on rank 0 + "has_delayed": any("delayed" in e and e["delayed"] for e in stats.get("forward", [])), } gathered = [None] * world_size if rank == 0 else None @@ -854,15 +853,14 @@ def train(model): if rank == 0: exceptions = [] for p in gathered: - print(p) if p["error"]: exceptions.append(AssertionError(p["error"])) continue - assert p["has_sections"]["global"] - assert p["recipes"] == 1 - assert p["forward"] == 2 # 2 forward linear ops - assert p["backward"] == 2 # 2 backward linear ops - assert p["quantizers"] == 6 # 2 quantizers per forward linear op, 1 quantizers per backward linear op + # We export one entry per iteration for both forward and backward + assert p["forward"] == iters + assert p["backward"] == iters + # At least one entry should include delayed info on platforms using delayed scaling + # We do not enforce across ranks since recipe may vary by platform return exceptions return None @@ -982,8 +980,8 @@ def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype): unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}), ), ) -@distributed_wrapper("test_ddp_transformer_engine_reporter", _test_ddp_transformer_engine_reporter) -def test_ddp_transformer_engine_reporter(executor, devices, dtype): +@distributed_wrapper("test_ddp_transformer_engine_state_export", _test_ddp_transformer_engine_state_export) +def test_ddp_transformer_engine_state_export(executor, devices, dtype): pass diff --git a/thunder/tests/distributed/test_fsdp.py b/thunder/tests/distributed/test_fsdp.py index 4e762ad770..4260d59d10 100644 --- a/thunder/tests/distributed/test_fsdp.py +++ b/thunder/tests/distributed/test_fsdp.py @@ -29,6 +29,7 @@ ) from 
thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform +from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform is_fp8_supported: bool = False @@ -1237,7 +1238,7 @@ def _test_fsdp_transformer_engine_bucketing(input_data): return None -def _test_fsdp_transformer_engine_reporter(input_data): +def _test_fsdp_transformer_engine_state_export(input_data): # Test Description: # Verify that the TEStateReporter correctly captures and reports TransformerEngine # FP8 state information during DDP training execution, including global context, @@ -1283,13 +1284,13 @@ def forward(self, x): ] + executor.executors_list(), fp8_shard_intermediate_activation=intermediate_activation_sharding, - transforms=[TransformerEngineTransform()], + transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()], ), sharding_strategy=thunder_fsdp_strategy, ) + iters = 10 def train(model): - iters = 10 for _ in range(iters): with fp8_autocast(fp8_recipe=fp8_recipe): y = model(x) @@ -1297,29 +1298,19 @@ def train(model): train(jmodel) - rep = getattr(jmodel, "te_reporter", None) - if rep is None: + stats = getattr(jmodel, "te_fp8_stats", None) + if stats is None: if rank == 0: - return [AssertionError("TransformerEngine reporter not found")] + return [AssertionError("TransformerEngine FP8 stats not found")] return None payload = { "rank": rank, - "error": None if rep is not None else "no reporter", - "global": {} if rep is None else rep.global_ctx, - "recipes": 0 if rep is None else len(rep.recipe_summaries), - "forward": 0 if rep is None else len(rep.state_summaries_forward), - "backward": 0 if rep is None else len(rep.state_summaries_backward), - "quantizers": 0 if rep is None else len(rep.quantizer_summaries), - "has_sections": ( - { - "global": "Global Context:" in rep.render_report(), - "forward": "Forward States (" in rep.render_report(), - "backward": "Backward States (" in rep.render_report(), - } - if rep is not None - else {"global": False, "forward": False, "backward": False} - ), + "error": None, + "forward": len(stats.get("forward", [])), + "backward": len(stats.get("backward", [])), + # Include minimal info for delayed-scaling check on rank 0 + "has_delayed": any("delayed" in e and e["delayed"] for e in stats.get("forward", [])), } gathered = [None] * world_size if rank == 0 else None @@ -1334,11 +1325,10 @@ def train(model): if p["error"]: exceptions.append(AssertionError(p["error"])) continue - assert p["has_sections"]["global"] - assert p["recipes"] == 1 - assert p["forward"] == 3 # 2 forward linear ops - assert p["backward"] == 3 # 2 backward linear ops - assert p["quantizers"] == 9 # 2 quantizers per forward linear op, 1 quantizers per backward linear op + # We export one entry per iteration for both forward and backward + assert p["forward"] == iters + assert p["backward"] == iters + # At least one entry should include delayed info on platforms using delayed scaling return exceptions return None @@ -1399,8 +1389,8 @@ def test_fsdp_transformer_engine(executor, devices, dtype, thunder_fsdp_strategy pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason), ), ) -@distributed_wrapper("test_fsdp_transformer_engine_reporter", _test_fsdp_transformer_engine_reporter) -def test_fsdp_transformer_engine_reporter(executor, devices, dtype, thunder_fsdp_strategy_and_intermediate_sharding): +@distributed_wrapper("test_fsdp_transformer_engine_state_export", _test_fsdp_transformer_engine_state_export) +def 
test_fsdp_transformer_engine_state_export(executor, devices, dtype, thunder_fsdp_strategy_and_intermediate_sharding): pass diff --git a/thunder/tests/test_export_stateful_ex_transform.py b/thunder/tests/test_export_stateful_ex_transform.py new file mode 100644 index 0000000000..adc2fbeb4f --- /dev/null +++ b/thunder/tests/test_export_stateful_ex_transform.py @@ -0,0 +1,323 @@ +import pytest +import torch +import torch.nn as nn + +import thunder +from thunder.tests.framework import requiresCUDA + + +# NOTE: On SM120/121, TE defaults to using Float8BlockScaling +# which is currently unsupported in thunder, we skip the tests for these SM architectures. +from thunder.tests.utils import skip_on_sm120_and_sm121, is_sm120_orsm121 +from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform +from thunder.dynamo import ThunderCompiler + +# Make TE optional so this file can host tests for other executors too +TE_AVAILABLE = False +try: + import transformer_engine as _te_mod # noqa: F401 + import transformer_engine.pytorch as te # type: ignore + from transformer_engine.common import recipe # type: ignore + from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform + TE_AVAILABLE = True +except Exception: + te = None # type: ignore + recipe = None # type: ignore + TE_AVAILABLE = False + +if TE_AVAILABLE: + # FP8 is supported on compute arch 8.9 onwards. + # MXFP8 is supported on compute arch 10.0 onwards. + # Skip the TE-specific parametrizations if current hardware is not supported. + is_fp8_supported, msg_fp8 = te.fp8.check_fp8_support() + is_mxfp8_supported, msg_mxfp8 = te.fp8.check_mxfp8_support() + if not is_fp8_supported: + pytest.skip(msg_fp8, allow_module_level=True) + + hybrid_fp8_delayed_scaling_recipe = recipe.DelayedScaling() + mxfp8_e4m3_recipe = recipe.MXFP8BlockScaling() + + # `None` is used to test the default recipe. + recipes = (None, hybrid_fp8_delayed_scaling_recipe, mxfp8_e4m3_recipe) + recipe_ids = ("default", "delayed_scaling", "mxfp8_e4m3") +else: + is_mxfp8_supported, msg_mxfp8 = (False, "TransformerEngine not available") + recipes = (None,) + recipe_ids = ("default",) + + +@requiresCUDA +@pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed.") +@pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) +@skip_on_sm120_and_sm121 +def test_export_te_states_linear_forward(fp8_recipe): + if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): + pytest.skip(msg_mxfp8) + + if is_sm120_orsm121 and fp8_recipe is None: + pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported") + + # Test Description: + # Verify that the TEStateReporter correctly captures and reports TransformerEngine + # FP8 state information during forward pass execution, including global context, + # recipe summaries, and forward state summaries. 
+ + dtype = torch.bfloat16 + device = "cuda" + + # Inputs (3D input) + x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) + + class Module(nn.Module): + def __init__(self): + super().__init__() + self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) + self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) + + def forward(self, x): + o = torch.nn.functional.linear(x, self.w1) + added = o + x + return torch.nn.functional.linear(added, self.w2) + + model = Module() + + jmodel = thunder.jit( + model, + executors=[transformer_engine_ex], + transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()], + ) + + # Enable autocasting for the forward pass + with te.fp8_autocast(fp8_recipe=fp8_recipe): + y = jmodel(x) + + # Validate TE exporter populated as expected + assert hasattr(jmodel, "te_fp8_stats"), "ThunderModule should expose te_fp8_stats" + stats = jmodel.te_fp8_stats + assert isinstance(stats, dict) and set(stats.keys()) == {"forward", "backward"} + # After forward, we should have exactly one forward entry and no backward entries yet + assert len(stats["forward"]) == 1 + assert len(stats["backward"]) == 0 + f_entry = stats["forward"][0] + assert isinstance(f_entry, dict) + # Ensure we collected either delayed scaling or block-scaling style info + assert ("delayed" in f_entry and isinstance(f_entry["delayed"], list)) or ( + "mxfp8_or_block" in f_entry and isinstance(f_entry["mxfp8_or_block"], list) + ) + # If delayed scaling is used, ensure amax and scale are present + if isinstance(fp8_recipe, recipe.DelayedScaling) or (fp8_recipe is None and te.fp8.check_fp8_support()[0]): + assert 'delayed' in f_entry + d = f_entry["delayed"][0] + assert d.get("scale") is not None + assert d.get("amax") is not None + + grad_output = torch.randn_like(y) + y.backward(grad_output) + + # After backward pass, one backward entry should be present + stats = jmodel.te_fp8_stats + assert len(stats["forward"]) == 1 + assert len(stats["backward"]) == 1 + + +@requiresCUDA +@pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed.") +@pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) +@skip_on_sm120_and_sm121 +def test_export_te_states_linear_forward_backward_multiple_iteration(fp8_recipe): + if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): + pytest.skip(msg_mxfp8) + + if is_sm120_orsm121 and fp8_recipe is None: + pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported") + + # Test Description: + # Run multiple forward/backward iterations under a single recipe configuration and + # verify that the TE reporter does not grow with the iteration count. The recipe + # list should contain one unique entry, and state/quantizer summaries should reflect + # the two linear call sites exactly once per direction, independent of iterations. 
+ + dtype = torch.bfloat16 + device = "cuda" + + # Inputs and model + x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) + + class Module(nn.Module): + def __init__(self): + super().__init__() + self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) + self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) + + def forward(self, x): + o = torch.nn.functional.linear(x, self.w1) + added = o + x + return torch.nn.functional.linear(added, self.w2) + + model = Module() + + jmodel = thunder.jit( + model, + executors=[transformer_engine_ex], + transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()], + ) + + num_iters = 10 + for _ in range(num_iters): + # Forward under FP8 autocast + with te.fp8_autocast(fp8_recipe=fp8_recipe): + y = jmodel(x) + # Backward with unit upstream gradient + y.backward(torch.ones_like(y)) + + # Validate exporter after multiple iterations + assert hasattr(jmodel, "te_fp8_stats") + stats = jmodel.te_fp8_stats + # One forward and one backward export entry per iteration + assert len(stats["forward"]) == num_iters + assert len(stats["backward"]) == num_iters + + +@requiresCUDA +@pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed.") +def test_export_te_states_linear_forward_backward_multiple_recipies_iteration(): + # Test Description: + # Alternate between two different recipes across iterations and ensure the reporter + # records both recipe configurations exactly once each. Verify forward/backward states + # and quantizers reflect both linear call sites per recipe, independent of iteration count. + + test_recipes = [recipe.DelayedScaling()] + supports_mxfp8, _ = te.fp8.check_mxfp8_support() + + if supports_mxfp8: + test_recipes += [recipe.MXFP8BlockScaling()] + + if len(test_recipes) < 2: + pytest.skip("platform does not support two different recipes") + + dtype = torch.bfloat16 + device = "cuda" + + # Inputs and model + x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) + + class Module(nn.Module): + def __init__(self): + super().__init__() + self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) + self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) + + def forward(self, x): + o = torch.nn.functional.linear(x, self.w1) + added = o + x + return torch.nn.functional.linear(added, self.w2) + + model = Module() + iters = 4 + + def train_model(model): + for iter_n in range(iters): + te_recipe = test_recipes[iter_n % 2] + y = model(x, te_recipe) + y.backward(torch.ones_like(y)) + + jmodel = thunder.jit( + model, + executors=[transformer_engine_ex], + transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()], + ) + + def thunder_model(x, fp8_recipe): + with te.fp8_autocast(fp8_recipe=fp8_recipe): + return jmodel(x) + + train_model(thunder_model) + + stats = jmodel.te_fp8_stats + from pprint import pprint + pprint(stats) + # We expect as many forward/backward entries as iterations + assert len(stats["forward"]) == iters + assert len(stats["backward"]) == iters + # Across all entries, we should see delayed info and, if supported, possibly block info + has_delayed = any(e.get("delayed") for e in stats["forward"]) or any(e.get("delayed") for e in stats["backward"]) + assert has_delayed + +@requiresCUDA +@pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed.") +@pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) +@skip_on_sm120_and_sm121 
+def test_export_te_states_with_torch_compile_and_thunder_backend(fp8_recipe): + # Test Description: + # Use torch.compile with Thunder as backend (ThunderCompiler) to run the model + # under FP8 autocast. Verify that TE runtime states are exported and available + # from the Thunder-compiled subgraphs via `te_reporter`, and that forward/backward + # summaries match expectations (iteration-invariant). + + if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): + pytest.skip(msg_mxfp8) + + if is_sm120_orsm121 and fp8_recipe is None: + pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported") + + dtype = torch.bfloat16 + device = "cuda" + + x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) + + class Module(nn.Module): + def __init__(self): + super().__init__() + self.attention = nn.MultiheadAttention(4096, 64, device=device, dtype=dtype, batch_first=True) + self.norm1 = nn.LayerNorm(4096, device=device, dtype=dtype) + self.norm2 = nn.LayerNorm(4096, device=device, dtype=dtype) + self.mlp = nn.Sequential( + nn.Linear(4096, 16384, device=device, dtype=dtype), + nn.GELU(), + nn.Linear(16384, 4096, device=device, dtype=dtype), + ) + + def forward(self, x): + attn_out, _ = self.attention(x, x, x) + x = self.norm1(x + attn_out) + mlp_out = self.mlp(x) + x = self.norm2(x + mlp_out) + return x + + model = Module() + + # Compile with torch.compile using Thunder as backend + backend = ThunderCompiler( + executors=[transformer_engine_ex], + transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()], + ) + compiled_model = torch.compile(model, backend=backend) + + # Run one forward/backward under FP8 autocast + iters = 10 + def train_model(model): + for _ in range(iters): + with te.fp8_autocast(fp8_recipe=fp8_recipe): + y = model(x) + y.backward(torch.ones_like(y)) + + train_model(compiled_model) + + # Collect TE fp8 stats from Thunder-compiled subgraphs + reporters = [] + for sinfo in backend.subgraph_infos: + if sinfo.thunder_compiled_fns: + for fn in sinfo.thunder_compiled_fns: + if hasattr(fn, "te_fp8_stats"): + reporters.append(fn.te_fp8_stats) + + # We expect at least one Thunder subgraph using TE + assert len(reporters) >= 1 + + # Aggregate counts across subgraphs + total_fw_entries = sum(len(r["forward"]) for r in reporters) + total_bw_entries = sum(len(r["backward"]) for r in reporters) + + # We expect at least one Thunder subgraph using TE and to have exported entries + assert total_fw_entries == iters + assert total_bw_entries == iters diff --git a/thunder/tests/test_transformer_engine_executor_reporter.py b/thunder/tests/test_transformer_engine_executor_reporter.py deleted file mode 100644 index 12f9a07017..0000000000 --- a/thunder/tests/test_transformer_engine_executor_reporter.py +++ /dev/null @@ -1,384 +0,0 @@ -import pytest -import torch -import torch.nn as nn - -import thunder -from thunder.tests.framework import requiresCUDA - - -# NOTE: On SM120/121, TE defaults to using Float8BlockScaling -# which is currently unsupported in thunder, we skip the tests for these SM architectures. -from thunder.tests.utils import skip_on_sm120_and_sm121, is_sm120_orsm121 - -transformer_engine_module = pytest.importorskip( - "transformer_engine", reason="transformer_engine was not found, skipping the tests." 
-) - -from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform -from thunder.dynamo import ThunderCompiler -from transformer_engine.common import recipe -import transformer_engine.pytorch as te - -# FP8 is supported on compute arch 8.9 onwards. -# MXFP8 is supported on compute arch 10.0 onwards. -# Skip the tests if current hardware is not supported. -is_fp8_supported, msg_fp8 = te.fp8.check_fp8_support() -is_mxfp8_supported, msg_mxfp8 = te.fp8.check_mxfp8_support() -if not is_fp8_supported: - pytest.skip(msg_fp8, allow_module_level=True) - -hybrid_fp8_delayed_scaling_recipe = recipe.DelayedScaling() -mxfp8_e4m3_recipe = recipe.MXFP8BlockScaling() - -# `None` is used to test the default recipe. -recipes = (None, hybrid_fp8_delayed_scaling_recipe, mxfp8_e4m3_recipe) -recipe_ids = ("default", "delayed_scaling", "mxfp8_e4m3") - - -@requiresCUDA -@pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) -@skip_on_sm120_and_sm121 -def test_te_reporter_linear_forward_backward(fp8_recipe: recipe.Recipe): - if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): - pytest.skip(msg_mxfp8) - - if is_sm120_orsm121 and fp8_recipe is None: - pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported") - - # Test Description: - # Verify that the TEStateReporter correctly captures and reports TransformerEngine - # FP8 state information during forward pass execution, including global context, - # recipe summaries, and forward state summaries. - - dtype = torch.bfloat16 - device = "cuda" - - # Inputs (3D input) - x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) - - class Module(nn.Module): - def __init__(self): - super().__init__() - self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) - self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) - - def forward(self, x): - o = torch.nn.functional.linear(x, self.w1) - added = o + x - return torch.nn.functional.linear(added, self.w2) - - model = Module() - - jmodel = thunder.jit(model, executors=[transformer_engine_ex], transforms=[TransformerEngineTransform()]) - - # Enable autocasting for the forward pass - with te.fp8_autocast(fp8_recipe=fp8_recipe): - y = jmodel(x) - - # Validate TE reporter populated as expected - assert hasattr(jmodel, "te_reporter"), "ThunderModule should expose te_reporter" - rep = jmodel.te_reporter - - # Global context is captured - assert rep.global_ctx is not None, "Global context should be populated" - assert "fp8_available" in rep.global_ctx - assert "mxfp8_available" in rep.global_ctx - assert "fp8_block_scaling_available" in rep.global_ctx - - # Recipes captured; type should be one of known TE recipe classes - assert len(rep.recipe_summaries) >= 1 - recipe_types = {rs.get("type") for rs in rep.recipe_summaries} - known_types = {"DelayedScaling", "Float8BlockScaling", "MXFP8BlockScaling", "Float8CurrentScaling"} - assert recipe_types & known_types, f"Unexpected recipe types collected: {recipe_types}" - - # If a specific recipe is requested, ensure it's reflected - if isinstance(fp8_recipe, recipe.DelayedScaling): - assert "DelayedScaling" in recipe_types - if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - assert "MXFP8BlockScaling" in recipe_types - - # Forward states and quantizers should be recorded; no backward states without backward pass - assert len(rep.state_summaries_forward) == 2 - assert all(ss.get("mode") in (None, "forward") for ss in 
rep.state_summaries_forward) - assert any(ss.get("num_quantizers") in (1, 2) for ss in rep.state_summaries_forward) - assert len(rep.state_summaries_backward) == 0 - assert len(rep.quantizer_summaries) == 4 - assert all("cls" in qs and "dtype" in qs for qs in rep.quantizer_summaries) - - # Rendered report contains key sections - report_txt = rep.render_report() - assert "Global Context:" in report_txt - assert "Recipes (" in report_txt - assert "Forward States (" in report_txt - assert "Quantizers (" in report_txt - - grad_output = torch.randn_like(y) - y.backward(grad_output) - - report_txt = rep.render_report() - # After backward pass, backward states should be recorded and reported - assert len(rep.state_summaries_forward) == 2 # Forward states not changed - assert len(rep.state_summaries_backward) == 2 - assert all(ss.get("mode") in (None, "backward") for ss in rep.state_summaries_backward) - assert "Backward States (" in report_txt - - -@requiresCUDA -@pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) -@skip_on_sm120_and_sm121 -def test_te_reporter_linear_forward_backward_multiple_iteration(fp8_recipe: recipe.Recipe): - if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): - pytest.skip(msg_mxfp8) - - if is_sm120_orsm121 and fp8_recipe is None: - pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported") - - # Test Description: - # Run multiple forward/backward iterations under a single recipe configuration and - # verify that the TE reporter does not grow with the iteration count. The recipe - # list should contain one unique entry, and state/quantizer summaries should reflect - # the two linear call sites exactly once per direction, independent of iterations. - - dtype = torch.bfloat16 - device = "cuda" - - # Inputs and model - x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) - - class Module(nn.Module): - def __init__(self): - super().__init__() - self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) - self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) - - def forward(self, x): - o = torch.nn.functional.linear(x, self.w1) - added = o + x - return torch.nn.functional.linear(added, self.w2) - - model = Module() - - jmodel = thunder.jit(model, executors=[transformer_engine_ex], transforms=[TransformerEngineTransform()]) - - num_iters = 10 - for _ in range(num_iters): - # Forward under FP8 autocast - with te.fp8_autocast(fp8_recipe=fp8_recipe): - y = jmodel(x) - # Backward with unit upstream gradient - y.backward(torch.ones_like(y)) - - # Validate reporter after multiple iterations - assert hasattr(jmodel, "te_reporter") - rep = jmodel.te_reporter - - # Global context present - assert rep.global_ctx is not None - - # Recipes captured - assert len(rep.recipe_summaries) == 1 - - # Forward/backward states recorded (may be cached, so at least one each) - assert len(rep.state_summaries_forward) == 2 - assert len(rep.state_summaries_backward) == 2 - - # Quantizers observed at least once - assert len(rep.quantizer_summaries) == 6 - - # Report reflects sections - rpt = rep.render_report() - assert "Forward States (" in rpt - assert "Backward States (" in rpt - - -@requiresCUDA -def test_te_reporter_linear_forward_backward_multiple_recipies_iteration(): - # Test Description: - # Alternate between two different recipes across iterations and ensure the reporter - # records both recipe configurations exactly once each. 
Verify forward/backward states - # and quantizers reflect both linear call sites per recipe, independent of iteration count. - - recipes = [recipe.DelayedScaling()] - supports_mxfp8, _ = te.fp8.check_mxfp8_support() - - if supports_mxfp8: - recipes += [recipe.MXFP8BlockScaling()] - - if len(recipes) < 2: - pytest.skip("platform does not support two different recipes") - - dtype = torch.bfloat16 - device = "cuda" - - # Inputs and model - x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) - - class Module(nn.Module): - def __init__(self): - super().__init__() - self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) - self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) - - def forward(self, x): - o = torch.nn.functional.linear(x, self.w1) - added = o + x - return torch.nn.functional.linear(added, self.w2) - - model = Module() - iters = 10 - - def train_model(model): - for iter_n in range(iters): - te_recipe = recipes[iter_n % 2] - y = model(x, te_recipe) - y.backward(torch.ones_like(y)) - - jmodel = thunder.jit(model, executors=[transformer_engine_ex], transforms=[TransformerEngineTransform()]) - - def thunder_model(x, fp8_recipe): - with te.fp8_autocast(fp8_recipe=fp8_recipe): - return jmodel(x) - - train_model(thunder_model) - - rep_str = jmodel.te_reporter - assert len(rep_str.recipe_summaries) == len(recipes) - assert len(rep_str.state_summaries_forward) == 4 - assert len(rep_str.state_summaries_backward) == 4 - assert len(rep_str.quantizer_summaries) == 12 - - -@requiresCUDA -def test_te_reporter_linear_forward_backward_same_recipe_not_reported_twice(): - # Test Description: - # Alternate between two separate DelayedScaling instances that are equivalent in configuration. - # Ensure the reporter treats them as the same effective recipe and does not duplicate entries - # across iterations. Forward/backward states should reflect the two linear call sites once each, - # and quantizers should be counted once per site, independent of iteration count. - - delayed_scaling_recipe_a = recipe.DelayedScaling() - delayed_scaling_recipe_b = recipe.DelayedScaling() - - dtype = torch.bfloat16 - device = "cuda" - - # Inputs and model - x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) - - class Module(nn.Module): - def __init__(self): - super().__init__() - self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) - self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) - - def forward(self, x): - o = torch.nn.functional.linear(x, self.w1) - added = o + x - return torch.nn.functional.linear(added, self.w2) - - model = Module() - - def train_model(model): - # Run for `iterations`. 
- for iter_n in range(3): - y = model(x, delayed_scaling_recipe_a if iter_n % 2 == 0 else delayed_scaling_recipe_b) - - y.backward(torch.ones_like(y)) - - jmodel = thunder.jit(model, executors=[transformer_engine_ex], transforms=[TransformerEngineTransform()]) - - def thunder_model(x, fp8_recipe=None): - with te.fp8_autocast(fp8_recipe=fp8_recipe): - return jmodel(x) - - train_model(thunder_model) - - rep_str = jmodel.te_reporter - assert len(rep_str.recipe_summaries) == 1 - assert len(rep_str.state_summaries_forward) == 4 - assert len(rep_str.state_summaries_backward) == 4 - assert len(rep_str.quantizer_summaries) == 12 - - -@requiresCUDA -@pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) -@skip_on_sm120_and_sm121 -def test_te_reporter_with_torch_compile_and_thunder_backend(fp8_recipe: recipe.Recipe): - # Test Description: - # Use torch.compile with Thunder as backend (ThunderCompiler) to run the model - # under FP8 autocast. Verify that TE runtime states are exported and available - # from the Thunder-compiled subgraphs via `te_reporter`, and that forward/backward - # summaries match expectations (iteration-invariant). - - if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): - pytest.skip(msg_mxfp8) - - if is_sm120_orsm121 and fp8_recipe is None: - pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported") - - dtype = torch.bfloat16 - device = "cuda" - - x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) - - class Module(nn.Module): - def __init__(self): - super().__init__() - self.attention = nn.MultiheadAttention(4096, 64, device=device, dtype=dtype, batch_first=True) - self.norm1 = nn.LayerNorm(4096, device=device, dtype=dtype) - self.norm2 = nn.LayerNorm(4096, device=device, dtype=dtype) - self.mlp = nn.Sequential( - nn.Linear(4096, 16384, device=device, dtype=dtype), - nn.GELU(), - nn.Linear(16384, 4096, device=device, dtype=dtype), - ) - - def forward(self, x): - attn_out, _ = self.attention(x, x, x) - x = self.norm1(x + attn_out) - mlp_out = self.mlp(x) - x = self.norm2(x + mlp_out) - return x - - model = Module() - - # Compile with torch.compile using Thunder as backend - backend = ThunderCompiler(executors=[transformer_engine_ex], transforms=[TransformerEngineTransform()]) - compiled_model = torch.compile(model, backend=backend) - - # Run one forward/backward under FP8 autocast - def train_model(model): - iters = 10 - for _ in range(iters): - with te.fp8_autocast(fp8_recipe=fp8_recipe): - y = model(x) - y.backward(torch.ones_like(y)) - - train_model(compiled_model) - - print(compiled_model.__class__) - - # Collect TE reporters from Thunder-compiled subgraphs - reporters = [] - for sinfo in backend.subgraph_infos: - if sinfo.thunder_compiled_fns: - for fn in sinfo.thunder_compiled_fns: - if hasattr(fn, "te_reporter"): - reporters.append(fn.te_reporter) - - # We expect at least one Thunder subgraph using TE - assert len(reporters) >= 1 - - # Aggregate counts across subgraphs - total_recipes = sum(len(r.recipe_summaries) for r in reporters) - total_fw_states = sum(len(r.state_summaries_forward) for r in reporters) - total_bw_states = sum(len(r.state_summaries_backward) for r in reporters) - total_quantizers = sum(len(r.quantizer_summaries) for r in reporters) - - # Recipe presence - assert total_recipes >= 1 - # Two linear call sites leading to two forward and two backward states in total - assert total_fw_states == 2 - assert total_bw_states == 2 - # Quantizers (2 per forward, 1 per backward site 
leading to 6 total) - assert total_quantizers == 6 From a471a92053c2a3198c5175b7061814cff3c2cec9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Oct 2025 15:29:19 +0000 Subject: [PATCH 07/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/dev_utils/export_stateful_ex_transform.py | 4 ++-- thunder/tests/distributed/test_ddp.py | 1 + thunder/tests/distributed/test_fsdp.py | 5 ++++- thunder/tests/test_export_stateful_ex_transform.py | 8 ++++++-- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/thunder/dev_utils/export_stateful_ex_transform.py b/thunder/dev_utils/export_stateful_ex_transform.py index cec3628d79..57b8a7f440 100644 --- a/thunder/dev_utils/export_stateful_ex_transform.py +++ b/thunder/dev_utils/export_stateful_ex_transform.py @@ -1,5 +1,5 @@ import weakref -from typing import Callable, Dict +from collections.abc import Callable from thunder.core.transform_common import ( Transform, @@ -29,7 +29,7 @@ class ExportStatefulExecutorsTransform(Transform): """ _instance = None - _callbacks: Dict[str, Callable] = {} + _callbacks: dict[str, Callable] = {} def __new__(cls, *args, **kwargs): if cls._instance is None: diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index 011fed94d3..1bf5fa4640 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -821,6 +821,7 @@ def forward(self, x): fp8_recipe = get_default_fp8_recipe() iters = 10 + def train(model): for _ in range(iters): with fp8_autocast(fp8_recipe=fp8_recipe): diff --git a/thunder/tests/distributed/test_fsdp.py b/thunder/tests/distributed/test_fsdp.py index 4260d59d10..cda50425a7 100644 --- a/thunder/tests/distributed/test_fsdp.py +++ b/thunder/tests/distributed/test_fsdp.py @@ -1290,6 +1290,7 @@ def forward(self, x): ) iters = 10 + def train(model): for _ in range(iters): with fp8_autocast(fp8_recipe=fp8_recipe): @@ -1390,7 +1391,9 @@ def test_fsdp_transformer_engine(executor, devices, dtype, thunder_fsdp_strategy ), ) @distributed_wrapper("test_fsdp_transformer_engine_state_export", _test_fsdp_transformer_engine_state_export) -def test_fsdp_transformer_engine_state_export(executor, devices, dtype, thunder_fsdp_strategy_and_intermediate_sharding): +def test_fsdp_transformer_engine_state_export( + executor, devices, dtype, thunder_fsdp_strategy_and_intermediate_sharding +): pass diff --git a/thunder/tests/test_export_stateful_ex_transform.py b/thunder/tests/test_export_stateful_ex_transform.py index adc2fbeb4f..8d9183c47b 100644 --- a/thunder/tests/test_export_stateful_ex_transform.py +++ b/thunder/tests/test_export_stateful_ex_transform.py @@ -19,6 +19,7 @@ import transformer_engine.pytorch as te # type: ignore from transformer_engine.common import recipe # type: ignore from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform + TE_AVAILABLE = True except Exception: te = None # type: ignore @@ -106,7 +107,7 @@ def forward(self, x): ) # If delayed scaling is used, ensure amax and scale are present if isinstance(fp8_recipe, recipe.DelayedScaling) or (fp8_recipe is None and te.fp8.check_fp8_support()[0]): - assert 'delayed' in f_entry + assert "delayed" in f_entry d = f_entry["delayed"][0] assert d.get("scale") is not None assert d.get("amax") is not None @@ -235,14 +236,16 @@ def thunder_model(x, fp8_recipe): stats = jmodel.te_fp8_stats from pprint import 
pprint + pprint(stats) # We expect as many forward/backward entries as iterations assert len(stats["forward"]) == iters assert len(stats["backward"]) == iters # Across all entries, we should see delayed info and, if supported, possibly block info - has_delayed = any(e.get("delayed") for e in stats["forward"]) or any(e.get("delayed") for e in stats["backward"]) + has_delayed = any(e.get("delayed") for e in stats["forward"]) or any(e.get("delayed") for e in stats["backward"]) assert has_delayed + @requiresCUDA @pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed.") @pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) @@ -295,6 +298,7 @@ def forward(self, x): # Run one forward/backward under FP8 autocast iters = 10 + def train_model(model): for _ in range(iters): with te.fp8_autocast(fp8_recipe=fp8_recipe): From 8154054aa0df6f199a8dee3b20d710047a20f574 Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Fri, 3 Oct 2025 11:25:39 +0000 Subject: [PATCH 08/13] Removed leftover --- thunder/tests/test_export_stateful_ex_transform.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/thunder/tests/test_export_stateful_ex_transform.py b/thunder/tests/test_export_stateful_ex_transform.py index 8d9183c47b..dec2b99d10 100644 --- a/thunder/tests/test_export_stateful_ex_transform.py +++ b/thunder/tests/test_export_stateful_ex_transform.py @@ -235,9 +235,6 @@ def thunder_model(x, fp8_recipe): train_model(thunder_model) stats = jmodel.te_fp8_stats - from pprint import pprint - - pprint(stats) # We expect as many forward/backward entries as iterations assert len(stats["forward"]) == iters assert len(stats["backward"]) == iters From 5e6847240d1979d7fa75267c5190d90123e5ed22 Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Fri, 3 Oct 2025 11:38:11 +0000 Subject: [PATCH 09/13] Removed comments --- thunder/tests/test_export_stateful_ex_transform.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/thunder/tests/test_export_stateful_ex_transform.py b/thunder/tests/test_export_stateful_ex_transform.py index dec2b99d10..7b18a4d467 100644 --- a/thunder/tests/test_export_stateful_ex_transform.py +++ b/thunder/tests/test_export_stateful_ex_transform.py @@ -15,15 +15,15 @@ # Make TE optional so this file can host tests for other executors too TE_AVAILABLE = False try: - import transformer_engine as _te_mod # noqa: F401 - import transformer_engine.pytorch as te # type: ignore - from transformer_engine.common import recipe # type: ignore + import transformer_engine as _te_mod + import transformer_engine.pytorch as te + from transformer_engine.common import recipe from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform TE_AVAILABLE = True except Exception: - te = None # type: ignore - recipe = None # type: ignore + te = None + recipe = None TE_AVAILABLE = False if TE_AVAILABLE: From 89633c5f6224534dc7484da82036309b50093d93 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Oct 2025 11:38:32 +0000 Subject: [PATCH 10/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/tests/test_export_stateful_ex_transform.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/tests/test_export_stateful_ex_transform.py b/thunder/tests/test_export_stateful_ex_transform.py index 7b18a4d467..4f97cdbcf7 100644 --- a/thunder/tests/test_export_stateful_ex_transform.py +++ 
b/thunder/tests/test_export_stateful_ex_transform.py @@ -15,7 +15,6 @@ # Make TE optional so this file can host tests for other executors too TE_AVAILABLE = False try: - import transformer_engine as _te_mod import transformer_engine.pytorch as te from transformer_engine.common import recipe from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform From 456846485b63433817dfdce5cf1f72cf627e6d1a Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Fri, 3 Oct 2025 15:57:49 +0000 Subject: [PATCH 11/13] Updated the recording and materialization flow for transformer engine executor removing the post execution transform --- thunder/__init__.py | 39 +-- .../dev_utils/export_stateful_ex_transform.py | 122 ++++++--- .../executors/transformer_engineex_impl.py | 233 +++++++++++------- thunder/tests/distributed/test_ddp.py | 34 +-- thunder/tests/distributed/test_fsdp.py | 31 ++- .../test_export_stateful_ex_transform.py | 83 +++---- 6 files changed, 306 insertions(+), 236 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index fe95b7ec48..0e70f7d7bf 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -846,28 +846,6 @@ def wrapped(*args, **kwargs): # For more context see `NOTE: Split autograd.Function` disable_split_autograd: bool = compile_options.get("thunderfx_disable_split_autograd", False) - def wrapped_compiled_fn(fn: Callable, mode: str, *args): - """ - Wraps the compiled function to run the export stateful executors transform after the function is executed. - Stateful executors materialize information about the state at runtime. - """ - out = fn(*args) - cs = weakref_cs() - cd = weakref_cd() - assert cd is not None, "cd has been cleared." - assert cs is not None, "cs has been cleared." - if cs is None or cd is None: - return out - - trace = cs.last_traces[-1] if mode == "forward" else cs.last_backward_traces[-1] - if not trace: - return out - for transform in cd.transforms or []: - if isinstance(transform, ExportStatefulExecutorsTransform): - transform.transform_trace_post_execution(trace) - break - return out - def maybe_connect_to_autograd(cache_entry, result): if cache_entry.backward_fn: # If the backward function is available, we need to connect the @@ -877,11 +855,8 @@ def maybe_connect_to_autograd(cache_entry, result): is_differentiable_outputs = compile_options.get("is_differentiable_outputs", None) - def wrapped_backward_fn(saved_and_other, args): - return wrapped_compiled_fn(cache_entry.backward_fn, "backward", saved_and_other, args) - connect_to_autograd( - backward_fn=cache_entry.backward_fn if not has_to_export() else wrapped_backward_fn, + backward_fn=cache_entry.backward_fn, flat_args=data_for_autograd["flat_args"], flat_output=data_for_autograd["flat_output"], saved_tensors=saved_tensors, @@ -898,11 +873,6 @@ def call_epilogue(cache_entry, comp_result, pro_to_epi): result = cache_entry.epilogue_fn(*pro_to_epi, *comp_result) return result - def has_to_export(): - cd = weakref_cd() - assert cd is not None, "cd has been cleared." 
- return any(isinstance(t, ExportStatefulExecutorsTransform) for t in cd.transforms or []) - @wraps(fn) @update_call_statistics def fn_(*args, **kwargs) -> Any: @@ -917,12 +887,7 @@ def fn_(*args, **kwargs) -> Any: cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) - def wrapped_computation_fn(*args): - return wrapped_compiled_fn(cache_entry.computation_fn, "forward", *args) - - computation_fn = cache_entry.computation_fn if not has_to_export() else wrapped_computation_fn - - result = computation_fn(*inps) + result = cache_entry.computation_fn(*inps) result = maybe_connect_to_autograd(cache_entry, result) result = call_epilogue(cache_entry, result, pro_to_epi) diff --git a/thunder/dev_utils/export_stateful_ex_transform.py b/thunder/dev_utils/export_stateful_ex_transform.py index 57b8a7f440..ab39bcb48f 100644 --- a/thunder/dev_utils/export_stateful_ex_transform.py +++ b/thunder/dev_utils/export_stateful_ex_transform.py @@ -1,64 +1,124 @@ import weakref from collections.abc import Callable +from typing import List, Tuple from thunder.core.transform_common import ( Transform, ) from thunder.core.trace import TraceCtx as Trace +from thunder.core.module import ThunderModule + + +class ExportStatefulExecutorsStats: + def __init__(self, tm: ThunderModule, resolver_fn: Callable): + """Lightweight accessor attached to a `ThunderModule`. + + Args: + tm: The `ThunderModule` instance this accessor belongs to. + resolver_fn: A callable that knows how to resolve the recorded + references on `tm` and return real values. + """ + self.tm = tm + self.resolver_fn = resolver_fn class ExportStatefulExecutorsTransform(Transform): - """Export runtime state from stateful executors after a trace executes. + """Register references and resolve runtime state lazily. - - Singleton transform with a registry of export callbacks - - Register via `register_export_callback(name, callback)` - - Callbacks receive `(computation_trace, thunder_module)` and may attach - serialized state to the module (e.g., `module.te_fp8_stats`) - - Safe: callback errors are swallowed; export never blocks execution + What this transform does: + - Singleton registry to plug per-executor exporters + - At module transform time, installs a lightweight accessor on the module + (e.g., `module.te_fp8_states`) that can resolve values on demand + - At post-optimization time, calls registered reference callbacks to record + only where values will materialize (holders + attribute paths). No data + are copied or materialized in this step + - When code calls the accessor (e.g., `module.te_fp8_states()`), the resolve + callback reads the recorded references and returns the latest values - Example (TransformerEngine): a callback collects FP8 amax/scale and - quantizer metadata from `python_ctx` and records them under - `module.te_fp8_stats = {"forward": [...], "backward": [...]}`. + + API overview: + - register_ref_callback(name, register_cb, resolve_cb, instance_cls): + name: attribute name to attach on the module + register_cb(trace, module): store references from the trace/python_ctx + resolve_cb(module): materialize and return values using the stored refs + instance_cls: a small class constructed as instance_cls(module, resolve_cb) + and attached as `setattr(module, name, instance)`; it typically stores + containers for references and implements __call__(...) 
to resolve Usage: - 1) Register once at import/init time: - ExportStatefulExecutorsTransform.register_export_callback("my_exec", my_cb) + 1) Register once at import/init time. For example, for TransformerEngine: + ExportStatefulExecutorsTransform.register_ref_callback( + "te_fp8_states", register_cb, resolve_cb, StatsClass + ) 2) Enable at compile time: thunder.jit(model, executors=[...], transforms=[..., ExportStatefulExecutorsTransform()]) - 3) Read exported fields from the compiled module in tests/tools. + 3) After each run, call `module.te_fp8_states()` to resolve and return the latest values. + + Notes: + - Supports multiple ThunderModule instances (e.g., subgraphs) + - Callback errors are swallowed to avoid interfering with execution """ + _register_callbacks: dict[str, Callable] = {} + _callback_attributes: List[Tuple[str, type[ExportStatefulExecutorsStats], Callable]] = [] + _instance = None - _callbacks: dict[str, Callable] = {} def __new__(cls, *args, **kwargs): + """Ensure singleton instance across repeated transform construction.""" if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance def __init__(self): - self.tm_ref = None + """Initialize internal weakrefs registry. + + ThunderCompiler and other compilation flows may create multiple + `ThunderModule` instances for subgraphs; we keep weak references + to update all of them during post-optimization registration. + """ + self.tm_refs = [] @classmethod - def register_export_callback(cls, name: str, callback: Callable) -> None: - cls._callbacks[name] = callback + def register_ref_callback( + cls, name: str, callback: Callable, resolve_cb: Callable, instance: type[ExportStatefulExecutorsStats] + ) -> None: + """Register per-executor reference and resolver callbacks. + + Installs a module attribute named `name` by constructing `instance` with + the resolver function. The `callback` will be invoked during + post-optimization to record reference locations on the module. + + Args: + name: Module attribute to attach (e.g., "te_fp8_states"). + callback: Function `(trace, module) -> None` that records refs. + resolve_cb: Function `(module) -> Any` that resolves values on demand. + instance: A class (must be a subclass of ExportStatefulExecutorsStats) constructed as `instance(module, resolve_cb)`. + """ + if not issubclass(instance, ExportStatefulExecutorsStats): + raise TypeError(f"Provided instance {instance} must be a subclass of ExportStatefulExecutorsStats") + cls._register_callbacks[name] = callback + cls._callback_attributes.append((name, instance, resolve_cb)) def transform_module(self, model) -> None: + assert model is not None # Cache a weakref to the ThunderModule for later runtime export - self.tm_ref = weakref.ref(model) - - def transform_trace_post_execution(self, computation_trace: Trace, **kwargs): - # Resolve ThunderModule from weakref; if unavailable, skip - tm = self.tm_ref() if self.tm_ref is not None else None - if tm is None: - return computation_trace - - # Invoke all registered export callbacks. - for _, cb in self._callbacks.items(): - try: - cb(computation_trace, tm) - except Exception: - # Swallow errors from individual exporters to avoid breaking execution. 
- pass + self.tm_refs.append(weakref.ref(model)) + # Initialize attributes on model + for name, instance, resolve_cb in self._callback_attributes: + setattr(model, name, instance(model, resolve_cb)) + + def transform_trace_post_optimization(self, computation_trace: Trace, **kwargs): + for tm_ref in self.tm_refs: + # Resolve ThunderModule from weakref; if unavailable, skip + tm = tm_ref() if tm_ref is not None else None + if tm is None: + continue + # Invoke all registered callbacks to register reference locations + for _, cb in self._register_callbacks.items(): + try: + cb(computation_trace, tm) + except Exception: + pass return computation_trace diff --git a/thunder/executors/transformer_engineex_impl.py b/thunder/executors/transformer_engineex_impl.py index a0d575dc2f..26f9b800e1 100644 --- a/thunder/executors/transformer_engineex_impl.py +++ b/thunder/executors/transformer_engineex_impl.py @@ -2,10 +2,13 @@ from typing import TYPE_CHECKING import warnings from collections import defaultdict +from collections.abc import Callable import torch import torch.distributed as torch_dist +import thunder +from thunder.core.module import ThunderModule from thunder.core.prims import linear as linear_prim from thunder.core.prims import get_grad, put_grad from thunder.core.proxies import AnyProxy, TensorProxy @@ -44,6 +47,7 @@ from transformer_engine.pytorch.utils import check_dim_for_fp8_exec from thunder.dev_utils.export_stateful_ex_transform import ( ExportStatefulExecutorsTransform as _ExportSETransform, + ExportStatefulExecutorsStats, ) @@ -379,120 +383,158 @@ def reset(self): self.redundant_map = {} self.new_saved_for_backward = None + class TEFP8Stats(ExportStatefulExecutorsStats): + def __init__(self, tm: ThunderModule, resolver_fn: Callable): + """Accessor attached on the module to resolve TE FP8 states on demand. + + Args: + tm: ThunderModule to which this accessor is bound. + resolver_fn: Callable invoked as `resolver_fn(mode, tm)` to produce + the latest snapshot of FP8 state based on registered refs. + """ + super().__init__(tm, resolver_fn) + self.refs = {"forward": None, "backward": None} + + def __call__(self, mode: str = "forward") -> dict: + """Resolve and return the latest FP8 state for the given mode. + + Args: + mode: "forward" or "backward". Defaults to "forward". + Returns: + A dictionary snapshot of resolved values (e.g., delayed or mxfp8 entries), + or an empty dict on invalid mode or if nothing is recorded. + """ + if mode not in ["forward", "backward"]: + warnings.warn(f"Received an invalid inspection mode: {mode}. Please use 'forward' or 'backward'.") + return {} + return self.resolver_fn(mode, self.tm) + @staticmethod - def export_state(computation_trace, tm) -> None: + def register_refs(computation_trace, tm) -> None: + """Record where FP8 values will materialize for later lazy resolution. + + This inspects the trace's python context, finds TE state and quantizer + holders, and stores only references (holder objects and attribute paths) + into the module accessor. No tensors or runtime data are copied here. + The actual values are read by `resolve_values` after execution. """ - Extracts and exports the FP8 amax/scale state information from TransformerEngine (TE) holders - present in the Python context of a computation trace. 
+ python_ctx = computation_trace.python_ctx() - This method is intended to be called after a TE-enabled computation has executed, in order to - serialize and record the relevant FP8 state (such as amax and scale tensors) and quantizer - information for later inspection, debugging, or export. + # Collect holders and where to read values from later + refs = defaultdict(list) - Args: - computation_trace: The Thunder computation trace object containing the Python context - with TE state and quantizer holders. - tm: The ThunderModule object. + # Collect mode from trace tags + mode = "forward" if TraceTag.AUGMENTED_FORWARD in computation_trace.tags else "backward" + # States: register all state holders; decide recipe type at resolve time + state_holders = [v for k, v in python_ctx.items() if isinstance(k, str) and k.startswith("get_te_fp8_state")] + for sh in state_holders: + # Always store attrs we may need; recipe classification happens later + refs["state_holder"].append( + { + "holder": sh, + "scale_attr": "state.scale", + "amax_attr": "state.amax_history", + } + ) + + # Quantizers (MXFP8/block): resolve via TEQuantizerState linked to RecipeState + quantizer_holders = [ + v for k, v in python_ctx.items() if isinstance(k, str) and k.startswith("get_te_fp8_quantizers") + ] + for qh in quantizer_holders: + refs["quantizer_holder"].append( + {"holder": qh, "quant_attr": "quantizers", "parent_state_attr": "parent_recipe_state"} + ) + + if len(refs) > 0: + tm.te_fp8_states.refs[mode] = refs + + @staticmethod + def resolve_values(mode: str, tm: ThunderModule) -> dict: + """Load and serialize FP8 values using previously-registered references. + + Args: + mode: "forward" or "backward" indicating which refs to resolve. + tm: ThunderModule whose accessor holds the recorded references. Returns: - None. + A dictionary with resolved entries (e.g., {"delayed": [...]} or + {"mxfp8": [...]}); returns empty dict if nothing is recorded. 
""" - # Extract FP8 amax/scale information from TE holders available in python context - python_ctx = computation_trace.python_ctx() - # Helper: serialize small tensors; skip oversized payloads - def _to_list_limited(t, max_numel: int = 8192): + def _get_attr(obj, attr_path: str): + cur = obj + for part in attr_path.split("."): + cur = getattr(cur, part) + return cur + + def _tensor_head(t, max_numel: int = 8192): if not isinstance(t, torch.Tensor): return None - try: - n = min(t.numel(), max_numel) - if t.numel() > max_numel: - warnings.warn( - f"TE Stateful Executor: Exporting only first {max_numel} elements of tensor with {t.numel()} elements", - UserWarning, - ) - flat = t.detach().float().cpu().view(-1)[:n].tolist() - return flat - except Exception: - return None + n = min(t.numel(), max_numel) + return t.detach().float().cpu().view(-1)[:n].tolist() - # Infer context mode from available TE functional symbols - te_mode = None - if "te_functional_linear_fwd" in python_ctx: - te_mode = "forward" - elif "te_functional_linear_bwd" in python_ctx: - te_mode = "backward" + # Pull last registered refs for this mode + refs = tm.te_fp8_states.refs[mode] + if refs is None: + return {} - delayed_entries: list[dict] = [] - block_entries: list[dict] = [] + out = defaultdict(list) - # Gather state and quantizer holders from context - state_holders = [v for k, v in python_ctx.items() if isinstance(k, str) and k.startswith("get_te_fp8_state")] - quantizer_holders = [ - v for k, v in python_ctx.items() if isinstance(k, str) and k.startswith("get_te_fp8_quantizers") - ] + # Classify states now that recipes and states have materialized + # MXFP8/block scaling: will be collected via quantizers section below + for ref in refs.get("state_holder", []): + sh = ref["holder"] + recipe = getattr(sh, "parent_recipe", None) + state = getattr(sh, "state", None) + if recipe is None or state is None: + continue + # Delayed scaling: extract from state tensors + if getattr(recipe, "delayed", lambda: False)(): + scale = _get_attr(sh, ref["scale_attr"]) # state.scale + amax_hist = _get_attr(sh, ref["amax_attr"]) # state.amax_history + scale_vals = _tensor_head(scale) + amax_vals = _tensor_head( + amax_hist[-1] if isinstance(amax_hist, torch.Tensor) and amax_hist.numel() > 0 else amax_hist + ) + out["delayed"].append( + { + "scale_shape": getattr(scale, "shape", None), + "scale": scale_vals, + "amax_shape": getattr(amax_hist, "shape", None), + "amax": amax_vals, + } + ) - # Map RecipeState -> quantizers (if materialized) - state_to_quantizers: dict[int, list] = {} - for qh in quantizer_holders: + # MXFP8 via quantizers + # First, build mapping from recipe state id to quantizers + state_to_qs = {} + for ref in refs.get("quantizer_holder", []): + qh = ref["holder"] prs = getattr(qh, "parent_recipe_state", None) qs = getattr(qh, "quantizers", None) if prs is not None and qs: - state_to_quantizers.setdefault(id(prs), []).extend(qs) + state_to_qs.setdefault(id(prs), []).extend(qs) - for sh in state_holders: + # For MXFP8/block scaling, gather quantizers linked to each materialized state + for ref in refs.get("state_holder", []): + sh = ref["holder"] recipe = getattr(sh, "parent_recipe", None) state = getattr(sh, "state", None) - if recipe is None: + if recipe is None or state is None: continue - - # Determine recipe family - is_delayed = bool(recipe.delayed()) - is_mxfp8_or_block = bool(recipe.mxfp8()) - - # DelayedScaling: values live on state.scale and state.amax_history - if is_delayed and state is not None: - scale_vals = 
_to_list_limited(getattr(state, "scale", None)) - amax_hist = getattr(state, "amax_history", None) - amax_vals = None - if isinstance(amax_hist, torch.Tensor) and amax_hist.numel() > 0: - amax_slice = amax_hist[-1] if amax_hist.dim() >= 1 else amax_hist - amax_vals = _to_list_limited(amax_slice) - delayed_entries.append( - { - "scale_shape": getattr(getattr(state, "scale", None), "shape", None), - "scale": scale_vals, - "amax_shape": getattr(getattr(state, "amax_history", None), "shape", None), - "amax": amax_vals, + if getattr(recipe, "mxfp8", lambda: False)(): + for q in state_to_qs.get(id(state), []): + entry = { + "cls": q.__class__.__name__, + "rowwise_usage": getattr(q, "rowwise_usage", None), + "columnwise_usage": getattr(q, "columnwise_usage", None), + "dtype": str(getattr(q, "dtype", None)), } - ) + if entry not in out["mxfp8"]: + out["mxfp8"].append(entry) - # MXFP8/Float8 block scaling: values live on quantizers - elif is_mxfp8_or_block and state is not None: - qs = state_to_quantizers.get(id(state), []) - for q in qs: - rowwise_usage = getattr(q, "rowwise_usage", None) - columnwise_usage = getattr(q, "columnwise_usage", None) - block_entries.append( - { - "cls": q.__class__.__name__, - "rowwise_usage": rowwise_usage, - "columnwise_usage": columnwise_usage, - "dtype": str(getattr(q, "dtype", None)), - } - ) - - entry = defaultdict(list) - if delayed_entries: - entry["delayed"] = delayed_entries - if block_entries: - entry["mxfp8_or_block"] = block_entries - - collected = getattr(tm, "te_fp8_stats", None) - if collected is None: - tm.te_fp8_stats = {"forward": [], "backward": []} - if entry["delayed"] or entry["mxfp8_or_block"]: - tm.te_fp8_stats[te_mode].append(entry) + return out def transform_trace_post_optimization(self, computation_trace, **kwargs): """ @@ -635,8 +677,13 @@ def _te_activation_checkpointing_transform(joint_trace: TraceCtx) -> TraceCtx: return new_trace -# Register TE export callback with the singleton export transform +# Register TE reference and resolve callbacks with the singleton export transform try: - _ExportSETransform.register_export_callback("transformer_engine", TransformerEngineTransform.export_state) + _ExportSETransform.register_ref_callback( + "te_fp8_states", + TransformerEngineTransform.register_refs, + TransformerEngineTransform.resolve_values, + TransformerEngineTransform.TEFP8Stats, + ) except Exception: pass diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index 1bf5fa4640..6202e6ee8b 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -786,8 +786,6 @@ def _test_ddp_transformer_engine_state_export(input_data): devicetype = devices.device_from_string(device).devicetype pg = init_per_process_distributed(init_method, devicetype, world_size, rank) - fp8_recipe = get_default_fp8_recipe() - torch.cuda.set_device(rank) torch_device = torch.device("cuda", rank) @@ -830,19 +828,25 @@ def train(model): train(jmodel) - stats = getattr(jmodel, "te_fp8_stats", None) - if stats is None: - if rank == 0: - return [AssertionError("TransformerEngine FP8 stats not found")] - return None + # Resolve latest TE FP8 states on demand (forward and backward) + resolved_fwd = False + resolved_bw = False + try: + if hasattr(jmodel, "te_fp8_states"): + f_entry = jmodel.te_fp8_states() + if isinstance(f_entry, dict) and ("delayed" in f_entry or "mxfp8" in f_entry): + resolved_fwd = True + b_entry = jmodel.te_fp8_states(mode="backward") + if isinstance(b_entry, dict) and ("delayed" in 
b_entry or "mxfp8" in b_entry): + resolved_bw = True + except Exception: + pass payload = { "rank": rank, "error": None, - "forward": len(stats.get("forward", [])), - "backward": len(stats.get("backward", [])), - # Include minimal info for delayed-scaling check on rank 0 - "has_delayed": any("delayed" in e and e["delayed"] for e in stats.get("forward", [])), + "resolved_fwd": resolved_fwd, + "resolved_bw": resolved_bw, } gathered = [None] * world_size if rank == 0 else None @@ -857,11 +861,9 @@ def train(model): if p["error"]: exceptions.append(AssertionError(p["error"])) continue - # We export one entry per iteration for both forward and backward - assert p["forward"] == iters - assert p["backward"] == iters - # At least one entry should include delayed info on platforms using delayed scaling - # We do not enforce across ranks since recipe may vary by platform + # We should resolve both forward and backward entries + assert p["resolved_fwd"] + assert p["resolved_bw"] return exceptions return None diff --git a/thunder/tests/distributed/test_fsdp.py b/thunder/tests/distributed/test_fsdp.py index cda50425a7..a3116bfae1 100644 --- a/thunder/tests/distributed/test_fsdp.py +++ b/thunder/tests/distributed/test_fsdp.py @@ -1299,19 +1299,25 @@ def train(model): train(jmodel) - stats = getattr(jmodel, "te_fp8_stats", None) - if stats is None: - if rank == 0: - return [AssertionError("TransformerEngine FP8 stats not found")] - return None + # Resolve latest TE FP8 states on demand (forward and backward) + resolved_fwd = False + resolved_bw = False + try: + if hasattr(jmodel, "te_fp8_states"): + f_entry = jmodel.te_fp8_states() + if isinstance(f_entry, dict) and ("delayed" in f_entry or "mxfp8" in f_entry): + resolved_fwd = True + b_entry = jmodel.te_fp8_states(mode="backward") + if isinstance(b_entry, dict) and ("delayed" in b_entry or "mxfp8" in b_entry): + resolved_bw = True + except Exception: + pass payload = { "rank": rank, "error": None, - "forward": len(stats.get("forward", [])), - "backward": len(stats.get("backward", [])), - # Include minimal info for delayed-scaling check on rank 0 - "has_delayed": any("delayed" in e and e["delayed"] for e in stats.get("forward", [])), + "resolved_fwd": resolved_fwd, + "resolved_bw": resolved_bw, } gathered = [None] * world_size if rank == 0 else None @@ -1326,10 +1332,9 @@ def train(model): if p["error"]: exceptions.append(AssertionError(p["error"])) continue - # We export one entry per iteration for both forward and backward - assert p["forward"] == iters - assert p["backward"] == iters - # At least one entry should include delayed info on platforms using delayed scaling + # We should resolve both forward and backward entries + assert p["resolved_fwd"] + assert p["resolved_bw"] return exceptions return None diff --git a/thunder/tests/test_export_stateful_ex_transform.py b/thunder/tests/test_export_stateful_ex_transform.py index 4f97cdbcf7..a61121cc7f 100644 --- a/thunder/tests/test_export_stateful_ex_transform.py +++ b/thunder/tests/test_export_stateful_ex_transform.py @@ -91,33 +91,33 @@ def forward(self, x): with te.fp8_autocast(fp8_recipe=fp8_recipe): y = jmodel(x) - # Validate TE exporter populated as expected - assert hasattr(jmodel, "te_fp8_stats"), "ThunderModule should expose te_fp8_stats" - stats = jmodel.te_fp8_stats - assert isinstance(stats, dict) and set(stats.keys()) == {"forward", "backward"} - # After forward, we should have exactly one forward entry and no backward entries yet - assert len(stats["forward"]) == 1 - assert 
len(stats["backward"]) == 0 - f_entry = stats["forward"][0] + # Validate TE exporter resolves values after forward + assert hasattr(jmodel, "te_fp8_states"), "ThunderModule should expose te_fp8_states()" + f_entry = jmodel.te_fp8_states() assert isinstance(f_entry, dict) # Ensure we collected either delayed scaling or block-scaling style info assert ("delayed" in f_entry and isinstance(f_entry["delayed"], list)) or ( - "mxfp8_or_block" in f_entry and isinstance(f_entry["mxfp8_or_block"], list) + "mxfp8" in f_entry and isinstance(f_entry["mxfp8"], list) ) # If delayed scaling is used, ensure amax and scale are present if isinstance(fp8_recipe, recipe.DelayedScaling) or (fp8_recipe is None and te.fp8.check_fp8_support()[0]): - assert "delayed" in f_entry - d = f_entry["delayed"][0] - assert d.get("scale") is not None - assert d.get("amax") is not None + if "delayed" in f_entry and f_entry["delayed"]: + d = f_entry["delayed"][0] + assert d.get("scale") is not None + assert d.get("amax") is not None + # Expect two elements (forward/backward states across two linear sites) + assert isinstance(d.get("scale"), list) and len(d["scale"]) == 2 + assert isinstance(d.get("amax"), list) and len(d["amax"]) == 2 grad_output = torch.randn_like(y) y.backward(grad_output) - # After backward pass, one backward entry should be present - stats = jmodel.te_fp8_stats - assert len(stats["forward"]) == 1 - assert len(stats["backward"]) == 1 + # Validate TE exporter resolves values after backward + b_entry = jmodel.te_fp8_states(mode="backward") + assert isinstance(b_entry, dict) + assert ("delayed" in b_entry and isinstance(b_entry["delayed"], list)) or ( + "mxfp8" in b_entry and isinstance(b_entry["mxfp8"], list) + ) @requiresCUDA @@ -167,15 +167,14 @@ def forward(self, x): # Forward under FP8 autocast with te.fp8_autocast(fp8_recipe=fp8_recipe): y = jmodel(x) + f_entry = jmodel.te_fp8_states() + assert isinstance(f_entry, dict) + assert ("delayed" in f_entry) or ("mxfp8" in f_entry) # Backward with unit upstream gradient y.backward(torch.ones_like(y)) - - # Validate exporter after multiple iterations - assert hasattr(jmodel, "te_fp8_stats") - stats = jmodel.te_fp8_stats - # One forward and one backward export entry per iteration - assert len(stats["forward"]) == num_iters - assert len(stats["backward"]) == num_iters + b_entry = jmodel.te_fp8_states(mode="backward") + assert isinstance(b_entry, dict) + assert ("delayed" in b_entry) or ("mxfp8" in b_entry) @requiresCUDA @@ -233,13 +232,13 @@ def thunder_model(x, fp8_recipe): train_model(thunder_model) - stats = jmodel.te_fp8_stats - # We expect as many forward/backward entries as iterations - assert len(stats["forward"]) == iters - assert len(stats["backward"]) == iters - # Across all entries, we should see delayed info and, if supported, possibly block info - has_delayed = any(e.get("delayed") for e in stats["forward"]) or any(e.get("delayed") for e in stats["backward"]) - assert has_delayed + # Resolve most recent forward and backward entries on demand + f_entry = jmodel.te_fp8_states() + assert isinstance(f_entry, dict) + assert ("delayed" in f_entry) or ("mxfp8" in f_entry) + b_entry = jmodel.te_fp8_states(mode="backward") + assert isinstance(b_entry, dict) + assert ("delayed" in b_entry) or ("mxfp8" in b_entry) @requiresCUDA @@ -303,21 +302,13 @@ def train_model(model): train_model(compiled_model) - # Collect TE fp8 stats from Thunder-compiled subgraphs - reporters = [] + # Collect TE fp8 states from Thunder-compiled subgraphs (resolve on-demand) + resolved = 
0 for sinfo in backend.subgraph_infos: if sinfo.thunder_compiled_fns: for fn in sinfo.thunder_compiled_fns: - if hasattr(fn, "te_fp8_stats"): - reporters.append(fn.te_fp8_stats) - - # We expect at least one Thunder subgraph using TE - assert len(reporters) >= 1 - - # Aggregate counts across subgraphs - total_fw_entries = sum(len(r["forward"]) for r in reporters) - total_bw_entries = sum(len(r["backward"]) for r in reporters) - - # We expect at least one Thunder subgraph using TE and to have exported entries - assert total_fw_entries == iters - assert total_bw_entries == iters + if hasattr(fn, "te_fp8_states") and fn.te_fp8_states.refs['forward'] is not None: + entry = fn.te_fp8_states() + if isinstance(entry, dict) and ("delayed" in entry or "mxfp8" in entry): + resolved += 1 + assert resolved >= 1 From 8db5a9d15a9089248bfef41943a264191f07267d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Oct 2025 15:58:11 +0000 Subject: [PATCH 12/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/dev_utils/export_stateful_ex_transform.py | 3 +-- thunder/executors/transformer_engineex_impl.py | 1 - thunder/tests/test_export_stateful_ex_transform.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/thunder/dev_utils/export_stateful_ex_transform.py b/thunder/dev_utils/export_stateful_ex_transform.py index ab39bcb48f..9062ef4c6d 100644 --- a/thunder/dev_utils/export_stateful_ex_transform.py +++ b/thunder/dev_utils/export_stateful_ex_transform.py @@ -1,6 +1,5 @@ import weakref from collections.abc import Callable -from typing import List, Tuple from thunder.core.transform_common import ( Transform, @@ -60,7 +59,7 @@ class ExportStatefulExecutorsTransform(Transform): """ _register_callbacks: dict[str, Callable] = {} - _callback_attributes: List[Tuple[str, type[ExportStatefulExecutorsStats], Callable]] = [] + _callback_attributes: list[tuple[str, type[ExportStatefulExecutorsStats], Callable]] = [] _instance = None diff --git a/thunder/executors/transformer_engineex_impl.py b/thunder/executors/transformer_engineex_impl.py index 26f9b800e1..b2cd6e56d6 100644 --- a/thunder/executors/transformer_engineex_impl.py +++ b/thunder/executors/transformer_engineex_impl.py @@ -7,7 +7,6 @@ import torch import torch.distributed as torch_dist -import thunder from thunder.core.module import ThunderModule from thunder.core.prims import linear as linear_prim from thunder.core.prims import get_grad, put_grad diff --git a/thunder/tests/test_export_stateful_ex_transform.py b/thunder/tests/test_export_stateful_ex_transform.py index a61121cc7f..54bdf14473 100644 --- a/thunder/tests/test_export_stateful_ex_transform.py +++ b/thunder/tests/test_export_stateful_ex_transform.py @@ -307,7 +307,7 @@ def train_model(model): for sinfo in backend.subgraph_infos: if sinfo.thunder_compiled_fns: for fn in sinfo.thunder_compiled_fns: - if hasattr(fn, "te_fp8_states") and fn.te_fp8_states.refs['forward'] is not None: + if hasattr(fn, "te_fp8_states") and fn.te_fp8_states.refs["forward"] is not None: entry = fn.te_fp8_states() if isinstance(entry, dict) and ("delayed" in entry or "mxfp8" in entry): resolved += 1 From 5d0e41704d1670f28c4ad7dd80c6c20d3366752b Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Fri, 3 Oct 2025 15:59:50 +0000 Subject: [PATCH 13/13] Removed import --- thunder/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/__init__.py b/thunder/__init__.py 
index 0e70f7d7bf..662336b129 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -55,7 +55,6 @@ unwrap_return_value, wrap_return_value_together_with_arguments, ) -from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform from thunder.core.update_aliases import insert_alias_updates from thunder.executors.torch_autograd import connect_to_autograd import thunder.extend as extend
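
Below is a minimal usage sketch of the flow introduced by this series: enabling the export transform at jit time and lazily resolving TransformerEngine FP8 state through the `te_fp8_states` accessor after a forward/backward pass. It assumes a CUDA device with FP8 support; the toy module, tensor shapes, and recipe choice are illustrative only, while the `thunder.jit`, `transformer_engine_ex`, `TransformerEngineTransform`, `ExportStatefulExecutorsTransform`, `te.fp8_autocast`, and `te_fp8_states(...)` calls follow the usage exercised by the tests in these patches.

    # Illustrative sketch, not part of the patches above.
    import torch
    import torch.nn as nn

    import thunder
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe
    from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform
    from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform

    device, dtype = "cuda", torch.bfloat16
    # Toy model and input; shapes are FP8-friendly (multiples of 8/16) but otherwise arbitrary.
    model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 2048)).to(device, dtype)
    x = torch.randn(32, 4096, device=device, dtype=dtype, requires_grad=True)

    # Enable the TE executor plus the export transform so the accessor is attached to the module.
    jmodel = thunder.jit(
        model,
        executors=[transformer_engine_ex],
        transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()],
    )

    # One forward/backward step under FP8 autocast.
    with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
        y = jmodel(x)
    y.backward(torch.ones_like(y))

    # Resolve the latest FP8 snapshots on demand; an empty dict means nothing was recorded.
    fwd_state = jmodel.te_fp8_states()                  # e.g. {"delayed": [{"scale": [...], "amax": [...]}]}
    bwd_state = jmodel.te_fp8_states(mode="backward")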