@mattteochen mattteochen commented Oct 2, 2025

What does this PR do?

Closes #2438.

Summary

  • Added ExportStatefulExecutorsTransform (singleton) with a registry of export callbacks.
  • Executors can register a callback to export runtime state post-execution.
  • Currently integrates with transformer_engine_ex.
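
To illustrate the idea in the summary, here is a minimal sketch of a singleton registry of export callbacks. It is not the code from this PR: `StateExportRegistry`, `register`, `export_all`, and the fake executor callback are hypothetical names chosen for illustration, and the real `ExportStatefulExecutorsTransform` may be structured differently.

```python
from typing import Any, Callable, Dict

# Hypothetical callback type: takes some runtime object, returns a state report.
ExportCallback = Callable[[Any], Dict[str, Any]]


class StateExportRegistry:
    """Illustrative singleton that maps executor names to export callbacks."""

    _instance = None

    def __new__(cls):
        # Create the single shared instance on first use, then reuse it.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._callbacks = {}
        return cls._instance

    def register(self, executor_name: str, callback: ExportCallback) -> None:
        """Register (or overwrite) the export callback for one executor."""
        self._callbacks[executor_name] = callback

    def export_all(self, runtime: Any) -> Dict[str, Dict[str, Any]]:
        """Call every registered callback and collect the reports by executor name."""
        return {name: cb(runtime) for name, cb in self._callbacks.items()}


def export_fake_scaling_state(runtime: Dict[str, Any]) -> Dict[str, Any]:
    """Made-up callback standing in for a stateful executor's exporter."""
    history = runtime["history"]
    return {"amax": max(history), "steps": len(history)}


registry = StateExportRegistry()
registry.register("fake_fp8_ex", export_fake_scaling_state)

# A second construction returns the same object, so the callback is still visible.
report = StateExportRegistry().export_all({"history": [0.5, 2.0, 1.25]})
print(report)
```

Under these assumptions, each stateful executor only has to supply one callback, and the code that runs after execution does not need to know which executors are present.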

Usage

The export is controlled by passing the transform to the transforms list:

thunder.jit(model, executors=[...], transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()])

We may later add a compiler flag that applies this transform automatically, which would be more intuitive for users.

Whole example:

from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform
import torch
import torch.nn as nn
import thunder
from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform
from transformer_engine.common import recipe
import transformer_engine.pytorch as te
from pprint import pprint

torch.manual_seed(42)

# Device and data type
dtype = torch.bfloat16
device = "cuda"

# Inputs (3D input)
x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True)

# TE recipe
fp8_recipe = recipe.DelayedScaling()
# fp8_recipe = recipe.MXFP8BlockScaling()

# Dummy model
class Module(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype))
        self.w2 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype))
        self.w3 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype))

    def forward(self, x):
        # Apply the three linear layers in sequence: w2, then w3, then w1
        x = torch.nn.functional.linear(x, self.w2)
        x = torch.nn.functional.linear(x, self.w3)
        return torch.nn.functional.linear(x, self.w1)

model = Module()
jmodel = thunder.jit(model, executors=[transformer_engine_ex], transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()])

# Enable autocasting for the forward pass
for _ in range(2):
    with te.fp8_autocast(fp8_recipe=fp8_recipe):
        y = jmodel(torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True))
        rep = jmodel.te_fp8_states()
        pprint(rep) # Visualize states information after operation
    grad_output = torch.randn_like(y)
    y.backward(grad_output)
    rep = jmodel.te_fp8_states(mode="backward")
    pprint(rep) # Visualize states information after operation

Why it matters

  • Not TE-specific: provides a reusable, uniform path for any stateful executor to export runtime state for debugging, validation, and reporting.

@mattteochen mattteochen marked this pull request as draft October 2, 2025 15:48
@mattteochen mattteochen changed the title 2438 Export stateful executor states to ThunderModule Oct 2, 2025
@mattteochen mattteochen marked this pull request as ready for review October 2, 2025 18:34
Comment on lines 16 to 27
TE_AVAILABLE = False
try:
    import transformer_engine as _te_mod  # noqa: F401
    import transformer_engine.pytorch as te  # type: ignore
    from transformer_engine.common import recipe  # type: ignore
    from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform

    TE_AVAILABLE = True
except Exception:
    te = None  # type: ignore
    recipe = None  # type: ignore
    TE_AVAILABLE = False
Collaborator

At a glance, I suspect that pytest.importorskip("transformer_engine") would suffice (ref)

Collaborator Author

Yes, but I didn't want to skip the entire test file, since future tests might not depend on the TE executor.
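
For context, the trade-off discussed here is between skipping the whole module at import time and skipping only the tests that need Transformer Engine. The sketch below shows the per-test variant using only the standard library, so it runs without extra dependencies. The test names are hypothetical, and `importlib.util.find_spec` only checks that the package can be located, which is weaker than the try/except import used in the PR (that also catches failures raised during import). With pytest, the same flag could be passed to `pytest.mark.skipif(not TE_AVAILABLE, reason=...)`.

```python
import importlib.util
import unittest


def module_available(name: str) -> bool:
    """Return True if the top-level module `name` can be located, without importing it."""
    return importlib.util.find_spec(name) is not None


# Availability flag computed once per test module.
TE_AVAILABLE = module_available("transformer_engine")


class ExportTests(unittest.TestCase):
    # Hypothetical test that needs Transformer Engine: skipped individually when it is missing.
    @unittest.skipUnless(TE_AVAILABLE, "transformer_engine is not installed")
    def test_te_state_export(self):
        self.assertTrue(TE_AVAILABLE)

    # Hypothetical test with no TE dependency: still runs when TE is absent.
    def test_independent_of_te(self):
        self.assertEqual(1 + 1, 2)


suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExportTests)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

A module-level `pytest.importorskip("transformer_engine")` would instead skip both tests when the package is missing, which is the behaviour the author wanted to avoid.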

@mattteochen mattteochen marked this pull request as draft October 3, 2025 12:13
@mattteochen mattteochen marked this pull request as ready for review October 4, 2025 08:42
Development

Successfully merging this pull request may close these issues.

Connect TEv2 states to ThunderModule
3 participants