Skip to content

Commit 157ac80

Browse files
Initial TE NVFP4 recipe support (#2523)
1 parent 5ece7ef commit 157ac80

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

thunder/executors/transformer_engineex_impl.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import time
22
from typing import TYPE_CHECKING
3+
import warnings
34

45
import torch.distributed as torch_dist
56

@@ -27,6 +28,7 @@
2728
from thunder.core.proxies import TensorProxy
2829

2930
import transformer_engine.pytorch as te
31+
import transformer_engine.common.recipe as te_recipe
3032
from transformer_engine.pytorch.constants import MXFP8_BLOCK_SCALING_SIZE
3133
from transformer_engine.pytorch.fp8 import (
3234
_amax_and_scale_update,
@@ -271,14 +273,35 @@ def _view_input_as_2d(x):
271273

272274
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
273275

276+
supported_recipes = (te_recipe.DelayedScaling, te_recipe.MXFP8BlockScaling)
277+
if hasattr(te_recipe, "NVFP4BlockScaling"):
278+
supported_recipes = (*supported_recipes, te_recipe.NVFP4BlockScaling)
279+
280+
if not isinstance(fp8_recipe, supported_recipes):
281+
warnings.warn(f"{type(fp8_recipe)} is not supported by TE executor, TE wont be used.")
282+
return False
283+
274284
def check_valid_fp8_shapes(a):
    """Return True when tensor ``a`` satisfies the active FP8 recipe's shape rules.

    Each supported recipe type has different shape requirements; an
    unrecognized recipe yields False so the TE executor is skipped.
    """
    if fp8_recipe.delayed():
        # DelayedScaling: defer entirely to TE's own dimension check.
        return check_dim_for_fp8_exec(a)

    dims = a.shape

    if fp8_recipe.mxfp8():
        # MXFP8: the first two dimensions must each be a multiple of the
        # MXFP8 block size.
        # NOTE(review): unlike the NVFP4 branch below, rank is not checked
        # here — presumably inputs are already viewed as 2-D upstream
        # (`_view_input_as_2d`); confirm before relying on <2-D inputs.
        return dims[0] % MXFP8_BLOCK_SCALING_SIZE == 0 and dims[1] % MXFP8_BLOCK_SCALING_SIZE == 0

    if hasattr(fp8_recipe, "nvfp4") and fp8_recipe.nvfp4():
        # Imported lazily: the constant only exists in TE versions that
        # ship NVFP4 support.
        from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE

        # Check inherited from TE https://github.com/ksivaman/TransformerEngine-1/blob/1af7dd88aae5afb45e82148089038e1d1de9675d/transformer_engine/pytorch/tensor/nvfp4_tensor.py#L176-L184
        return (
            len(dims) >= 2
            and dims[0] % NVFP4_BLOCK_SCALING_SIZE == 0
            and dims[1] % NVFP4_BLOCK_SCALING_SIZE == 0
        )

    # Recipe type not handled above: not supported by this executor.
    return False
282305

283306
# Inputs must be on CUDA and
284307
# input sizes must satisfy size constraints based on the recipe.

0 commit comments

Comments
 (0)