40 changes: 38 additions & 2 deletions thunder/__init__.py
@@ -55,6 +55,7 @@
unwrap_return_value,
wrap_return_value_together_with_arguments,
)
from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform
from thunder.core.update_aliases import insert_alias_updates
from thunder.executors.torch_autograd import connect_to_autograd
import thunder.extend as extend
@@ -845,6 +846,28 @@ def wrapped(*args, **kwargs):
# For more context see `NOTE: Split autograd.Function`
disable_split_autograd: bool = compile_options.get("thunderfx_disable_split_autograd", False)

def wrapped_compiled_fn(fn: Callable, mode: str, *args):
"""
Wraps the compiled function so that the export stateful executors transform runs after it executes.
Stateful executors only materialize their state at runtime, so the export must happen post-execution.
"""
out = fn(*args)
cs = weakref_cs()
cd = weakref_cd()
if cs is None or cd is None:
return out

traces = cs.last_traces if mode == "forward" else cs.last_backward_traces
if not traces:
return out
trace = traces[-1]
for transform in cd.transforms or []:
if isinstance(transform, ExportStatefulExecutorsTransform):
transform.transform_trace_post_execution(trace)
break
return out

def maybe_connect_to_autograd(cache_entry, result):
if cache_entry.backward_fn:
# If the backward function is available, we need to connect the
@@ -854,8 +877,11 @@ def maybe_connect_to_autograd(cache_entry, result):

is_differentiable_outputs = compile_options.get("is_differentiable_outputs", None)

def wrapped_backward_fn(saved_and_other, args):
return wrapped_compiled_fn(cache_entry.backward_fn, "backward", saved_and_other, args)

connect_to_autograd(
backward_fn=cache_entry.backward_fn,
backward_fn=cache_entry.backward_fn if not has_to_export() else wrapped_backward_fn,
flat_args=data_for_autograd["flat_args"],
flat_output=data_for_autograd["flat_output"],
saved_tensors=saved_tensors,
@@ -872,6 +898,11 @@ def call_epilogue(cache_entry, comp_result, pro_to_epi):
result = cache_entry.epilogue_fn(*pro_to_epi, *comp_result)
return result

def has_to_export():
cd = weakref_cd()
if cd is None:
return False
return any(isinstance(t, ExportStatefulExecutorsTransform) for t in cd.transforms or [])

@wraps(fn)
@update_call_statistics
def fn_(*args, **kwargs) -> Any:
@@ -886,7 +917,12 @@ def fn_(*args, **kwargs) -> Any:

cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)

result = cache_entry.computation_fn(*inps)
def wrapped_computation_fn(*args):
return wrapped_compiled_fn(cache_entry.computation_fn, "forward", *args)

computation_fn = cache_entry.computation_fn if not has_to_export() else wrapped_computation_fn

result = computation_fn(*inps)
result = maybe_connect_to_autograd(cache_entry, result)
result = call_epilogue(cache_entry, result, pro_to_epi)

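# A minimal, self-contained sketch of the "run first, export afterwards" pattern used by
# wrapped_compiled_fn above. The names compiled_fn, compile_stats, and compile_data are
# hypothetical stand-ins for the cache entry's computation function and the objects behind
# weakref_cs/weakref_cd; the point is only that the trace is handed to the export transform
# after execution, once stateful executors have materialized their runtime state.
from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform


def run_and_export(compiled_fn, compile_stats, compile_data, *args):
    out = compiled_fn(*args)  # execute first so executor state exists
    traces = compile_stats.last_traces
    if not traces:
        return out
    for transform in compile_data.transforms or []:
        if isinstance(transform, ExportStatefulExecutorsTransform):
            transform.transform_trace_post_execution(traces[-1])
            break
    return out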
64 changes: 64 additions & 0 deletions thunder/dev_utils/export_stateful_ex_transform.py
@@ -0,0 +1,64 @@
import weakref
from collections.abc import Callable

from thunder.core.transform_common import (
Transform,
)
from thunder.core.trace import TraceCtx as Trace


class ExportStatefulExecutorsTransform(Transform):
"""Export runtime state from stateful executors after a trace executes.

- Singleton transform with a registry of export callbacks
- Register via `register_export_callback(name, callback)`
- Callbacks receive `(computation_trace, thunder_module)` and may attach
serialized state to the module (e.g., `module.te_fp8_stats`)
- Safe: callback errors are swallowed; export never blocks execution

Example (TransformerEngine): a callback collects FP8 amax/scale and
quantizer metadata from `python_ctx` and records them under
`module.te_fp8_stats = {"forward": [...], "backward": [...]}`.

Usage:
1) Register once at import/init time:
ExportStatefulExecutorsTransform.register_export_callback("my_exec", my_cb)
2) Enable at compile time:
thunder.jit(model, executors=[...], transforms=[..., ExportStatefulExecutorsTransform()])
3) Read exported fields from the compiled module in tests/tools.
"""

_instance = None
_callbacks: dict[str, Callable] = {}

def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

def __init__(self):
self.tm_ref = None

@classmethod
def register_export_callback(cls, name: str, callback: Callable) -> None:
cls._callbacks[name] = callback

def transform_module(self, model) -> None:
# Cache a weakref to the ThunderModule for later runtime export
self.tm_ref = weakref.ref(model)

def transform_trace_post_execution(self, computation_trace: Trace, **kwargs):
# Resolve ThunderModule from weakref; if unavailable, skip
tm = self.tm_ref() if self.tm_ref is not None else None
if tm is None:
return computation_trace

# Invoke all registered export callbacks.
for _, cb in self._callbacks.items():
try:
cb(computation_trace, tm)
except Exception:
# Swallow errors from individual exporters to avoid breaking execution.
pass

return computation_trace
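# A usage sketch assembled from the docstring above. The callback my_export_cb and the
# exported attribute name my_exec_state are hypothetical; the callback body only
# illustrates the (computation_trace, thunder_module) signature and the idea of
# attaching serialized state to the compiled module.
import torch
import thunder
from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform


def my_export_cb(computation_trace, thunder_module):
    # Attach whatever runtime state an executor exposes; here, just the executed symbol names.
    thunder_module.my_exec_state = [bsym.sym.name for bsym in computation_trace.bound_symbols]


# 1) Register once at import/init time
ExportStatefulExecutorsTransform.register_export_callback("my_exec", my_export_cb)

# 2) Enable at compile time
model = torch.nn.Linear(16, 16)
jitted = thunder.jit(model, transforms=[ExportStatefulExecutorsTransform()])

# 3) Run, then read the exported field from the compiled module
jitted(torch.randn(4, 16))
print(getattr(jitted, "my_exec_state", None))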
127 changes: 127 additions & 0 deletions thunder/executors/transformer_engineex_impl.py
@@ -1,7 +1,9 @@
import time
from typing import TYPE_CHECKING
import warnings
from collections import defaultdict

import torch
import torch.distributed as torch_dist

from thunder.core.prims import linear as linear_prim
@@ -40,6 +42,9 @@
from transformer_engine.pytorch.ops import BasicLinear
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.utils import check_dim_for_fp8_exec
from thunder.dev_utils.export_stateful_ex_transform import (
ExportStatefulExecutorsTransform as _ExportSETransform,
)


transformer_engine_ex = StatefulExecutor("transformer_engine")
@@ -374,6 +379,121 @@ def reset(self):
self.redundant_map = {}
self.new_saved_for_backward = None

@staticmethod
def export_state(computation_trace, tm) -> None:
"""
Extracts and exports the FP8 amax/scale state information from TransformerEngine (TE) holders
present in the Python context of a computation trace.

This method is intended to be called after a TE-enabled computation has executed, in order to
serialize and record the relevant FP8 state (such as amax and scale tensors) and quantizer
information for later inspection, debugging, or export.

Args:
computation_trace: The Thunder computation trace object containing the Python context
with TE state and quantizer holders.
tm: The ThunderModule object.

Returns:
None.
"""
# Extract FP8 amax/scale information from TE holders available in python context
python_ctx = computation_trace.python_ctx()

# Helper: serialize small tensors; truncate oversized payloads (with a warning)
def _to_list_limited(t, max_numel: int = 8192):
if not isinstance(t, torch.Tensor):
return None
try:
n = min(t.numel(), max_numel)
if t.numel() > max_numel:
warnings.warn(
f"TE Stateful Executor: Exporting only first {max_numel} elements of tensor with {t.numel()} elements",
UserWarning,
)
flat = t.detach().float().cpu().view(-1)[:n].tolist()
return flat
except Exception:
return None

# Infer context mode from available TE functional symbols
te_mode = None
if "te_functional_linear_fwd" in python_ctx:
te_mode = "forward"
elif "te_functional_linear_bwd" in python_ctx:
te_mode = "backward"

delayed_entries: list[dict] = []
block_entries: list[dict] = []

# Gather state and quantizer holders from context
state_holders = [v for k, v in python_ctx.items() if isinstance(k, str) and k.startswith("get_te_fp8_state")]
quantizer_holders = [
v for k, v in python_ctx.items() if isinstance(k, str) and k.startswith("get_te_fp8_quantizers")
]

# Map RecipeState -> quantizers (if materialized)
state_to_quantizers: dict[int, list] = {}
for qh in quantizer_holders:
prs = getattr(qh, "parent_recipe_state", None)
qs = getattr(qh, "quantizers", None)
if prs is not None and qs:
state_to_quantizers.setdefault(id(prs), []).extend(qs)

for sh in state_holders:
recipe = getattr(sh, "parent_recipe", None)
state = getattr(sh, "state", None)
if recipe is None:
continue

# Determine recipe family
is_delayed = bool(recipe.delayed())
is_mxfp8_or_block = bool(recipe.mxfp8())

# DelayedScaling: values live on state.scale and state.amax_history
if is_delayed and state is not None:
scale_vals = _to_list_limited(getattr(state, "scale", None))
amax_hist = getattr(state, "amax_history", None)
amax_vals = None
if isinstance(amax_hist, torch.Tensor) and amax_hist.numel() > 0:
amax_slice = amax_hist[-1] if amax_hist.dim() >= 1 else amax_hist
amax_vals = _to_list_limited(amax_slice)
delayed_entries.append(
{
"scale_shape": getattr(getattr(state, "scale", None), "shape", None),
"scale": scale_vals,
"amax_shape": getattr(getattr(state, "amax_history", None), "shape", None),
"amax": amax_vals,
}
)

# MXFP8/Float8 block scaling: values live on quantizers
elif is_mxfp8_or_block and state is not None:
qs = state_to_quantizers.get(id(state), [])
for q in qs:
rowwise_usage = getattr(q, "rowwise_usage", None)
columnwise_usage = getattr(q, "columnwise_usage", None)
block_entries.append(
{
"cls": q.__class__.__name__,
"rowwise_usage": rowwise_usage,
"columnwise_usage": columnwise_usage,
"dtype": str(getattr(q, "dtype", None)),
}
)

entry = defaultdict(list)
if delayed_entries:
entry["delayed"] = delayed_entries
if block_entries:
entry["mxfp8_or_block"] = block_entries

collected = getattr(tm, "te_fp8_stats", None)
if collected is None:
tm.te_fp8_stats = {"forward": [], "backward": []}
if entry["delayed"] or entry["mxfp8_or_block"]:
tm.te_fp8_stats[te_mode].append(entry)

def transform_trace_post_optimization(self, computation_trace, **kwargs):
"""
Finds TE executor recipe calls and replaces them with a single one.
@@ -513,3 +633,10 @@ def _te_activation_checkpointing_transform(joint_trace: TraceCtx) -> TraceCtx:
new_trace.bound_symbols = [bsym.from_bsym_swap_proxies(swapmap) for bsym in reversed(reversed_bsyms)]

return new_trace


# Register TE export callback with the singleton export transform
try:
_ExportSETransform.register_export_callback("transformer_engine", TransformerEngineTransform.export_state)
except Exception:
pass
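# A hedged sketch of reading the exported FP8 statistics after running a TE-enabled
# compiled module. It assumes a CUDA device with FP8 support and that passing the TE
# executor and transforms to thunder.jit as below matches the intended setup; the layer
# sizes and dtypes are illustrative only. Field names follow export_state above.
import torch
import thunder
from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform
from thunder.executors.transformer_engineex_impl import TransformerEngineTransform, transformer_engine_ex

model = torch.nn.Linear(64, 64, bias=False).to("cuda", torch.bfloat16)
jitted = thunder.jit(
    model,
    executors=[transformer_engine_ex],
    transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()],
)

x = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
jitted(x).sum().backward()

# te_fp8_stats is populated post-execution by the registered "transformer_engine" callback.
stats = getattr(jitted, "te_fp8_stats", None)
if stats is not None:
    for entry in stats["forward"] + stats["backward"]:
        # DelayedScaling recipes export scale/amax values; MXFP8/block recipes export quantizer metadata.
        print(entry.get("delayed"), entry.get("mxfp8_or_block"))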