diff --git a/thunder/dev_utils/export_stateful_ex_transform.py b/thunder/dev_utils/export_stateful_ex_transform.py new file mode 100644 index 0000000000..9062ef4c6d --- /dev/null +++ b/thunder/dev_utils/export_stateful_ex_transform.py @@ -0,0 +1,123 @@ +import weakref +from collections.abc import Callable + +from thunder.core.transform_common import ( + Transform, +) +from thunder.core.trace import TraceCtx as Trace +from thunder.core.module import ThunderModule + + +class ExportStatefulExecutorsStats: + def __init__(self, tm: ThunderModule, resolver_fn: Callable): + """Lightweight accessor attached to a `ThunderModule`. + + Args: + tm: The `ThunderModule` instance this accessor belongs to. + resolver_fn: A callable that knows how to resolve the recorded + references on `tm` and return real values. + """ + self.tm = tm + self.resolver_fn = resolver_fn + + +class ExportStatefulExecutorsTransform(Transform): + """Register references and resolve runtime state lazily. + + What this transform does: + - Singleton registry to plug per-executor exporters + - At module transform time, installs a lightweight accessor on the module + (e.g., `module.te_fp8_states`) that can resolve values on demand + - At post-optimization time, calls registered reference callbacks to record + only where values will materialize (holders + attribute paths). No data + are copied or materialized in this step + - When code calls the accessor (e.g., `module.te_fp8_states()`), the resolve + callback reads the recorded references and returns the latest values + + + API overview: + - register_ref_callback(name, register_cb, resolve_cb, instance_cls): + name: attribute name to attach on the module + register_cb(trace, module): store references from the trace/python_ctx + resolve_cb(module): materialize and return values using the stored refs + instance_cls: a small class constructed as instance_cls(module, resolve_cb) + and attached as `setattr(module, name, instance)`; it typically stores + containers for references and implements __call__(...) to resolve + + Usage: + 1) Register once at import/init time. For example, for TransformerEngine: + ExportStatefulExecutorsTransform.register_ref_callback( + "te_fp8_states", register_cb, resolve_cb, StatsClass + ) + 2) Enable at compile time: + thunder.jit(model, executors=[...], transforms=[..., ExportStatefulExecutorsTransform()]) + 3) After each run, call `module.te_fp8_states()` to resolve and return the latest values. + + Notes: + - Supports multiple ThunderModule instances (e.g., subgraphs) + - Callback errors are swallowed to avoid interfering with execution + """ + + _register_callbacks: dict[str, Callable] = {} + _callback_attributes: list[tuple[str, type[ExportStatefulExecutorsStats], Callable]] = [] + + _instance = None + + def __new__(cls, *args, **kwargs): + """Ensure singleton instance across repeated transform construction.""" + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + """Initialize internal weakrefs registry. + + ThunderCompiler and other compilation flows may create multiple + `ThunderModule` instances for subgraphs; we keep weak references + to update all of them during post-optimization registration. + """ + self.tm_refs = [] + + @classmethod + def register_ref_callback( + cls, name: str, callback: Callable, resolve_cb: Callable, instance: type[ExportStatefulExecutorsStats] + ) -> None: + """Register per-executor reference and resolver callbacks. 
+ + Installs a module attribute named `name` by constructing `instance` with + the resolver function. The `callback` will be invoked during + post-optimization to record reference locations on the module. + + Args: + name: Module attribute to attach (e.g., "te_fp8_states"). + callback: Function `(trace, module) -> None` that records refs. + resolve_cb: Function `(module) -> Any` that resolves values on demand. + instance: A class (must be a subclass of ExportStatefulExecutorsStats) constructed as `instance(module, resolve_cb)`. + """ + if not issubclass(instance, ExportStatefulExecutorsStats): + raise TypeError(f"Provided instance {instance} must be a subclass of ExportStatefulExecutorsStats") + cls._register_callbacks[name] = callback + cls._callback_attributes.append((name, instance, resolve_cb)) + + def transform_module(self, model) -> None: + assert model is not None + # Cache a weakref to the ThunderModule for later runtime export + self.tm_refs.append(weakref.ref(model)) + # Initialize attributes on model + for name, instance, resolve_cb in self._callback_attributes: + setattr(model, name, instance(model, resolve_cb)) + + def transform_trace_post_optimization(self, computation_trace: Trace, **kwargs): + for tm_ref in self.tm_refs: + # Resolve ThunderModule from weakref; if unavailable, skip + tm = tm_ref() if tm_ref is not None else None + if tm is None: + continue + + # Invoke all registered callbacks to register reference locations + for _, cb in self._register_callbacks.items(): + try: + cb(computation_trace, tm) + except Exception: + pass + return computation_trace diff --git a/thunder/executors/transformer_engineex_impl.py b/thunder/executors/transformer_engineex_impl.py index ac4fc43372..b2cd6e56d6 100644 --- a/thunder/executors/transformer_engineex_impl.py +++ b/thunder/executors/transformer_engineex_impl.py @@ -1,9 +1,13 @@ import time from typing import TYPE_CHECKING import warnings +from collections import defaultdict +from collections.abc import Callable +import torch import torch.distributed as torch_dist +from thunder.core.module import ThunderModule from thunder.core.prims import linear as linear_prim from thunder.core.prims import get_grad, put_grad from thunder.core.proxies import AnyProxy, TensorProxy @@ -40,6 +44,10 @@ from transformer_engine.pytorch.ops import BasicLinear from transformer_engine.pytorch.tensor import Quantizer from transformer_engine.pytorch.utils import check_dim_for_fp8_exec +from thunder.dev_utils.export_stateful_ex_transform import ( + ExportStatefulExecutorsTransform as _ExportSETransform, + ExportStatefulExecutorsStats, +) transformer_engine_ex = StatefulExecutor("transformer_engine") @@ -374,6 +382,159 @@ def reset(self): self.redundant_map = {} self.new_saved_for_backward = None + class TEFP8Stats(ExportStatefulExecutorsStats): + def __init__(self, tm: ThunderModule, resolver_fn: Callable): + """Accessor attached on the module to resolve TE FP8 states on demand. + + Args: + tm: ThunderModule to which this accessor is bound. + resolver_fn: Callable invoked as `resolver_fn(mode, tm)` to produce + the latest snapshot of FP8 state based on registered refs. + """ + super().__init__(tm, resolver_fn) + self.refs = {"forward": None, "backward": None} + + def __call__(self, mode: str = "forward") -> dict: + """Resolve and return the latest FP8 state for the given mode. + + Args: + mode: "forward" or "backward". Defaults to "forward". 
+ Returns: + A dictionary snapshot of resolved values (e.g., delayed or mxfp8 entries), + or an empty dict on invalid mode or if nothing is recorded. + """ + if mode not in ["forward", "backward"]: + warnings.warn(f"Received an invalid inspection mode: {mode}. Please use 'forward' or 'backward'.") + return {} + return self.resolver_fn(mode, self.tm) + + @staticmethod + def register_refs(computation_trace, tm) -> None: + """Record where FP8 values will materialize for later lazy resolution. + + This inspects the trace's python context, finds TE state and quantizer + holders, and stores only references (holder objects and attribute paths) + into the module accessor. No tensors or runtime data are copied here. + The actual values are read by `resolve_values` after execution. + """ + python_ctx = computation_trace.python_ctx() + + # Collect holders and where to read values from later + refs = defaultdict(list) + + # Collect mode from trace tags + mode = "forward" if TraceTag.AUGMENTED_FORWARD in computation_trace.tags else "backward" + + # States: register all state holders; decide recipe type at resolve time + state_holders = [v for k, v in python_ctx.items() if isinstance(k, str) and k.startswith("get_te_fp8_state")] + for sh in state_holders: + # Always store attrs we may need; recipe classification happens later + refs["state_holder"].append( + { + "holder": sh, + "scale_attr": "state.scale", + "amax_attr": "state.amax_history", + } + ) + + # Quantizers (MXFP8/block): resolve via TEQuantizerState linked to RecipeState + quantizer_holders = [ + v for k, v in python_ctx.items() if isinstance(k, str) and k.startswith("get_te_fp8_quantizers") + ] + for qh in quantizer_holders: + refs["quantizer_holder"].append( + {"holder": qh, "quant_attr": "quantizers", "parent_state_attr": "parent_recipe_state"} + ) + + if len(refs) > 0: + tm.te_fp8_states.refs[mode] = refs + + @staticmethod + def resolve_values(mode: str, tm: ThunderModule) -> dict: + """Load and serialize FP8 values using previously-registered references. + + Args: + mode: "forward" or "backward" indicating which refs to resolve. + tm: ThunderModule whose accessor holds the recorded references. + Returns: + A dictionary with resolved entries (e.g., {"delayed": [...]} or + {"mxfp8": [...]}); returns empty dict if nothing is recorded. 
+ """ + + def _get_attr(obj, attr_path: str): + cur = obj + for part in attr_path.split("."): + cur = getattr(cur, part) + return cur + + def _tensor_head(t, max_numel: int = 8192): + if not isinstance(t, torch.Tensor): + return None + n = min(t.numel(), max_numel) + return t.detach().float().cpu().view(-1)[:n].tolist() + + # Pull last registered refs for this mode + refs = tm.te_fp8_states.refs[mode] + if refs is None: + return {} + + out = defaultdict(list) + + # Classify states now that recipes and states have materialized + # MXFP8/block scaling: will be collected via quantizers section below + for ref in refs.get("state_holder", []): + sh = ref["holder"] + recipe = getattr(sh, "parent_recipe", None) + state = getattr(sh, "state", None) + if recipe is None or state is None: + continue + # Delayed scaling: extract from state tensors + if getattr(recipe, "delayed", lambda: False)(): + scale = _get_attr(sh, ref["scale_attr"]) # state.scale + amax_hist = _get_attr(sh, ref["amax_attr"]) # state.amax_history + scale_vals = _tensor_head(scale) + amax_vals = _tensor_head( + amax_hist[-1] if isinstance(amax_hist, torch.Tensor) and amax_hist.numel() > 0 else amax_hist + ) + out["delayed"].append( + { + "scale_shape": getattr(scale, "shape", None), + "scale": scale_vals, + "amax_shape": getattr(amax_hist, "shape", None), + "amax": amax_vals, + } + ) + + # MXFP8 via quantizers + # First, build mapping from recipe state id to quantizers + state_to_qs = {} + for ref in refs.get("quantizer_holder", []): + qh = ref["holder"] + prs = getattr(qh, "parent_recipe_state", None) + qs = getattr(qh, "quantizers", None) + if prs is not None and qs: + state_to_qs.setdefault(id(prs), []).extend(qs) + + # For MXFP8/block scaling, gather quantizers linked to each materialized state + for ref in refs.get("state_holder", []): + sh = ref["holder"] + recipe = getattr(sh, "parent_recipe", None) + state = getattr(sh, "state", None) + if recipe is None or state is None: + continue + if getattr(recipe, "mxfp8", lambda: False)(): + for q in state_to_qs.get(id(state), []): + entry = { + "cls": q.__class__.__name__, + "rowwise_usage": getattr(q, "rowwise_usage", None), + "columnwise_usage": getattr(q, "columnwise_usage", None), + "dtype": str(getattr(q, "dtype", None)), + } + if entry not in out["mxfp8"]: + out["mxfp8"].append(entry) + + return out + def transform_trace_post_optimization(self, computation_trace, **kwargs): """ Finds and replaces TE executor recipe calls and replaces them with one. 
@@ -513,3 +674,15 @@ def _te_activation_checkpointing_transform(joint_trace: TraceCtx) -> TraceCtx: new_trace.bound_symbols = [bsym.from_bsym_swap_proxies(swapmap) for bsym in reversed(reversed_bsyms)] return new_trace + + +# Register TE reference and resolve callbacks with the singleton export transform +try: + _ExportSETransform.register_ref_callback( + "te_fp8_states", + TransformerEngineTransform.register_refs, + TransformerEngineTransform.resolve_values, + TransformerEngineTransform.TEFP8Stats, + ) +except Exception: + pass diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index c1eb9e8d93..6202e6ee8b 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -26,6 +26,7 @@ ) from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform +from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform is_fp8_supported: bool = False @@ -774,6 +775,99 @@ def _test_ddp_transformer_engine_llama_sanity(input_data): return None +def _test_ddp_transformer_engine_state_export(input_data): + # Test Description: + # This test ensures that the ExportStatefulExecutorsTransform correctly collects and exports + # TransformerEngine FP8 state information (such as amax/scale and quantizer stats) during + # distributed DDP training. It verifies that the state reporter works as expected across + # multiple processes, capturing both forward and backward state summaries for each rank. + + init_method, world_size, rank, _executor, device, _dtype, _unused_kwargs = input_data + devicetype = devices.device_from_string(device).devicetype + pg = init_per_process_distributed(init_method, devicetype, world_size, rank) + + torch.cuda.set_device(rank) + torch_device = torch.device("cuda", rank) + + dtype_t = torch.bfloat16 + if rank == 0: + x = torch.randn(3, 768, 4096, device=torch_device, dtype=dtype_t, requires_grad=True) + else: + x = torch.randn(2, 768, 4096, device=torch_device, dtype=dtype_t, requires_grad=True) + + class Module(nn.Module): + def __init__(self): + super().__init__() + self.w1 = nn.Parameter(torch.randn(4096, 4096, device=torch_device, dtype=dtype_t)) + self.w2 = nn.Parameter(torch.randn(2048, 4096, device=torch_device, dtype=dtype_t)) + + def forward(self, x): + o = torch.nn.functional.linear(x, self.w1) + added = o + x + return torch.nn.functional.linear(added, self.w2) + + model = Module() + jmodel = thunder.distributed.ddp( + thunder.jit( + model, + executors=[transformer_engine_ex], + transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()], + ) + ) + + # Use default TE recipe for this test + fp8_recipe = get_default_fp8_recipe() + + iters = 10 + + def train(model): + for _ in range(iters): + with fp8_autocast(fp8_recipe=fp8_recipe): + y = model(x) + y.backward(torch.ones_like(y)) + + train(jmodel) + + # Resolve latest TE FP8 states on demand (forward and backward) + resolved_fwd = False + resolved_bw = False + try: + if hasattr(jmodel, "te_fp8_states"): + f_entry = jmodel.te_fp8_states() + if isinstance(f_entry, dict) and ("delayed" in f_entry or "mxfp8" in f_entry): + resolved_fwd = True + b_entry = jmodel.te_fp8_states(mode="backward") + if isinstance(b_entry, dict) and ("delayed" in b_entry or "mxfp8" in b_entry): + resolved_bw = True + except Exception: + pass + + payload = { + "rank": rank, + "error": None, + "resolved_fwd": resolved_fwd, + "resolved_bw": resolved_bw, + } + + gathered = [None] * world_size if rank == 
0 else None
+    tdist.gather_object(payload, object_gather_list=gathered, dst=0, group=pg)
+
+    tdist.barrier(pg)
+    tdist.destroy_process_group(pg)
+
+    if rank == 0:
+        exceptions = []
+        for p in gathered:
+            if p["error"]:
+                exceptions.append(AssertionError(p["error"]))
+                continue
+            # We should resolve both forward and backward entries
+            assert p["resolved_fwd"]
+            assert p["resolved_bw"]
+        return exceptions
+    return None
+
+
 # NOTE This is just a stub, see the NOTE for ddp_wrapper
 @instantiate(
     dtypes=(thunder.float32,),
@@ -877,5 +971,22 @@ def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype):
     pass
 
 
+@instantiate(
+    dtypes=(thunder.float32,),
+    num_devices=2,
+    devicetypes=(devices.DeviceType.CUDA,),
+    executors=(TorchExecutor,),
+    decorators=(
+        pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."),
+        pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason),
+        pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices"),
+        unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}),
+    ),
+)
+@distributed_wrapper("test_ddp_transformer_engine_state_export", _test_ddp_transformer_engine_state_export)
+def test_ddp_transformer_engine_state_export(executor, devices, dtype):
+    pass
+
+
 if __name__ == "__main__":
     common_utils.run_tests()
diff --git a/thunder/tests/distributed/test_fsdp.py b/thunder/tests/distributed/test_fsdp.py
index 5acc79cdaa..a3116bfae1 100644
--- a/thunder/tests/distributed/test_fsdp.py
+++ b/thunder/tests/distributed/test_fsdp.py
@@ -29,6 +29,7 @@
 )
 
 from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform
+from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform
 
 is_fp8_supported: bool = False
@@ -1237,6 +1238,107 @@ def _test_fsdp_transformer_engine_bucketing(input_data):
     return None
 
 
+def _test_fsdp_transformer_engine_state_export(input_data):
+    # Test Description:
+    # Verify that ExportStatefulExecutorsTransform correctly records and resolves TransformerEngine
+    # FP8 state information (such as amax/scale and quantizer stats) during FSDP training, capturing
+    # forward and backward state snapshots on each rank across distributed processes.
+
+    init_method, world_size, rank, executor, device, _dtype, kwargs = input_data
+    thunder_fsdp_strategy, intermediate_activation_sharding = kwargs["thunder_fsdp_strategy_and_intermediate_sharding"]
+    devicetype = devices.device_from_string(device).devicetype
+
+    # Setting LOCAL_RANK is necessary for thunder.distributed.fsdp
+    with unittest.mock.patch.dict(os.environ, {"LOCAL_RANK": str(rank)}):
+        pg = init_per_process_distributed(init_method, devicetype, world_size, rank)
+        torch.cuda.set_device(rank)
+        torch_device = torch.device("cuda", rank)
+        dtype_t = torch.bfloat16
+
+        fp8_recipe = get_default_fp8_recipe()
+
+        dim = 256
+
+        class ThunderModel(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.fc1 = torch.nn.Linear(dim, dim, bias=False, device=torch_device, dtype=dtype_t)
+                self.fc2 = torch.nn.Linear(dim, dim, bias=False, device=torch_device, dtype=dtype_t)
+                self.fc3 = torch.nn.Linear(dim, dim, bias=False, device=torch_device, dtype=dtype_t)
+
+            def forward(self, x):
+                return self.fc3(self.fc2(torch.nn.functional.relu(self.fc1(x))))
+
+        # Inputs (different input on different rank).
+        if rank == 0:
+            x = torch.arange(dim * dim, dtype=dtype_t, device=torch_device).view(dim, dim)
+        if rank == 1:
+            x = torch.randn(dim, dim, device=torch_device, dtype=dtype_t) * 100
+
+        model = ThunderModel()
+        jmodel = thunder.distributed.fsdp(
+            thunder.jit(
+                model,
+                executors=[
+                    transformer_engine_ex,
+                ]
+                + executor.executors_list(),
+                fp8_shard_intermediate_activation=intermediate_activation_sharding,
+                transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()],
+            ),
+            sharding_strategy=thunder_fsdp_strategy,
+        )
+
+        iters = 10
+
+        def train(model):
+            for _ in range(iters):
+                with fp8_autocast(fp8_recipe=fp8_recipe):
+                    y = model(x)
+                y.backward(torch.ones_like(y))
+
+        train(jmodel)
+
+        # Resolve latest TE FP8 states on demand (forward and backward)
+        resolved_fwd = False
+        resolved_bw = False
+        try:
+            if hasattr(jmodel, "te_fp8_states"):
+                f_entry = jmodel.te_fp8_states()
+                if isinstance(f_entry, dict) and ("delayed" in f_entry or "mxfp8" in f_entry):
+                    resolved_fwd = True
+                b_entry = jmodel.te_fp8_states(mode="backward")
+                if isinstance(b_entry, dict) and ("delayed" in b_entry or "mxfp8" in b_entry):
+                    resolved_bw = True
+        except Exception:
+            pass
+
+        payload = {
+            "rank": rank,
+            "error": None,
+            "resolved_fwd": resolved_fwd,
+            "resolved_bw": resolved_bw,
+        }
+
+        gathered = [None] * world_size if rank == 0 else None
+        tdist.gather_object(payload, object_gather_list=gathered, dst=0, group=pg)
+
+        tdist.barrier(pg)
+        tdist.destroy_process_group(pg)
+
+        if rank == 0:
+            exceptions = []
+            for p in gathered:
+                if p["error"]:
+                    exceptions.append(AssertionError(p["error"]))
+                    continue
+                # We should resolve both forward and backward entries
+                assert p["resolved_fwd"]
+                assert p["resolved_bw"]
+            return exceptions
+        return None
+
+
 @instantiate(
     dtypes=(thunder.float32,),
     num_devices=2,
@@ -1269,6 +1371,37 @@ def test_fsdp_transformer_engine(executor, devices, dtype, thunder_fsdp_strategy
     pass
 
 
+@instantiate(
+    dtypes=(thunder.float32,),
+    num_devices=2,
+    devicetypes=(devices.DeviceType.CUDA,),
+    executors=(TorchExecutor,),
+    decorators=(
+        pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices"),
+        unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}),
+        pytest.mark.parametrize(
+            "thunder_fsdp_strategy_and_intermediate_sharding",
+            (
+                (FSDPType.ZERO2, False),
+                (FSDPType.ZERO3, False),
+                # Intermediate sharding is only available from TE v1.8 onwards
+                pytest.param(
+                    (FSDPType.ZERO3, True),
+                    marks=pytest.mark.skip("Intermediate sharding errors in TE 2.0 (also with eager)."),
+                ),
+            ),
+        ),
+        pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."),
+        pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason),
+    ),
+)
+@distributed_wrapper("test_fsdp_transformer_engine_state_export", _test_fsdp_transformer_engine_state_export)
+def test_fsdp_transformer_engine_state_export(
+    executor, devices, dtype, thunder_fsdp_strategy_and_intermediate_sharding
+):
+    pass
+
+
 @instantiate(
     dtypes=(thunder.float32,),
     num_devices=2,
diff --git a/thunder/tests/test_export_stateful_ex_transform.py b/thunder/tests/test_export_stateful_ex_transform.py
new file mode 100644
index 0000000000..54bdf14473
--- /dev/null
+++ b/thunder/tests/test_export_stateful_ex_transform.py
@@ -0,0 +1,314 @@
+import pytest
+import torch
+import torch.nn as nn
+
+import thunder
+from thunder.tests.framework import requiresCUDA
+
+
+# NOTE: On SM120/121, TE defaults to using Float8BlockScaling
+# which is currently unsupported in thunder, so we
skip the tests for these SM architectures. +from thunder.tests.utils import skip_on_sm120_and_sm121, is_sm120_orsm121 +from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform +from thunder.dynamo import ThunderCompiler + +# Make TE optional so this file can host tests for other executors too +TE_AVAILABLE = False +try: + import transformer_engine.pytorch as te + from transformer_engine.common import recipe + from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform + + TE_AVAILABLE = True +except Exception: + te = None + recipe = None + TE_AVAILABLE = False + +if TE_AVAILABLE: + # FP8 is supported on compute arch 8.9 onwards. + # MXFP8 is supported on compute arch 10.0 onwards. + # Skip the TE-specific parametrizations if current hardware is not supported. + is_fp8_supported, msg_fp8 = te.fp8.check_fp8_support() + is_mxfp8_supported, msg_mxfp8 = te.fp8.check_mxfp8_support() + if not is_fp8_supported: + pytest.skip(msg_fp8, allow_module_level=True) + + hybrid_fp8_delayed_scaling_recipe = recipe.DelayedScaling() + mxfp8_e4m3_recipe = recipe.MXFP8BlockScaling() + + # `None` is used to test the default recipe. + recipes = (None, hybrid_fp8_delayed_scaling_recipe, mxfp8_e4m3_recipe) + recipe_ids = ("default", "delayed_scaling", "mxfp8_e4m3") +else: + is_mxfp8_supported, msg_mxfp8 = (False, "TransformerEngine not available") + recipes = (None,) + recipe_ids = ("default",) + + +@requiresCUDA +@pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed.") +@pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) +@skip_on_sm120_and_sm121 +def test_export_te_states_linear_forward(fp8_recipe): + if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): + pytest.skip(msg_mxfp8) + + if is_sm120_orsm121 and fp8_recipe is None: + pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported") + + # Test Description: + # Verify that the TEStateReporter correctly captures and reports TransformerEngine + # FP8 state information during forward pass execution, including global context, + # recipe summaries, and forward state summaries. 
+ + dtype = torch.bfloat16 + device = "cuda" + + # Inputs (3D input) + x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) + + class Module(nn.Module): + def __init__(self): + super().__init__() + self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) + self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) + + def forward(self, x): + o = torch.nn.functional.linear(x, self.w1) + added = o + x + return torch.nn.functional.linear(added, self.w2) + + model = Module() + + jmodel = thunder.jit( + model, + executors=[transformer_engine_ex], + transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()], + ) + + # Enable autocasting for the forward pass + with te.fp8_autocast(fp8_recipe=fp8_recipe): + y = jmodel(x) + + # Validate TE exporter resolves values after forward + assert hasattr(jmodel, "te_fp8_states"), "ThunderModule should expose te_fp8_states()" + f_entry = jmodel.te_fp8_states() + assert isinstance(f_entry, dict) + # Ensure we collected either delayed scaling or block-scaling style info + assert ("delayed" in f_entry and isinstance(f_entry["delayed"], list)) or ( + "mxfp8" in f_entry and isinstance(f_entry["mxfp8"], list) + ) + # If delayed scaling is used, ensure amax and scale are present + if isinstance(fp8_recipe, recipe.DelayedScaling) or (fp8_recipe is None and te.fp8.check_fp8_support()[0]): + if "delayed" in f_entry and f_entry["delayed"]: + d = f_entry["delayed"][0] + assert d.get("scale") is not None + assert d.get("amax") is not None + # Expect two elements (forward/backward states across two linear sites) + assert isinstance(d.get("scale"), list) and len(d["scale"]) == 2 + assert isinstance(d.get("amax"), list) and len(d["amax"]) == 2 + + grad_output = torch.randn_like(y) + y.backward(grad_output) + + # Validate TE exporter resolves values after backward + b_entry = jmodel.te_fp8_states(mode="backward") + assert isinstance(b_entry, dict) + assert ("delayed" in b_entry and isinstance(b_entry["delayed"], list)) or ( + "mxfp8" in b_entry and isinstance(b_entry["mxfp8"], list) + ) + + +@requiresCUDA +@pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed.") +@pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) +@skip_on_sm120_and_sm121 +def test_export_te_states_linear_forward_backward_multiple_iteration(fp8_recipe): + if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): + pytest.skip(msg_mxfp8) + + if is_sm120_orsm121 and fp8_recipe is None: + pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported") + + # Test Description: + # Run multiple forward/backward iterations under a single recipe configuration and + # verify that the TE reporter does not grow with the iteration count. The recipe + # list should contain one unique entry, and state/quantizer summaries should reflect + # the two linear call sites exactly once per direction, independent of iterations. 
+ + dtype = torch.bfloat16 + device = "cuda" + + # Inputs and model + x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) + + class Module(nn.Module): + def __init__(self): + super().__init__() + self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) + self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) + + def forward(self, x): + o = torch.nn.functional.linear(x, self.w1) + added = o + x + return torch.nn.functional.linear(added, self.w2) + + model = Module() + + jmodel = thunder.jit( + model, + executors=[transformer_engine_ex], + transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()], + ) + + num_iters = 10 + for _ in range(num_iters): + # Forward under FP8 autocast + with te.fp8_autocast(fp8_recipe=fp8_recipe): + y = jmodel(x) + f_entry = jmodel.te_fp8_states() + assert isinstance(f_entry, dict) + assert ("delayed" in f_entry) or ("mxfp8" in f_entry) + # Backward with unit upstream gradient + y.backward(torch.ones_like(y)) + b_entry = jmodel.te_fp8_states(mode="backward") + assert isinstance(b_entry, dict) + assert ("delayed" in b_entry) or ("mxfp8" in b_entry) + + +@requiresCUDA +@pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed.") +def test_export_te_states_linear_forward_backward_multiple_recipies_iteration(): + # Test Description: + # Alternate between two different recipes across iterations and ensure the reporter + # records both recipe configurations exactly once each. Verify forward/backward states + # and quantizers reflect both linear call sites per recipe, independent of iteration count. + + test_recipes = [recipe.DelayedScaling()] + supports_mxfp8, _ = te.fp8.check_mxfp8_support() + + if supports_mxfp8: + test_recipes += [recipe.MXFP8BlockScaling()] + + if len(test_recipes) < 2: + pytest.skip("platform does not support two different recipes") + + dtype = torch.bfloat16 + device = "cuda" + + # Inputs and model + x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) + + class Module(nn.Module): + def __init__(self): + super().__init__() + self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype)) + self.w2 = nn.Parameter(torch.randn(2048, 4096, device=device, dtype=dtype)) + + def forward(self, x): + o = torch.nn.functional.linear(x, self.w1) + added = o + x + return torch.nn.functional.linear(added, self.w2) + + model = Module() + iters = 4 + + def train_model(model): + for iter_n in range(iters): + te_recipe = test_recipes[iter_n % 2] + y = model(x, te_recipe) + y.backward(torch.ones_like(y)) + + jmodel = thunder.jit( + model, + executors=[transformer_engine_ex], + transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()], + ) + + def thunder_model(x, fp8_recipe): + with te.fp8_autocast(fp8_recipe=fp8_recipe): + return jmodel(x) + + train_model(thunder_model) + + # Resolve most recent forward and backward entries on demand + f_entry = jmodel.te_fp8_states() + assert isinstance(f_entry, dict) + assert ("delayed" in f_entry) or ("mxfp8" in f_entry) + b_entry = jmodel.te_fp8_states(mode="backward") + assert isinstance(b_entry, dict) + assert ("delayed" in b_entry) or ("mxfp8" in b_entry) + + +@requiresCUDA +@pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed.") +@pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) +@skip_on_sm120_and_sm121 +def test_export_te_states_with_torch_compile_and_thunder_backend(fp8_recipe): + # Test Description: + # Use 
torch.compile with Thunder as backend (ThunderCompiler) to run the model
+    # under FP8 autocast. Verify that TE runtime states are exported and available
+    # from the Thunder-compiled subgraphs via `te_fp8_states`, and that forward/backward
+    # state snapshots can be resolved regardless of the iteration count.
+
+    if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported):
+        pytest.skip(msg_mxfp8)
+
+    if is_sm120_orsm121 and fp8_recipe is None:
+        pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported")
+
+    dtype = torch.bfloat16
+    device = "cuda"
+
+    x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True)
+
+    class Module(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.attention = nn.MultiheadAttention(4096, 64, device=device, dtype=dtype, batch_first=True)
+            self.norm1 = nn.LayerNorm(4096, device=device, dtype=dtype)
+            self.norm2 = nn.LayerNorm(4096, device=device, dtype=dtype)
+            self.mlp = nn.Sequential(
+                nn.Linear(4096, 16384, device=device, dtype=dtype),
+                nn.GELU(),
+                nn.Linear(16384, 4096, device=device, dtype=dtype),
+            )
+
+        def forward(self, x):
+            attn_out, _ = self.attention(x, x, x)
+            x = self.norm1(x + attn_out)
+            mlp_out = self.mlp(x)
+            x = self.norm2(x + mlp_out)
+            return x
+
+    model = Module()
+
+    # Compile with torch.compile using Thunder as backend
+    backend = ThunderCompiler(
+        executors=[transformer_engine_ex],
+        transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()],
+    )
+    compiled_model = torch.compile(model, backend=backend)
+
+    # Run several forward/backward iterations under FP8 autocast
+    iters = 10
+
+    def train_model(model):
+        for _ in range(iters):
+            with te.fp8_autocast(fp8_recipe=fp8_recipe):
+                y = model(x)
+            y.backward(torch.ones_like(y))
+
+    train_model(compiled_model)
+
+    # Collect TE fp8 states from Thunder-compiled subgraphs (resolve on-demand)
+    resolved = 0
+    for sinfo in backend.subgraph_infos:
+        if sinfo.thunder_compiled_fns:
+            for fn in sinfo.thunder_compiled_fns:
+                if hasattr(fn, "te_fp8_states") and fn.te_fp8_states.refs["forward"] is not None:
+                    entry = fn.te_fp8_states()
+                    if isinstance(entry, dict) and ("delayed" in entry or "mxfp8" in entry):
+                        resolved += 1
+    assert resolved >= 1
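
The snippet below is a minimal, hypothetical sketch (not part of the diff) of how another stateful executor could plug into the new transform and how callers would query the exported state. The names MyStats, register_cb, resolve_cb, my_states, and MyModel are illustrative placeholders; the only APIs assumed are the ones introduced above (ExportStatefulExecutorsStats, ExportStatefulExecutorsTransform.register_ref_callback, thunder.jit with transforms, and the installed module accessor).

# Hypothetical usage sketch for ExportStatefulExecutorsTransform; names marked
# "illustrative" below are not part of the patch.
import thunder
from thunder.dev_utils.export_stateful_ex_transform import (
    ExportStatefulExecutorsStats,
    ExportStatefulExecutorsTransform,
)


class MyStats(ExportStatefulExecutorsStats):
    # Illustrative accessor attached on the module as `module.my_states`;
    # it stores recorded references and resolves them lazily when called.
    def __init__(self, tm, resolver_fn):
        super().__init__(tm, resolver_fn)
        self.refs = None

    def __call__(self):
        return self.resolver_fn(self.tm)


def register_cb(trace, module):
    # Record only *where* values will materialize (holders found in the
    # trace's python context); no runtime data is copied here.
    module.my_states.refs = {k: v for k, v in trace.python_ctx().items() if isinstance(k, str)}


def resolve_cb(module):
    # Read the previously recorded references and return the latest values.
    return module.my_states.refs or {}


# 1) Register once at import/init time.
ExportStatefulExecutorsTransform.register_ref_callback("my_states", register_cb, resolve_cb, MyStats)

# 2) Enable at compile time and query after running (MyModel is illustrative):
# jmodel = thunder.jit(MyModel(), transforms=[ExportStatefulExecutorsTransform()])
# jmodel(x)
# latest = jmodel.my_states()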