|
1 | 1 | import time
|
2 | 2 | from typing import TYPE_CHECKING
|
| 3 | +import warnings |
3 | 4 |
|
4 | 5 | import torch.distributed as torch_dist
|
5 | 6 |
|
|
27 | 28 | from thunder.core.proxies import TensorProxy
|
28 | 29 |
|
29 | 30 | import transformer_engine.pytorch as te
|
| 31 | +import transformer_engine.common.recipe as te_recipe |
30 | 32 | from transformer_engine.pytorch.constants import MXFP8_BLOCK_SCALING_SIZE
|
31 | 33 | from transformer_engine.pytorch.fp8 import (
|
32 | 34 | _amax_and_scale_update,
|
@@ -271,14 +273,35 @@ def _view_input_as_2d(x):
|
271 | 273 |
|
272 | 274 | fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
|
273 | 275 |
|
| 276 | + supported_recipes = (te_recipe.DelayedScaling, te_recipe.MXFP8BlockScaling) |
| 277 | + if hasattr(te_recipe, "NVFP4BlockScaling"): |
| 278 | + supported_recipes = (*supported_recipes, te_recipe.NVFP4BlockScaling) |
| 279 | + |
| 280 | + if not isinstance(fp8_recipe, supported_recipes): |
| 281 | + warnings.warn(f"{type(fp8_recipe)} is not supported by TE executor, TE won't be used.") |
| 282 | + return False |
| 283 | + |
def check_valid_fp8_shapes(a):
    # Return True when tensor proxy `a` has a shape the active FP8 recipe can
    # execute. `fp8_recipe` is the recipe object captured from the enclosing
    # scope; each recipe type has different shape requirements.
    if fp8_recipe.delayed():
        # DelayedScaling: defer to TE's own dimension check.
        return check_dim_for_fp8_exec(a)

    shape = a.shape

    if fp8_recipe.mxfp8():
        # MXFP8: the first two dims must be multiples of the block size.
        # Guard on rank >= 2 (consistent with the nvfp4 branch below) so a
        # sub-2D shape reports "unsupported" instead of raising IndexError.
        return (
            len(shape) >= 2
            and shape[0] % MXFP8_BLOCK_SCALING_SIZE == 0
            and shape[1] % MXFP8_BLOCK_SCALING_SIZE == 0
        )

    if hasattr(fp8_recipe, "nvfp4") and fp8_recipe.nvfp4():
        # NVFP4 only exists in newer TE releases, hence the hasattr probe and
        # the deferred import of its block-scaling constant.
        from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE

        # Check inherited from TE https://github.com/ksivaman/TransformerEngine-1/blob/1af7dd88aae5afb45e82148089038e1d1de9675d/transformer_engine/pytorch/tensor/nvfp4_tensor.py#L176-L184
        return (
            len(shape) >= 2
            and shape[0] % NVFP4_BLOCK_SCALING_SIZE == 0
            and shape[1] % NVFP4_BLOCK_SCALING_SIZE == 0
        )

    # Unknown recipe type: conservatively report the shape as unsupported.
    return False
282 | 305 |
|
283 | 306 | # Inputs must be on CUDA and
|
284 | 307 | # input sizes must satisfy size constraints based on the recipe.
|
|
0 commit comments