123 changes: 123 additions & 0 deletions thunder/dev_utils/export_stateful_ex_transform.py
@@ -0,0 +1,123 @@
import weakref
from collections.abc import Callable

from thunder.core.transform_common import (
Transform,
)
from thunder.core.trace import TraceCtx as Trace
from thunder.core.module import ThunderModule


class ExportStatefulExecutorsStats:
def __init__(self, tm: ThunderModule, resolver_fn: Callable):
"""Lightweight accessor attached to a `ThunderModule`.

Args:
tm: The `ThunderModule` instance this accessor belongs to.
resolver_fn: A callable that knows how to resolve the recorded
references on `tm` and return real values.
"""
self.tm = tm
self.resolver_fn = resolver_fn


class ExportStatefulExecutorsTransform(Transform):
"""Register references and resolve runtime state lazily.

What this transform does:
    - Maintains a singleton registry into which per-executor exporters plug
    - At module transform time, installs a lightweight accessor on the module
      (e.g., `module.te_fp8_states`) that can resolve values on demand
    - At post-optimization time, calls registered reference callbacks to record
      only where values will materialize (holders + attribute paths); no data
      are copied or materialized in this step
- When code calls the accessor (e.g., `module.te_fp8_states()`), the resolve
callback reads the recorded references and returns the latest values


API overview:
    - register_ref_callback(name, callback, resolve_cb, instance_cls):
        name: attribute name to attach on the module
        callback(trace, module): store references from the trace/python_ctx
resolve_cb(module): materialize and return values using the stored refs
instance_cls: a small class constructed as instance_cls(module, resolve_cb)
and attached as `setattr(module, name, instance)`; it typically stores
containers for references and implements __call__(...) to resolve

Usage:
1) Register once at import/init time. For example, for TransformerEngine:
ExportStatefulExecutorsTransform.register_ref_callback(
"te_fp8_states", register_cb, resolve_cb, StatsClass
)
2) Enable at compile time:
thunder.jit(model, executors=[...], transforms=[..., ExportStatefulExecutorsTransform()])
3) After each run, call `module.te_fp8_states()` to resolve and return the latest values.

Notes:
- Supports multiple ThunderModule instances (e.g., subgraphs)
- Callback errors are swallowed to avoid interfering with execution
"""

_register_callbacks: dict[str, Callable] = {}
_callback_attributes: list[tuple[str, type[ExportStatefulExecutorsStats], Callable]] = []

_instance = None

def __new__(cls, *args, **kwargs):
"""Ensure singleton instance across repeated transform construction."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

    def __init__(self):
        """Initialize the internal weakref registry.

        ThunderCompiler and other compilation flows may create multiple
        `ThunderModule` instances for subgraphs; we keep weak references
        to update all of them during post-optimization registration.
        """
        # This class is a singleton, so `__init__` runs again on every
        # construction; guard so re-construction does not drop previously
        # recorded module references.
        if not hasattr(self, "tm_refs"):
            self.tm_refs = []

    @classmethod
    def register_ref_callback(
        cls, name: str, callback: Callable, resolve_cb: Callable, instance_cls: type[ExportStatefulExecutorsStats]
    ) -> None:
        """Register per-executor reference and resolver callbacks.

        Installs a module attribute named `name` by constructing `instance_cls`
        with the resolver function. The `callback` will be invoked during
        post-optimization to record reference locations on the module.

        Args:
            name: Module attribute to attach (e.g., "te_fp8_states").
            callback: Function `(trace, module) -> None` that records refs.
            resolve_cb: Function `(module) -> Any` that resolves values on demand.
            instance_cls: A subclass of `ExportStatefulExecutorsStats`,
                constructed as `instance_cls(module, resolve_cb)`.
        """
        if not issubclass(instance_cls, ExportStatefulExecutorsStats):
            raise TypeError(f"Provided class {instance_cls} must be a subclass of ExportStatefulExecutorsStats")
        cls._register_callbacks[name] = callback
        cls._callback_attributes.append((name, instance_cls, resolve_cb))

def transform_module(self, model) -> None:
assert model is not None
# Cache a weakref to the ThunderModule for later runtime export
self.tm_refs.append(weakref.ref(model))
# Initialize attributes on model
for name, instance, resolve_cb in self._callback_attributes:
setattr(model, name, instance(model, resolve_cb))

def transform_trace_post_optimization(self, computation_trace: Trace, **kwargs):
for tm_ref in self.tm_refs:
# Resolve ThunderModule from weakref; if unavailable, skip
tm = tm_ref() if tm_ref is not None else None
if tm is None:
continue

# Invoke all registered callbacks to register reference locations
            for cb in self._register_callbacks.values():
try:
cb(computation_trace, tm)
except Exception:
pass
return computation_trace
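
# A minimal usage sketch (illustrative and commented out, not executed): how a
# hypothetical executor could plug into the registry above. The names
# `MyStats`, `my_register_cb`, and `my_resolve_cb` are assumptions for this
# example, not real Thunder or executor APIs.
#
#     class MyStats(ExportStatefulExecutorsStats):
#         def __init__(self, tm, resolver_fn):
#             super().__init__(tm, resolver_fn)
#             self.refs = {}
#
#         def __call__(self):
#             return self.resolver_fn(self.tm)
#
#     def my_register_cb(trace, module):
#         # Record only references (e.g., holder keys from trace.python_ctx()).
#         module.my_states.refs = {"ctx_keys": list(trace.python_ctx().keys())}
#
#     def my_resolve_cb(module):
#         # Materialize values from the recorded references on demand.
#         return module.my_states.refs or {}
#
#     ExportStatefulExecutorsTransform.register_ref_callback(
#         "my_states", my_register_cb, my_resolve_cb, MyStats
#     )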
173 changes: 173 additions & 0 deletions thunder/executors/transformer_engineex_impl.py
@@ -1,9 +1,13 @@
import time
from typing import TYPE_CHECKING
import warnings
from collections import defaultdict
from collections.abc import Callable

import torch
import torch.distributed as torch_dist

from thunder.core.module import ThunderModule
from thunder.core.prims import linear as linear_prim
from thunder.core.prims import get_grad, put_grad
from thunder.core.proxies import AnyProxy, TensorProxy
@@ -40,6 +44,10 @@
from transformer_engine.pytorch.ops import BasicLinear
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.utils import check_dim_for_fp8_exec
from thunder.dev_utils.export_stateful_ex_transform import (
ExportStatefulExecutorsTransform as _ExportSETransform,
ExportStatefulExecutorsStats,
)


transformer_engine_ex = StatefulExecutor("transformer_engine")
@@ -374,6 +382,159 @@ def reset(self):
self.redundant_map = {}
self.new_saved_for_backward = None

class TEFP8Stats(ExportStatefulExecutorsStats):
def __init__(self, tm: ThunderModule, resolver_fn: Callable):
"""Accessor attached on the module to resolve TE FP8 states on demand.

Args:
tm: ThunderModule to which this accessor is bound.
resolver_fn: Callable invoked as `resolver_fn(mode, tm)` to produce
the latest snapshot of FP8 state based on registered refs.
"""
super().__init__(tm, resolver_fn)
self.refs = {"forward": None, "backward": None}

def __call__(self, mode: str = "forward") -> dict:
"""Resolve and return the latest FP8 state for the given mode.

Args:
mode: "forward" or "backward". Defaults to "forward".
Returns:
A dictionary snapshot of resolved values (e.g., delayed or mxfp8 entries),
or an empty dict on invalid mode or if nothing is recorded.
"""
if mode not in ["forward", "backward"]:
warnings.warn(f"Received an invalid inspection mode: {mode}. Please use 'forward' or 'backward'.")
return {}
return self.resolver_fn(mode, self.tm)

@staticmethod
def register_refs(computation_trace, tm) -> None:
"""Record where FP8 values will materialize for later lazy resolution.

This inspects the trace's python context, finds TE state and quantizer
holders, and stores only references (holder objects and attribute paths)
into the module accessor. No tensors or runtime data are copied here.
The actual values are read by `resolve_values` after execution.
"""
python_ctx = computation_trace.python_ctx()

# Collect holders and where to read values from later
refs = defaultdict(list)

# Collect mode from trace tags
mode = "forward" if TraceTag.AUGMENTED_FORWARD in computation_trace.tags else "backward"

# States: register all state holders; decide recipe type at resolve time
state_holders = [v for k, v in python_ctx.items() if isinstance(k, str) and k.startswith("get_te_fp8_state")]
for sh in state_holders:
# Always store attrs we may need; recipe classification happens later
refs["state_holder"].append(
{
"holder": sh,
"scale_attr": "state.scale",
"amax_attr": "state.amax_history",
}
)

# Quantizers (MXFP8/block): resolve via TEQuantizerState linked to RecipeState
quantizer_holders = [
v for k, v in python_ctx.items() if isinstance(k, str) and k.startswith("get_te_fp8_quantizers")
]
for qh in quantizer_holders:
refs["quantizer_holder"].append(
{"holder": qh, "quant_attr": "quantizers", "parent_state_attr": "parent_recipe_state"}
)

        if refs:
tm.te_fp8_states.refs[mode] = refs

@staticmethod
def resolve_values(mode: str, tm: ThunderModule) -> dict:
"""Load and serialize FP8 values using previously-registered references.

Args:
mode: "forward" or "backward" indicating which refs to resolve.
tm: ThunderModule whose accessor holds the recorded references.
Returns:
A dictionary with resolved entries (e.g., {"delayed": [...]} or
{"mxfp8": [...]}); returns empty dict if nothing is recorded.
"""

def _get_attr(obj, attr_path: str):
cur = obj
for part in attr_path.split("."):
cur = getattr(cur, part)
return cur

def _tensor_head(t, max_numel: int = 8192):
if not isinstance(t, torch.Tensor):
return None
n = min(t.numel(), max_numel)
return t.detach().float().cpu().view(-1)[:n].tolist()

# Pull last registered refs for this mode
refs = tm.te_fp8_states.refs[mode]
if refs is None:
return {}

out = defaultdict(list)

# Classify states now that recipes and states have materialized
# MXFP8/block scaling: will be collected via quantizers section below
for ref in refs.get("state_holder", []):
sh = ref["holder"]
recipe = getattr(sh, "parent_recipe", None)
state = getattr(sh, "state", None)
if recipe is None or state is None:
continue
# Delayed scaling: extract from state tensors
if getattr(recipe, "delayed", lambda: False)():
scale = _get_attr(sh, ref["scale_attr"]) # state.scale
amax_hist = _get_attr(sh, ref["amax_attr"]) # state.amax_history
scale_vals = _tensor_head(scale)
amax_vals = _tensor_head(
amax_hist[-1] if isinstance(amax_hist, torch.Tensor) and amax_hist.numel() > 0 else amax_hist
)
out["delayed"].append(
{
"scale_shape": getattr(scale, "shape", None),
"scale": scale_vals,
"amax_shape": getattr(amax_hist, "shape", None),
"amax": amax_vals,
}
)

# MXFP8 via quantizers
# First, build mapping from recipe state id to quantizers
state_to_qs = {}
for ref in refs.get("quantizer_holder", []):
qh = ref["holder"]
prs = getattr(qh, "parent_recipe_state", None)
qs = getattr(qh, "quantizers", None)
if prs is not None and qs:
state_to_qs.setdefault(id(prs), []).extend(qs)

# For MXFP8/block scaling, gather quantizers linked to each materialized state
for ref in refs.get("state_holder", []):
sh = ref["holder"]
recipe = getattr(sh, "parent_recipe", None)
state = getattr(sh, "state", None)
if recipe is None or state is None:
continue
if getattr(recipe, "mxfp8", lambda: False)():
for q in state_to_qs.get(id(state), []):
entry = {
"cls": q.__class__.__name__,
"rowwise_usage": getattr(q, "rowwise_usage", None),
"columnwise_usage": getattr(q, "columnwise_usage", None),
"dtype": str(getattr(q, "dtype", None)),
}
if entry not in out["mxfp8"]:
out["mxfp8"].append(entry)

return out
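
    # Illustrative shape of the snapshot returned above (values are invented
    # for the example; actual contents depend on the recipe and amax history):
    #   {"delayed": [{"scale_shape": torch.Size([1]), "scale": [1.0],
    #                 "amax_shape": torch.Size([16, 1]), "amax": [0.5]}],
    #    "mxfp8": [{"cls": "MXFP8Quantizer", "rowwise_usage": True,
    #               "columnwise_usage": True, "dtype": "torch.float8_e4m3fn"}]}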

def transform_trace_post_optimization(self, computation_trace, **kwargs):
"""
        Finds TE executor recipe calls and replaces them with a single one.
@@ -513,3 +674,15 @@ def _te_activation_checkpointing_transform(joint_trace: TraceCtx) -> TraceCtx:
new_trace.bound_symbols = [bsym.from_bsym_swap_proxies(swapmap) for bsym in reversed(reversed_bsyms)]

return new_trace


# Register TE reference and resolve callbacks with the singleton export transform
try:
_ExportSETransform.register_ref_callback(
"te_fp8_states",
TransformerEngineTransform.register_refs,
TransformerEngineTransform.resolve_values,
TransformerEngineTransform.TEFP8Stats,
)
except Exception:
pass
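
# End-to-end usage sketch (illustrative and commented out; `model` and `x` are
# assumptions for the example, not defined in this module). With the transform
# enabled, the accessor installed above resolves the latest FP8 state after
# each run:
#
#     import thunder
#     from thunder.dev_utils.export_stateful_ex_transform import (
#         ExportStatefulExecutorsTransform,
#     )
#
#     jmodel = thunder.jit(
#         model,
#         executors=[transformer_engine_ex],
#         transforms=[ExportStatefulExecutorsTransform()],
#     )
#     out = jmodel(x)
#     fwd_stats = jmodel.te_fp8_states("forward")   # e.g., {"delayed": [...]}
#     bwd_stats = jmodel.te_fp8_states("backward")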