Skip to content

Commit be82014

Browse files
TEv2 as default TE executor (#2510)
1 parent b23aa27 commit be82014

19 files changed

+1355
-1349
lines changed

thunder/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -574,8 +574,8 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
574574
if requires_grad:
575575
from thunder.transforms.autodiff import split_into_forward_and_backward
576576

577-
if "transformer_engine_v2" in {ex.name for ex in cd.executors_list}:
578-
from thunder.executors.transformer_engine_v2ex import _te_activation_checkpointing_transform
577+
if "transformer_engine" in {ex.name for ex in cd.executors_list}:
578+
from thunder.executors.transformer_engineex import _te_activation_checkpointing_transform
579579

580580
computation_trc = _te_activation_checkpointing_transform(computation_trc)
581581

thunder/benchmarks/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from thunder.executors.apexex import apex_ex, apex_entropy_available
2828
from thunder.executors.cudnn_layernormex import cudnn_layernorm_ex
2929
from thunder.executors.cudnnex import cudnn_ex, cudnn_available
30-
from thunder.executors.transformer_engineex import transformer_engine_ex, TE_AVAILABLE
30+
from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform
3131
from thunder.executors.sdpaex import sdpa_ex
3232
from thunder.executors.torch_compile import torch_compile_cat_ex, torch_compile_ex
3333
from thunder.transforms.cudagraph import CUDAGraphTransform
@@ -764,10 +764,16 @@ def thunder_cudnn_layer_norm_nvfuser_executor(fn: Callable) -> Callable:
764764

765765
thunder_transformerengine_executor: None | Callable = None
766766

767-
if TE_AVAILABLE:
767+
if transformer_engine_ex is not None:
768768

769769
def thunder_transformerengine_executor(fn: Callable):
770-
return thunder.jit(fn, executors=(transformer_engine_ex,) + thunder.get_default_executors())
770+
return thunder.jit(
771+
fn,
772+
executors=(transformer_engine_ex,) + thunder.get_default_executors(),
773+
transforms=[
774+
TransformerEngineTransform(),
775+
],
776+
)
771777

772778

773779
def thunder_sdpa_executor(fn: Callable) -> Callable:

thunder/benchmarks/benchmark_litgpt.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -625,19 +625,19 @@ def setup_compile(self, model):
625625

626626
executors.insert(0, torch_compile_ex)
627627

628-
if "transformerengine_v2" in self.compile:
629-
from thunder.executors.transformer_engine_v2ex import (
630-
transformer_engine_v2_ex,
631-
TransformerEngineTransformV2,
632-
)
628+
if "transformerengine_v1" in self.compile:
629+
from thunder.executors.transformer_engine_v1ex import transformer_engine_v1_ex
633630

634-
executors.insert(0, transformer_engine_v2_ex)
635-
transforms.insert(0, TransformerEngineTransformV2())
631+
executors.insert(0, transformer_engine_v1_ex)
636632

637633
elif "transformerengine" in self.compile:
638-
from thunder.executors.transformer_engineex import transformer_engine_ex
634+
from thunder.executors.transformer_engineex import (
635+
transformer_engine_ex,
636+
TransformerEngineTransform,
637+
)
639638

640639
executors.insert(0, transformer_engine_ex)
640+
transforms.insert(0, TransformerEngineTransform())
641641

642642
if "dynamo" in self.compile:
643643
if self.distributed_mode == "fsdp2":

thunder/core/trace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ def keyfn(class_or_module: type | ModuleType) -> str:
413413
# NOTE: For TE v1.6 onwards, `fp8_autocast` checks if `torch.is_grad_enabled` for updating
414414
# the FP8 scales/inverses. So this decorator should be applied before `torch.no_grad` (so that
415415
# it is in grad enabled part).
416-
from thunder.executors.transformer_engineex import _is_te_linear_enabled, _get_te_wrapper_string
416+
from thunder.executors.transformer_engine_v1ex import _is_te_linear_enabled, _get_te_wrapper_string
417417

418418
if TraceTag.AUGMENTED_FORWARD and _is_te_linear_enabled(import_ctx, object_ctx):
419419
program.append(_get_te_wrapper_string())

0 commit comments

Comments (0)