diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 310f2c0e7..f40b28189 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -18,8 +18,8 @@ jobs: - name: 'Dependencies' run: | apt-get update - apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12 - pip install cmake==3.21.0 + apt-get install -y git python3.9 pip cudnn9-cuda-12 + pip install cmake==3.21.0 pybind11[global] ninja - name: 'Checkout' uses: actions/checkout@v3 with: @@ -42,8 +42,8 @@ jobs: - name: 'Dependencies' run: | apt-get update - apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12 - pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11 + apt-get install -y git python3.9 pip cudnn9-cuda-12 + pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript - name: 'Checkout' uses: actions/checkout@v3 with: @@ -54,7 +54,6 @@ jobs: NVTE_FRAMEWORK: pytorch MAX_JOBS: 1 - name: 'Sanity check' - if: false # Sanity import test requires Flash Attention run: python3 tests/pytorch/test_sanity_import.py jax: name: 'JAX' @@ -63,6 +62,8 @@ jobs: image: ghcr.io/nvidia/jax:jax options: --user root steps: + - name: 'Dependencies' + run: pip install pybind11[global] - name: 'Checkout' uses: actions/checkout@v3 with: @@ -73,4 +74,24 @@ jobs: NVTE_FRAMEWORK: jax MAX_JOBS: 1 - name: 'Sanity check' - run: python tests/jax/test_sanity_import.py + run: python3 tests/jax/test_sanity_import.py + all: + name: 'All' + runs-on: ubuntu-latest + container: + image: ghcr.io/nvidia/jax:jax + options: --user root + steps: + - name: 'Dependencies' + run: pip install torch pybind11[global] einops onnxscript + - name: 'Checkout' + uses: actions/checkout@v3 + with: + submodules: recursive + - name: 'Build' + run: pip install --no-build-isolation . -v --no-deps + env: + NVTE_FRAMEWORK: all + MAX_JOBS: 1 + - name: 'Sanity check' + run: python3 tests/pytorch/test_sanity_import.py && python3 tests/jax/test_sanity_import.py diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 06077bbd6..66400ffd7 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -53,6 +53,7 @@ jobs: || github.actor == 'lhb8125' || github.actor == 'kunlunl' || github.actor == 'pstjohn' + || github.actor == 'mk-61' ) steps: - name: Check if comment is issued by authorized person diff --git a/.gitignore b/.gitignore index 874eed018..d3b18b358 100644 --- a/.gitignore +++ b/.gitignore @@ -49,8 +49,9 @@ downloads/ .pytest_cache/ compile_commands.json .nfs +tensor_dumps/ +artifacts/ **/profiler_outputs/ **/times.csv -tensor_dumps/ transformer_engine/build_info.txt transformer_engine/common/util/hip_nvml.* diff --git a/README.rst b/README.rst index b7afcbd2a..3b1a589db 100644 --- a/README.rst +++ b/README.rst @@ -449,7 +449,7 @@ Installation ============ System Requirements -^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^ * **Hardware:** Blackwell, Hopper, Grace Hopper/Blackwell, Ada, Ampere @@ -467,10 +467,10 @@ System Requirements * **Notes:** FP8 features require Compute Capability 8.9+ (Ada/Hopper/Blackwell) Installation Methods -^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^ Docker (Recommended) -^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^ The quickest way to get started with Transformer Engine is by using Docker images on `NVIDIA GPU Cloud (NGC) Catalog `_. @@ -495,7 +495,7 @@ Where 25.04 (corresponding to April 2025 release) is the container version. * NGC PyTorch 23.08+ containers include FlashAttention-2 pip Installation -^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^ **Prerequisites for pip installation:** @@ -519,13 +519,25 @@ Alternatively, install directly from the GitHub repository: .. code-block:: bash - pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable + pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable When installing from GitHub, you can explicitly specify frameworks using the environment variable: .. code-block:: bash - NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable + NVTE_FRAMEWORK=pytorch,jax pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable + +conda Installation +^^^^^^^^^^^^^^^^^^ + +To install the latest stable version with conda from conda-forge: + +.. code-block:: bash + + # For PyTorch integration + conda install -c conda-forge transformer-engine-torch + + # JAX integration (coming soon) Source Installation ^^^^^^^^^^^^^^^^^^^ @@ -533,7 +545,7 @@ Source Installation `See the installation guide `_ Environment Variables -^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^ These environment variables can be set before installation to customize the build process: * **CUDA_PATH**: Path to CUDA installation @@ -544,7 +556,7 @@ These environment variables can be set before installation to customize the buil * **NVTE_BUILD_THREADS_PER_JOB**: Control threads per build job Compiling with FlashAttention -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Transformer Engine supports both FlashAttention-2 and FlashAttention-3 in PyTorch for improved performance. FlashAttention-3 was added in release v1.11 and is prioritized over FlashAttention-2 when both are present in the environment. You can verify which FlashAttention version is being used by setting these environment variables: @@ -556,8 +568,9 @@ You can verify which FlashAttention version is being used by setting these envir It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug `_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. .. troubleshooting-begin-marker-do-not-remove + Troubleshooting -^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^ **Common Issues and Solutions:** @@ -691,7 +704,7 @@ Papers Videos ====== -* `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 `_ +* `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 `__ * `Blackwell Numerics for AI | GTC 2025 `_ * `Building LLMs: Accelerating Pretraining of Foundational Models With FP8 Precision | GTC 2025 `_ * `From FP8 LLM Training to Inference: Language AI at Scale | GTC 2025 `_ diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py new file mode 100644 index 000000000..0dbee212d --- /dev/null +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -0,0 +1,290 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import argparse +import torch +import torch.utils.benchmark as benchmark +import pandas as pd +import pathlib + +from transformer_engine.pytorch.module import GroupedLinear +from transformer_engine.common.recipe import Float8BlockScaling, MXFP8BlockScaling +from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager +from contextlib import nullcontext + +""" +# Profile BF16 recipe with Nsight Systems +nsys profile \ + --output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_bf16 \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe bf16 + +# Profile FP8 sub-channel recipe with Nsight Systems +nsys profile \ + --output=./benchmarks/linear/h100hbm_mkn_4096_4096_4096_numgemm_8_fp8_sub_channel \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe fp8_sub_channel + +# Profile MXFP8 recipe with Nsight Systems +nsys profile \ + --output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_mxfp8 \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe mxfp8 + +""" + +RECIPES = { + "bf16": None, + "fp8_sub_channel": Float8BlockScaling(), + "mxfp8": MXFP8BlockScaling(), +} + +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) + + +def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None): + assert mode in ["fwd_only", "fwd_bwd"] + fp8_context = ( + fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext() + ) + # print(f"fp8_context: {fp8_context} and is it nullcontext? {isinstance(fp8_context, nullcontext)}") + + if mode == "fwd_only": + with torch.no_grad(), fp8_context: + for i in range(run_num_steps): + y_q = layer.forward( + x, + m_splits, + is_first_microbatch=(i == 0), + ) + return y_q + else: + # reset gradients + layer.zero_grad() + x.grad = None + + with fp8_context: + for i in range(run_num_steps): + label = f"step_{i}" + torch.cuda.nvtx.range_push(label) + y_q = layer.forward( + x, + m_splits, + is_first_microbatch=(i == 0), + ) + y_q.backward(gradient) + torch.cuda.nvtx.range_pop() + + grads_q = [] + grads_q.append(x.grad) + # remaining derivatives are in respect to model parameters + for p in layer.parameters(): + if p.requires_grad: + grads_q.append(p.grad) + + return y_q, grads_q + + +def benchmark_linear( + x, + ws, + m_splits, + bias, + recipe_name, + mode, + num_gemms=4, +): + params_dtype = torch.bfloat16 + recipe = RECIPES[recipe_name] + + in_features = x.shape[1] + out_features = ws[0].shape[0] + gradient = torch.ones((x.shape[0], out_features), dtype=torch.bfloat16, device=x.device) + + layer = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + ) + + layer = layer.to("cuda") + with torch.no_grad(): + for i in range(num_gemms): + weight_i = getattr(layer, f"weight{i}") + weight_i.copy_(ws[i]) + if bias is not None: + bias_i = getattr(layer, f"bias{i}") + bias_i.copy_(bias) + + num_microbatches = 32 + + label = f"{recipe_name}_{'grouped'}" + torch.cuda.nvtx.range_push(label) + timing = benchmark.Timer( + stmt=( + "run_linear_multiple_steps(layer, x, m_splits, mode, gradient, num_microbatches," + " recipe)" + ), + globals={ + "run_linear_multiple_steps": run_linear_multiple_steps, + "layer": layer, + "x": x, + "m_splits": m_splits, + "mode": mode, + "gradient": gradient, + "num_microbatches": num_microbatches, + "recipe": recipe, + }, + num_threads=1, + ).blocked_autorange(min_run_time=5) + print(f"{recipe_name}: {timing} \n") + timing_ms = timing.median * 1000 / num_microbatches + + return timing_ms + + +def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4): + data = [] + assert not use_bias, "Bias is not supported for GroupedLinear benchmark" + + print(f"========== Benchmarking {recipe_name} ==========") + for m, k, n in mkns: + device = "cuda" + x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True) + ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)] + assert m % num_gemms == 0 + m_splits = [m // num_gemms] * num_gemms + # Bias is not supported for GroupedLinear benchmark + bias = None + + # Run the benchmark + print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}") + + grouped_fwd_bwd_timing_ms = benchmark_linear( + x, + ws, + m_splits, + bias, + recipe_name, + mode="fwd_bwd", + num_gemms=num_gemms, + ) + + # Append the results + data.append( + [ + m, + k, + n, + recipe_name, + num_gemms, + grouped_fwd_bwd_timing_ms, + ] + ) + + df = pd.DataFrame( + data=data, + columns=[ + "m", + "k", + "n", + "recipe", + "num_gemms", + "grouped_fwd_bwd_time_ms", + ], + ) + + print(df, "\n") + return df + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable profiling mode") + parser.add_argument( + "--output_dir", + type=str, + default="benchmark_output/", + help="output path for report", + ) + # arguments for recipe, options are fp8_sub_channel, mxfp8, bf16, all + parser.add_argument( + "--recipe", + type=str, + default="bf16", + help="Recipe to use, options are fp8_sub_channel, mxfp8, bf16, or all", + ) + args = parser.parse_args() + + use_bias = False + # Set the MKN values to benchmark + mkns = [] + for m in [8192]: + # for m in [4096, 8192, 16384]: + # for n in [1024, 2048, 4096, 8192, 16384]: + for n in [8192]: + for k in [4096]: + mkns.append((m, k, n)) + + # default recipes to run if not specified + recipe_list = ["bf16"] + + if args.recipe == "all": + recipe_list = ["bf16", "fp8_sub_channel", "mxfp8"] + else: + recipe_list = [args.recipe] + + num_gemms_list = [8] + + if args.profile: + mkns = [(4096, 4096, 4096)] + # in profile mode, only run one recipe specified in args.recipe + assert args.recipe != "all", ( + "In profile mode, only one recipe can be specified, please specify the recipe as" + " fp8_sub_channel, mxfp8, or bf16" + ) + recipe_list = [args.recipe] + num_gemms_list = [8] + torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__() + + # Initialize a dataframe to store the results + df_linears = pd.DataFrame() + + # Run the fp8 benchmarks + for num_gemms in num_gemms_list: + print(f"========== Benchmarking with num_gemms={num_gemms} ==========") + for recipe_name in recipe_list: + assert recipe_name in [ + "bf16", + "fp8_sub_channel", + "mxfp8", + ], "Recipe must be one of bf16, fp8_sub_channel, or mxfp8" + if recipe_name == "mxfp8" and not mxfp8_available: + print(f"MXFP8 is not available, skipping {recipe_name}") + continue + if recipe_name == "fp8_sub_channel" and not fp8_block_scaling_available: + print(f"FP8 block scaling is not available, skipping {recipe_name}") + continue + + df = run_benchmark_linear( + mkns, + recipe_name, + use_bias, + num_gemms=num_gemms, + ) + df_linears = pd.concat([df_linears, df]) + + print(df_linears) + + if args.profile: + torch.autograd.profiler.emit_nvtx().__exit__(None, None, None) diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index ed3c1af81..2a45a8a5c 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.4.0.dev0 +2.6.0.dev0 diff --git a/build_tools/jax.py b/build_tools/jax.py index 4e587b965..182940c11 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -6,7 +6,6 @@ """JAX related extensions.""" import os -import shutil from pathlib import Path import setuptools @@ -16,6 +15,19 @@ from typing import List +def install_requirements() -> List[str]: + """Install dependencies for TE/JAX extensions.""" + if rocm_build(): + return jax_install_requires(["flax>=0.7.1"]) + else: + return ["jax", "flax>=0.7.1"] + + +def test_requirements() -> List[str]: + """Test dependencies for TE/JAX extensions.""" + return ["numpy"] + + def xla_path() -> str: """XLA root path lookup. Throws FileNotFoundError if XLA source is not found.""" @@ -88,22 +100,11 @@ def setup_jax_extension( # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension - class Pybind11CPPExtension(Pybind11Extension): - """Modified Pybind11Extension to allow custom CXX flags.""" - - def _add_cflags(self, flags: List[str]) -> None: - if isinstance(self.extra_compile_args, dict): - cxx_flags = self.extra_compile_args.pop("cxx", []) - cxx_flags += flags - self.extra_compile_args["cxx"] = cxx_flags - else: - self.extra_compile_args[:0] = flags - - return Pybind11CPPExtension( + return Pybind11Extension( "transformer_engine_jax", sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], - extra_compile_args={"cxx": cxx_flags}, + extra_compile_args=cxx_flags, ) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 0609d1bc9..d8fcc8a34 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -22,6 +22,29 @@ get_cuda_include_dirs, debug_build_enabled, ) +from typing import List + + +def install_requirements() -> List[str]: + """Install dependencies for TE/PyTorch extensions.""" + reqs = ["einops"] + reqs.append( + "nvdlfw-inspect @" + " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect" + ) + reqs.extend( + [ + "torch>=2.1", + "onnx", + "onnxscript@git+https://github.com/microsoft/onnxscript.git@51ecf47523ef079c53b0e620c62d56d70cfd3871", + ] + ) + return reqs + + +def test_requirements() -> List[str]: + """Test dependencies for TE/JAX extensions.""" + return ["numpy", "torchvision", "transformers"] def setup_pytorch_extension( diff --git a/build_tools/utils.py b/build_tools/utils.py index f848bad74..739e353e4 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -24,13 +24,7 @@ @functools.lru_cache(maxsize=None) def debug_build_enabled() -> bool: """Whether to build with a debug configuration""" - for arg in sys.argv: - if arg == "--debug": - sys.argv.remove(arg) - return True - if int(os.getenv("NVTE_BUILD_DEBUG", "0")): - return True - return False + return bool(int(os.getenv("NVTE_BUILD_DEBUG", "0"))) @functools.lru_cache(maxsize=None) @@ -313,9 +307,12 @@ def get_cuda_include_dirs() -> Tuple[str, str]: def cuda_archs() -> str: version = cuda_version() if os.getenv("NVTE_CUDA_ARCHS") is None: - os.environ["NVTE_CUDA_ARCHS"] = ( - "70;80;89;90;100;120" if version >= (12, 8) else "70;80;89;90" - ) + if version >= (13, 0): + os.environ["NVTE_CUDA_ARCHS"] = "75;80;89;90;100;120" + elif version >= (12, 8): + os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90;100;120" + else: + os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90" return os.getenv("NVTE_CUDA_ARCHS") @@ -454,7 +451,6 @@ def copy_common_headers( new_path.parent.mkdir(exist_ok=True, parents=True) shutil.copy(path, new_path) - def copy_hipify_tools( src_dir: Union[Path, str], dst_dir: Union[Path, str], diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 5d37ae1d9..a876c23e9 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -42,6 +42,9 @@ if [ "$ROCM_BUILD" = "1" ]; then ${PYBINDIR}pip install setuptools wheel fi +# Install deps +${PYBINDIR}pip install cmake pybind11[global] ninja + if $BUILD_METAPACKAGE ; then cd /TransformerEngine if [ "$ROCM_BUILD" != "1" ]; then @@ -85,25 +88,25 @@ if $BUILD_COMMON ; then fi if $BUILD_PYTORCH ; then - cd /TransformerEngine/transformer_engine/pytorch - if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/rocm6.3 - else - PYBINDIR=/opt/python/cp38-cp38/bin/ - ${PYBINDIR}pip install torch - fi - ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt - cp dist/* /wheelhouse/ + cd /TransformerEngine/transformer_engine/pytorch + if [ "$ROCM_BUILD" = "1" ]; then + ${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/rocm6.3 + else + PYBINDIR=/opt/python/cp38-cp38/bin/ + ${PYBINDIR}pip install torch + fi + ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt + cp dist/* /wheelhouse/ fi if $BUILD_JAX ; then - cd /TransformerEngine/transformer_engine/jax - if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install jax - else - PYBINDIR=/opt/python/cp310-cp310/bin/ - ${PYBINDIR}pip install "jax[cuda12_local]" jaxlib - fi - ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt - cp dist/* /wheelhouse/ + cd /TransformerEngine/transformer_engine/jax + if [ "$ROCM_BUILD" = "1" ]; then + ${PYBINDIR}pip install jax + else + PYBINDIR=/opt/python/cp310-cp310/bin/ + ${PYBINDIR}pip install "jax[cuda12_local]" jaxlib + fi + ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt + cp dist/* /wheelhouse/ fi diff --git a/docs/api/c/cast_transpose_noop.rst b/docs/api/c/cast_transpose_noop.rst new file mode 100644 index 000000000..ae80c5d2d --- /dev/null +++ b/docs/api/c/cast_transpose_noop.rst @@ -0,0 +1,9 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +cast_transpose_noop.h +===================== + +.. doxygenfile:: cast_transpose_noop.h diff --git a/docs/api/c/cudnn.rst b/docs/api/c/cudnn.rst new file mode 100644 index 000000000..5d93c4d6e --- /dev/null +++ b/docs/api/c/cudnn.rst @@ -0,0 +1,9 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +cudnn.h +======= + +.. doxygenfile:: cudnn.h diff --git a/docs/api/c/index.rst b/docs/api/c/index.rst index 7bc864dcc..0499f52f0 100644 --- a/docs/api/c/index.rst +++ b/docs/api/c/index.rst @@ -14,10 +14,13 @@ directly from C/C++, without Python. transformer_engine.h activation.h + cast_transpose_noop.h cast.h + cudnn.h fused_attn.h fused_rope.h gemm.h + multi_tensor.h normalization.h padding.h permutation.h diff --git a/docs/api/c/multi_tensor.rst b/docs/api/c/multi_tensor.rst new file mode 100644 index 000000000..8ba2d274c --- /dev/null +++ b/docs/api/c/multi_tensor.rst @@ -0,0 +1,9 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +multi_tensor.h +============== + +.. doxygenfile:: multi_tensor.h diff --git a/docs/api/common.rst b/docs/api/common.rst index 95d4b50f3..541118985 100644 --- a/docs/api/common.rst +++ b/docs/api/common.rst @@ -11,3 +11,7 @@ Common API .. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None) .. autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3) + +.. autoapiclass:: transformer_engine.common.recipe.Float8CurrentScaling(fp8_format=Format.HYBRID) + +.. autoapiclass:: transformer_engine.common.recipe.Float8BlockScaling(fp8_format=Format.E4M3) diff --git a/docs/debug.rst b/docs/debug.rst new file mode 100644 index 000000000..d33568ea3 --- /dev/null +++ b/docs/debug.rst @@ -0,0 +1,14 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. +Precision debug tools +============================================== + +.. toctree:: + :caption: Precision debug tools + + debug/1_getting_started.rst + debug/2_config_file_structure.rst + debug/api + debug/4_distributed.rst \ No newline at end of file diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst new file mode 100644 index 000000000..bc2b95057 --- /dev/null +++ b/docs/debug/1_getting_started.rst @@ -0,0 +1,241 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Getting started +============== + +.. note:: + + Precision debug tools with `Nvidia-DL-Framework-Inspect `_ for Transformer Engine are currently supported only for PyTorch. + +Transformer Engine provides a set of precision debug tools which allow you to easily: + +- log the statistics for each of the tensors in every matrix multiply (GEMM) operation, +- run selected GEMMs in higher precision, +- run current scaling - with one scaling factor per tensor - for particular GEMMs, +- test new precisions and integrate them with FP8 training, +- ... and many more. + +There are 4 things one needs to do to use Transformer Engine debug features: + +1. Create a configuration YAML file to configure the desired features. +2. Import, and initialize the `Nvidia-DL-Framework-Inspect `_ tool, which is installed as the dependency of the Transformer Engine. +3. One can pass ``name="..."`` when creating TE layers to easier identify layer names. If this is not provided, names will be inferred automatically. +4. Invoke ``debug_api.step()`` at the end of one forward-backward pass. + +To start debugging, one needs to create a configuration YAML file. This file lists the features to be used in particular layers. There are 2 kinds of features: + +- provided by the Transformer Engine - for example, DisableFP8GEMM or LogTensorStats - they are listed in the :doc:`debug features API <3_api_features>` section +- defined by the user. For details on how to create a custom feature - please read the :doc:`calls to Nvidia-DL-Framework-Inspect <3_api_te_calls>` section. + +.. figure:: ./img/introduction.svg + :align: center + + Fig 1: Example of Nvidia-DL-Framework-Inspect affecting training script with 3 TE Linear Layers. + ``config.yaml`` contains the specification of the features used for each Linear layer. Some feature classes are provided by TE, + one - ``UserProvidedPrecision`` - is a custom feature implemented by the user. Nvidia-DL-Framework-Inspect inserts features into the layers according to the config. + +Example training script +---------------------- + +Let's look at a simple example of training a Transformer layer using Transformer Engine with FP8 precision. This example demonstrates how to set up the layer, define an optimizer, and perform a few training iterations using synthetic data. + +.. code-block:: python + + # train.py + + from transformer_engine.pytorch import TransformerLayer + import torch + import torch.nn as nn + import torch.optim as optim + import transformer_engine.pytorch as te + + hidden_size = 512 + num_attention_heads = 8 + + transformer_layer = TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=hidden_size, + num_attention_heads=num_attention_heads + ).cuda() + + dummy_input = torch.randn(10, 32, hidden_size).cuda() + criterion = nn.MSELoss() + optimizer = optim.Adam(transformer_layer.parameters(), lr=1e-4) + dummy_target = torch.randn(10, 32, hidden_size).cuda() + + for epoch in range(5): + transformer_layer.train() + optimizer.zero_grad() + with te.fp8_autocast(enabled=True): + output = transformer_layer(dummy_input) + loss = criterion(output, dummy_target) + loss.backward() + optimizer.step() + +We will demonstrate two debug features on the code above: + +1. Disabling FP8 precision for specific GEMM operations, such as the FC1 and FC2 forward propagation GEMM. +2. Logging statistics for other GEMM operations, such as gradient statistics for data gradient GEMM within the LayerNormLinear sub-layer of the TransformerLayer. + +Config file +---------- + +We need to prepare the configuration YAML file, as below + +.. code-block:: yaml + + # config.yaml + + fc1_fprop_to_fp8: + enabled: True + layers: + layer_types: [fc1, fc2] # contains fc1 or fc2 in name + transformer_engine: + DisableFP8GEMM: + enabled: True + gemms: [fprop] + + log_tensor_stats: + enabled: True + layers: + layer_types: [layernorm_linear] # contains layernorm_linear in name + transformer_engine: + LogTensorStats: + enabled: True + stats: [max, min, mean, std, l1_norm] + tensors: [activation] + freq: 1 + start_step: 2 + end_step: 5 + +Further explanation on how to create config files is in the :doc:`next part of the documentation <2_config_file_structure>`. + +Adjusting Python file +-------------------- + +.. code-block:: python + + # (...) + + import nvdlfw_inspect.api as debug_api + debug_api.initialize( + config_file="./config.yaml", + feature_dirs=["/path/to/transformer_engine/debug/features"], + log_dir="./log", + default_logging_enabled=True) + + # initialization of the TransformerLayer with the name + transformer_layer = TransformerLayer( + name="transformer_layer", + # ...) + + # (...) + for epoch in range(5): + # forward and backward pass + # ... + debug_api.step() + +In the modified code above, the following changes were made: + +1. Added an import for ``nvdlfw_inspect.api``. +2. Initialized the Nvidia-DL-Framework-Inspect by calling ``debug_api.initialize()`` with appropriate configuration, specifying the path to the config file, feature directories, and log directory. +3. Added ``debug_api.step()`` after each of the forward-backward pass. + +Inspecting the logs +------------------ + +Let's look at the files with the logs. Two files will be created: + +1. debug logs. +2. statistics logs. + +Let's look inside them! + +In the main log file, you can find detailed information about the transformer layer's GEMMs behavior. You can see that ``fc1`` and ``fc2`` fprop GEMMs are run in high precision, as intended. + +.. code-block:: text + + # log/nvdlfw_inspect_logs/nvdlfw_inspect_globalrank-0.log + + INFO - Default logging to file enabled at ./log + INFO - Reading config from ./config.yaml. + INFO - Loaded configs for dict_keys(['fc1_fprop_to_fp8', 'log_tensor_stats']). + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: activation, gemm fprop - FP8 quantization + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: activation, gemm wgrad - FP8 quantization + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: weight, gemm fprop - FP8 quantization + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: weight, gemm dgrad - FP8 quantization + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: gradient, gemm dgrad - FP8 quantization + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: gradient, gemm wgrad - FP8 quantization + INFO - transformer_layer.self_attention.proj: Tensor: activation, gemm fprop - FP8 quantization + INFO - transformer_layer.self_attention.proj: Tensor: activation, gemm wgrad - FP8 quantization + INFO - transformer_layer.self_attention.proj: Tensor: weight, gemm fprop - FP8 quantization + INFO - transformer_layer.self_attention.proj: Tensor: weight, gemm dgrad - FP8 quantization + INFO - transformer_layer.self_attention.proj: Tensor: gradient, gemm dgrad - FP8 quantization + INFO - transformer_layer.self_attention.proj: Tensor: gradient, gemm wgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: activation, gemm fprop - High precision + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: activation, gemm wgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: weight, gemm fprop - High precision + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: weight, gemm dgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: gradient, gemm dgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: gradient, gemm wgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: activation, gemm fprop - High precision + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: activation, gemm wgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: weight, gemm fprop - High precision + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: weight, gemm dgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: gradient, gemm dgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: gradient, gemm wgrad - FP8 quantization + INFO - transformer_layer.self_attention.layernorm_qkv: Feature=LogTensorStats, API=look_at_tensor_before_process: activation + .... + +The second log file (``nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log``) contains statistics for tensors we requested in ``config.yaml``. + +.. code-block:: text + + # log/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log + + INFO - transformer_layer.self_attention.layernorm_qkv_activation_max iteration=000002 value=4.3188 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_min iteration=000002 value=-4.3386 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_mean iteration=000002 value=0.0000 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_std iteration=000002 value=0.9998 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_l1_norm iteration=000002 value=130799.6953 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_max iteration=000003 value=4.3184 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_min iteration=000003 value=-4.3381 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_mean iteration=000003 value=0.0000 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_std iteration=000003 value=0.9997 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_l1_norm iteration=000003 value=130788.1016 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_max iteration=000004 value=4.3181 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_min iteration=000004 value=-4.3377 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_mean iteration=000004 value=0.0000 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_std iteration=000004 value=0.9996 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_l1_norm iteration=000004 value=130776.7969 + +Logging using TensorBoard +------------------------ + +Precision debug tools support logging using `TensorBoard `_. To enable it, one needs to pass the argument ``tb_writer`` to the ``debug_api.initialize()``. Let's modify ``train.py`` file. + +.. code-block:: python + + # (...) + + from torch.utils.tensorboard import SummaryWriter + tb_writer = SummaryWriter('./tensorboard_dir/run1') + + # add tb_writer to the Debug API initialization + debug_api.initialize( + config_file="./config.yaml", + feature_dirs=["/path/to/transformer_engine/debug/features"], + log_dir="./log", + tb_writer=tb_writer) + + # (...) + +Let's run training and open TensorBoard by ``tensorboard --logdir=./tensorboard_dir/run1``: + +.. figure:: ./img/tensorboard.png + :align: center + + Fig 2: TensorBoard with plotted stats. \ No newline at end of file diff --git a/docs/debug/2_config_file_structure.rst b/docs/debug/2_config_file_structure.rst new file mode 100644 index 000000000..f1069b0c8 --- /dev/null +++ b/docs/debug/2_config_file_structure.rst @@ -0,0 +1,241 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Config File Structure +==================== + +To enable debug features, create a configuration YAML file to specify the desired behavior, such as determining which GEMMs (General Matrix Multiply operations) should run in higher precision rather than FP8 and defining which statistics to log. +Below, we outline how to structure the configuration YAML file. + +General Format +------------- + +A config file can have one or more sections, each containing settings for specific layers and features: + +.. code-block:: yaml + + section_name_1: + enabled: ... + layers: + # Specify layers here... + transformer_engine: + Feature1Name: + enabled: ... + # Feature details... + Feature2Name: + enabled: ... + # Feature details... + + section_name_2: + enabled: ... + layers: + # Specify layers here... + Feature1Name: # If feature has no namespace, then it is in the default namespace. + enabled: ... + # Feature details... + + section_name_3: + enabled: ... + layers: + # Specify layers here... + transformer_engine: + Feature1Name: + enabled: ... + # Feature details... + Feature2Name: + enabled: ... + # Feature details... + +Sections may have any name and must contain: + +1. An ``enabled`` field that specifies whether the features in that section will be active. +2. A ``layers`` field specifying which layers the section applies to. Each layer can belong to only one section. +3. Additional fields describing features for those layers. + +Layer Specification +------------------ + +Debug layers can be identified by a ``name`` parameter: + +.. code-block:: python + + linear = transformer_engine.debug.pytorch.Linear(in_features, out_features, name="linear1") + +This name is used in the config file to identify the layer. To specify the ``layers`` field, you can use one of the following methods: + +1. ``layer_name_regex_pattern``: Use a regular expression to match layer names. This expression must adhere to the Python ``re`` module syntax. +2. ``layer_types``: Provide a list of strings, where a layer will be selected if any string matches part of its name. + +Examples: + +.. code-block:: yaml + + # Example 1: Using regular expression to select layers + my_section: + enabled: ... + layers: + layer_name_regex_pattern: 'self_attn.*' + transformer_engine: + (...) + + # Example 2: Using layer type to select layers + another_section: + enabled: ... + layers: + layer_types: ['fc1', 'layernorm_linear'] + transformer_engine: + (...) + +Names in Transformer Layers +-------------------------- + +There are three ways to assign a name to a layer in the Transformer Engine: + +- Initialize the layer with the ``name=...`` argument. +- Use ``debug_api.infer_and_assign_layer_names(model)``, which assigns names based on class names. +- Rely on the default names assigned during module initialization, such as ``Layer_n``, where ``n`` represents the layer number. + +The ``TransformerLayer`` in Transformer Engine is a composition of multiple sub-layers. We can modify some of these layers using precision debug tools, particularly those that contain exactly one linear layer. To see the names of all such layers, we can inspect log files. For instance, a ``TransformerLayer`` named ``transformer_layer`` might consist of: + +- ``transformer_layer.self_attn.layernorm_linear_qkv`` / ``transformer_layer.self_attn.linear_qkv`` / ``transformer_layer.self_attn.layernorm_linear_q`` / ``transformer_layer.self_attn.linear_q`` / ``transformer_layer.self_attn.linear_kv``, +- ``transformer_layer.self_attn.proj``, +- ``transformer_layer.inter_attn.*`` for ``layer_type="decoder"``, +- ``transformer_layer.layernorm_mlp.fc1``, +- ``transformer_layer.layernorm_mlp.fc2``, + +depending on the configuration. Some layers, like ``LayerNormLinear``, are fusions of two layers: ``LayerNorm`` and ``Linear``. When referring to such layers in precision debug tools, only the ``Linear`` part is affected. + +Below is an example ``TransformerLayer`` with four linear layers that can be influenced by the precision debug tools. + +.. figure:: ./img/names.svg + :align: center + :width: 80% + + Fig 1: Names of layers in an example configuration of TransformerLayer. The most nested blocks represent the most basic layers, each containing one linear layer. Layers that do not contain linear layers, such as ``DotProductAttention``, are omitted. + +**Configuration File Example** + +.. code-block:: yaml + + # Disables wgrad in all 4 GEMMs + section1: + enabled: True + layers: + layer_types: [transformer_layer] + transformer_engine: + DisableFP8GEMM: + enabled: True + gemms: [wgrad] + + # Disables all GEMMs in layernorm_mlp layer + section2: + enabled: True + layers: + layer_types: [layernorm_mlp] + transformer_engine: + DisableFP8Layer: + enabled: True + + # Logs wgrad stats in fc1 + section3: + enabled: True + layers: + layer_types: [fc1] + transformer_engine: + LogTensorStats: + enabled: True + stats: [min] + tensors: [wgrad] + freq: 1 + start_step: 0 + end_step: 50 + + +Structured Configuration for GEMMs and Tensors +--------------------------------------------- + +Sometimes a feature is parameterized by a list of tensors or by a list of GEMMs. +There are multiple ways of describing this parameterization. + +We can pass lists, as below. + +.. code-block:: yaml + + Feature: + enabled: ... + gemms: [gemm1, gemm2] + tensors: [tensor1, tensor2] + ... + +We can use struct for tensors. + +.. code-block:: yaml + + Feature: + gemms: [gemm1, gemm2] + tensors_struct: + - tensor: tensor1 + feature_param1: value + - tensor: tensor2 + feature_param1: value + gemm_feature_param1: value + +Similarly, we can use struct for GEMMs. + +.. code-block:: yaml + + Feature: + enabled: ... + tensors: [tensor1, tensor2] + gemms_struct: + - gemm: gemm1 + feature_param1: value + - gemm: gemm2 + feature_param1: value + gemm_feature_param1: value + +We can use both structs for tensors and GEMMs. The tensors_struct should be nested inside gemms_struct. + +.. code-block:: yaml + + Feature: + enabled: ... + gemms_struct: + - gemm: gemm1 + tensors: [tensor1, tensor2] + tensor_feature_param1: value + gemm_feature_param1: value + - gemm: gemm2 + tensors_struct: + - tensor: tensor1 + tensor_feature_param1: value + - tensor: tensor2 + tensor_feature_param2: value + gemm_feature_param1: value + +Enabling or Disabling Sections and Features +------------------------------------------ + +Debug features can be enabled or disabled with the ``enabled`` keyword: + +.. code-block:: yaml + + section1: + enabled: True + layers: + layer_types: [self_attention] + transformer_engine: + LogTensorStats: + enabled: False # Disables the LogTensorStats feature + stats: [max, min, mean, std, l1_norm] + + section2: + enabled: False # Disables entire section2 + transformer_engine: + LogFp8TensorStats: + enabled: True # Does not enable the LogFp8TensorStats feature, because section2 is disabled + stats: [underflows, overflows] + +By organizing your ``config.yaml`` properly, you can easily manage debugging features, ensuring a more streamlined and customizable debugging experience. diff --git a/docs/debug/3_api_debug_setup.rst b/docs/debug/3_api_debug_setup.rst new file mode 100644 index 000000000..bda8f096d --- /dev/null +++ b/docs/debug/3_api_debug_setup.rst @@ -0,0 +1,87 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Setup +===== + +Precision debug tools for the Transformer Engine use `Nvidia-DL-Framework-Inspect `_ package from NVIDIA. +Please refer to the Nvidia-DL-Framework-Inspect `documentation `_ for more details. +Below, we outline the steps for debug initialization. + +initialize() +----------- + +Must be called once on every rank in the global context to initialize Nvidia-DL-Framework-Inspect. + +**Parameters** + +- **config_file** (*str*, default=""): Path to the configuration YAML file containing features to enable and layer names. If one wants to run without the configuration file, pass ``""``. +- **feature_dirs** (*List[str] | str*): List of directories containing features to load and register. One needs to pass ``[/path/to/transformerengine/transformer_engine/debug/features]`` to use TE features. +- **logger** (*Union[BaseLogger, None]*, default=None): Logger for logging tensor statistics. Should adhere to ``BaseLogger`` from the `Nvidia-DL-Framework-Inspect `_ package. +- **log_dir** (*str*, default= "."): Directory path to hold ``debug_logs`` and ``debug_statistics_logs``. +- **tb_writer** (*TensorBoardWriter*, default=None): TensorBoard writer for logging. +- **default_logging_enabled** (*bool*, default=False): Enable default logging to the file. + +.. code-block:: python + + import nvdlfw_inspect.api as debug_api + + debug_api.initialize( + config_file="./config.yaml", + feature_dirs=["/path/to/transformer_engine/debug/features"], + log_dir="./log_dir") + +set_tensor_reduction_group() +-------------------------- + +Needed only for logging tensor stats. In multi-GPU training, activation and gradient tensors are distributed across multiple nodes. This method lets you specify the group for the reduction of stats; see the `reduction group section <./4_distributed.rst#reduction-groups>`_ for more details. + +If the tensor reduction group is not specified, then statistics are reduced across all nodes in the run. + +**Parameters** + +- **group** (torch.distributed.ProcessGroup): The process group across which tensors will be reduced to get stats. + + +.. code-block:: python + + import nvdlfw_inspect.api as debug_api + + # initialization + # (...) + + pipeline_parallel_group = initialize_pipeline_parallel_group() + + debug_api.set_tensor_reduction_group(pipeline_parallel_group) + + # training + # (...) + # activation/gradient tensor statistics are reduced along pipeline_parallel_group + +set_weight_tensor_tp_group_reduce() +--------------------------------- + +By default, weight tensor statistics are reduced within the tensor parallel group. This function allows you to disable that behavior; for more details, see `reduction group section <./4_distributed.rst#reduction-groups>`_. + +This method is not provided by the ``debug_api``, but by the ``transformer_engine.debug``. + +**Parameters** + +- **enabled** (*bool*, default=True): A boolean flag to enable or disable the reduction of weight tensor statistics within the tensor parallel group. + + +.. code-block:: python + + import nvdlfw_inspect.api as debug_api + from transformer_engine.debug import set_weight_tensor_tp_group_reduce + + # initialization + # (...) + + set_weight_tensor_tp_group_reduce(False) + + # training + # (...) + # weight tensor statistics are not reduced diff --git a/docs/debug/3_api_features.rst b/docs/debug/3_api_features.rst new file mode 100644 index 000000000..b31c437b2 --- /dev/null +++ b/docs/debug/3_api_features.rst @@ -0,0 +1,14 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Debug features +========== + +.. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats +.. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats +.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM +.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer +.. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling +.. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant diff --git a/docs/debug/3_api_te_calls.rst b/docs/debug/3_api_te_calls.rst new file mode 100644 index 000000000..eb66c8ff2 --- /dev/null +++ b/docs/debug/3_api_te_calls.rst @@ -0,0 +1,45 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Calls to Nvidia-DL-Framework-Inspect +==================================== +Let's look deeper into how Nvidia-DL-Framework-Inspect with Transformer Engine work together. TransformerEngine layers have some hook calls inside each of the GEMMs. Users can define feature classes or use feature classes provided with TE. File ``config.yaml`` describes which hooks need to be used for which layers. Nvidia-DL-Framework-Inspect combines 3 things: TE training, feature classes and ``config.yaml`` and takes care of inserting hooks in the correct places. This process is illustrated in the image below. + +.. figure:: ./img/api_calls1.svg + :align: center + + Fig 1: Example of Nvidia-DL-Framework-Inspect affecting training script with 1 Linear Layer. For tensors mentioned in ``config.yaml``, behavior of ``modify_tensor_enabled()`` and ``modify_tensor()`` calls are substituted with definitions from the feature class. Other calls return default values - in fact they do nothing. + +In this page, all calls from TransformerEngine to the Nvidia-DL-Framework-Inspect for each GEMM are listed. The order of these calls is illustrated in the image below. + +.. figure:: ./img/api_calls2.svg + :align: center + + Fig 2: The calls to Nvidia-DL-Framework-Inspect done for Transformer Engine. There are 2 types of calls: GEMM calls and routing calls. + + +There are 2 categories of API calls, each is used for different purposes: + +- GEMM calls - invoked during every GEMM, used to process or quantize tensors and collect information about them, +- Routing calls - invoked at the beginning of every forward pass - they indicate whether a feature is going to use `modify_tensor()`, etc. + +If all routing calls for the layer return `False`, then the layer is invoked in an optimized version with Transformer Engine fusions. +If any of the routing calls return `True`, layers are run without the fusions. This is necessary because otherwise some tensors cannot be accessed +if fusions happen. An important remark is that if no feature is used for the layer, then it should perform as fast as the layer without initializing `debug_api`. + + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor_enabled + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.fp8_gemm_enabled + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_enabled + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize_enabled diff --git a/docs/debug/4_distributed.rst b/docs/debug/4_distributed.rst new file mode 100644 index 000000000..6f69f2712 --- /dev/null +++ b/docs/debug/4_distributed.rst @@ -0,0 +1,91 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Distributed training +=================== + +Nvidia-Pytorch-Inspect with Transformer Engine supports multi-GPU training. This guide describes how to run it and how the supported features work in the distributed setting. + +To use precision debug tools in multi-GPU training, one needs to: + +1. Run ``debug_api.initialize(...)`` and provide the same configuration YAML file on every node. +2. If one wants to log stats, one may want to invoke ``debug_api.set_tensor_reduction_group`` with a proper reduction group. + +Behavior of the features +----------------------- + +In a distributed setting, **DisableFP8GEMM** and **DisableFP8Layer** function similarly to the single-GPU case, with no notable differences. + +**PerTensorScaling** and **FakeQuant** calculate FP8 scaling factors independently on each node, meaning the number of GPUs may affect results. This differs from the delayed scaling FP8 recipe behavior, in which scaling factors are synchronized. + +.. figure:: ./img/scaling_factors.svg + :align: center + + Fig 1: For **PerTensorScaling** and **FakeQuant** tensor scaling factors are computed separately for each of the tensor shards. This is not the case for delayed scaling FP8 scaling factors, which are synchronized. + +Logging-related features are more complex and will be discussed further in the next sections. + +Reduction groups +-------------- + +In setups with tensor, data, or pipeline parallelism, some tensors are distributed across multiple GPUs, requiring a reduction operation to compute statistics for these tensors. + +The weight tensor is always split among the tensor parallel group, and debug tools automatically reduce statistics within this group by default. To disable this automatic reduction, use: + +.. code-block:: python + + transformer_engine.debug.set_weight_tensor_tp_group_reduce(False) + +In cases of data parallelism, Transformer Engine modules lack the process group needed for reduction. To manually specify the group, use: + +.. code-block:: python + + debug_api.set_tensor_reduction_group(group) + +This command ensures statistics are reduced across the defined group. Activation statistics are logged after the forward pass (immediately after exiting autocast), while gradient (dgrad and wgrad) statistics are logged following the backward pass. + +Below, we illustrate configurations for a 4-node setup with tensor parallelism size 2 and data parallelism size 2, showcasing different reduction configurations. + +.. figure:: ./img/reduction1.svg + :align: center + + Fig 2: There is a single tensor reduction group composed of all nodes. As a result, each node logs the same statistics for the tensors, as they are fully reduced across all nodes. + +.. figure:: ./img/reduction2.svg + :align: center + + Fig 3: Every node is set with a tensor reduction group consisting of itself. Every node prints the same statistics for weights (which are still synchronized within TP groups), but the statistics of activations and gradients are not synchronized. + +.. figure:: ./img/reduction3.svg + :align: center + + Fig 4: Weight synchronization is disabled by ``set_weight_tensor_tp_group_reduce(False)``, so every node logs stats for its shard of the weight. + + +Microbatching +----------- + +Let's dive into how statistics collection works with microbatching. By microbatching, we mean invoking multiple ``forward()`` calls for each ``debug_api.step()``. The behavior is as follows: + +- For weight tensors, the stats remain the same for each microbatch because the weight does not change. +- For other tensors, the stats are accumulated. + +Logging to files and TensorBoard +------------------------------ + +In a single-node setup with ``default_logging_enabled=True``, all logs are saved by default to ``log_dir/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log``. In multi-GPU training, each node writes its reduced statistics to its unique file, named ``log_dir/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-i.log`` for rank i. Because these logs contain reduced statistics, the logged values are identical for all nodes within a reduction group. + +If certain nodes are given a TensorBoard writer, only those nodes will log to TensorBoard. This is useful in scenarios involving pipeline, data, and tensor parallelism, such as with two transformer layers and settings TP_SIZE = 2, DP_SIZE = 2, and PP_SIZE = 2. To log all stats to TensorBoard, you should pass a TensorBoard writer to one process in each pipeline parallel group. + +.. figure:: ./img/pipeline_logging.svg + :align: center + + Fig 5: Example with pipeline parallelism, where a ``tb_writer`` is assigned to one node within each pipeline parallel group, setting these as tensor reduction groups. + +Alternatively, setting the tensor reduction group to None will yield unreduced statistics for wgrad and dgrad tensors on each node, allowing for post-processing. For weight statistics without reduction in the TP parallel group, use: + +.. code-block:: python + + transformer_engine.debug.set_weight_tensor_tp_group_reduce(False) \ No newline at end of file diff --git a/docs/debug/api.rst b/docs/debug/api.rst new file mode 100644 index 000000000..ac593d353 --- /dev/null +++ b/docs/debug/api.rst @@ -0,0 +1,13 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. +API +============ + +.. toctree:: + :caption: Precision debug tools API + + 3_api_debug_setup.rst + 3_api_features.rst + 3_api_te_calls.rst \ No newline at end of file diff --git a/docs/debug/img/api_calls1.svg b/docs/debug/img/api_calls1.svg new file mode 100644 index 000000000..098f384b2 --- /dev/null +++ b/docs/debug/img/api_calls1.svg @@ -0,0 +1 @@ +te.LinearLinear1Nvidia-DLFramework-Inspectconfig.yamlSection1:enabled: Truelayer_names: [Linear1]UserProvidedPrecision:enabled: Truegemms_struct:-gemm: frop-tensors: [activation, output]-gemm: dgrad-tensors: [weight]FeatureclassesUserProvidedPrecisionFPROPWGRADDGRADmodify_tensor_enabledDefaultmodify_tensormodify_tensor_enabledmodify_tensor \ No newline at end of file diff --git a/docs/debug/img/api_calls2.svg b/docs/debug/img/api_calls2.svg new file mode 100644 index 000000000..5df72fc2e --- /dev/null +++ b/docs/debug/img/api_calls2.svg @@ -0,0 +1 @@ +Tensor Ainspect_tensorfp8 castmodify_tensorinspect_tensor_postquantizeGEMMinspect_tensormodify_tensorinspect_tensor_enabledinspect_tensor_postquantize_enabledfp8_gemm_enabledmodify_tensor_enabledTensor Binspect_tensorfp8 castmodify_tensorinspect_tensor_postquantizeRouting callsGEMM calls \ No newline at end of file diff --git a/docs/debug/img/fake_quant.svg b/docs/debug/img/fake_quant.svg new file mode 100644 index 000000000..3ba6973d5 --- /dev/null +++ b/docs/debug/img/fake_quant.svg @@ -0,0 +1 @@ +FP8 GEMMBF16weightBF16inputFP8inputFP8weightBF16activationBF16 GEMMBF16weightBF16inputBF16activationBF16 Inputfake quantizedto FP8FP8inputBF16 Inputfake quantizedto FP8 \ No newline at end of file diff --git a/docs/debug/img/introduction.svg b/docs/debug/img/introduction.svg new file mode 100644 index 000000000..0eae8e820 --- /dev/null +++ b/docs/debug/img/introduction.svg @@ -0,0 +1 @@ +te.LinearLinear1Nvidia-DLFramework-InspectDisableFp8LayerLogTensorStatsconfig.yamlte.LinearLinear2DisableFp8LayerLogTensorStatsSection1:enabled: Truelayer_names: [Linear1, Linear2]DisableFp8Layer:enabled: TrueSection2:enabled: Truelayer_names: [Linear2]LogTensorStats:enabled: TrueotherparamsSection3:enabled: Truelayer_names: [Linear3]UserProvidedPrecision:enabled: Truete.LinearLinear3FeatureclassesDisableFp8LayerUserProvidedPrecisionUserProvidedPrecisionProvidedby the Transformer EngineUser candefinecustomfeatureclasses \ No newline at end of file diff --git a/docs/debug/img/names.svg b/docs/debug/img/names.svg new file mode 100644 index 000000000..3990939e7 --- /dev/null +++ b/docs/debug/img/names.svg @@ -0,0 +1 @@ +Transformer Layer with name transformer_layertransformer_layer.self_attntransformer_layer.self_attn.projtransformer_layer.self_attn.layernorm_linear_qkvtransformer_layer.layernorm_mlptransformer_layer.layernorm_mlp.fc1transformer_layer.layernorm_mlp.fc21 Linear1 Linear1 Linear1 Linear \ No newline at end of file diff --git a/docs/debug/img/pipeline_logging.svg b/docs/debug/img/pipeline_logging.svg new file mode 100644 index 000000000..b87254315 --- /dev/null +++ b/docs/debug/img/pipeline_logging.svg @@ -0,0 +1 @@ +Node 1Node 2Node 3Node 4Node 5Node 6Node 7Node 8TensorBoard logstb_writertb_writertensor reduction group 1=pipeline parallel group 1tensor reduction group 2=pipeline parallel group 2 \ No newline at end of file diff --git a/docs/debug/img/reduction1.svg b/docs/debug/img/reduction1.svg new file mode 100644 index 000000000..184799d53 --- /dev/null +++ b/docs/debug/img/reduction1.svg @@ -0,0 +1 @@ +Node 1Node 2Node 3Node 4TP group 1activation/gradient tensorsweight tensorsTensor reduction groupTP group 2StatsStatsStatsStats \ No newline at end of file diff --git a/docs/debug/img/reduction2.svg b/docs/debug/img/reduction2.svg new file mode 100644 index 000000000..36f94611e --- /dev/null +++ b/docs/debug/img/reduction2.svg @@ -0,0 +1 @@ +TP group 1activation/gradient tensorsweight tensorsTensor reduction groupTP group 2StatsStatsStatsStatsTensor reduction groupTensor reduction groupTensor reduction groupNode 1Node 2Node 3Node 4 \ No newline at end of file diff --git a/docs/debug/img/reduction3.svg b/docs/debug/img/reduction3.svg new file mode 100644 index 000000000..601fb8502 --- /dev/null +++ b/docs/debug/img/reduction3.svg @@ -0,0 +1 @@ +TP group 1activation/gradient tensorsweight tensorsTensor reduction groupTP group 2StatsStatsStatsStatsNode 1Node 2Node 3Node 4 \ No newline at end of file diff --git a/docs/debug/img/scaling_factors.svg b/docs/debug/img/scaling_factors.svg new file mode 100644 index 000000000..b70b51e66 --- /dev/null +++ b/docs/debug/img/scaling_factors.svg @@ -0,0 +1 @@ +One Scaling FactorScaling Factor No. 1Scaling Factor No. 2Node 1Node 2NodeOne Scaling FactorOne Scaling FactorNode 1Node 2PerTensorScalingandFakeQuantFP8 Delayed Scaling \ No newline at end of file diff --git a/docs/debug/img/tensorboard.png b/docs/debug/img/tensorboard.png new file mode 100644 index 000000000..481dbd2eb Binary files /dev/null and b/docs/debug/img/tensorboard.png differ diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index b6ec290b0..8297ac6d2 100644 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -19,7 +19,7 @@ LlamaRMSNorm, LlamaConfig, ) -from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model +from transformers.modeling_utils import _add_variant, load_state_dict from transformers.utils import WEIGHTS_INDEX_NAME from transformers.utils.hub import get_checkpoint_shard_files @@ -148,8 +148,8 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k state_dict = load_state_dict(shard_file) # replace_params copies parameters relevant only to TransformerEngine replace_params(state_dict, vanilla_model.state_dict(), config) - # _load_state_dict_into_model copies parameters other than those in TransformerEngine - _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") + # load_state_dict copies parameters other than those in TransformerEngine + vanilla_model.load_state_dict(state_dict, strict=False) # Force mem release. Taken from huggingface code del state_dict diff --git a/docs/index.rst b/docs/index.rst index cd9ce41cf..bbdb4fea6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -52,4 +52,5 @@ Transformer Engine documentation :caption: Advanced api/c/index + debug examples/attention/attention.ipynb diff --git a/docs/installation.rst b/docs/installation.rst index d0d6cf96d..ecb1e9a0d 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -112,7 +112,7 @@ To build the C++ extensions with debug symbols, e.g. with the `-g` flag: .. code-block:: bash - pip3 install --no-build-isolation . --global-option=--debug + NVTE_BUILD_DEBUG=1 pip3 install --no-build-isolation . .. include:: ../README.rst :start-after: troubleshooting-begin-marker-do-not-remove diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index a23df1be2..98c984839 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -6,12 +6,15 @@ """Shared functions for the encoder tests""" from functools import lru_cache +import jax +import jax.numpy import transformer_engine from transformer_engine_jax import get_device_compute_capability from transformer_engine.common import recipe from transformer_engine.jax.util import is_hip_extension if is_hip_extension(): from transformer_engine.jax.util import is_mi200 +import numpy as np @lru_cache @@ -44,6 +47,71 @@ def is_mxfp8_supported(): return gpu_arch >= 100 +def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01, print_info=False): + """Checks whether most params are sharded across sharding axis. + + (Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/315e551e5942b24656a4250dcfca986fb4135b72/MaxText/maxtext_utils.py#L348) + + This function determines whether the majority of parameters are distributed + across a specified sharding axes with an acceptable tolerance. It compares the + current distribution to a scenario where all parameters are fully sharded + across the axes on which the params are sharded e.g. 'tensor' axis. + + Args: + params: params of the model state + mesh: mesh constructed from config + tolerance: float between 0.0 and 1.0 representing the allowed percentage of + non-sharded parameters. + """ + + def get_product_num_devices_for_weight_sharding(weight_sharding_axes): + product_num_devices_for_weight_sharding = 1 + for axis in weight_sharding_axes: + product_num_devices_for_weight_sharding *= mesh.shape.get(axis, 1) + return product_num_devices_for_weight_sharding + + def assert_leaf_sharding(path, arr): + + # Is the weight sharded? Get the axes on which it is sharded. + partition_spec = arr.sharding.spec + weight_sharding_axes = set(partition_spec) - set([None]) # None is not a sharding axis + + # Total number of devices on the axes on which the weight is sharded. + product_num_devices_for_weight_sharding = get_product_num_devices_for_weight_sharding( + weight_sharding_axes + ) + + # Params present in one shard (on one device). + shard = arr.addressable_shards[0] + params_per_chip = np.prod(shard.data.shape) + + # Total number of params (across all devicess). + total_params = jax.numpy.size(arr) + + # Percentage of params that are unsharded. + unsharded_perc = ( + (params_per_chip / (total_params / product_num_devices_for_weight_sharding) - 1) * 100 + if params_per_chip < total_params + else 100 + ) + + if print_info: + print( + f"{path}: {unsharded_perc:.2f}% unsharded, unsharded param shape={arr.shape}," + f" partition spec={partition_spec}" + ) + + # If the weight is sharded on any axis, then the percentage of + # unsharded params should be less than the tolerance. + assert ( + product_num_devices_for_weight_sharding == 1 or unsharded_perc < tolerance + ), f"{path}: {unsharded_perc:.2f}% unsharded" + + jax.tree_util.tree_map_with_path( + lambda p, x: assert_leaf_sharding("/".join(str(x) for x in p), x), params + ) + + def get_fp8_recipe_from_name_string(name: str): """Query recipe from a given name string""" match name: diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index 6c9e9063f..2a1ac0f8f 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -30,7 +30,7 @@ for TEST_CASE in "${TEST_CASES[@]}"; do LOG_FILE="${TEST_CASE}_gpu_${i}.log" # Run pytest and redirect stdout and stderr to the log file - pytest -c "$TE_PATH/tests/jax/pytest.ini" \ + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ --num-process=$NUM_GPUS \ --process-id=$i > "$LOG_FILE" 2>&1 & @@ -38,22 +38,22 @@ for TEST_CASE in "${TEST_CASES[@]}"; do # Wait for the process to finish wait + tail -n +7 "${TEST_CASE}_gpu_0.log" # Check and print the log content accordingly - if grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then - HAS_FAILURE=1 - echo "... $TEST_CASE FAILED" - tail -n +7 "${TEST_CASE}_gpu_0.log" - elif grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then + if grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then echo "... $TEST_CASE SKIPPED" elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then echo "... $TEST_CASE PASSED" else - echo "Invalid ${TEST_CASE}_gpu_0.log" + HAS_FAILURE=1 + echo "... $TEST_CASE FAILED" fi # Remove the log file after processing it + wait rm ${TEST_CASE}_gpu_*.log done +wait exit $HAS_FAILURE diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index bf6fdf579..b761b1381 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -21,8 +21,13 @@ from jax.experimental import mesh_utils from jax.sharding import PartitionSpec, NamedSharding -from common import is_bf16_supported, get_fp8_recipe_from_name_string +from common import ( + is_bf16_supported, + get_fp8_recipe_from_name_string, + assert_params_sufficiently_sharded, +) import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_fp8_available, ScalingMode @@ -225,38 +230,6 @@ def check_fp8(state, var_collect, inputs, masks, labels): assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr -def get_params_sharding(sharding_rules, abs_var_collect, mesh): - """Refer params to create params sharding""" - rules_dict = dict(sharding_rules) - - def to_device_axis(logical_axis): - partitions = [rules_dict[key] for key in logical_axis] - return NamedSharding(mesh, PartitionSpec(*partitions)) - - params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {}) - params_axes_sharding = jax.tree_util.tree_map( - to_device_axis, nn_partitioning.get_axis_names(params_axes) - ) - params_axes_sharding = flax.core.unfreeze(params_axes_sharding) - params_sharding = jax.tree_util.tree_map( - lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY] - ) - params_sharding = {**params_sharding, **params_axes_sharding} - return params_sharding - - -def get_state_sharding(state, params_sharding): - """Refer params_sharding to create state sharding""" - - def replace_params(x): - return params_sharding if isinstance(x, dict) else None - - state_sharding = jax.tree_util.tree_map( - replace_params, state, is_leaf=lambda x: isinstance(x, dict) - ) - return state_sharding - - def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) @@ -293,8 +266,11 @@ def train_and_evaluate(args): device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) with jax.sharding.Mesh( devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) - ) as mesh: - + ) as mesh, te.fp8_autocast( + enabled=args.use_fp8, + fp8_recipe=fp8_recipe, + mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), + ): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) @@ -304,35 +280,65 @@ def train_and_evaluate(args): mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] - with te.fp8_autocast( - enabled=args.use_fp8, - fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), - ): + # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast + axis_rules = flax.linen.get_logical_axis_rules() + axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) + te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) + + with flax.linen.logical_axis_rules(te_extended_axis_rules): + + print(f"Device mesh: {mesh}") + print(f"Axis rules: {te_extended_axis_rules}") + encoder = Net(num_embed, args.enable_sp) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) - customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) - sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules - params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) + logical_partition_spec = nn.get_partition_spec(abs_var_collect) + + # Note that `nn.logical_to_mesh_sharding` returns a dict with an extra + # "params" key that contains the sharding for the parameters. + params_sharding = nn.logical_to_mesh_sharding( + logical_partition_spec, mesh, te_extended_axis_rules + ) + inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None)) masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None)) in_shardings = (None, inputs_sharding, masks_sharding) out_shardings = { - key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect + key: params_sharding[PARAMS_KEY] if key is PARAMS_KEY else None + for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + jit_encoder_init = jax.jit( + encoder.init, in_shardings=in_shardings, out_shardings=out_shardings + ) var_collect = jit_encoder_init(init_rngs, inputs, masks) + # Check if params are sufficiently sharded after initialization + assert_params_sufficiently_sharded(var_collect, mesh, print_info=False) + optimizer = optax.adamw(args.lr) var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) state = train_state.TrainState.create( apply_fn=encoder.apply, params=params, tx=optimizer ) - state_sharding = get_state_sharding(state, params_sharding) + + abs_state = jax.eval_shape( + lambda: train_state.TrainState.create( + apply_fn=encoder.apply, params=params, tx=optimizer + ) + ) + logical_state_partition_spec = nn.get_partition_spec(abs_state) + state_sharding = nn.logical_to_mesh_sharding( + logical_state_partition_spec, mesh, te_extended_axis_rules + ) + + # Check if params are sufficiently sharded after jitting the state creation + assert_params_sufficiently_sharded(state.params, mesh, print_info=False) + + # state_sharding = get_state_sharding(state, params_sharding) labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS)) in_shardings = ( @@ -344,11 +350,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, in_shardings, out_shardings) + jit_train_step = jax.jit( + train_step, in_shardings=in_shardings, out_shardings=out_shardings + ) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit( + eval_step, in_shardings=in_shardings, out_shardings=out_shardings + ) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -459,14 +469,14 @@ class TestEncoder(unittest.TestCase): is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) def setUp(self): - """Run 3 epochs for testing""" - self.args = encoder_parser(["--epochs", "3"]) + """Run 5 epochs for testing""" + self.args = encoder_parser(["--epochs", "5"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.39 and actual[1] > 0.83 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -474,7 +484,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.39 and actual[1] > 0.83 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -482,14 +492,14 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.39 and actual[1] > 0.83 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_with_sp(self): """Test Transformer Engine with BF16 + SP""" self.args.enable_sp = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.39 and actual[1] > 0.83 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp(self): @@ -498,7 +508,7 @@ def test_te_delayed_scaling_fp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.39 and actual[1] > 0.83 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8_with_sp(self): @@ -507,14 +517,14 @@ def test_te_mxfp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.39 and actual[1] > 0.83 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.39 and actual[1] > 0.83 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_shardy(self): @@ -523,7 +533,7 @@ def test_te_delayed_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.39 and actual[1] > 0.83 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp_shardy(self): @@ -533,9 +543,32 @@ def test_te_delayed_scaling_fp8_with_sp_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.39 and actual[1] > 0.83 + + @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) + @unittest.skipIf( + tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." + ) + def test_te_mxfp8_shardy(self): + """Test Transformer Engine with MXFP8""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "MXFP8BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.39 and actual[1] > 0.83 - # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. + @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) + @unittest.skipIf( + tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." + ) + def test_te_mxfp8_with_sp_shardy(self): + """Test Transformer Engine with MXFP8 + SP""" + self.args.enable_shardy = True + self.args.enable_sp = True + self.args.use_fp8 = True + self.args.fp8_recipe = "MXFP8BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.39 and actual[1] > 0.83 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 3310c57ff..26740c025 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -23,6 +23,7 @@ from common import is_bf16_supported, get_fp8_recipe_from_name_string import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_fp8_available, ScalingMode @@ -260,7 +261,13 @@ def train_and_evaluate(args): fp8_recipe = None device_mesh = mesh_utils.create_device_mesh((num_gpu,)) - with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh: + with jax.sharding.Mesh( + devices=device_mesh, axis_names=(DEVICE_DP_AXIS,) + ) as mesh, te.fp8_autocast( + enabled=args.use_fp8, + fp8_recipe=fp8_recipe, + mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None), + ): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) @@ -271,17 +278,14 @@ def train_and_evaluate(args): mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] - with te.fp8_autocast( - enabled=args.use_fp8, - fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None), - ): + # Add TE logical axis rules to our Flax logical axis rule context. This must be done inside fp8_autocast + sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + with flax.linen.logical_axis_rules(sharding_rules): encoder = Net(num_embed) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) - sharding_rules = te_flax.extend_logical_axis_rules(tuple()) params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None)) masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None)) @@ -290,7 +294,9 @@ def train_and_evaluate(args): out_shardings = { key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + jit_encoder_init = jax.jit( + encoder.init, in_shardings=in_shardings, out_shardings=out_shardings + ) var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) @@ -314,11 +320,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, in_shardings, out_shardings) + jit_train_step = jax.jit( + train_step, in_shardings=in_shardings, out_shardings=out_shardings + ) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit( + eval_step, in_shardings=in_shardings, out_shardings=out_shardings + ) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -426,14 +436,14 @@ class TestEncoder(unittest.TestCase): is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) def setUp(self): - """Run 3 epochs for testing""" - self.args = encoder_parser(["--epochs", "3"]) + """Run 5 epochs for testing""" + self.args = encoder_parser(["--epochs", "5"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.52 and actual[1] > 0.74 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -441,7 +451,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.52 and actual[1] > 0.74 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8(self): @@ -449,7 +459,7 @@ def test_te_current_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.52 and actual[1] > 0.74 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -457,14 +467,14 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.52 and actual[1] > 0.74 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.52 and actual[1] > 0.74 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_shardy(self): @@ -473,9 +483,7 @@ def test_te_delayed_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 - - # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. + assert actual[0] < 0.52 and actual[1] > 0.74 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8_shardy(self): @@ -484,7 +492,19 @@ def test_te_current_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.52 and actual[1] > 0.74 + + @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) + @unittest.skipIf( + tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." + ) + def test_te_mxfp8_shardy(self): + """Test Transformer Engine with MXFP8""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "MXFP8BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.52 and actual[1] > 0.74 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 1bb32b7d6..e8a14a146 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -30,8 +30,8 @@ get_fp8_recipe_from_name_string, ) import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -381,8 +381,11 @@ def train_and_evaluate(args): device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) with jax.sharding.Mesh( devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) - ) as mesh: - + ) as mesh, te.fp8_autocast( + enabled=args.use_fp8, + fp8_recipe=fp8_recipe, + mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), + ): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) @@ -392,18 +395,18 @@ def train_and_evaluate(args): mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] - with te.fp8_autocast( - enabled=args.use_fp8, - fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), - ): + # Create custom Flax logical axis rules for sharding. + customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) + # Extend the logical axis rules with TE's rules. This must be done inside fp8_autocast. + sharding_rules = te_flax.extend_logical_axis_rules(customized_rules) + + with flax.linen.logical_axis_rules(sharding_rules): + encoder = Net(num_embed) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) - customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) - sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) @@ -414,7 +417,9 @@ def train_and_evaluate(args): out_shardings = { key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + jit_encoder_init = jax.jit( + encoder.init, in_shardings=in_shardings, out_shardings=out_shardings + ) var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) @@ -434,11 +439,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, in_shardings, out_shardings) + jit_train_step = jax.jit( + train_step, in_shardings=in_shardings, out_shardings=out_shardings + ) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit( + eval_step, in_shardings=in_shardings, out_shardings=out_shardings + ) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -580,8 +589,8 @@ class TestEncoder(unittest.TestCase): """Encoder unittests""" def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): - """Run 3 epochs for testing""" - args = encoder_parser([]) + """Run 5 epochs for testing""" + args = encoder_parser(["--epochs", "5"]) num_gpu = self.num_process tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1 @@ -603,7 +612,7 @@ def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): def test_te_bf16(self): """Test Transformer Engine with BF16""" result = self.exec(False, None) - assert result[0] < 0.505 and result[1] > 0.755 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" @@ -611,7 +620,7 @@ def test_te_bf16(self): def test_te_delayed_scaling_fp8(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling") - assert result[0] < 0.505 and result[1] > 0.753 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" @@ -619,7 +628,7 @@ def test_te_delayed_scaling_fp8(self): def test_te_current_scaling_fp8(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling") - assert result[0] < 0.507 and result[1] > 0.753 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" @@ -627,13 +636,13 @@ def test_te_current_scaling_fp8(self): def test_te_mxfp8(self): """Test Transformer Engine with MXFP8""" result = self.exec(True, "MXFP8BlockScaling") - assert result[0] < 0.505 and result[1] > 0.754 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" result = self.exec(False, None, enable_shardy=True) - assert result[0] < 0.505 and result[1] > 0.755 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" @@ -641,9 +650,7 @@ def test_te_bf16_shardy(self): def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling", enable_shardy=True) - assert result[0] < 0.505 and result[1] > 0.753 - - # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" @@ -651,7 +658,18 @@ def test_te_delayed_scaling_fp8_shardy(self): def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling", enable_shardy=True) - assert result[0] < 0.507 and result[1] > 0.753 + assert result[0] < 0.43 and result[1] > 0.80 + + @unittest.skipIf( + not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" + ) + @unittest.skipIf( + tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." + ) + def test_te_mxfp8_shardy(self): + """Test Transformer Engine with MXFP8""" + result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True) + assert result[0] < 0.43 and result[1] > 0.80 if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml new file mode 100755 index 000000000..ef112d279 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +[build-system] +requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax[cuda12]", "flax>=0.7.1"] + +# Use legacy backend to import local packages in setup.py +build-backend = "setuptools.build_meta:__legacy__" + diff --git a/qa/L0_cppunittest/test.sh b/qa/L0_cppunittest/test.sh index df8a48b66..cd46b0b63 100755 --- a/qa/L0_cppunittest/test.sh +++ b/qa/L0_cppunittest/test.sh @@ -6,7 +6,7 @@ set -e # Find TE : ${TE_PATH:=/opt/transformerengine} -TE_LIB_PATH=`pip3 show transformer-engine | grep Location | cut -d ' ' -f 2` +TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}') export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH # Set parallelization parameters diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index 92434c28e..d9c46347f 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -28,7 +28,7 @@ python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/py wait python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" wait -. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" +TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 6ffc5945a..3d00e0346 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -27,9 +27,6 @@ mkdir -p "$XML_LOG_DIR" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*" -# Test without custom calls -NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py" - pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" @@ -37,6 +34,9 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" +# Test without custom calls +export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" +NVTE_JAX_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh new file mode 100644 index 000000000..c94edba2b --- /dev/null +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -0,0 +1,27 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + + + +: ${TE_PATH:=/opt/transformerengine} +: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} +: ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/} + +# Config with the dummy feature which prevents nvinspect from being disabled. +# Nvinspect will be disabled if no feature is active. +: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} + +FAIL=0 + +pip install pytest==8.2.1 +pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 + +# standard sanity and numerics tests with initialized debug +NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 +NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 + +exit $FAIL diff --git a/qa/L0_pytorch_lint/test.sh b/qa/L0_pytorch_lint/test.sh index 81d7822d7..e2c50c445 100644 --- a/qa/L0_pytorch_lint/test.sh +++ b/qa/L0_pytorch_lint/test.sh @@ -20,5 +20,5 @@ if [ -z "${CPP_ONLY}" ] then cd $TE_PATH echo "Checking Python files" - python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch + python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch transformer_engine/debug fi diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 79f3c8fb9..7fe439b37 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -23,6 +23,8 @@ set -x mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" +pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime" +pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" @@ -38,12 +40,16 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gem python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" +NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 4319e96c7..09ef661c4 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -20,6 +20,7 @@ FAILED_CASES="" : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" + pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" @@ -30,6 +31,19 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" + +# debug tests + + +# Config with the dummy feature which prevents nvinspect from being disabled. +# Nvinspect will be disabled if no feature is active. +: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} +: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} + +pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" +# standard numerics tests with initialized debug +NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" + if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" exit 1 diff --git a/qa/L2_jax_unittest/test.sh b/qa/L2_jax_unittest/test.sh index 9691f0e7c..c5c193351 100644 --- a/qa/L2_jax_unittest/test.sh +++ b/qa/L2_jax_unittest/test.sh @@ -25,18 +25,18 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" - -# Test without custom calls -NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py" +NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" +NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements" # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" +NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" +# Test without custom calls +export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" +NVTE_JAX_CUSTOM_CALLS_RE="" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 8a510becd..547849e95 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -18,10 +18,10 @@ sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); pri export FLASH_ATTN_CUDA_ARCHS=$sm_arch if [ $sm_arch -gt 90 ] then - FA_versions=(2.7.3) + FA_versions=(2.8.1) elif [ $sm_arch -eq 90 ] then - FA_versions=(2.5.7 2.7.3 3.0.0b1) + FA_versions=(2.7.3 2.8.1 3.0.0b1) fi for fa_version in "${FA_versions[@]}" diff --git a/setup.py b/setup.py index 91817d56e..d078b594f 100644 --- a/setup.py +++ b/setup.py @@ -20,11 +20,7 @@ from build_tools.utils import ( rocm_build, cuda_archs, - found_cmake, - found_ninja, - found_pybind11, get_frameworks, - install_and_import, remove_dups, ) @@ -39,7 +35,6 @@ if "pytorch" in frameworks: from torch.utils.cpp_extension import BuildExtension elif "jax" in frameworks: - install_and_import("pybind11[global]") from pybind11.setup_helpers import build_ext as BuildExtension @@ -91,6 +86,11 @@ def setup_common_extension() -> CMakeExtension: if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))): cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON") + # Add custom CMake arguments from environment variable + nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") + if nvte_cmake_extra_args: + cmake_flags.extend(nvte_cmake_extra_args.split()) + # Project directory root root_path = Path(__file__).resolve().parent @@ -101,25 +101,13 @@ def setup_common_extension() -> CMakeExtension: ) -def setup_requirements() -> Tuple[List[str], List[str], List[str]]: +def setup_requirements() -> Tuple[List[str], List[str]]: """Setup Python dependencies - Returns dependencies for build, runtime, and testing. + Returns dependencies for runtime and testing. """ # Common requirements - if rocm_build(): - setup_reqs: List[str] = [] - else: - setup_reqs: List[str] = [ - "nvidia-cuda-runtime-cu12", - "nvidia-cublas-cu12", - "nvidia-cudnn-cu12", - "nvidia-cuda-cccl-cu12", - "nvidia-cuda-nvcc-cu12", - "nvidia-nvtx-cu12", - "nvidia-cuda-nvrtc-cu12", - ] install_reqs: List[str] = [ "pydantic", "importlib-metadata>=1.0", @@ -127,42 +115,20 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ] test_reqs: List[str] = ["pytest>=8.2.1"] - # Requirements that may be installed outside of Python - if not found_cmake(): - setup_reqs.append("cmake>=3.21") - if not found_ninja(): - import sys - - subprocess.check_call([sys.executable, "-m", "pip", "install", "ninja"]) - setup_reqs.append("ninja") - if not found_pybind11(): - setup_reqs.append("pybind11") - # Framework-specific requirements if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: - if rocm_build(): - install_reqs.extend(["einops"]) - else: - setup_reqs.extend(["torch>=2.1"]) - install_reqs.extend(["torch>=2.1"]) - install_reqs.append( - "nvdlfw-inspect @" - " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect" - ) - # Blackwell is not supported as of Triton 3.2.0, need custom internal build - # install_reqs.append("triton") - test_reqs.extend(["numpy", "torchvision"]) + from build_tools.pytorch import install_requirements, test_requirements + + install_reqs.extend(install_requirements()) + test_reqs.extend(test_requirements()) if "jax" in frameworks: - if rocm_build(): - from build_tools.jax import jax_install_requires - install_reqs.extend(jax_install_requires(["flax>=0.7.1"])) - else: - setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"]) - install_reqs.extend(["jax", "flax>=0.7.1"]) - test_reqs.extend(["numpy"]) + from build_tools.jax import install_requirements, test_requirements + + install_reqs.extend(install_requirements()) + test_reqs.extend(test_requirements()) - return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] + return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]] if __name__ == "__main__": @@ -181,14 +147,13 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: cmdclass = {} package_data = {} include_package_data = False - setup_requires = [] install_requires = ([f"transformer_engine_{te_cuda_vers}=={__version__}"],) extras_require = { "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], } else: - setup_requires, install_requires, test_requires = setup_requirements() + install_requires, test_requires = setup_requirements() ext_modules = [setup_common_extension()] cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} package_data = {"": ["VERSION.txt"]} @@ -234,15 +199,8 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: long_description_content_type="text/x-rst", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, - python_requires=">=3.8, <3.13", - classifiers=[ - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - ], - setup_requires=setup_requires, + python_requires=">=3.8", + classifiers=["Programming Language :: Python :: 3"], install_requires=install_requires, license_files=("LICENSE",), include_package_data=include_package_data, diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 4ab5fd237..75c52fdd7 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -64,7 +64,7 @@ else() project(transformer_engine_tests LANGUAGES HIP CXX) # Ask hcc to generate device code during compilation so we can use # host linker to link. - set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted -Wno-unused-result") + set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted -Wno-unused-result -ftemplate-backtrace-limit=0") endif() add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest) @@ -74,11 +74,13 @@ enable_testing() include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) if(NOT DEFINED TE_LIB_PATH) - execute_process(COMMAND bash -c "pip3 show transformer-engine | grep Location | cut -d ' ' -f 2 | tr -d '\n'" - OUTPUT_VARIABLE TE_LIB_PATH) + execute_process(COMMAND bash -c "python3 -c 'import transformer_engine as te; print(te.__file__)'" + OUTPUT_VARIABLE TE_LIB_FILE + OUTPUT_STRIP_TRAILING_WHITESPACE) + get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY) endif() -find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/transformer_engine" ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) +find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) message(STATUS "Found transformer_engine library: ${TE_LIB}") include_directories(../../transformer_engine/common/include) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 8a0c0177c..e3af4a360 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -12,7 +12,6 @@ list(APPEND test_cuda_sources test_qdq.cu test_cast_mxfp8.cu test_dequantize_mxfp8.cu - test_cast_transpose.cu test_transpose.cu test_cast_transpose.cu test_cast_transpose_current_scaling.cu @@ -26,6 +25,7 @@ list(APPEND test_cuda_sources test_normalization_mxfp8.cu test_multi_cast_transpose.cu test_multi_padding.cu + test_multi_unpadding.cu test_causal_softmax.cu test_swizzle.cu ../test_common.cu) diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 4acbac4fb..96663e752 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -422,7 +422,7 @@ std::vector> matrix_sizes = { {256, 256}, {993, 512}, {768, 1024}, - {65536, 128}, + {65504, 128}, {16384, 1632}, }; diff --git a/tests/cpp/operator/test_multi_unpadding.cu b/tests/cpp/operator/test_multi_unpadding.cu new file mode 100644 index 000000000..ca685b962 --- /dev/null +++ b/tests/cpp/operator/test_multi_unpadding.cu @@ -0,0 +1,186 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_unpadding_ref(const std::vector>& input_list, + std::vector>& output_list, + const std::vector& height_list, + const std::vector& width_list, + const std::vector& padded_height_list) { + using compute_t = float; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + const auto& input = input_list[tensor_id]; + auto& output = output_list[tensor_id]; + const size_t height = height_list[tensor_id]; + const size_t width = width_list[tensor_id]; + const size_t padded_height = padded_height_list[tensor_id]; + + // Only copy the valid (unpadded) portion + for (size_t i = 0; i < height; ++i) { + for (size_t j = 0; j < width; ++j) { + const compute_t x = static_cast(input[i * width + j]); + const OutputType y = static_cast(x); + output[i * width + j] = y; + } + } + } +} + +template +void performUnpaddingTest() { + using namespace test; + + const DType itype = TypeInfo::dtype; + const DType otype = TypeInfo::dtype; + const std::vector> tensor_dims = {{1,1}, + {1,768}, + {768,1}, + {768,768}, + {43,43}, + {43,256}, + {256,43}, + {256,256}}; + const size_t num_tensors = tensor_dims.size(); + constexpr int align = 16; + + // Buffers for Transformer Engine implementation + std::vector padded_input_list, unpadded_output_list; + + // Buffers for reference implementation + std::vector> ref_padded_input_list; + std::vector> ref_unpadded_output_list; + std::vector ref_height_list(num_tensors), ref_width_list(num_tensors); + std::vector ref_padded_height_list(num_tensors); + + // Initialize buffers + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + const size_t original_height = tensor_dims[tensor_id].first; + const size_t width = tensor_dims[tensor_id].second; + const size_t padded_height = (original_height + align - 1) / align * align; + + // Input is padded tensor (padded_height x width) + padded_input_list.emplace_back( + Tensor("padded_input_" + std::to_string(tensor_id), + std::vector{padded_height, width}, itype)); + + // Output is unpadded tensor (original_height x width) + unpadded_output_list.emplace_back( + Tensor("unpadded_output_" + std::to_string(tensor_id), + std::vector{original_height, width}, otype)); + + auto& padded_input = padded_input_list.back(); + auto& unpadded_output = unpadded_output_list.back(); + + // Fill padded input with random data (including padding area) + fillUniform(&padded_input); + setRandomScale(&unpadded_output); + + // Initialize reference buffers + ref_padded_input_list.emplace_back(padded_height * width); + ref_unpadded_output_list.emplace_back(original_height * width); + + // Copy data to reference buffers + std::copy(padded_input.rowwise_cpu_dptr(), + padded_input.rowwise_cpu_dptr() + padded_height * width, + ref_padded_input_list.back().begin()); + + ref_height_list[tensor_id] = original_height; + ref_width_list[tensor_id] = width; + ref_padded_height_list[tensor_id] = padded_height; + } + + // Transformer Engine implementation + auto make_nvte_vector = [](std::vector& tensor_list) + -> std::vector { + std::vector nvte_tensor_list; + for (auto& tensor : tensor_list) { + nvte_tensor_list.emplace_back(tensor.data()); + } + return nvte_tensor_list; + }; + + // Convert height_list to int for the API + std::vector original_height_list_int(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + original_height_list_int[i] = static_cast(ref_height_list[i]); + } + + // Call unpadding API + nvte_multi_unpadding(num_tensors, + make_nvte_vector(padded_input_list).data(), + make_nvte_vector(unpadded_output_list).data(), + original_height_list_int.data(), + 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + // Reference implementation + compute_unpadding_ref(ref_padded_input_list, + ref_unpadded_output_list, + ref_height_list, + ref_width_list, + ref_padded_height_list); + + // Check correctness + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + auto [atol, rtol] = getTolerances(otype); + compareResults("unpadded_output", + unpadded_output_list[tensor_id], + ref_unpadded_output_list[tensor_id].data(), + true, + atol, rtol); + } +} + +} // namespace + +class MultiUnpaddingTestSuite + : public ::testing::TestWithParam {}; + +TEST_P(MultiUnpaddingTestSuite, TestMultiUnpadding) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = GetParam(); + const DType output_type = input_type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performUnpaddingTest(); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MultiUnpaddingTestSuite, + ::testing::ValuesIn(test::all_fp_types), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(info.param); + return name; + }); diff --git a/tests/cpp/operator/test_normalization.h b/tests/cpp/operator/test_normalization.h index 5f5603a7f..39ff358cb 100644 --- a/tests/cpp/operator/test_normalization.h +++ b/tests/cpp/operator/test_normalization.h @@ -73,7 +73,12 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const // Remove the use_cudnn check here when it is supported by both backends. const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype; +#ifdef __HIP_PLATFORM_AMD__ if constexpr (std::is_same_v || std::is_same_v){ +#else + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v){ +#endif compute_t g = static_cast(gamma); if (zero_centered_gamma) { g += static_cast(1.f); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 71ac2ce66..a608f6ef2 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -47,7 +47,7 @@ bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) { return true; } -size_t typeToSize(DType type) { +size_t typeToNumBits(DType type) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, { return TypeInfo::size; @@ -64,7 +64,8 @@ const std::string &typeName(DType type) { {DType::kBFloat16, "bfloat16"}, {DType::kFloat8E4M3, "float8e4m3"}, {DType::kFloat8E5M2, "float8e5m2"}, - {DType::kFloat8E8M0, "float8e8m0"}}; + {DType::kFloat8E8M0, "float8e8m0"}, + {DType::kFloat4E2M1, "float4e2m1"}}; return name_map.at(type); } @@ -111,9 +112,16 @@ size_t DIVUP(const size_t &x, const size_t &y){ struct scale_inv_meta { std::vector shape; DType type; - size_t type_size; + size_t type_size_bits; + size_t bytes() const noexcept { + return (product(shape) * type_size_bits) / 8; + } }; +size_t bytes(const NVTEShape& shape, const DType type) { + return (product(shape) * typeToNumBits(type)) / 8; +} + NVTEShape convertShape(const std::vector& s) { return nvte_make_shape(s.data(), s.size()); } @@ -124,7 +132,7 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret; ret.shape = {1}; ret.type = DType::kFloat32; - ret.type_size = sizeof(float); + ret.type_size_bits = typeToNumBits(DType::kFloat32); return {ret, ret}; } if (scaling_mode == NVTE_MXFP8_1D_SCALING) { @@ -154,8 +162,8 @@ std::pair get_scales(const NVTEShape& shape, } ret_rowwise.type = DType::kFloat8E8M0; ret_colwise.type = DType::kFloat8E8M0; - ret_rowwise.type_size = sizeof(uint8_t); - ret_colwise.type_size = sizeof(uint8_t); + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); return {ret_rowwise, ret_colwise}; } @@ -181,8 +189,8 @@ std::pair get_scales(const NVTEShape& shape, } ret_rowwise.type = DType::kFloat32; ret_colwise.type = DType::kFloat32; - ret_rowwise.type_size = sizeof(float); - ret_colwise.type_size = sizeof(float); + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32); + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32); return {ret_rowwise, ret_colwise}; } @@ -207,8 +215,8 @@ std::pair get_scales(const NVTEShape& shape, } ret_rowwise.type = DType::kFloat32; ret_colwise.type = DType::kFloat32; - ret_rowwise.type_size = sizeof(float); - ret_colwise.type_size = sizeof(float); + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32); + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32); return {ret_rowwise, ret_colwise}; } @@ -224,8 +232,7 @@ Tensor::Tensor(const std::string& name, gen_.seed(seed); rowwise_ = rowwise; columnwise_ = columnwise; - size_t s = typeToSize(type); - size_t total_size = product(shape) * s; + size_t total_size = bytes(shape, type); void *dptr_rowwise = nullptr; void *dptr_columnwise = nullptr; cpu_data_rowwise_ = nullptr; @@ -307,8 +314,8 @@ Tensor::Tensor(const std::string& name, } else { auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode()); - auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; - auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + auto rowwise_scale_size = rowwise_scale_meta.bytes(); + auto columnwise_scale_size = colwise_scale_meta.bytes(); auto scale_shape = rowwise_scale_meta.shape; auto columnwise_scale_shape = colwise_scale_meta.shape; if (rowwise) { @@ -333,7 +340,7 @@ Tensor::Tensor(const std::string& name, void Tensor::to_cpu() const { const NVTEShape s = tensor_.shape(); - const size_t size = product(s) * typeToSize(tensor_.dtype()); + const size_t size = bytes(s, tensor_.dtype()); if (rowwise_) { (void)cudaMemcpy(cpu_data_rowwise_.get(), tensor_.get_rowwise_data().data_ptr, @@ -362,14 +369,14 @@ void Tensor::to_cpu() const { auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { - auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; + auto scale_size = rowwise_scale_meta.bytes(); (void)cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), tensor_.get_rowwise_scale_inv().data_ptr, scale_size, cudaMemcpyDeviceToHost); } if (columnwise_) { - auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + auto scale_size = colwise_scale_meta.bytes(); (void)cudaMemcpy(columnwise_scale_inv_cpu_data_.get(), tensor_.get_columnwise_scale_inv().data_ptr, scale_size, @@ -380,34 +387,32 @@ void Tensor::to_cpu() const { void Tensor::from_cpu() const { const NVTEShape s = tensor_.shape(); - const size_t size = product(s) * typeToSize(tensor_.dtype()); + const size_t size = bytes(s, tensor_.dtype()); if (rowwise_) { - (void)cudaMemcpy(tensor_.get_rowwise_data().data_ptr, - cpu_data_rowwise_.get(), size, cudaMemcpyHostToDevice); + (void)cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size, + cudaMemcpyHostToDevice); } if (columnwise_) { - (void)cudaMemcpy(tensor_.get_columnwise_data().data_ptr, - cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice); + (void)cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size, + cudaMemcpyHostToDevice); } if (isFp8Type(dtype())) { if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { if (tensor_.amax() != nullptr){ - (void)cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); + (void)cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } - (void)cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); + (void)cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { - auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; + auto scale_size = rowwise_scale_meta.bytes(); (void)cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, rowwise_scale_inv_cpu_data_.get(), scale_size, cudaMemcpyHostToDevice); } if (columnwise_) { - auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + auto scale_size = colwise_scale_meta.bytes(); (void)cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr, columnwise_scale_inv_cpu_data_.get(), scale_size, cudaMemcpyHostToDevice); @@ -783,7 +788,21 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { for (int i = 0; i < size; i++) { data[i] = static_cast(dis(*gen)); } + gen->discard(size); #else + // Check how many RNG calls are required to generate one uniform random value + int rng_calls_per_val = 0; + { + std::mt19937 gen1 = *gen, gen2 = *gen; + std::uniform_real_distribution<> dis(-2.0, 1.0); + const float _ = dis(gen1); + while (gen2 != gen1) { + auto _ = gen2(); + ++rng_calls_per_val; + } + } + + // Generate uniform random values in parallel #pragma omp parallel proc_bind(spread) { std::mt19937 gen_local = *gen; @@ -792,15 +811,15 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { const int chunk_size = (size + threads_num - 1) / threads_num; const int idx_min = chunk_size * thread_ID; const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast(size)); - gen_local.discard(idx_min); + gen_local.discard(idx_min * rng_calls_per_val); std::uniform_real_distribution<> dis(-2.0, 1.0); for (int i = idx_min; i < idx_max; ++i) { data[i] = static_cast(dis(gen_local)); } } + gen->discard(size * rng_calls_per_val); #endif - gen->discard(size); } void fillUniform(Tensor *t) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 6b9514d38..a7290a535 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -14,12 +14,16 @@ #include #ifndef USE_ROCM +#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080) #include #include +#if FP4_TYPE_SUPPORTED +#include +#endif #else +#define FP4_TYPE_SUPPORTED (false) #include #include "amd_detail/hip_float8.h" -#include #endif #include @@ -68,19 +72,32 @@ using fp8e4m3 = te_hip_fp8_e4m3; using fp8e5m2 = te_hip_fp8_e5m2; #endif //USE_ROCM using fp8e8m0 = uint8_t; +#if FP4_TYPE_SUPPORTED +using fp4e2m1 = __nv_fp4_e2m1; +#endif template -struct TypeInfo{ - using types = std::tuple; +struct BitsNumber; + +#if FP4_TYPE_SUPPORTED +template <> +struct BitsNumber { + static constexpr size_t num_bits = 4; +}; +#endif + +template +struct BitsNumber { + static constexpr size_t num_bits = 8 * sizeof(T); +}; + +template +struct TypeInfo { +#if FP4_TYPE_SUPPORTED + using types = std::tuple; +#else + using types = std::tuple; +#endif template struct Helper { @@ -107,7 +124,7 @@ struct TypeInfo{ } constexpr static DType dtype = getType(); - constexpr static size_t size = sizeof(T); + constexpr static size_t size = BitsNumber::num_bits;; }; class Tensor { @@ -443,9 +460,10 @@ inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); } inline float srelu(const float x) { return x > 0 ? x * x : 0; } inline float dsrelu(const float x) { return fmaxf(0, 2 * x); } -size_t typeToSize(DType type); +size_t typeToNumBits(DType type); size_t product(const NVTEShape &shape); size_t product(const std::vector &shape); +size_t bytes(const NVTEShape& shape, const DType type); size_t first_dimension(const std::vector &shape); size_t last_dimension(const std::vector &shape); @@ -499,6 +517,16 @@ constexpr int32_t blackwellComputeCapability = 100; } // namespace test +#if FP4_TYPE_SUPPORTED +#define SWITCH_FP4_TYPE_HANDLE(type, ...) \ + case DType::kFloat4E2M1: { \ + using type = fp4e2m1; \ + { __VA_ARGS__ } \ + } break; +#else +#define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing +#endif + #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ @@ -550,8 +578,16 @@ constexpr int32_t blackwellComputeCapability = 100; {__VA_ARGS__} \ } \ break; \ + case DType::kFloat8E8M0: \ + { \ + using type = fp8e8m0; \ + {__VA_ARGS__} \ + } \ + break; \ + SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ default: \ - NVTE_ERROR("Invalid type."); \ + printf("dtype: %d\n", static_cast(dtype)); \ + NVTE_ERROR("Invalid type MARKED TEST."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ @@ -570,7 +606,15 @@ constexpr int32_t blackwellComputeCapability = 100; } \ break; \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Invalid type MARKED TEST 2."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + SWITCH_FP4_HANDLE(type, __VA_ARGS__) \ + default: \ + NVTE_ERROR("Invalid type MARKED TEST 3."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \ @@ -595,5 +639,5 @@ constexpr int32_t blackwellComputeCapability = 100; } \ break; \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Invalid type MARKED TEST 4."); \ } diff --git a/tests/cpp/util/test_string.cpp b/tests/cpp/util/test_string.cpp index a2e8bc141..6a9fe0d9a 100644 --- a/tests/cpp/util/test_string.cpp +++ b/tests/cpp/util/test_string.cpp @@ -38,7 +38,7 @@ TEST(UtilTest, ToStringLike) { // to_string_like // Non-zero integer types EXPECT_EQ(to_string_like(static_cast(1)), "1"); - EXPECT_EQ(to_string_like(static_cast(-1)), "-1"); + EXPECT_EQ(to_string_like(static_cast(-1)), "-1"); EXPECT_EQ(to_string_like(static_cast(2)), "2"); EXPECT_EQ(to_string_like(static_cast(3)), "3"); EXPECT_EQ(to_string_like(static_cast(-5)), "-5"); diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index d0424de5a..20a8037eb 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -6,16 +6,16 @@ import jax import jax.numpy as jnp -import numpy as np import pytest from jax import jit, value_and_grad from functools import reduce +from typing import Union import operator from utils import ( assert_allclose, - assert_tree_like_allclose, pytest_parametrize_wrapper, + use_jax_gemm, ) from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm_mlp import layernorm_mlp @@ -34,17 +34,19 @@ from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax.cpp_extensions.misc import is_hip_extension from transformer_engine.jax.quantize import ( - DelayedScaleQuantizer, ScaledTensor, + ScaledTensor1x, + ScaledTensor2x, + GroupedScaledTensor1x, ScalingMode, QuantizerFactory, QuantizeLayout, + noop_quantizer_set, ) from transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation -from transformer_engine.jax.dense import dense +from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense -from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x from transformer_engine.jax.util import get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type GEMM_CASES = [ @@ -61,8 +63,8 @@ FP8_COMPUTE_TYPE = [jnp_float8_e4m3_type, jnp_float8_e5m2_type] LN_CASES = [(256, 128), (128, 256)] DTYPES = [jnp.bfloat16, jnp.float32] -is_fp8_supported, reason = helper.is_fp8_available() -is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) +is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available() +is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) supported_scaling_modes = [] """ Find supported scaling modes""" @@ -115,12 +117,44 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray): else: assert_allclose(a.dequantize(), b, dtype=a.data.dtype) elif isinstance(a, ScaledTensor2x): - assert_dequantized_scaled_tensor(a.get_rowwise_tensor(), b) - assert_dequantized_scaled_tensor(a.get_colwise_tensor(), b) + assert_dequantized_scaled_tensor(a.rowwise_tensor, b) + assert_dequantized_scaled_tensor(a.colwise_tensor, b) else: pytest.fail("a must be a ScaledTensor object") +def assert_dequantized_grouped_scaled_tensor( + a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray +): + if isinstance(a, GroupedScaledTensor1x): + assert a.group_sizes.sum() == b.shape[0] + b = jnp.split(b, jnp.cumulative_sum(a.group_sizes)[:-1], axis=0) + dq_a = a.dequantize() + for dq_a_i, b_i in zip(dq_a, b): + if len(dq_a_i) == 0: + continue + if a.data_layout == "T": + data_ndim = len(a.original_shape) + flatten_axis = a.flatten_axis + if b_i.shape[0] == 1: + b_i = jnp.transpose( + b_i, (0, *range(flatten_axis, data_ndim), *range(1, flatten_axis)) + ) + else: + b_i = jnp.transpose( + b_i, (*range(flatten_axis, data_ndim), *range(flatten_axis)) + ) + dq_a_i = dq_a_i.reshape(b_i.shape) + assert_allclose(dq_a_i, b_i, dtype=a.data.dtype) + elif isinstance(a, ScaledTensor2x): + assert isinstance(a.rowwise_tensor, GroupedScaledTensor1x) + assert isinstance(a.colwise_tensor, GroupedScaledTensor1x) + assert_dequantized_grouped_scaled_tensor(a.rowwise_tensor, b) + assert_dequantized_grouped_scaled_tensor(a.colwise_tensor, b) + else: + pytest.fail("a must be a GroupedScaledTensor object") + + ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)] ALL_ACTIVATION_TYPES = [ ("gelu",), @@ -181,7 +215,7 @@ def test_act_grad(self, shape, activation_type): assert_allclose(prim_out, ref_out, dtype=x.dtype) assert_allclose(prim_grad, ref_grad, dtype=x.dtype) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("output_type", FP8_COMPUTE_TYPE) @@ -212,7 +246,7 @@ def test_act_grad_with_tensor_scaling_fp8( assert_allclose(prim_out, ref_out, dtype=output_type) assert_allclose(prim_grad, ref_grad, dtype=output_type) - @pytest.mark.skipif(not is_mxfp8_supported, reason=reason) + @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("output_type", FP8_COMPUTE_TYPE) @@ -242,7 +276,7 @@ def test_act_forward_with_tensor_scaling_fp8( assert_bitwise_scaled_tensors(te_output, jax_output) - @pytest.mark.skipif(not is_mxfp8_supported, reason=reason) + @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)]) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("output_type", FP8_COMPUTE_TYPE) @@ -363,7 +397,7 @@ def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) # No Norm FWD E5M2 in TE backend @pytest_parametrize_wrapper("out_dtype", [jnp_float8_e4m3_type]) @pytest_parametrize_wrapper( @@ -478,7 +512,7 @@ def _test_norm_forward( if norm_type == "layernorm": assert_allclose(mu, ref_mu, dtype=inp_dtype) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) # No Norm FWD E5M2 in TE backend @pytest_parametrize_wrapper("out_dtype", [jnp_float8_e4m3_type]) @pytest_parametrize_wrapper( @@ -514,7 +548,7 @@ def test_norm_forward_with_tensor_scaling_fp8( q_layout=q_layout, ) - @pytest.mark.skipif(not is_mxfp8_supported, reason=reason) + @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @pytest.mark.parametrize("out_dtype", FP8_COMPUTE_TYPE) def test_norm_forward_with_block_scaling_fp8( self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype @@ -540,7 +574,7 @@ def test_norm_forward_with_block_scaling_fp8( ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [ ((32, 64), -1), ((2, 64, 32), -1), - ((2, 64, 32), -2), + ((64, 2, 32), -2), ((32, 256, 128), -1), ((32, 256, 128), -2), ((64, 32, 32, 256), -1), @@ -552,7 +586,7 @@ def test_norm_forward_with_block_scaling_fp8( "L0": [ ((32, 64), -1), ((2, 64, 32), -1), - ((2, 64, 32), -2), + ((64, 2, 32), -2), ], "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES, } @@ -563,7 +597,7 @@ def test_norm_forward_with_block_scaling_fp8( } -@pytest.mark.skipif(not is_fp8_supported, reason=reason) +@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("q_dtype", FP8_COMPUTE_TYPE) @pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) @@ -585,9 +619,6 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt q_dtype=q_dtype, q_layout=q_layout, ) - # Adding dimension to test if padding is done correctly when flatten 3D to 2D - if flatten_axis == -2: - input_shape = input_shape[:-1] + (2,) + input_shape[-1:] n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): @@ -601,8 +632,6 @@ def test_quantize_bitwise( ): key = jax.random.PRNGKey(0) - if flatten_axis == -2: - input_shape = input_shape[:-1] + (2,) + input_shape[-1:] input = jax.random.uniform(key, input_shape, in_dtype) te_quantizer, jax_quantizer = QuantizerFactory.create( @@ -615,10 +644,65 @@ def test_quantize_bitwise( assert_bitwise_scaled_tensors(te_output, jax_output) +@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) +@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) +@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)]) +@pytest_parametrize_wrapper("q_dtype", [jnp_float8_e4m3_type]) +@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) +@pytest_parametrize_wrapper("flatten_axis", [-1]) +@pytest_parametrize_wrapper("with_group_sizes", [True, False]) +@pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE] +) +class TestGroupedQuantize: + def test_grouped_qdq( + self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes + ): + n_groups, m, n = input_shape + key = jax.random.PRNGKey(0) + subkeys = jax.random.split(key, 2) + + # *32 so that the input shapes works for MXFP8 + input_shape = (m * 32, n) + + if with_group_sizes: + group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) + group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) + group_sizes = jnp.diff(group_sizes) + assert group_sizes.sum() == m + assert jnp.any(group_sizes == 0) # make sure that at least one group has 0 row + group_sizes = group_sizes * 32 + else: + group_sizes = None + input_shape = (n_groups, input_shape[0] // n_groups, input_shape[1]) + + if flatten_axis == -2: + input_shape = input_shape[:-1] + (2,) + input_shape[-1:] + + x = jax.random.uniform(subkeys[1], input_shape, in_dtype) + + grouped_quantizer = QuantizerFactory.create( + scaling_mode=scaling_mode, + q_dtype=q_dtype, + q_layout=q_layout, + n_groups=n_groups, + ) + + # grouped_quantize does not work with cudaGraph yet, so the jitting will breaks + # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to + # disable cudaGraph, then use the following jitted function + + scaled_tensor = tex.grouped_quantize( + x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer + ) + + assert_dequantized_grouped_scaled_tensor(scaled_tensor, x) + + @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) class TestFusedQuantize: - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @@ -633,12 +717,6 @@ def test_quantize_dbias( ): pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") - if (flatten_axis < 0 and flatten_axis + len(input_shape) <= 0) or flatten_axis <= 0: - pytest.skip( - f"Flatten axis {flatten_axis} is not supported for input shape {input_shape}. There" - " must be at least one axis on either side of the flatten_axis split." - ) - key = jax.random.PRNGKey(0) input = jax.random.uniform(key, input_shape, in_dtype) @@ -725,7 +803,7 @@ def test_quantize_dact_dbias_no_quantization( q_layout=QuantizeLayout.ROWWISE, ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @@ -749,7 +827,7 @@ def test_quantize_dact_dbias_tensor_scaling( q_layout=q_layout, ) - @pytest.mark.skipif(not is_mxfp8_supported, reason=reason) + @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper( "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)] @@ -781,6 +859,22 @@ def test_quantize_dact_dbias_mxfp8_scaling( ) +valid_fp8_gemm_operand_types = [ + (jnp_float8_e4m3_type, jnp_float8_e4m3_type), + (jnp_float8_e5m2_type, jnp_float8_e4m3_type), + (jnp_float8_e4m3_type, jnp_float8_e5m2_type), +] + + +def _use_jax_fp8_gemm(enabled=False): + import os + + if enabled: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") + + class TestDense: def _ref_gemm_with_jnp_dot(self, a, b, data_layout): if data_layout[0] == "T": @@ -813,27 +907,47 @@ def _generate_gemm_input(self, m, n, k, data_layout): def test_gemm_bf16(self, m, n, k, data_layout): x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) - primitive_out = tex.gemm(x, w, contracting_dims) + primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper("q_dtype", FP8_COMPUTE_TYPE) + @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) - def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm): + if ( + not with_jax_gemm + and scaling_mode.is_1d_block_scaling() + and jnp_float8_e5m2_type in (x_qtype, w_qtype) + ): + pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.") + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False - ) - primitive_out = tex.gemm( - x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set + scaling_mode=scaling_mode, + fwd_dtype=jnp_float8_e4m3_type, + bwd_dtype=jnp_float8_e5m2_type, + is_2x2x=False, ) + with use_jax_gemm(enabled=with_jax_gemm): + primitive_out = tex.gemm( + x, + w, + contracting_dims=contracting_dims, + lhs_quantizer=( + quantizer_set.x if x_qtype == jnp_float8_e4m3_type else quantizer_set.dgrad + ), + rhs_quantizer=( + quantizer_set.kernel if w_qtype == jnp_float8_e4m3_type else quantizer_set.dgrad + ), + ) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) - assert_allclose(primitive_out, ref_out, dtype=q_dtype) + assert_allclose(primitive_out, ref_out, dtype=jnp_float8_e4m3_type) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) def test_dense_grad_bf16(self, m, n, k): @@ -860,11 +974,11 @@ def ref_func(x, w, data_layout): assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16) assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper("q_dtype", FP8_COMPUTE_TYPE) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm): data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -886,23 +1000,27 @@ def ref_func(x, w, bias, data_layout): value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True + scaling_mode=scaling_mode, + fwd_dtype=jnp_float8_e4m3_type, + bwd_dtype=jnp_float8_e5m2_type if scaling_mode.is_tensor_scaling() else jnp_float8_e4m3_type, + is_2x2x=True, ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 - for _ in range(n_iterations): - primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( - value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) - ) + with use_jax_gemm(enabled=with_jax_gemm): + for _ in range(n_iterations): + primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( + value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) + ) ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func( x, w, bias, data_layout ) - assert_allclose(primitive_out, ref_out, dtype=q_dtype) - assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype) - assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype) - assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype) + assert_allclose(primitive_out, ref_out, dtype=jnp_float8_e4m3_type) + assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp_float8_e5m2_type) + assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp_float8_e5m2_type) + assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp_float8_e5m2_type) @pytest.fixture(name="random_inputs") @@ -927,22 +1045,15 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan class TestFusedDense: - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) - @pytest.mark.parametrize("q_dtype", FP8_COMPUTE_TYPE) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) - def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm): """ Test layernorm_dense VJP Rule """ - # No Norm FWD E5M2 in TE backend - if q_dtype == jnp_float8_e5m2_type and scaling_mode in ( - ScalingMode.DELAYED_TENSOR_SCALING, - ScalingMode.CURRENT_TENSOR_SCALING, - ): - pytest.skip("E5M2 is not supported in normalization with TE Backend!") - # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False eps = 1e-6 @@ -958,8 +1069,8 @@ def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): quantizer_set = QuantizerFactory.create_set( scaling_mode=scaling_mode, - fwd_dtype=q_dtype, - bwd_dtype=q_dtype, + fwd_dtype=jnp_float8_e4m3_type, + bwd_dtype=jnp_float8_e5m2_type if scaling_mode.is_tensor_scaling() else jnp_float8_e4m3_type, is_2x2x=True, ) @@ -997,41 +1108,35 @@ def ref_func(x, w, gamma, beta): ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 - for _ in range(n_iterations): - prim_out, ( - prim_x_grad, - prim_w_grad, - prim_gamma_grad, - prim_beta_grad, - ) = value_n_grad_prim_func(x, w, gamma, beta) - - assert_allclose(prim_out, ref_out, dtype=q_dtype) - assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) - assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype) - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype) + with use_jax_gemm(enabled=with_jax_gemm): + for _ in range(n_iterations): + prim_out, ( + prim_x_grad, + prim_w_grad, + prim_gamma_grad, + prim_beta_grad, + ) = value_n_grad_prim_func(x, w, gamma, beta) + + assert_allclose(prim_out, ref_out, dtype=jnp_float8_e4m3_type) + assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp_float8_e5m2_type) + assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp_float8_e5m2_type) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp_float8_e5m2_type) if beta is not None: - assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype) + assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp_float8_e5m2_type) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest.mark.parametrize("q_dtype", FP8_COMPUTE_TYPE) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) - @pytest.mark.parametrize("use_bias", [True, False]) + @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( - self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias + self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm ): """ Test layernorm_mlp VJP Rule """ - # No Norm FWD E5M2 in TE backend - if q_dtype == jnp_float8_e5m2_type and scaling_mode in ( - ScalingMode.DELAYED_TENSOR_SCALING, - ScalingMode.CURRENT_TENSOR_SCALING, - ): - pytest.skip("E5M2 is not supported in normalization with TE Backend!") - # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False eps = 1e-6 @@ -1056,8 +1161,8 @@ def test_layernorm_mlp_grad( quantizer_sets = QuantizerFactory.create_set( n_quantizer_sets=2, scaling_mode=scaling_mode, - fwd_dtype=q_dtype, - bwd_dtype=q_dtype, + fwd_dtype=jnp_float8_e4m3_type, + bwd_dtype=jnp_float8_e5m2_type if scaling_mode.is_tensor_scaling() else jnp_float8_e4m3_type, is_2x2x=True, ) @@ -1086,14 +1191,13 @@ def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ln_out = _ref_jax_norm_impl( x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None ) - # TODO: replace gemm with jnp.dot - linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,))) + linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ()))) if use_bias: bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape linear_1_out += jnp.reshape(bias_1, bias_1_shape) x = _jax_act_lu(linear_1_out, activation_type) - linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,))) + linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ()))) if use_bias: bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape linear_2_out += jnp.reshape(bias_2, bias_2_shape) @@ -1107,15 +1211,16 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): value_n_grad_ref_func = value_and_grad(ref_func, range(6)) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 - for _ in range(n_iterations): - prim_out, ( - prim_x_grad, - prim_gamma_grad, - prim_kernel_1_grad, - prim_kernel_2_grad, - prim_bias_1_grad, - prim_bias_2_grad, - ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) + with use_jax_gemm(enabled=with_jax_gemm): + for _ in range(n_iterations): + prim_out, ( + prim_x_grad, + prim_gamma_grad, + prim_kernel_1_grad, + prim_kernel_2_grad, + prim_bias_1_grad, + prim_bias_2_grad, + ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) ref_out, ( ref_x_grad, @@ -1126,36 +1231,18 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ref_bias_2_grad, ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) - assert_allclose(prim_out, ref_out, dtype=q_dtype) + assert_allclose(prim_out, ref_out, dtype=jnp_float8_e4m3_type) - assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype) + assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp_float8_e5m2_type) if use_bias: - assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype) + assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype) + assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp_float8_e5m2_type) if use_bias: - assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype) - - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype) - assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) - + assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp_float8_e5m2_type) -# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm() -def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer): - ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims - lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1 - rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1 - lhs_q = lhs_quantizer.quantize( - lhs, - is_rowwise=lhs_is_rowwise, - is_colwise=not lhs_is_rowwise, - ) - rhs_q = rhs_quantizer.quantize( - rhs, - is_rowwise=rhs_is_rowwise, - is_colwise=not rhs_is_rowwise, - ) - return lhs_q, rhs_q + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp_float8_e5m2_type) + assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp_float8_e5m2_type) # E5M2 * E5M2 is not supported @@ -1165,219 +1252,217 @@ def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer [jnp_float8_e5m2_type, jnp_float8_e4m3_type], ] -""" -@pytest_parametrize_wrapper( - "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]] -) +GROUPED_DENSE_INPUT_SHAPES = [ + # (n_groups, m, n, k), the actual m will be multiplied by 32 + (5, 32, 128, 64), # Test the case where n_groups is not a multiple of 4 + (8, 64, 32, 128), + (8, 64, 128, 256), +] + + +@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES) class TestGroupedDense: - def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list): - ref_out_list = [] - for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list): - dim_nums = (contracting_dims, ((), ())) - ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums)) - return ref_out_list - - def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list): + def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims): + lhs_contract_dim, _ = contracting_dims + assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3 + if bias is None: + bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype) + else: + assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2]) + remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop() + lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis) + rhs = jnp.split(rhs, rhs.shape[0], axis=0) + bias = jnp.split(bias, bias.shape[0], axis=0) + ref_out = [] + dim_num = (contracting_dims, ((), ())) + for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias): + out_i = jax.lax.dot_general( + lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST + ) + jnp.expand_dims(bias_i, axis=0) + ref_out.append(jnp.squeeze(out_i)) + return ref_out + + def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False): key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, len(shape_list) * 2) - - lhs_list, rhs_list, contracting_dims_list = [], [], [] - for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)): - lhs = jax.random.uniform( - subkeys[2 * i], - (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m), - dtype=dtype, - ) - rhs = jax.random.uniform( - subkeys[2 * i + 1], - (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k), - dtype=dtype, - ) - lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) - rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,) - contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) + subkeys = jax.random.split(key, 4) + n_groups, m, n, k = input_shape + + group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) + group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) + group_sizes = jnp.diff(group_sizes) + # Make one empty input lhs to test empty GEMM handling + group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1]) + group_sizes = group_sizes.at[1].set(0) + assert group_sizes.sum() == m - lhs_list.append(lhs) - rhs_list.append(rhs) - contracting_dims_list.append(contracting_dims) + # *32 to make sure that input shape works for MXFP8 + group_sizes = group_sizes * 32 + m = m * 32 - return lhs_list, rhs_list, contracting_dims_list + lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m) + rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k) + bias_shape = (n_groups, n) + + lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype) + rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype) + bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None + + lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) + rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,) + contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) + + return lhs, rhs, group_sizes, contracting_dims, bias + + def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype): + assert out.dtype == ref_list[0].dtype + out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + for i in range(len(ref_list)): + assert_allclose(out_list[i], ref_list[i], dtype=dtype) @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) - @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]]) - def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list): - lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input( - dtype, shape_list, layout_list + @pytest_parametrize_wrapper("layout", ["NN"]) + def test_grouped_gemm_fp16(self, dtype, input_shape, layout): + lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( + dtype, input_shape, layout ) - ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list) - primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list) - for i in range(len(shape_list)): - assert_allclose(primitive_out[i], ref_out[i], dtype=dtype) + ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + + # grouped_gemm does not work with cudaGraph yet, so the jitting will breaks + # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to + # disable cudaGraph, then use the following jitted function + + # jitting grouped_gemm + # prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + # lhs, rhs, group_sizes, contracting_dims, + # ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims) + self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) + + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]]) - def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list): + @pytest_parametrize_wrapper("layout", ["NN"]) + def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout): fwd_dtype, bwd_dtype = fwd_bwd_dtype quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False + scaling_mode=scaling_mode, + fwd_dtype=fwd_dtype, + bwd_dtype=bwd_dtype, + is_2x2x=False, + n_groups=input_shape[0], ) + # quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype + # We want to test E4M3 * E5M2, manually set the quantizer_set.kernel.q_dtype to bwd_dtype + quantizer_set.kernel.q_dtype = bwd_dtype + for quantizer in quantizer_set.kernel.quantizers: + quantizer.q_dtype = bwd_dtype + out_dtype = jnp.bfloat16 - lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input( - out_dtype, shape_list, layout_list + lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( + out_dtype, input_shape, layout ) - q_lhs_list = [] - q_rhs_list = [] - for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list): - # quantizer_set.x and quantizer_set.kernel have the same q_dtype, we want to - # test the case where lhs and rhs have different q_dtypes - q_lhs, q_rhs = _quantize_gemm_pair( - lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad - ) - q_lhs_list.append(q_lhs) - q_rhs_list.append(q_rhs) + ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) - ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list) - primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list) + # jitting grouped_gemm + # prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))( + # lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set + # ) + + prim_out = tex.grouped_gemm( + lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set + ) allclose_dtype = jnp_float8_e4m3_type - if fwd_dtype == jnp_float8_e5m2_type or bwd_dtype == jnp_float8_e5m2_type: + if jnp_float8_e5m2_type in fwd_bwd_dtype: allclose_dtype = jnp_float8_e5m2_type - for i in range(len(shape_list)): - assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype) - @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) - def test_grouped_dense_grad_fp16(self, dtype, shape_list): - group_size = len(shape_list) - layout_list = ["NN" for _ in range(group_size)] + self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype) - x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input( - dtype, shape_list, layout_list + def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims): + out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims) + # Note: we use jnp.sum instead of jnp.mean to make the gradient larger + # and prevent them from being clamp to zero in FP8. / sqrt(x.size) is used to + # normalize the output and prevent the gradient from being too large for FP8. + out_sum_list = [jnp.sum(out) for out in out_list] + return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size) + + def _primitive_sum_grouped_dense( + self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set + ): + out = grouped_dense( + x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set ) - bias_list = [] - key = jax.random.PRNGKey(1) - for shape in shape_list: - n = shape[1] - bias = jax.random.uniform(key, n, dtype=dtype) - bias_list.append(bias) - - def ref_func(x_list, kernel_list, bias_list, contracting_dims_list): - out_list = [] - for i in range(len(x_list)): - out_list.append( - dense( - x_list[i], - kernel_list[i], - bias_list[i], - contracting_dims=contracting_dims_list[i], - ) - ) - # Note: we use jnp.sum instead of jnp.mean to make the gradient larger - # and prevent them from being clamp to zero - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) + return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size) - def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list): - out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list) - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) + @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) + def test_grouped_dense_grad_fp16(self, dtype, input_shape): + x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( + dtype, + input_shape, + with_bias=True, + ) - value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) - value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) + value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) + # jitting the grouped_dense + # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), + # static_argnums=(4,)) + value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) - ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func( - x_list, kernel_list, bias_list, contracting_dims_list + ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( + x, kernel, bias, group_sizes, contracting_dims ) - primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = ( - value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list) + prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func( + x, kernel, bias, group_sizes, contracting_dims ) - assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype) - for i in range(group_size): - assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype) - assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype) - assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype) + assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype) + assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype) + assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype) + assert_allclose(prim_dbias, ref_dbias, dtype=dtype) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) + @pytest.mark.parametrize( + "fwd_bwd_dtype", + [(jnp_float8_e4m3_type, jnp_float8_e4m3_type), (jnp_float8_e4m3_type, jnp_float8_e5m2_type)], + ) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list): - group_size = len(shape_list) - layout_list = ["NN" for _ in range(group_size)] + def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): fwd_dtype, bwd_dtype = fwd_bwd_dtype - if fwd_dtype == jnp_float8_e5m2_type: - pytest.skip("We never use E5M2 for fwd_dtype in training") - - # Question: should we use different quantizers for different groups? - ref_quantizer_set_list = [] - quantizer_set_list = [] - for _ in range(group_size): - ref_quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True - ) - ref_quantizer_set_list.append(ref_quantizer_set) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True - ) - quantizer_set_list.append(quantizer_set) - - out_dtype = jnp.bfloat16 - x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input( - out_dtype, shape_list, layout_list + dtype = jnp.bfloat16 + x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( + dtype, + input_shape, + with_bias=True, ) - bias_list = [] - key = jax.random.PRNGKey(1) - for shape in shape_list: - n = shape[1] - bias = jax.random.uniform(key, n, dtype=out_dtype) - bias_list.append(bias) - - def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list): - out_list = [] - for i in range(len(x_list)): - out_list.append( - dense( - x_list[i], - kernel_list[i], - bias_list[i], - contracting_dims=contracting_dims_list[i], - quantizer_set=quantizer_set_list[i], - ) - ) - # Note: we use jnp.sum instead of jnp.mean to make the gradient larger - # and prevent them from being clamp to zero - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) - - def primitive_func( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list - ): - out_list = grouped_dense( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list - ) - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) - - value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) - value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) - ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func( - x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list + quantizer_set = QuantizerFactory.create_set( + scaling_mode=scaling_mode, + fwd_dtype=fwd_dtype, + bwd_dtype=bwd_dtype, + is_2x2x=True, + n_groups=group_sizes.size, ) - primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = ( - value_n_grad_primitive_func( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list - ) + value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) + + # jitting the grouped_dense + # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), + # static_argnums=(4,)) + value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) + + ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( + x, + kernel, + bias, + group_sizes, + contracting_dims, + ) + prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func( + x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set ) - allclose_dtype = jnp_float8_e4m3_type - if fwd_dtype == jnp_float8_e5m2_type or bwd_dtype == jnp_float8_e5m2_type: - allclose_dtype = jnp_float8_e5m2_type - assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype) - for i in range(group_size): - assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype) - assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype) - assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype) -""" + assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype) + assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype) + assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) + assert_allclose(prim_dbias, ref_dbias, dtype=dtype) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 3f56e261c..5a824e8c6 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -72,6 +72,7 @@ def impl_test_self_attn( batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( + is_training, dtype, dtype, QKVLayout.BS3HD, @@ -222,6 +223,7 @@ def test_cross_attn( batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( + is_training, dtype, dtype, QKVLayout.BSHD_BS2HD, @@ -300,6 +302,7 @@ def impl_test_context_parallel_attn( cp_strategy, use_shardy, use_scan_ring=False, + window_size=None, ): if qkv_layout.is_thd(): if is_hip_extension() and cp_strategy == CPStrategy.RING: @@ -348,7 +351,7 @@ def impl_test_context_parallel_attn( is_training, qkv_layout, bias_shape, - None, + window_size, SeqDescFormat.SegmentIDs, number_of_devices=device_count, mesh_shape=mesh_shape, @@ -360,6 +363,7 @@ def impl_test_context_parallel_attn( def check_has_backend_for_mask(mask_type): return is_fused_attn_kernel_available( + is_training, dtype, dtype, qkv_layout, @@ -500,6 +504,13 @@ def test_context_parallel_allgather_attn( "use_scan", [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")], ) + @pytest.mark.parametrize( + "window_size", + [ + pytest.param((-1, -1), id="window_size(-1, -1)"), + pytest.param((20, 0), id="window_size(20, 0)"), + ], + ) def test_context_parallel_ring_attn( self, device_count, @@ -513,7 +524,15 @@ def test_context_parallel_ring_attn( qkv_layout, load_balanced, use_scan, + window_size, ): + if window_size != (-1, -1) and not qkv_layout.is_thd(): + pytest.skip("Sliding window attention is only supported for THD layout") + if window_size != (-1, -1) and qkv_layout.is_thd() and use_scan: + pytest.skip( + "When context parallelism and sliding window attention are used, " + "scanloop is not supported" + ) self.impl_test_context_parallel_attn( device_count, mesh_shape, @@ -528,6 +547,7 @@ def test_context_parallel_ring_attn( CPStrategy.RING, use_shardy=False, use_scan_ring=use_scan, + window_size=window_size, ) @pytest.mark.skipif(version.parse(jax.__version__) < version.parse("0.5.0"), reason="shardy sharding requires JAX 0.5.0") diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 098ee25fb..57098b0e2 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -77,8 +77,6 @@ def generate_collectives_count_ref( all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize ) other_bytes = 0 - if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes: - other_bytes = 384 # required for small scale shapes that require padding if fp8_recipe == recipe.Float8CurrentScaling(): allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction return generate_collectives_count( diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index ff6ebb3ed..694610978 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -1,9 +1,10 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. from typing import Callable, Sequence, Union, Optional import pytest -from packaging import version import jax import jax.numpy as jnp @@ -14,6 +15,7 @@ assert_tree_like_allclose, is_devices_enough, pytest_parametrize_wrapper, + use_jax_gemm, ) from transformer_engine.common import recipe @@ -34,7 +36,11 @@ ) from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.quantize import QuantizerFactory +from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability +from transformer_engine.jax.util import get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type +jnp_float8_e4m3_type = get_jnp_float8_e4m3_type() +jnp_float8_e5m2_type = get_jnp_float8_e5m2_type() is_fp8_supported, reason = is_fp8_available() is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @@ -147,7 +153,15 @@ def layernorm_fp8_mlp_prim_func( ) def _test_layernorm_mlp_grad( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + use_shardy, + with_jax_gemm, ): jax.config.update("jax_use_shardy_partitioner", use_shardy) device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config @@ -157,72 +171,87 @@ def _test_layernorm_mlp_grad( input_shape, activation_type, use_bias, dtype ) static_inputs = [layernorm_type, activation_type] - value_and_grad_func = jax.value_and_grad( - self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs)) - ) - # Single GPU - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - single_jitter = jax.jit( - value_and_grad_func, - static_argnums=range(len(inputs), len(static_inputs) + len(inputs)), + with use_jax_gemm(enabled=with_jax_gemm): + value_and_grad_func = jax.value_and_grad( + self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs)) ) - single_fwd, single_grads = single_jitter(*inputs, *static_inputs) - - # Multi GPUs - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): - k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp")) - k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) - k1_ = jax.device_put(k1, k1_sharding) - k2_ = jax.device_put(k2, k2_sharding) - if use_bias: - b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp")) - b1_ = jax.device_put(b1, b1_sharding) - else: - b1_sharding = b1_ = None - multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]] - - # Position ref for sharding pspec lists - # x, gamma, k1, k2, b1, - # b2 - in_shardings = ( - None, - None, - k1_sharding, - k2_sharding, - b1_sharding, - None, - ) - out_shardings = ( - None, - (None, None, k1_sharding, k2_sharding, b1_sharding, None), - ) - - multi_jitter = jax.jit( - value_and_grad_func, - in_shardings=in_shardings, - out_shardings=out_shardings, - static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1), - ) # +1 for multi_gpus - multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) - - assert_allclose(multi_fwd, single_fwd, dtype=dtype) + # Single GPU + with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + single_jitter = jax.jit( + value_and_grad_func, + static_argnums=range(len(inputs), len(static_inputs) + len(inputs)), + ) + single_fwd, single_grads = single_jitter(*inputs, *static_inputs) + + # Multi GPUs + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + with mesh, fp8_autocast( + enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + ): + k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp")) + k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) + k1_ = jax.device_put(k1, k1_sharding) + k2_ = jax.device_put(k2, k2_sharding) + if use_bias: + b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp")) + b1_ = jax.device_put(b1, b1_sharding) + else: + b1_sharding = b1_ = None + multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]] + + # Position ref for sharding pspec lists + # x, gamma, k1, k2, b1, + # b2 + in_shardings = ( + None, + None, + k1_sharding, + k2_sharding, + b1_sharding, + None, + ) + out_shardings = ( + None, + (None, None, k1_sharding, k2_sharding, b1_sharding, None), + ) + + multi_jitter = jax.jit( + value_and_grad_func, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=range( + len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1 + ), + ) # +1 for multi_gpus + + multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) + + # TODO: skip cases with single fwd as nan/inf + if jnp.any(jnp.isnan(single_fwd)) or jnp.any(jnp.isinf(single_fwd)): + pytest.skip("skip tests with nan/inf single fwd.") + + fwd_test_type = dtype if fp8_recipe is None else jnp_float8_e4m3_type + bwd_test_type = dtype if fp8_recipe is None else jnp_float8_e5m2_type + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) for i in range(len(inputs)): if multi_grads[i] is not None: if isinstance(multi_grads[i], list): assert isinstance(single_grads[i], list) for m_grad, s_grad in zip(multi_grads[i], single_grads[i]): assert_allclose( - m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close" + m_grad, + s_grad, + dtype=bwd_test_type, + err_msg=f"multi_grads[{i}] is not close", ) else: assert_allclose( multi_grads[i], single_grads[i], - dtype=dtype, + dtype=bwd_test_type, err_msg=f"multi_grads[{i}] is not close", ) @@ -233,8 +262,16 @@ def _test_layernorm_mlp_grad( @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + with_jax_gemm, ): self._test_layernorm_mlp_grad( mesh_config, @@ -244,6 +281,7 @@ def test_layernorm_mlp_grad( dtype, fp8_recipe, use_shardy=False, + with_jax_gemm=with_jax_gemm, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -252,20 +290,29 @@ def test_layernorm_mlp_grad( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest.mark.skipif(version.parse(jax.__version__) < version.parse("0.5.0"), reason="shardy sharding requires JAX 0.5.0") + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + with_jax_gemm, ): - # We don't test block scaling with Shardy because at the time of writing, - # it is not supported in JAX's scaled_matmul_stablehlo. + if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") self._test_layernorm_mlp_grad( mesh_config, activation_type, use_bias, input_shape, dtype, - fp8_recipe=recipe.DelayedScaling(), + fp8_recipe=fp8_recipe, use_shardy=True, + with_jax_gemm=with_jax_gemm, ) def _test_layernorm_mlp( @@ -278,6 +325,7 @@ def _test_layernorm_mlp( use_fp8, fp8_recipe, use_shardy, + with_jax_gemm, ): jax.config.update("jax_use_shardy_partitioner", use_shardy) batch, seqlen, hidden_in = input_shape @@ -289,62 +337,95 @@ def _test_layernorm_mlp( x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) init_rngs = {"params": subkeys[1]} - # Single GPUs - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): - ln_mlp_single = LayerNormMLP( - layernorm_type=layernorm_type, - transpose_batch_sequence=False, # input: [batch, seqlen, hidden] - intermediate_dim=INTERMEDIATE, - activations=activation_type, - use_bias=use_bias, - ) - params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) - mlp_out_single, ln_out_single = ln_mlp_single.apply( - params_single, x, deterministic=True - ) - - # Multi GPUs - device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast( - enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource - ): - ln_mlp_sharded = LayerNormMLP( - layernorm_type=layernorm_type, - transpose_batch_sequence=False, - intermediate_dim=INTERMEDIATE, - activations=activation_type, - scale_axes=LN_SCALE_AXES, - ln_bias_axes=LN_BIAS_AXES, - kernel_axes_1=KERNEL_1_AXES, - kernel_axes_2=KERNEL_2_AXES, - use_bias=use_bias, - bias_axes_1=BIAS_1_AXES, - bias_axes_2=BIAS_2_AXES, - layernorm_input_axes=LAYERNORM_INPUT_AXES, - dot_1_input_axes=DOT_1_INPUT_AXES, - dot_2_input_axes=DOT_2_INPUT_AXES, - name="mlp", - ) - params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True) - mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( - params_sharded, x, deterministic=True - ) + with use_jax_gemm(enabled=with_jax_gemm): + # Single GPUs + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + ln_mlp_single = LayerNormMLP( + layernorm_type=layernorm_type, + transpose_batch_sequence=False, # input: [batch, seqlen, hidden] + intermediate_dim=INTERMEDIATE, + activations=activation_type, + use_bias=use_bias, + ) + params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) + mlp_out_single, ln_out_single = ln_mlp_single.apply( + params_single, x, deterministic=True + ) + + # Multi GPUs + device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + with mesh, fp8_autocast( + enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + ): + ln_mlp_sharded = LayerNormMLP( + layernorm_type=layernorm_type, + transpose_batch_sequence=False, + intermediate_dim=INTERMEDIATE, + activations=activation_type, + scale_axes=LN_SCALE_AXES, + ln_bias_axes=LN_BIAS_AXES, + kernel_axes_1=KERNEL_1_AXES, + kernel_axes_2=KERNEL_2_AXES, + use_bias=use_bias, + bias_axes_1=BIAS_1_AXES, + bias_axes_2=BIAS_2_AXES, + layernorm_input_axes=LAYERNORM_INPUT_AXES, + dot_1_input_axes=DOT_1_INPUT_AXES, + dot_2_input_axes=DOT_2_INPUT_AXES, + name="mlp", + ) + params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True) + mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( + params_sharded, x, deterministic=True + ) # Make sure params values are the same assert_tree_like_allclose(params_sharded["params"], params_single["params"]) assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype) - assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype) + + atol = None + rtol = None + l40_tolerance_update = ( + get_min_device_compute_capability() == 89 + and fp8_recipe == recipe.DelayedScaling() + and use_fp8 + and dtype == jnp.float16 + and activation_type == ("gelu",) + ) + if l40_tolerance_update: + atol = 0.04 + rtol = 11 + + # JAX's FP8 GEMM, jax.lax.dot_general, now uses the + # Triton backend by default. The error of + # the Triton FP8 gemm has been verified to be less than or equal + # to the error of the cuDNN FP8 gemm w.r.t a float32 ground truth. + # However, Triton can auto-tune a different kernel for the single GPU + # and multi-GPU run in this test, meaning the diff between single GPU + # and multi-GPU can be larger in some cases, even though both are + # within tolerance to the float32 ground truth. + jax_triton_gemm_precision_tolerance_update = ( + with_jax_gemm + and isinstance(fp8_recipe, recipe.Float8CurrentScaling) + and dtype == jnp.bfloat16 + and activation_type == ("gelu", "linear") + ) + if jax_triton_gemm_precision_tolerance_update: + atol = 0.08 + rtol = 15 + + assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("use_shardy", [False, True] if version.parse(jax.__version__) >= version.parse("0.5.0") else [False]) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer( - self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy + self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm ): self._test_layernorm_mlp( mesh_config, @@ -354,7 +435,8 @@ def test_layernorm_mlp_layer( dtype, use_fp8=False, fp8_recipe=None, - use_shardy=use_shardy, + use_shardy=False, + with_jax_gemm=with_jax_gemm, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -364,8 +446,9 @@ def test_layernorm_mlp_layer( @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm ): self._test_layernorm_mlp( mesh_config, @@ -376,4 +459,51 @@ def test_layernorm_mlp_layer_fp8( use_fp8=True, fp8_recipe=fp8_recipe, use_shardy=False, + with_jax_gemm=with_jax_gemm, + ) + + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_layernorm_mlp_layer_shardy( + self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm + ): + self._test_layernorm_mlp( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8=False, + fp8_recipe=None, + use_shardy=True, + with_jax_gemm=with_jax_gemm, + ) + + @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) + @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_layernorm_mlp_layer_fp8_shardy( + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm + ): + if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + self._test_layernorm_mlp( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8=True, + fp8_recipe=fp8_recipe, + use_shardy=True, + with_jax_gemm=with_jax_gemm, ) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index ab32afe40..e08c3a1b9 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -373,6 +373,7 @@ def _check_configs(self): pytest.skip("Aiter currently supports MLA hd192_hd128 only with mask-based SeqDescFormat.") self.backend = FusedAttnHelper( + self.is_training, self.dtype, self.dtype, self.qkv_layout, diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index 73721b318..e237318a4 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -92,7 +92,7 @@ def test_fp8_autocast_delayed_scaling(self): self._check_default_state() @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) - def test_fp8_autocast_mxfp8_scaling(self): + def test_fp8_autocast_current_scaling(self): QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. self._check_default_state() @@ -116,7 +116,7 @@ def test_fp8_autocast_mxfp8_scaling(self): self._check_default_state() @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) - def test_fp8_autocast_mxfp8_scaling(self): + def test_fp8_autocast_mxfp8_block_scaling(self): QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. self._check_default_state() diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 8d69fd817..f34fb5448 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -5,11 +5,12 @@ # See LICENSE for license information. """Utility for the TE layer tests""" +import os import functools import math import operator -from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional -import os +from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional, NewType +from contextlib import contextmanager import jax import jax.numpy as jnp @@ -22,7 +23,6 @@ import pytest from transformer_engine.jax.attention import ( - AttnMaskType, canonicalize_attn_mask_type, make_swa_mask, ) @@ -31,8 +31,8 @@ PRNGKey = Any Shape = Tuple[int, ...] -DType = jnp.dtype -Array = Any +DType = NewType("DType", jnp.dtype) +Array = NewType("Array", jnp.ndarray) PrecisionLike = Union[ None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] ] @@ -1522,7 +1522,7 @@ def dtype_tols( TEDType.kFloat8E5M2: get_jnp_float8_e5m2_type(), }[dtype] elif isinstance(dtype, np.dtype): - dtype = jnp.dtype(dtype) + dtype = DType(dtype) # Expect bit-wise accuracy for integer dtypes if not jnp.issubdtype(dtype, jnp.floating): @@ -1603,3 +1603,20 @@ def print_debug_tensor_stats(prefix, tensor, hist=False): fmt = fmt + "\n {}\n {}" jax.debug.print(fmt, *args) + + +@contextmanager +def use_jax_gemm(enabled=False): + orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None) + + try: + if enabled: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + yield + + finally: + if enabled: + if orig_custom_calls_filter is None: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") + else: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter diff --git a/tests/pytorch/debug/conftest.py b/tests/pytorch/debug/conftest.py new file mode 100644 index 000000000..20edc6aab --- /dev/null +++ b/tests/pytorch/debug/conftest.py @@ -0,0 +1,27 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--feature_dirs", nargs="+", action="store", default="", help="List of feature directories" + ) + parser.addoption( + "--configs_dir", + action="store", + default="", + type=str, + help="Path to the directory with configs.", + ) + + +@pytest.fixture +def feature_dirs(request): + return request.config.getoption("--feature_dirs") + + +@pytest.fixture +def configs_dir(request): + return request.config.getoption("--configs_dir") diff --git a/tests/pytorch/debug/run_distributed.py b/tests/pytorch/debug/run_distributed.py new file mode 100644 index 000000000..b12f8c3d3 --- /dev/null +++ b/tests/pytorch/debug/run_distributed.py @@ -0,0 +1,666 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import tempfile +import functools +import os +import itertools +import random +import argparse +import re + +import torch +import torch.distributed as dist +import transformer_engine +import transformer_engine_torch as tex +import nvdlfw_inspect.api as debug_api +from transformer_engine.debug import set_weight_tensor_tp_group_reduce +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + +from test_numerics import ( + _emulate_linear, + _init_debug, + disable_fp8_gemms_create_config, + DISABLE_FP8_LAYER_CONFIG, + _cmp, + IN_SIZE, + OUT_SIZE, + _init_model, + SEED, + SEQ_LEN, + BATCH_SIZE, + FP8_RECIPE, + fake_quant_fp8_create_config, + _get_current_scale, + _prepare_per_tensor_scaling_config, + AMAX_HISTORY_LEN, + set_scaling_factors, + set_current_scaling_factors, +) + +WORLD_RANK, WORLD_SIZE = None, None +NCCL_WORLD = None +FEATURE_DIRS = None +all_boolean = [True, False] +TEST_NR = 0 + +fp8_available, _ = FP8GlobalStateManager.is_fp8_available() + + +def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None): + if tp_size is None: + tp_size = WORLD_SIZE + tp_rank = WORLD_RANK + torch.manual_seed(weight_seed) + weight = torch.randn((OUT_SIZE, IN_SIZE)).cuda() + torch.manual_seed(data_seed) + in_split_size = IN_SIZE // tp_size + out_split_size = OUT_SIZE // tp_size + x = torch.randn((SEQ_LEN * BATCH_SIZE, IN_SIZE), requires_grad=True).cuda() + if parallel_mode == "row": + x = x[:, tp_rank * in_split_size : (tp_rank + 1) * in_split_size] + x.retain_grad() + + with torch.no_grad(): + if parallel_mode == "column": + weight = weight[tp_rank * out_split_size : (tp_rank + 1) * out_split_size, :] + else: + weight = weight[:, tp_rank * in_split_size : (tp_rank + 1) * in_split_size] + + return x, weight.contiguous() + + +def _init_model(weight, parallel_mode=None, tp_group=None, name="linear"): + model = transformer_engine.pytorch.Linear( + IN_SIZE, + OUT_SIZE, + name=name, + parallel_mode=parallel_mode, + tp_group=(tp_group or NCCL_WORLD if parallel_mode else None), + ) + with torch.no_grad(): + model.weight.copy_(weight) + return model + + +class AllGather(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor, dim, group=None): + if group is None: + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + world_size = torch.distributed.get_world_size(group=group) + rank = torch.distributed.get_rank(group=group) + dist.barrier() + + # Create a list to gather tensors from all processes + y_list = [torch.zeros_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(y_list, tensor, group=group) + + # Save the world size and rank for backward computation + ctx.world_size = world_size + ctx.rank = rank + ctx.dim = dim + + # Concatenate the gathered tensors along the feature dimension + y_full = torch.cat(y_list, dim=dim) + + return y_full + + @staticmethod + def backward(ctx, grad_output): + # Split the gradient output and return the portion corresponding to this rank + grad_input = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)[ctx.rank] + return grad_input, None, None + + +def _run_forward_backward(x, model, parallel_mode=None, group=None): + with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + y = model(x) + + y.requires_grad_(True) + y.retain_grad() + if parallel_mode == "column": + y = AllGather.apply(y, -1, group) + y.requires_grad_(True) + y.retain_grad() + l = y.sum() + l.backward() + elif parallel_mode == "row": + l = y.sum() + l.backward() + debug_api.step() + return y + + +def _emulate_linear_distributed(*args, parallel_mode=None, **kwargs): + assert parallel_mode in ["column", "row"] + + def split(gradient): + split_size = OUT_SIZE // WORLD_SIZE + gradient = gradient[:, WORLD_RANK * split_size : (WORLD_RANK + 1) * split_size] + return gradient + + activation_sync = None + gradient_sync = None + if parallel_mode == "column": + activation_sync = lambda x: AllGather.apply(x, -1) + gradient_sync = split + else: + activation_sync = ( + lambda activation: dist.all_reduce(activation, op=dist.ReduceOp.SUM) or activation + ) + + output = _emulate_linear( + *args, activation_sync=activation_sync, gradient_sync=gradient_sync, **kwargs + ) + + if parallel_mode == "column": + dist.all_reduce(output["dgrad"], op=dist.ReduceOp.SUM) + + return output + + +def check_debug_log(msg): + with open(f"log/debug_logs/debug_log_globalrank-{WORLD_RANK}.log", "r") as f: + for line in f.readlines(): + if msg in line: + return True + return False + + +def run_debug_test(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank = dist.get_rank() + temp_file_name = None + temp_logdir_name = None + + if rank == 0: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: + temp_file_name = temp_file.name + temp_dir_obj = tempfile.TemporaryDirectory() + temp_logdir_name = temp_dir_obj.name + + # Store the TemporaryDirectory object to prevent it from being deleted + wrapper.temp_dir_obj = temp_dir_obj + + temp_file_name_list = [temp_file_name] + temp_logdir_name_list = [temp_logdir_name] + + # Broadcast the temporary file and directory names to all processes + dist.broadcast_object_list(temp_file_name_list, src=0) + dist.broadcast_object_list(temp_logdir_name_list, src=0) + + temp_file_name = temp_file_name_list[0] + temp_logdir_name = temp_logdir_name_list[0] + + dist.barrier() + + config_file = open(temp_file_name, mode="r+", buffering=1) + + try: + kwargs["config_file"] = config_file + kwargs["log_dir"] = temp_logdir_name + + if rank == 0: + global TEST_NR + print(f"Running test {TEST_NR} {func.__name__} with args = {args}.") + TEST_NR += 1 + + func(*args, **kwargs) + finally: + if rank == 0 and temp_file_name is not None: + os.unlink(temp_file_name) + + debug_api.end_debug() + + if rank == 0 and hasattr(wrapper, "temp_dir_obj"): + wrapper.temp_dir_obj.cleanup() + + return wrapper + + +CONFIG_LOG_TEST_DISTRIBUTED_FP8 = """log_distributed: + layers: + layer_types: [linear] + enabled: + True + transformer_engine: + LogTensorStats: + enabled: True + tensors: [activation, gradient, weight, output, wgrad, dgrad] + stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range] + start_step : 0 + end_step: 1 + LogFp8TensorStats: + enabled: True + tensors: [activation, gradient, weight] + stats: [underflows%] + start_step : 0 + end_step: 1 +""" + +CONFIG_LOG_TEST_DISTRIBUTED_NO_FP8 = """log_distributed: + layers: + layer_types: [linear] + enabled: + True + transformer_engine: + LogTensorStats: + enabled: True + tensors: [activation, gradient, weight, output, wgrad, dgrad] + stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range] + start_step : 0 + end_step: 1 +""" + + +def _prepare_config_test_log_distributed(config_file): + if WORLD_RANK != 0: + return + config_file.write( + CONFIG_LOG_TEST_DISTRIBUTED_FP8 if fp8_available else CONFIG_LOG_TEST_DISTRIBUTED_NO_FP8 + ) + config_file.flush() + + +def _compute_dynamic_range(tensor): + tensor_abs = tensor.abs() + tensor_abs = tensor_abs[tensor_abs != 0] + if tensor_abs.any(): + amin = tensor_abs.min().float() + else: + amin = torch.tensor(1, device=tensor.device).to(torch.float) + amax = tensor_abs.max().float() + if not amax.all(): + amax = torch.tensor(1, device=tensor.device).to(torch.float) + dynamic_range = torch.log2(amax) - torch.log2(amin) + return dynamic_range + + +@run_debug_test +def test_log_distributed(parallel_mode, gather_weight, **kwargs): + _prepare_config_test_log_distributed(kwargs["config_file"]) + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS) + set_weight_tensor_tp_group_reduce(gather_weight) + if WORLD_SIZE % 2 != 0: + return # skip + TP_SIZE = WORLD_SIZE // 2 + DP_SIZE = 2 + TP_RANK = WORLD_RANK % TP_SIZE + DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE + + debug_api.set_tensor_reduction_group(NCCL_WORLD) + + x, weight = _get_tensors( + parallel_mode, + weight_seed=TP_RANK * 1234, + data_seed=DP_RANK * 1234, + tp_size=TP_SIZE, + tp_rank=TP_RANK, + ) + + tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)] + tp_group = dist.new_group(ranks=tp_group_ranks) + + dp_group_ranks = [i for i in range(TP_RANK, WORLD_SIZE, TP_SIZE)] + dp_group = dist.new_group(ranks=dp_group_ranks) + + model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group) + output = _run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group) + + gathered_activation = AllGather.apply(x.contiguous(), 0) + gathered_weight = AllGather.apply(weight.contiguous(), 0, tp_group) + gathered_gradient = AllGather.apply(output.grad.contiguous(), 0, dp_group) + if parallel_mode == "row": + gathered_gradient = AllGather.apply(gathered_gradient, 0, tp_group) + + log_file = kwargs["log_dir"] + "/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log" + + dist.barrier() + if WORLD_RANK != 0: + return # stats are gathered on node 0 + with open(log_file) as f: + content = f.read() + + def get_stat(tensor, stat): + regex = r".*_{tensor}_{stat}\s+.*iteration=(\d+)\s+.*value=([-+]?\d*\.?\d+)".format( + tensor=tensor, stat=stat + ) + for line in content.splitlines(): + match = re.search(regex, line) + if match: + value = float(match.group(2)) + return value + + rf = lambda x: round(float(x), 4) + stats = [] + tensors = { + "activation": gathered_activation, + "weight": gathered_weight if gather_weight else weight, + "gradient": gathered_gradient, + } + stats = { + "min": torch.min, + "max": torch.max, + "mean": torch.mean, + "std": torch.std, + "l1_norm": lambda x: torch.norm(x, p=1), + "l2_norm": lambda x: torch.norm(x, p=2), + "cur_amax": lambda x: x.abs().max(), + "dynamic_range": _compute_dynamic_range, + } + for stat_key in stats.keys(): + for tensor_key in tensors.keys(): + torch.testing.assert_close( + get_stat(tensor_key, stat_key), + rf(stats[stat_key](tensors[tensor_key])), + atol=0.0001, + rtol=0.0001, + ) + set_weight_tensor_tp_group_reduce(True) # reset + + +@run_debug_test +def test_log_expert_parallel(**kwargs): + """ + This test tests the scenario, when one of the node of data parallel does not invoke the debug layer. + It naturally occurs in the expert parallelism, when one expert doesn't get input on one node, + but gets it on other nodes. If there were all_gather inside forward(), this would result in deadlock. + """ + _prepare_config_test_log_distributed(kwargs["config_file"]) + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS) + debug_api.set_tensor_reduction_group(NCCL_WORLD) + x, weight = _get_tensors( + "row", weight_seed=WORLD_RANK * 1234, data_seed=WORLD_RANK * 1234, tp_size=1, tp_rank=0 + ) # data parallel + model = _init_model(weight, parallel_mode=None, name="linear1") + model1 = _init_model(weight, parallel_mode=None, name="linear2") + with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE): + y1 = model(x) + y2 = model1(x) + y = y1 + y2 + y.sum().backward() + debug_api.step() + with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE): + y = model(x) + if WORLD_RANK != 0: + y = y + model1(x) + + y.sum().backward() + + +@run_debug_test +def test_disable_fp8_gemms(fprop_fp8, dgrad_fp8, wgrad_fp8, parallel_mode, **kwargs): + disable_fp8_gemms_create_config(fprop_fp8, dgrad_fp8, wgrad_fp8, kwargs["config_file"]) + fp8_kwargs = { + "fprop_fp8": fprop_fp8, + "dgrad_fp8": dgrad_fp8, + "wgrad_fp8": wgrad_fp8, + } + + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS) + x, weight = _get_tensors(parallel_mode) + model = _init_model(weight, parallel_mode=parallel_mode) + y = _run_forward_backward(x, model, parallel_mode=parallel_mode) + + output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} + + x.grad.zero_() + ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs) + _cmp(ground_truth, output) + + +@run_debug_test +def test_disable_fp8_layer(parallel_mode, **kwargs): + if WORLD_RANK == 0: + kwargs["config_file"].write(DISABLE_FP8_LAYER_CONFIG) + kwargs["config_file"].flush() + dist.barrier() + + x, weight = _get_tensors(parallel_mode) + + ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode) + x.grad.zero_() + + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS) + + model = _init_model(weight, parallel_mode) + y = _run_forward_backward(x, model, parallel_mode) + + output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} + _cmp(ground_truth, output) + + +@run_debug_test +def test_per_tensor_scaling( + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + parallel_mode, + **kwargs, +): + input_kwargs = { + "fprop_inp": fprop_inp, + "fprop_weight": fprop_weight, + "dgrad_weight": dgrad_weight, + "dgrad_grad": dgrad_grad, + "wgrad_input": wgrad_input, + "wgrad_grad": wgrad_grad, + } + fp8_kwargs = { + "fprop_fp8": True, + "dgrad_fp8": True, + "wgrad_fp8": True, + } + """ + Runs a test to validate per-tensor (current) scaling in FP8 computations. + The function performs warm-up iterations to populate the amax buffer of the model and compute scaling factors based on delayed scaling. + Subsequently, weights and inputs are switched to ensure their current scaling factors differ from those based on delayed scaling; + similarly, the loss is multiplied by a large factor to alter the gradient's magnitude, + creating a discrepancy between the original (delayed) and per-tensor (current) scaling factors. + Finally, a linear pass is emulated, and the results are compared.” + """ + _prepare_per_tensor_scaling_config( + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + kwargs["config_file"], + ) + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS) + + warmup_input, warmup_weight = _get_tensors(parallel_mode=parallel_mode) + model = _init_model(warmup_weight, parallel_mode=parallel_mode) + + # Warmup run to setup amax and scaling factors. + for _ in range(AMAX_HISTORY_LEN): + _run_forward_backward(warmup_input, model, parallel_mode=parallel_mode) + + x, weight = _get_tensors( + parallel_mode=parallel_mode, weight_seed=WORLD_RANK * 2137, data_seed=WORLD_RANK * 2137 + ) + model.weight.data = weight.data + x.retain_grad() + + # delayed scaling factor + # need to be collected before forward pass with test data, + # because this forward pass changes scaling factors + set_scaling_factors(model, input_kwargs, fp8_kwargs) + + LOSS_MULTIPLIER = 100 + + with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + y = model(x) + model.zero_grad() + if parallel_mode == "column": + y = AllGather.apply(y, -1) + y.retain_grad() + + ( + LOSS_MULTIPLIER * y.sum() + ).backward() # Loss multiplication to change gradient's order of magintude + + output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} + # per tensor - current - scaling factors + # need to be collected after forward pass with test data, + # because gradient(y.grad) cannot be accessed before forward, + # but it needs to be collected. + + set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs) + ground_truth = _emulate_linear_distributed( + x, weight, parallel_mode=parallel_mode, loss_multiplier=LOSS_MULTIPLIER, **fp8_kwargs + ) + + _cmp(ground_truth, output) + + +@run_debug_test +def test_fake_quant_fp8( + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + parallel_mode, + **kwargs, +): + + fp8_kwargs = { + "fprop_input_fake_quant": fprop_inp, + "fprop_weight_fake_quant": fprop_weight, + "dgrad_gradient_fake_quant": dgrad_grad, + "dgrad_weight_fake_quant": dgrad_weight, + "wgrad_gradient_fake_quant": wgrad_grad, + "wgrad_input_fake_quant": wgrad_input, + "fprop_fp8": not (fprop_inp or fprop_weight), + "dgrad_fp8": not (dgrad_weight or dgrad_grad), + "wgrad_fp8": not (wgrad_grad or wgrad_input), + } + if WORLD_RANK == 0: + fake_quant_fp8_create_config( + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + kwargs["config_file"], + ) + dist.barrier() + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS) + + x, weight = _get_tensors(parallel_mode) + model = _init_model(weight, parallel_mode) + y = _run_forward_backward(x, model, parallel_mode) + + output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} + fp8_kwargs["fprop_input_scale"] = ( + _get_current_scale(x, fprop_inp) if not fp8_kwargs["fprop_fp8"] else None + ) + fp8_kwargs["fprop_weight_scale"] = ( + _get_current_scale(weight, fprop_weight) if not fp8_kwargs["fprop_fp8"] else None + ) + fp8_kwargs["dgrad_gradient_scale"] = ( + _get_current_scale(y.grad, dgrad_grad) if not fp8_kwargs["dgrad_fp8"] else None + ) + fp8_kwargs["dgrad_weight_scale"] = ( + _get_current_scale(weight, dgrad_weight) if not fp8_kwargs["dgrad_fp8"] else None + ) + fp8_kwargs["wgrad_gradient_scale"] = ( + _get_current_scale(y.grad, wgrad_grad) if not fp8_kwargs["wgrad_fp8"] else None + ) + fp8_kwargs["wgrad_input_scale"] = ( + _get_current_scale(x, wgrad_input) if not fp8_kwargs["wgrad_fp8"] else None + ) + ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs) + _cmp(ground_truth, output) + + +def _init_distributed(): + global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, FP8 + + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + dist_init_kwargs["init_method"] = "env://" + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + torch.cuda.set_device(LOCAL_RANK) + dist.init_process_group(**dist_init_kwargs) + + NCCL_WORLD = dist.new_group(backend="nccl") + + WORLD_SIZE = dist.get_world_size() + + +def _run_test_with_combinations( + test_function, values_list, num_repeat, extra_args, sample_size=None +): + combinations = itertools.product(values_list, repeat=num_repeat) + total_combinations = itertools.product(combinations, extra_args) + + if sample_size is not None: + total_combinations = random.sample(list(total_combinations), sample_size) + + for comb, arg in total_combinations: + test_function(*comb, arg) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--feature_dirs", type=str) + args = parser.parse_args() + FEATURE_DIRS = args.feature_dirs + random.seed(SEED) + _init_distributed() + + test_log_expert_parallel() + for parallel_mode in ["column", "row"]: + for gather_weight in [True, False]: + test_log_distributed(parallel_mode, gather_weight) + + if fp8_available: + for parallel_mode in ["row", "column"]: + test_disable_fp8_layer(parallel_mode) + + # test_disable_fp8_gemms + _run_test_with_combinations( + test_disable_fp8_gemms, all_boolean, num_repeat=3, extra_args=["column", "row"] + ) + + # test_fake_quant_fp8 + dtype_options = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None] + _run_test_with_combinations( + test_fake_quant_fp8, + dtype_options, + num_repeat=6, + extra_args=["column", "row"], + sample_size=20, + ) + + _run_test_with_combinations( + test_per_tensor_scaling, + all_boolean, + num_repeat=6, + extra_args=["column"], + sample_size=20, + ) diff --git a/tests/pytorch/debug/test_api_features.py b/tests/pytorch/debug/test_api_features.py new file mode 100644 index 000000000..f9cd234ba --- /dev/null +++ b/tests/pytorch/debug/test_api_features.py @@ -0,0 +1,398 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer + +import nvdlfw_inspect.api as debug_api + +try: + import transformer_engine + import transformer_engine_torch as tex +except (ImportError, ModuleNotFoundError): + print("Could not find TransformerEngine package.") + exit(1) + + +def test_transformer_engine_no_config(feature_dirs): + debug_api.initialize("", feature_dirs=feature_dirs) + try: + + tensor = torch.rand(24, 2046).cuda() + + # FP8 enabled - true by the default + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="fprop", iteration=0 + ) + + # modify_tensor_enabled - False by default + assert not debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0 + ) + + # inspect_tensor_enabled - False by default + assert not debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.1.attn.qkv", tensor_name="activation", iteration=0 + ) + + # inspect_tensor_postquantize - False by default + assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0 + ) + + finally: + debug_api.end_debug() + + +def test_disable_fp8_gemm(configs_dir, feature_dirs): + try: + debug_api.initialize(configs_dir + "disable_fp8_gemms.yaml", feature_dirs=feature_dirs) + + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="fprop", iteration=0 + ) + assert not debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="dgrad", iteration=0 + ) + assert not debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="wgrad", iteration=0 + ) + + # caching + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="fprop", iteration=0 + ) + assert not debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="dgrad", iteration=0 + ) + assert not debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="wgrad", iteration=0 + ) + + finally: + debug_api.end_debug() + + +def test_disable_fp8_layer(configs_dir, feature_dirs): + try: + debug_api.initialize(configs_dir + "disable_fp8_layer.yaml", feature_dirs=feature_dirs) + + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.mlp.fc1", gemm="fprop", iteration=0 + ) + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.mlp.fc1", gemm="wgrad", iteration=0 + ) + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.mlp.fc1", gemm="dgrad", iteration=0 + ) + assert not debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="fprop", iteration=0 + ) + assert not debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="wgrad", iteration=0 + ) + assert not debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="dgrad", iteration=0 + ) + + finally: + debug_api.end_debug() + + +def test_per_tensor_scaling(configs_dir, feature_dirs): + try: + + debug_api.initialize(configs_dir + "per_tensor_scaling.yaml", feature_dirs=feature_dirs) + + tensor = torch.rand(24, 2046).cuda() + + # check modify_tensor_enabled + assert debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0 + ) + assert debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0 + ) + assert debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0 + ) + assert not debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0 + ) + assert not debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0 + ) + assert not debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0 + ) + + # check modify_tensor + + default_quantizer1 = Float8Quantizer( + scale=torch.tensor([1]).cuda(), + amax=torch.tensor([0]).cuda(), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + default_quantizer2 = Float8Quantizer( + scale=torch.tensor([1]).cuda(), + amax=torch.tensor([0]).cuda(), + fp8_dtype=tex.DType.kFloat8E5M2, + ) + + output1 = debug_api.transformer_engine.modify_tensor( + layer_name="decoder.1.mlp.fc1", + gemm="fprop", + tensor_name="activation", + default_quantizer=default_quantizer1, + iteration=0, + tensor=tensor, + ) + assert type(output1) == Float8Tensor + assert output1._fp8_dtype == tex.DType.kFloat8E4M3 + + output2 = debug_api.transformer_engine.modify_tensor( + "decoder.1.mlp.fc1", + gemm="dgrad", + tensor=tensor, + tensor_name="gradient", + default_quantizer=default_quantizer2, + iteration=0, + ) + assert type(output2) == Float8Tensor + assert output2._fp8_dtype == tex.DType.kFloat8E5M2 + + assert not debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", + gemm="wgrad", + tensor_name="gradient", + iteration=0, + ) + + assert not debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc4", + gemm="fprop", + tensor_name="activation", + iteration=0, + ) + finally: + debug_api.end_debug() + + +def test_fake_quant(configs_dir, feature_dirs): + try: + debug_api.initialize( + configs_dir + "fake_quantization_config.yaml", feature_dirs=feature_dirs + ) + + tensor = torch.rand(24, 2046).cuda() + + # modify_tensor_enabled + assert debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0 + ) + + assert debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0 + ) + + # modify_tensor + debug_api.transformer_engine.modify_tensor( + "decoder.1.mlp.fc1", + gemm="fprop", + tensor=tensor, + tensor_name="activation", + iteration=0, + default_quantizer=None, + ) + + debug_api.transformer_engine.modify_tensor( + "decoder.1.mlp.fc1", + gemm="dgrad", + tensor=tensor, + tensor_name="gradient", + iteration=0, + default_quantizer=None, + ) + + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.fc2", gemm="wgrad", iteration=0 + ) + # caching + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.fc2", gemm="wgrad", iteration=0 + ) + finally: + debug_api.end_debug() + + +def test_statistics_collection(configs_dir, feature_dirs): + try: + debug_api.initialize( + config_file=configs_dir + "stats_collection_test_config.yaml", + feature_dirs=feature_dirs, + default_logging_enabled=False, + ) + + tensor = torch.randn((100, 100, 5)).cuda() + tensor_fp8 = Float8Tensor( + data=tensor.to(torch.uint8).cuda(), + fp8_scale_inv=torch.full([1], 1.0).cuda(), + fp8_dtype=tex.DType.kFloat8E4M3, + shape=tensor.shape, + dtype=torch.float32, + ) + + def log(): + from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS + + return STATS_BUFFERS.log_stats() + + def assert_empty(): + stats = log() + assert len(stats) == 0 + + # TE tensor stats -- + debug_api.transformer_engine.inspect_tensor( + "decoder.1.mlp.fc1", + tensor=tensor, + tensor_name="activation", + iteration=200, + tp_group=None, + ) + stats = log() + assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max() + assert not debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.1.mlp.fc1", tensor_name="activation", iteration=201 + ) + assert not debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.2.mlp.fc1", tensor_name="activation", iteration=200 + ) + assert not debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200 + ) + + expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5) + expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5) + + # TE FP8 tensor stats -- + assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + "decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200 + ) + debug_api.transformer_engine.inspect_tensor_postquantize( + "decoder.1.mlp.fc1", + tensor=tensor_fp8, + tensor_name="gradient", + iteration=200, + rowwise=True, + tp_group=None, + ) + stats = log() + torch.testing.assert_close( + stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows + ) + + assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + "decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201 + ) + assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + "decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200 + ) + + # Second config in same yaml + tensor = torch.rand((100, 100, 5)) + debug_api.transformer_engine.inspect_tensor( + "decoder.6.mlp.fc1", + tensor=tensor, + tensor_name="activation", + iteration=200, + tp_group=None, + ) + stats = log() + stats_names = [x[3] for x in stats.keys()] + all(s in stats_names for s in ["cur_amax", "dynamic_range", "mean", "std", "l1_norm"]) + assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean() + + debug_api.transformer_engine.inspect_tensor( + "decoder.7.mlp.fc1", + tensor=tensor, + tensor_name="weight", + iteration=200, + tp_group=None, + ) + stats = log() + stats_names = [x[3] for x in stats.keys()] + all(s in stats_names for s in ["mean", "std", "l1_norm", "min", "max"]) + assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max() + + assert not debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.7.mlp.fc1", tensor_name="weight", iteration=201 + ) + assert_empty() + + finally: + debug_api.end_debug() + + +def test_statistics_multi_run(configs_dir, feature_dirs): + try: + debug_api.initialize( + config_file=configs_dir + "stats_collection_test_config.yaml", + feature_dirs=feature_dirs, + default_logging_enabled=False, + ) + + def feed(tensor, tensor_fp8): + debug_api.transformer_engine.inspect_tensor( + "decoder.5.mlp.fc1", + tensor=tensor, + tensor_name="activation", + iteration=1, + tp_group=None, + ) + debug_api.transformer_engine.inspect_tensor_postquantize( + "decoder.5.mlp.fc1", + tensor=tensor_fp8, + tensor_name="activation", + iteration=1, + rowwise=True, + tp_group=None, + ) + + def log_stats(): + from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS + + return STATS_BUFFERS.log_stats() + + def fp8_tensor(t): + return Float8Tensor( + data=t.to(torch.uint8).cuda(), + fp8_scale_inv=torch.ones([1]).cuda(), + fp8_dtype=tex.DType.kFloat8E4M3, + shape=t.shape, + dtype=torch.float32, + ) + + shape = [1024, 1024] + tensors = [torch.randn(shape) for _ in range(2)] + tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)] + + feed(tensors[0], tensors_fp8[0]) + feed(tensors[1], tensors_fp8[1]) + stats1 = log_stats() + + tensor2 = torch.cat((tensors[0], tensors[1])).cuda() + fp8tensor2 = fp8_tensor(tensor2) + feed(tensor2, fp8tensor2) + stats2 = log_stats() + + assert len(stats1.keys()) > 0 + for k in stats1.keys(): + torch.testing.assert_close(stats1[k], stats2[k]) + finally: + debug_api.end_debug() + + +if __name__ == "__main__": + pass diff --git a/tests/pytorch/debug/test_config.py b/tests/pytorch/debug/test_config.py new file mode 100644 index 000000000..71715a686 --- /dev/null +++ b/tests/pytorch/debug/test_config.py @@ -0,0 +1,151 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +import pathlib, os + +from nvdlfw_inspect.config_manager import ConfigManager + +import nvdlfw_inspect.api as debug_api + +try: + import transformer_engine + from transformer_engine.debug.features.api import TEConfigAPIMapper +except (ImportError, ModuleNotFoundError): + print("Could not find TransformerEngine debug module.") + exit(1) + + +def test_transformer_engine_config_parsing(feature_dirs): + debug_api.initialize( + config_file=pathlib.Path(__file__).resolve().parent + / "test_configs/tensor_manipulation_transformer_engine.yaml", + feature_dirs=feature_dirs, + log_dir="./log", + ) + + cfg_fc1 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc1")["transformer_engine"] + cfg_fc2 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc2")["transformer_engine"] + assert cfg_fc1 and cfg_fc2 + + gemm_parsing = True + tensor_parsing = True + + # Per tensor scaling set for dgrad, filter based on gemm + ret, _ = TEConfigAPIMapper().parse_config_and_api( + cfg_fc1["PerTensorScaling"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="wgrad", + tensor_name="activation", + ) + assert not ret + + # per tensor scaling set for gradient, filter based on tensor name + ret, _ = TEConfigAPIMapper().parse_config_and_api( + cfg_fc1["PerTensorScaling"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="dgrad", + tensor_name="activation", + ) + assert not ret + + ret, parsed_cfg_fc1 = TEConfigAPIMapper().parse_config_and_api( + cfg_fc1["PerTensorScaling"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="dgrad", + tensor_name="gradient", + ) + assert ret + assert parsed_cfg_fc1 == {"gemm": "dgrad", "tensor": "gradient"} + + # Test tensor struct + ret, parsed_cfg_fc1_act = TEConfigAPIMapper().parse_config_and_api( + cfg_fc1["FakeQuant"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="fprop", + tensor_name="activation", + ) + ret, parsed_cfg_fc1_wei = TEConfigAPIMapper().parse_config_and_api( + cfg_fc1["FakeQuant"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="fprop", + tensor_name="weight", + ) + assert ret + assert parsed_cfg_fc1_act == { + "gemm": "fprop", + "tensor": "activation", + "quant_format": "FP8E4M3", + } + assert parsed_cfg_fc1_wei == { + "gemm": "fprop", + "tensor": "weight", + "quant_format": "FP8E4M3", + } + + # Test gemms struct + ret, parsed_cfg_fc2_grad = TEConfigAPIMapper().parse_config_and_api( + cfg_fc2["FakeQuant"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="dgrad", + tensor_name="gradient", + ) + assert ret + assert parsed_cfg_fc2_grad == {"gemm": "dgrad", "tensor": "gradient", "quant_format": "FP8E5M2"} + ret, parsed_cfg_fc2_wei = TEConfigAPIMapper().parse_config_and_api( + cfg_fc2["FakeQuant"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="dgrad", + tensor_name="weight", + ) + assert ret + assert parsed_cfg_fc2_wei == {"gemm": "dgrad", "tensor": "weight", "quant_format": "FP8E5M2"} + + # Test gemm + tensor struct + ret, parsed_cfg_fc2_fprop_act = TEConfigAPIMapper().parse_config_and_api( + cfg_fc2["PerTensorScaling"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="fprop", + tensor_name="activation", + ) + assert ret + assert parsed_cfg_fc2_fprop_act == {"gemm": "fprop", "tensor": "activation"} + + ret, parsed_cfg_fc2_fprop_wei = TEConfigAPIMapper().parse_config_and_api( + cfg_fc2["PerTensorScaling"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="fprop", + tensor_name="weight", + ) + assert ret + assert parsed_cfg_fc2_fprop_wei == {"gemm": "fprop", "tensor": "weight"} + + ret, parsed_cfg_fc2_wgrad_act = TEConfigAPIMapper().parse_config_and_api( + cfg_fc2["PerTensorScaling"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="wgrad", + tensor_name="activation", + ) + assert ret + assert parsed_cfg_fc2_wgrad_act == {"gemm": "wgrad", "tensor": "activation"} + + ret, parsed_cfg_fc2_wgrad_grad = TEConfigAPIMapper().parse_config_and_api( + cfg_fc2["PerTensorScaling"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="wgrad", + tensor_name="gradient", + ) + assert ret + assert parsed_cfg_fc2_wgrad_grad == {"gemm": "wgrad", "tensor": "gradient"} + + ConfigManager.reset() diff --git a/tests/pytorch/debug/test_configs/disable_fp8_gemms.yaml b/tests/pytorch/debug/test_configs/disable_fp8_gemms.yaml new file mode 100644 index 000000000..b832f26d8 --- /dev/null +++ b/tests/pytorch/debug/test_configs/disable_fp8_gemms.yaml @@ -0,0 +1,8 @@ +test_disable_fp8_gemm_1: + enabled: True + layers: + layer_types: [qkv, fc2] + transformer_engine: + DisableFP8GEMM: + enabled: True + gemms: [dgrad, wgrad] \ No newline at end of file diff --git a/tests/pytorch/debug/test_configs/disable_fp8_layer.yaml b/tests/pytorch/debug/test_configs/disable_fp8_layer.yaml new file mode 100644 index 000000000..39bfc7a25 --- /dev/null +++ b/tests/pytorch/debug/test_configs/disable_fp8_layer.yaml @@ -0,0 +1,7 @@ +test_disable_fp8_layer: + enabled: True + layers: + layer_types: [qkv] + transformer_engine: + DisableFP8Layer: + enabled: True \ No newline at end of file diff --git a/tests/pytorch/debug/test_configs/dummy_feature.yaml b/tests/pytorch/debug/test_configs/dummy_feature.yaml new file mode 100644 index 000000000..540e3ac42 --- /dev/null +++ b/tests/pytorch/debug/test_configs/dummy_feature.yaml @@ -0,0 +1,9 @@ +deummy_feature_everywhere: + enabled: True + layers: + layer_name_regex_pattern: .* + transformer_engine: + TestDummyFeature: + enabled: True + tensors: [weight, activation, gradient, output, wgrad, dgrad] + gemms: [wgrad, dgrad, fprop] \ No newline at end of file diff --git a/tests/pytorch/debug/test_configs/fake_quantization_config.yaml b/tests/pytorch/debug/test_configs/fake_quantization_config.yaml new file mode 100644 index 000000000..62feace6d --- /dev/null +++ b/tests/pytorch/debug/test_configs/fake_quantization_config.yaml @@ -0,0 +1,14 @@ +test_fake_quant_fp8: + enabled: True + layers: + layer_numbers: [1] + layer_types: [fc1, fc2] + transformer_engine: + FakeQuant: + enabled: True + gemms: [fprop, dgrad] + tensors_struct: + - tensor: activation + quant_format: FP8E4M3 + - tensor: gradient + quant_format: FP8E5M2 \ No newline at end of file diff --git a/tests/pytorch/debug/test_configs/per_tensor_scaling.yaml b/tests/pytorch/debug/test_configs/per_tensor_scaling.yaml new file mode 100644 index 000000000..c17f2f7d2 --- /dev/null +++ b/tests/pytorch/debug/test_configs/per_tensor_scaling.yaml @@ -0,0 +1,19 @@ +test_per_tensor_scaling: + enabled: True + layers: + layer_numbers: [1] + layer_types: [fc1, fc2] + transformer_engine: + DisableFP8GEMM: + enabled: True + gemms: [wgrad] + PerTensorScaling: + enabled: True + gemms_struct: + - gemm: fprop + tensors_struct: + - tensor: activation + - tensor: weight + - gemm: dgrad + tensors_struct: + - tensor: gradient \ No newline at end of file diff --git a/tests/pytorch/debug/test_configs/stats_collection_test_config.yaml b/tests/pytorch/debug/test_configs/stats_collection_test_config.yaml new file mode 100644 index 000000000..8f01b2d62 --- /dev/null +++ b/tests/pytorch/debug/test_configs/stats_collection_test_config.yaml @@ -0,0 +1,59 @@ +stat_collection_test_1: + enabled: True + layers: + layer_numbers: [1, 3] + LogTensorStats: + enabled: True + stats: [mean, std, l1_norm, l2_norm] + tensors: [activation] + freq: 1 + start_step: 100 + end_step: 500 + transformer_engine: + LogTensorStats: + enabled: True + stats: [cur_amax, dynamic_range] + tensors: [activation] + freq: 2 + start_step: 100 + end_step: 500 + LogFp8TensorStats: + enabled: True + stats: [underflows%] + tensors: [gradient] + freq: 5 + start_step: 100 + end_step: 500 + +stat_collection_test_2: + enabled: True + layers: + layer_numbers: [6, 7] + transformer_engine: + LogTensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [cur_amax, dynamic_range, mean, std, l1_norm] + freq: 2 + start_step: 100 + end_step: 500 + - tensor: weight + stats: [mean, std, l1_norm, min, max] + freq: 5 + start_step: 100 + end_step: 500 + +stat_collection_test_4: + enabled: True + layers: + layer_numbers: [5] + transformer_engine: + LogTensorStats: + enabled: True + tensors: [activation] + stats: [cur_amax, dynamic_range, mean, std, l1_norm] + LogFp8TensorStats: + enabled: True + stats: [underflows%] + tensors: [activation] \ No newline at end of file diff --git a/tests/pytorch/debug/test_configs/tensor_manipulation_transformer_engine.yaml b/tests/pytorch/debug/test_configs/tensor_manipulation_transformer_engine.yaml new file mode 100644 index 000000000..e86486366 --- /dev/null +++ b/tests/pytorch/debug/test_configs/tensor_manipulation_transformer_engine.yaml @@ -0,0 +1,45 @@ +# This config is used when FP8 training is ON + +transformer_engine_fc1_manipulation: + enabled: True + layers: + layer_name_regex_pattern: .*(fc1) # Select layers if they end in fc1 + transformer_engine: # namespace + DisableFP8GEMM: # Disable FP8 GEMM. FProp run in high precision + enabled: True + gemms: [fprop] + PerTensorScaling: # Scale DGrad gradients using per tensor current scaling and run FP8 GEMM + enabled: True + gemms: [dgrad] + tensors: [gradient] + FakeQuant: # Disable FP8 GEMM for Wgrad. Fake quantize activations to Wgrad and run high precision GEMM + enabled: True + gemms: [fprop] + tensors_struct: + - tensor: activation + quant_format: FP8E4M3 + - tensor: weight + quant_format: FP8E4M3 + +transformer_engine_fc2_manipulation: + enabled: True + layers: + layer_name_regex_pattern: .*(fc2) # Select layers if they end in fc2 + transformer_engine: # namespace + PerTensorScaling: # Scale WGrad and Fprop inputs using per tensor current scaling and run FP8 GEMM + enabled: True + gemms_struct: + - gemm: fprop + tensors_struct: + - tensor: activation + - tensor: weight + - gemm: wgrad + tensors_struct: + - tensor: activation + - tensor: gradient + FakeQuant: # Disable FP8 GEMM for DGrad. Fake quantize weights and gradients to DGrad and run high precision GEMM + enabled: True + gemms_struct: + - gemm: dgrad + tensors: [weight, gradient] + quant_format: FP8E5M2 \ No newline at end of file diff --git a/tests/pytorch/debug/test_distributed.py b/tests/pytorch/debug/test_distributed.py new file mode 100644 index 000000000..ab5b60a13 --- /dev/null +++ b/tests/pytorch/debug/test_distributed.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import subprocess +from pathlib import Path +import pytest +import torch + +""" + Distributed numerics tests + + These tests test the numerical corectness of the TransformerEngine layers. + Tests are parametrized by the layer and fp8 precision. + One test consists of running multiple configurations from file run_numerics.py + Such design is due to the fact the initialization of one test is long + - 2 processes need to start and load torch and TE. Multiple configurations + are run in one test - this reduces the initialization overhead. + +""" + +if torch.cuda.device_count() < 2: + pytest.skip("Distributed training needs at least 2 GPUs.") + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS: int = min(4, torch.cuda.device_count()) +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +def test_debug_distributed(feature_dirs): + test_path = TEST_ROOT / "run_distributed.py" + test_cmd = LAUNCH_CMD + [str(test_path), f"--feature_dirs={feature_dirs[0]}"] + + result = subprocess.run(test_cmd, env=os.environ, check=False, text=True) + if result.returncode != 0: + raise AssertionError(f"torchrun exited with {result.returncode}") diff --git a/tests/pytorch/debug/test_numerics.py b/tests/pytorch/debug/test_numerics.py new file mode 100644 index 000000000..749fa16bc --- /dev/null +++ b/tests/pytorch/debug/test_numerics.py @@ -0,0 +1,761 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import functools +import itertools +import os +import random +import tempfile +from string import Template + +import pytest +import torch + +import nvdlfw_inspect.api as debug_api +import transformer_engine.debug +import transformer_engine.pytorch as tepytorch +import transformer_engine_torch as tex +from transformer_engine.common.recipe import DelayedScaling, Format +from transformer_engine.pytorch.fp8 import _default_sf_compute +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.module.base import ( + _2X_ACC_DGRAD, + _2X_ACC_FPROP, + _2X_ACC_WGRAD, +) +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +all_boolean = [True, False] +FP8_FORMAT = Format.HYBRID +AMAX_HISTORY_LEN = 16 +FP8_RECIPE = DelayedScaling( + fp8_format=FP8_FORMAT, amax_history_len=AMAX_HISTORY_LEN, amax_compute_algo="max" +) +SEED = 1234 +IN_SIZE = 128 +OUT_SIZE = 64 +BATCH_SIZE = 16 +SEQ_LEN = 128 +LOSS_FN = torch.nn.functional.cross_entropy + + +def _cast_to_fp8(tensor, scale, dtype): + tensor = tensor.contiguous() + if type(scale) == torch.Tensor: + amax = scale.abs().max().float() + quantizer = Float8Quantizer(scale, amax, dtype) + else: + quantizer = Float8CurrentScalingQuantizer(scale, device=tensor.device) + + return quantizer(tensor) + + +def _get_current_scale(tensor, fp8_dtype): + if fp8_dtype == tex.DType.kFloat8E4M3: + fp8_max = Format.E4M3.value.max_fwd + else: + fp8_max = Format.E5M2.value.max_fwd + + amax = tensor.abs().max().float() + one = torch.ones(1, device=tensor.device) + + return _default_sf_compute(amax, one, fp8_max, 0).detach() + + +def _fake_cast(tensor, fp8_dtype, scale): + scale = scale or _get_current_scale(tensor, fp8_dtype) + fp8_tensor = _cast_to_fp8(tensor, scale, fp8_dtype) + + return fp8_tensor.dequantize() + + +def _fp8_gemm_kernel(tensor1, scale1, dtype1, tensor2, scale2, dtype2, use_split_accumulator): + fp8_tensor1 = _cast_to_fp8(tensor1, scale1, dtype1) + fp8_tensor2 = _cast_to_fp8(tensor2, scale2, dtype2) + + out, *_ = tepytorch.cpp_extensions.general_gemm( + fp8_tensor1, + fp8_tensor2, + tepytorch.module.base.get_workspace(), + torch.float32, + use_split_accumulator=use_split_accumulator, + ) + out.requires_grad = True + return out.T + + +def _emulate_linear( + input: torch.Tensor, + weight: torch.Tensor, + fprop_fp8: bool = False, + fprop_input_fake_quant: tex.DType = None, + fprop_input_scale: torch.Tensor = None, + fprop_weight_fake_quant: tex.DType = None, + fprop_weight_scale: torch.Tensor = None, + dgrad_fp8: bool = False, + dgrad_gradient_fake_quant: tex.DType = None, + dgrad_gradient_scale: torch.Tensor = None, + dgrad_weight_fake_quant: tex.DType = None, + dgrad_weight_scale: torch.Tensor = None, + wgrad_fp8: bool = False, + wgrad_gradient_fake_quant: tex.DType = None, + wgrad_gradient_scale: torch.Tensor = None, + wgrad_input_fake_quant: tex.DType = None, + wgrad_input_scale: torch.Tensor = None, + loss_multiplier: float = 1.0, + activation_sync=None, + gradient_sync=None, +): + _scalar = lambda x: torch.Tensor([x]).cuda() if type(x) in [float, torch.Tensor] else x + if fprop_fp8: + activation = _fp8_gemm_kernel( + input, + _scalar(fprop_input_scale or 1.0), + tex.DType.kFloat8E4M3, + weight, + _scalar(fprop_weight_scale or 1.0), + tex.DType.kFloat8E4M3, + _2X_ACC_FPROP, + ) + activation = activation.clone().detach().contiguous().requires_grad_(True) + else: + fprop_input = ( + _fake_cast(input, fprop_input_fake_quant, _scalar(fprop_input_scale)) + if fprop_input_fake_quant is not None + else input + ) + fprop_weight = ( + _fake_cast(weight, fprop_weight_fake_quant, _scalar(fprop_weight_scale)) + if fprop_weight_fake_quant is not None + else weight + ) + + activation = (fprop_input @ fprop_weight.T).contiguous() + + if activation_sync: + activation = activation_sync(activation) + + activation.retain_grad() + + (loss_multiplier * activation.sum()).backward(retain_graph=True) + gradient = activation.grad.clone() + + if gradient_sync: + gradient = gradient_sync(gradient) + + if dgrad_fp8: + dgrad = _fp8_gemm_kernel( + weight.T, + _scalar(dgrad_weight_scale or 1.0), + tex.DType.kFloat8E4M3, + gradient, + _scalar(dgrad_gradient_scale or 1.0), + tex.DType.kFloat8E5M2, + _2X_ACC_DGRAD, + ).T + else: + dgrad_gradient = ( + _fake_cast(gradient, dgrad_gradient_fake_quant, _scalar(dgrad_gradient_scale)) + if dgrad_gradient_fake_quant is not None + else gradient + ) + + dgrad_weight = ( + _fake_cast(weight, dgrad_weight_fake_quant, _scalar(dgrad_weight_scale)) + if dgrad_weight_fake_quant is not None + else weight + ) + dgrad = dgrad_gradient @ dgrad_weight + + if wgrad_fp8: + wgrad = _fp8_gemm_kernel( + input.T, + _scalar(wgrad_input_scale or 1.0), + tex.DType.kFloat8E4M3, + gradient.T, + _scalar(wgrad_gradient_scale or 1.0), + tex.DType.kFloat8E5M2, + _2X_ACC_WGRAD, + ).T + else: + wgrad_gradient = ( + _fake_cast(gradient, wgrad_gradient_fake_quant, _scalar(wgrad_gradient_scale)) + if wgrad_gradient_fake_quant is not None + else gradient + ) + wgrad_input = ( + _fake_cast(input, wgrad_input_fake_quant, _scalar(wgrad_input_scale)) + if wgrad_input_fake_quant is not None + else input + ) + wgrad_input = wgrad_input.contiguous() + wgrad_gradient = wgrad_gradient.contiguous() + wgrad, *_ = tepytorch.cpp_extensions.general_gemm( + wgrad_input, + wgrad_gradient, + tepytorch.module.base.get_workspace(), + torch.float32, + layout="NT", + grad=True, + use_split_accumulator=_2X_ACC_WGRAD, + ) + + return {"activation": activation, "wgrad": wgrad, "dgrad": dgrad} + + +def _init_debug(config_name, log_dir, feature_dirs): + debug_api.initialize( + config_file=config_name, + feature_dirs=feature_dirs, + log_dir=log_dir, + default_logging_enabled=True, + ) + + +def create_config_file(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: + with tempfile.TemporaryDirectory() as temp_dir: + try: + kwargs["config_file"] = temp_file + kwargs["log_dir"] = temp_dir + result = func(*args, **kwargs) + finally: + temp_file_name = temp_file.name + debug_api.end_debug() + os.unlink(temp_file_name) + return result + + return wrapper + + +def _cmp(ground_truth, output): + torch.testing.assert_close(ground_truth["activation"], output["activation"]) + torch.testing.assert_close(ground_truth["wgrad"], output["wgrad"]) + torch.testing.assert_close(ground_truth["dgrad"], output["dgrad"]) + + +def _init_model(weight): + model = transformer_engine.pytorch.Linear(IN_SIZE, OUT_SIZE, name="linear") + with torch.no_grad(): + model.weight.copy_(weight.contiguous()) + return model + + +def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None, fp8=True): + with tepytorch.fp8_autocast(enabled=fp8, fp8_recipe=FP8_RECIPE): + y = model(x, is_first_microbatch=is_first_microbatch) + (y.sum() * loss_scale).backward() + debug_api.step() + return y + + +def _get_tensors(): + torch.manual_seed(SEED) + x = torch.randn((SEQ_LEN * BATCH_SIZE, IN_SIZE), requires_grad=True).cuda() + x.retain_grad() + weight = torch.randn((OUT_SIZE, IN_SIZE)).cuda() + return x, weight + + +LOGGING_CONFIG = """logging_config: + enabled: True + layers: + layer_types: [linear] + transformer_engine: + LogTensorStats: + enabled: True + tensors: [activation, gradient, weight, output, wgrad, dgrad] + stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range] +""" + + +DISABLE_FP8_CONFIG = Template( + """disable_fp8_config: + enabled: True + layers: + layer_types: [linear] + transformer_engine: + DisableFP8GEMM: + enabled: True + gemms: [$gemms] +""" +) + + +@create_config_file +def run_logging_zero_numel_tensor(feature_dirs, **kwargs): + kwargs["config_file"].write(LOGGING_CONFIG) + kwargs["config_file"].flush() + + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs) + + x, weight = _get_tensors() + x1 = x[:0, :] + model = _init_model(weight) + _ = _run_forward_backward(x1, model, fp8=False) + _ = _run_forward_backward(x, model, fp8=False) + + +def test_logging_zero_numel_tensor(feature_dirs): + run_logging_zero_numel_tensor(feature_dirs) + + +@pytest.mark.parametrize("fprop_fp8", all_boolean) +@pytest.mark.parametrize("dgrad_fp8", all_boolean) +@pytest.mark.parametrize("wgrad_fp8", all_boolean) +def test_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8) + + +def disable_fp8_gemms_create_config(fprop_fp8, dgrad_fp8, wgrad_fp8, config_file): + gemms = "" + if not fprop_fp8: + gemms += "fprop," + if not dgrad_fp8: + gemms += "dgrad," + if not wgrad_fp8: + gemms += "wgrad," + if len(gemms) > 0: + gemms = gemms[:-1] # remove last ',' + config_file.write(DISABLE_FP8_CONFIG.safe_substitute(gemms=gemms)) + config_file.flush() + + +@create_config_file +def run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8, **kwargs): + disable_fp8_gemms_create_config(fprop_fp8, dgrad_fp8, wgrad_fp8, kwargs["config_file"]) + fp8_kwargs = { + "fprop_fp8": fprop_fp8, + "dgrad_fp8": dgrad_fp8, + "wgrad_fp8": wgrad_fp8, + } + + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs) + x, weight = _get_tensors() + model = _init_model(weight) + y = _run_forward_backward(x, model) + + output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} + + x.grad.zero_() + ground_truth = _emulate_linear(x, weight, **fp8_kwargs) + _cmp(ground_truth, output) + + +def test_disable_fp8_layer(feature_dirs): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + run_disable_fp8_layer(feature_dirs) + + +DISABLE_FP8_LAYER_CONFIG = """disable_fp8_config: + enabled: True + layers: + layer_types: [linear] + transformer_engine: + DisableFP8Layer: + enabled: True +""" + + +@create_config_file +def run_disable_fp8_layer(feature_dirs, **kwargs): + kwargs["config_file"].write(DISABLE_FP8_LAYER_CONFIG) + kwargs["config_file"].flush() + + x, weight = _get_tensors() + + ground_truth = _emulate_linear(x, weight) + x.grad.zero_() + + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs) + + model = _init_model(weight) + y = _run_forward_backward(x, model) + + output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} + _cmp(ground_truth, output) + + +random.seed(1234) + +all_combinations = list(itertools.product(all_boolean, repeat=6)) +subset_combinations = random.sample(all_combinations, 20) + + +@pytest.mark.parametrize( + "fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad", + subset_combinations, +) +def test_per_tensor_scaling( + feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad +): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]): + pytest.skip("Skipping test because all parameters are False") + run_per_tensor_scaling( + feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad + ) + + +PER_TENSOR_SCALING_CONFIG = Template( + """per_tensor_scaling_config: + enabled: True + layers: + layer_types: [linear] + transformer_engine: + PerTensorScaling: + enabled: True + gemms_struct: +$gemms +""" +) + + +def _prepare_per_tensor_scaling_config( + fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad, config_file +): + gemms = "" + title = lambda x: f" - gemm: {x}\n tensors: [" + + def add_tensor(if_add, gemm_name): + nonlocal gemms + if if_add: + gemms += gemm_name + "," + + if fprop_inp or fprop_weight: + gemms += title("fprop") + add_tensor(fprop_inp, "activation") + add_tensor(fprop_weight, "weight") + gemms = gemms[:-1] + "]\n" + if dgrad_weight or dgrad_grad: + gemms += title("dgrad") + add_tensor(dgrad_weight, "weight") + add_tensor(dgrad_grad, "gradient") + gemms = gemms[:-1] + "]\n" + if wgrad_input or wgrad_grad: + gemms += title("wgrad") + add_tensor(wgrad_input, "activation") + add_tensor(wgrad_grad, "gradient") + gemms = gemms[:-1] + "]\n" + config_file.write(PER_TENSOR_SCALING_CONFIG.safe_substitute(gemms=gemms)) + config_file.flush() + + +def set_scaling_factors(model, input_kwargs, fp8_kwargs): + # Copy fp8 scaling factors into fp8_kwargs dict if respective flag in input_kwargs is set. + if not input_kwargs["fprop_inp"]: + fp8_kwargs["fprop_input_scale"] = model.fp8_meta["scaling_fwd"].scale[0].clone() + if not input_kwargs["fprop_weight"]: + fp8_kwargs["fprop_weight_scale"] = model.fp8_meta["scaling_fwd"].scale[1].clone() + if not input_kwargs["dgrad_grad"]: + fp8_kwargs["dgrad_gradient_scale"] = model.fp8_meta["scaling_bwd"].scale[0].clone() + if not input_kwargs["dgrad_weight"]: + fp8_kwargs["dgrad_weight_scale"] = model.fp8_meta["scaling_fwd"].scale[1].clone() + if not input_kwargs["wgrad_grad"]: + fp8_kwargs["wgrad_gradient_scale"] = model.fp8_meta["scaling_bwd"].scale[0].clone() + if not input_kwargs["wgrad_input"]: + fp8_kwargs["wgrad_input_scale"] = model.fp8_meta["scaling_fwd"].scale[0].clone() + + +def set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs): + # Compute per tensor scaling factor if respective flag in input_kwargs is set. + if input_kwargs["fprop_inp"]: + fp8_kwargs["fprop_input_scale"] = tex.DType.kFloat8E4M3 + if input_kwargs["fprop_weight"]: + fp8_kwargs["fprop_weight_scale"] = tex.DType.kFloat8E4M3 + if input_kwargs["dgrad_grad"]: + fp8_kwargs["dgrad_gradient_scale"] = tex.DType.kFloat8E5M2 + if input_kwargs["dgrad_weight"]: + fp8_kwargs["dgrad_weight_scale"] = tex.DType.kFloat8E4M3 + if input_kwargs["wgrad_grad"]: + fp8_kwargs["wgrad_gradient_scale"] = tex.DType.kFloat8E5M2 + if input_kwargs["wgrad_input"]: + fp8_kwargs["wgrad_input_scale"] = tex.DType.kFloat8E4M3 + + +@create_config_file +def run_per_tensor_scaling( + feature_dirs, + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + **kwargs, +): + input_kwargs = { + "fprop_inp": fprop_inp, + "fprop_weight": fprop_weight, + "dgrad_weight": dgrad_weight, + "dgrad_grad": dgrad_grad, + "wgrad_input": wgrad_input, + "wgrad_grad": wgrad_grad, + } + fp8_kwargs = { + "fprop_fp8": True, + "dgrad_fp8": True, + "wgrad_fp8": True, + } + """ + Runs a test to validate per-tensor (current) scaling in FP8 computations. + The function performs warm-up iterations to populate the amax buffer of the model and compute scaling factors based on delayed scaling. + Subsequently, weights and inputs are switched to ensure their current scaling factors differ from those based on delayed scaling; + similarly, the loss is multiplied by a large factor to alter the gradient's magnitude, + creating a discrepancy between the original (delayed) and per-tensor (current) scaling factors. + Finally, a linear pass is emulated, and the results are compared.” + """ + _prepare_per_tensor_scaling_config( + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + kwargs["config_file"], + ) + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs) + + warmup_input, warmup_weight = _get_tensors() + model = _init_model(warmup_weight) + + # Warmup run to setup amax and scaling factors. + for _ in range(AMAX_HISTORY_LEN): + _run_forward_backward(warmup_input, model) + + x = torch.randn_like(warmup_input, requires_grad=True).cuda() + weight = torch.randn_like(warmup_weight, requires_grad=True).cuda() + model.weight.data = weight.data + x.retain_grad() + + # delayed scaling factor + # need to be collected before forward pass with test data, + # because this forward pass changes scaling factors + set_scaling_factors(model, input_kwargs, fp8_kwargs) + + LOSS_MULTIPLIER = 100 + + with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + y = model(x, is_first_microbatch=True) + model.zero_grad() + y.retain_grad() + ( + LOSS_MULTIPLIER * y.sum() + ).backward() # Loss multiplication to change gradient's order of magintude + + output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} + + # per tensor - current - scaling factors + # need to be collected after forward pass with test data, + # because gradient(y.grad) cannot be accessed before forward, + # but it needs to be collected. + set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs) + + ground_truth = _emulate_linear(x, weight, loss_multiplier=LOSS_MULTIPLIER, **fp8_kwargs) + _cmp(ground_truth, output) + + +@pytest.mark.parametrize( + "fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad", + subset_combinations, +) +def test_microbatching_per_tensor_scaling( + feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad +): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]): + pytest.skip("Skipping test because all parameters are False") + + @create_config_file + def run_microbatching_test( + feature_dirs, + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + **kwargs, + ): + # Prepare the configuration file + _prepare_per_tensor_scaling_config( + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + kwargs["config_file"], + ) + + # Initialize debug + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs) + + # Get data + x_full, weight = _get_tensors() + microbatch_size = x_full.size(0) // 2 + x_mb1 = x_full[:microbatch_size, ...].clone().detach().requires_grad_(True) + x_mb2 = x_full[microbatch_size:, ...].clone().detach().requires_grad_(True) + + def init_and_warmup(): + model = _init_model(weight) + _run_forward_backward(x_mb1, model, loss_scale=0.5) + _run_forward_backward(x_mb2, model, loss_scale=0.5) + return model + + # Run without is_first_microbatch + + model = init_and_warmup() # running next 2 iters does not change amaxes and scaling factors + y_mb1 = _run_forward_backward(x_mb1, model, loss_scale=0.5) + y_mb2 = _run_forward_backward(x_mb2, model, loss_scale=0.5) + + # Collect outputs + output1 = { + "activation": torch.cat([y_mb1.clone(), y_mb2.clone()], dim=0), + "wgrad": model.weight.grad.clone(), + "dgrad": torch.cat([x_mb1.grad.clone(), x_mb2.grad.clone()], dim=0), + } + + # Run with is_first_microbatch + model = init_and_warmup() # running next 2 iters does not change amaxes and scaling factors + y_mb1 = _run_forward_backward(x_mb1, model, loss_scale=0.5, is_first_microbatch=True) + y_mb2 = _run_forward_backward(x_mb2, model, loss_scale=0.5, is_first_microbatch=False) + + # Collect outputs + output2 = { + "activation": torch.cat([y_mb1.clone(), y_mb2.clone()], dim=0), + "wgrad": model.weight.grad.clone(), + "dgrad": torch.cat([x_mb1.grad.clone(), x_mb2.grad.clone()], dim=0), + } + + # Compare outputs + torch.testing.assert_close(output1["activation"], output2["activation"], atol=1.0, rtol=0.5) + torch.testing.assert_close(output1["dgrad"], output2["dgrad"], atol=1.0, rtol=0.5) + torch.testing.assert_close(output1["wgrad"], output2["wgrad"], atol=1.0, rtol=0.5) + + # Run the test + run_microbatching_test( + feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad + ) + + +all_combinations = list( + itertools.product([tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None], repeat=6) +) +subset_combinations = random.sample(all_combinations, 10) + + +@pytest.mark.parametrize( + "fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad", + subset_combinations, +) +def test_fake_quant_fp8( + feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad +): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + run_fake_quant_fp8( + feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad + ) + + +FAKE_QUANT_CONFIG = Template( + """fake_quant_config: + enabled: True + layers: + layer_types: [linear] + transformer_engine: + FakeQuant: + enabled: True + gemms_struct: +$gemms +""" +) + + +def fake_quant_fp8_create_config( + fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad, config_file +): + format_to_str = {tex.DType.kFloat8E4M3: "FP8E4M3", tex.DType.kFloat8E5M2: "FP8E5M2"} + gemms = "" + + def _add_tensor(quant_format, tensor): + nonlocal gemms + if quant_format: + gemms += " " * 8 + "- tensor: " + tensor + "\n" + gemms += " " * 8 + " quant_format: " + format_to_str[quant_format] + "\n" + + title = lambda x: f" - gemm: {x}\n tensors_struct:\n" + if fprop_inp or fprop_weight: + gemms += title("fprop") + _add_tensor(fprop_inp, "activation") + _add_tensor(fprop_weight, "weight") + gemms = gemms[:-1] + "\n" + if dgrad_weight or dgrad_grad: + gemms += title("dgrad") + _add_tensor(dgrad_weight, "weight") + _add_tensor(dgrad_grad, "gradient") + gemms = gemms[:-1] + "\n" + if wgrad_input or wgrad_grad: + gemms += title("wgrad") + _add_tensor(wgrad_input, "activation") + _add_tensor(wgrad_grad, "gradient") + gemms = gemms[:-1] + "\n" + config = FAKE_QUANT_CONFIG.safe_substitute(gemms=gemms) + config_file.write(config) + config_file.flush() + + +@create_config_file +def run_fake_quant_fp8( + feature_dirs, + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + **kwargs, +): + fp8_kwargs = { + "fprop_input_fake_quant": fprop_inp, + "fprop_weight_fake_quant": fprop_weight, + "dgrad_gradient_fake_quant": dgrad_grad, + "dgrad_weight_fake_quant": dgrad_weight, + "wgrad_gradient_fake_quant": wgrad_grad, + "wgrad_input_fake_quant": wgrad_input, + "fprop_fp8": not (fprop_inp or fprop_weight), + "dgrad_fp8": not (dgrad_weight or dgrad_grad), + "wgrad_fp8": not (wgrad_grad or wgrad_input), + } + fake_quant_fp8_create_config( + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + kwargs["config_file"], + ) + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs) + + x, weight = _get_tensors() + model = _init_model(weight) + y = _run_forward_backward(x, model) + + output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} + ground_truth = _emulate_linear(x, weight, **fp8_kwargs) + _cmp(ground_truth, output) diff --git a/tests/pytorch/debug/test_sanity.py b/tests/pytorch/debug/test_sanity.py new file mode 100644 index 000000000..e4ce35be6 --- /dev/null +++ b/tests/pytorch/debug/test_sanity.py @@ -0,0 +1,99 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch + +import nvdlfw_inspect.api as debug_api +import transformer_engine.pytorch as te +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + +from test_numerics import create_config_file + +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +B, S, H, D = 64, 64, 64, 64 + +model_keys = ["linear", "layernorm_linear", "layernorm_mlp", "mha_attention", "transformer_layer"] + +configs = { + "": "", + "log": """log: + layers: + layer_types: [linear] + enabled: + True + transformer_engine: + LogTensorStats: + enabled: True + tensors: [activation, gradient, weight, output, wgrad, dgrad] + stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range] + start_step : 0 + end_step: 1 + LogFp8TensorStats: + enabled: True + tensors: [activation, gradient, weight] + stats: [underflows, overflows] + start_step : 0 + end_step: 1 +""", + "fake_quant": """ +fake_quant_config: + enabled: True + layers: + layer_types: [linear] + transformer_engine: + FakeQuant: + enabled: True + gemms: [fprop, dgrad, wgrad] + quant_format: FP8E5M2 +""", +} + + +def _get_model(model_key): + if model_key == "linear": + return te.Linear(D, D) + if model_key == "layernorm_linear": + return te.LayerNormLinear(D, D) + if model_key == "layernorm_mlp": + return te.LayerNormMLP(D, D, D) + if model_key == "mha_attention": + return te.MultiheadAttention(D, H) + if model_key == "transformer_layer": + return te.TransformerLayer(D, D, H) + + +def _run_forward_backward(model, fp8): + for _ in range(3): + inp = torch.randn((S, B, H)).cuda() + with te.fp8_autocast(enabled=fp8): + out = model(inp) + out.sum().backward() + debug_api.step() + + +@create_config_file +def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir): + try: + if config != "": + config_file.write(config) + config_file.flush() + config_file_name = config_file.name if config != "" else "" + debug_api.initialize(feature_dirs=feature_dirs, config_file=config_file_name) + model = _get_model(model_key) + _run_forward_backward(model, fp8) + except Exception as error: + raise error + finally: + debug_api.end_debug() + + +@pytest.mark.parametrize("model_key", model_keys) +@pytest.mark.parametrize("fp8", [False, True]) +@pytest.mark.parametrize("config_key", configs.keys()) +def test_sanity_debug(model_key, fp8, config_key, feature_dirs): + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + _run_test(model_key, fp8, configs[config_key], feature_dirs) diff --git a/tests/pytorch/debug/utils.py b/tests/pytorch/debug/utils.py new file mode 100644 index 000000000..f03ee56b5 --- /dev/null +++ b/tests/pytorch/debug/utils.py @@ -0,0 +1,22 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os + +LOG_FILE = os.path.join("nvdlfw_inspect_logs", "nvdlfw_inspect_globalrank-0.log") + + +def reset_debug_log(): + if os.path.isfile(LOG_FILE): + # delete all content + with open(LOG_FILE, "w") as f: + pass + + +def check_debug_log(msg): + with open(LOG_FILE, "r") as f: + for line in f.readlines(): + if msg in line: + return True + return False diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index ce936f964..6d9e2f152 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -273,7 +273,9 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") assert dist.is_nccl_available() dist.init_process_group(**dist_init_kwargs) - tp_group = dist.new_group(backend="nccl") + tp_group = dist.new_group( + backend="nccl", pg_options=dist.ProcessGroupNCCL.Options(is_high_priority_stream=True) + ) tp_rank = dist.get_rank(tp_group) tp_size = dist.get_world_size(tp_group) dist_print( diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 48ace31c3..8638c1bce 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -323,6 +323,7 @@ def _train(opts): new_group_kwargs = { "backend": "nccl", "ranks": tp_rank_list, + "pg_options": dist.ProcessGroupNCCL.Options(is_high_priority_stream=True), } else: opts.tp = WORLD_SIZE diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index ac72960c4..8a201b72d 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -40,10 +40,17 @@ LOSS_FN = nn.MSELoss() QUANTIZATION = None - -# Disable TF32 -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False +if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False): + # The numerics of all the layers should work the same, + # when debug=True. I fed them with dummy feature + # to prevent switching off debug, which can happen if + # no feature is active. + import nvdlfw_inspect.api as debug_api + + debug_api.initialize( + os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"], + feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"], + ) # Quantization recipe setup @@ -105,11 +112,15 @@ def main(argv=None, namespace=None): # Quantization scheme QUANTIZATION = args.quantization - if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"): - global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE + global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE + if QUANTIZATION in ("fp8", "mxfp8"): SEQ_LEN = 32 BATCH_SIZE = 32 HIDDEN_SIZE = 128 + elif QUANTIZATION == "fp8_block_scaling": + SEQ_LEN = 128 + BATCH_SIZE = 128 + HIDDEN_SIZE = 512 test_dict = [ test_quantizer, @@ -167,7 +178,7 @@ def backward(ctx, grad_output): def _constant(tensor): - return nn.init.constant_(tensor, 0.5) + return nn.init.constant_(tensor, 0.05) def dist_print(msg, src=None, end="\n", error=False): @@ -195,7 +206,8 @@ def _get_tolerances(dtype): if dtype == torch.bfloat16: return {"rtol": 1.6e-2, "atol": 1e-5} if dtype == torch.float32: - return {"rtol": 1.3e-6, "atol": 4e-5} + # TF32 has same mantissa bits as FP16 + return {"rtol": 1e-3, "atol": 1e-5} raise ValueError(f"Unsupported dtype ({dtype})") @@ -526,8 +538,11 @@ def test_linear(): {"return_bias": True}, {"params_dtype": torch.float16}, {"delay_wgrad_compute": True}, + {"save_original_input": True}, ] for kwargs in kwargs_list: + if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8": + continue for parallel_mode in ["column", "row"]: for sequence_parallel in [False, True]: _test_linear(parallel_mode, sequence_parallel, **kwargs) @@ -659,7 +674,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs if "return_layernorm_output" in kwargs: output_single_node, norm_s = output_single_node output_distributed, norm_d = output_distributed - if sequence_parallel: + if sequence_parallel and not kwargs.get("return_layernorm_output_gathered", False): norm_d = _gather(norm_d) _check_outputs(norm_s, norm_d) @@ -768,7 +783,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg if "return_layernorm_output" in kwargs: output_single_node, norm_s = output_single_node output_distributed, norm_d = output_distributed - if sequence_parallel: + if sequence_parallel and not kwargs.get("return_layernorm_output_gathered", False): norm_d = _gather(norm_d) _check_outputs(norm_s, norm_d) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index b8911eede..74d1dc69c 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -51,7 +51,7 @@ torch._dynamo.reset() -def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization): +def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization): test_path = TEST_ROOT / "run_gemm_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -78,6 +78,8 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization): if torch.cuda.get_device_properties(0).major != 9: pytest.skip("Atomic GEMM is requires device compute capability 9.x (Hopper).") test_cmd.append("--atomic") + if aggregate: + test_cmd.append("--aggregate") result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) if ( @@ -135,12 +137,13 @@ def _run_layer_with_overlap( @pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8")) -def test_split_all_gather_overlaps(quantization): +@pytest.mark.parametrize("aggregate", (False, True)) +def test_split_all_gather_overlaps(quantization, aggregate): """ Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap("AG", False, True, False, quantization) + _run_gemm_with_overlap("AG", False, True, False, aggregate, quantization) @pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8")) @@ -150,7 +153,7 @@ def test_split_reduce_scatter_overlaps(quantization, p2p): Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap("RS", False, p2p, False, quantization) + _run_gemm_with_overlap("RS", False, p2p, False, False, quantization) @pytest.mark.parametrize( @@ -183,10 +186,10 @@ def test_bulk_overlaps(comm_type, quantization, connections): " 9.0 (HOPPER ARCH)." ) os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8" - _run_gemm_with_overlap(comm_type, True, False, False, quantization) + _run_gemm_with_overlap(comm_type, True, False, False, False, quantization) os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" else: - _run_gemm_with_overlap(comm_type, True, False, False, quantization) + _run_gemm_with_overlap(comm_type, True, False, False, False, quantization) @pytest.mark.parametrize( diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 1e639b06e..6dc17b126 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -11,6 +11,7 @@ import functools import itertools import os +import pathlib import subprocess import sys from typing import Optional @@ -23,18 +24,27 @@ import transformer_engine.pytorch as te from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer import transformer_engine.pytorch.ops as te_ops -from transformer_engine.pytorch.ops._common import is_float8_tensor from transformer_engine.pytorch.utils import is_bf16_compatible, is_fp8_fnuz import transformer_engine_torch as tex -# Check if FP8 is supported +# Import utility functions +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import dtype_tols, make_recipe + + +# Check what quantization schemes are supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() quantization_list: list[Optional[str]] = [None] if fp8_available: - quantization_list.append("fp8") + quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: quantization_list.append("mxfp8") @@ -63,11 +73,12 @@ def reset_rng(seed: int = 1234) -> None: @torch.no_grad() def make_reference_and_test_tensors( shape: int | Iterable[int], + quantization: Optional[str] = None, ref_dtype: torch.dtype = torch.float64, ref_device: torch.device = "cpu", test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", - test_is_fp8: bool = False, + test_is_quantized: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -76,78 +87,55 @@ def make_reference_and_test_tensors( operations in high precision. The test tensor is intended for use in Transformer Engine operations. + If a quantization scheme is provided, the tensor values are + quantized so that they are representable. + """ + + # Random reference tensor ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + + # Construct test tensor from reference tensor test = ref.to(device=test_device, dtype=test_dtype) - if test_is_fp8: + if quantization is None: + if test_is_quantized: + raise ValueError("Quantization scheme not provided") + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + elif quantization in ("fp8", "fp8_delayed_scaling"): quantizer = Float8Quantizer( - scale=torch.ones(1, dtype=torch.float32, device=test_device), + scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), fp8_dtype=tex.DType.kFloat8E4M3, ) test = quantizer(test) - elif test.data_ptr() == ref.data_ptr(): - test = test.clone() + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=test_device, + ) + test = quantizer(test) + elif quantization == "mxfp8": + test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + else: + raise ValueError(f"Unsupported quantization scheme ({quantization})") + if isinstance(test, QuantizedTensor) and not test_is_quantized: + test = test.dequantize() + + # Make sure reference and test tensors match each other ref.copy_(test) + ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test -def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: - """Estimated numerical error for a datatype - - Based on tolerances for torch.testing.assert_close. - - """ - - # Transformer Engine dtypes - if isinstance(dtype, tex.DType): - if dtype == tex.DType.kFloat8E4M3: - return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 - if dtype == tex.DType.kFloat8E5M2: - return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 - dtype = { - tex.DType.kByte: torch.uint8, - tex.DType.kInt32: torch.int32, - tex.DType.kFloat32: torch.float32, - tex.DType.kFloat16: torch.half, - tex.DType.kBFloat16: torch.bfloat16, - }[dtype] - - # PyTorch dtypes - if dtype == torch.float16: - return dict(rtol=1e-3, atol=1e-5) - if dtype == torch.bfloat16: - return dict(rtol=1.6e-2, atol=1e-5) - if dtype == torch.float32: - return dict(rtol=1.3e-6, atol=1e-5) - if dtype == torch.float64: - return dict(rtol=1e-7, atol=1e-7) - raise ValueError(f"Unsupported dtype ({dtype})") - - -def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: - """Make recipe for quantization scheme""" - if name is None: - return None - if name == "fp8": - return transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - if name == "mxfp8": - return transformer_engine.common.recipe.MXFP8BlockScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - raise ValueError(f"Unsupported quantization scheme ({name})") - - def _test_all_reduce( *, - local_size: int = 17, + local_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8: bool = False, + quantization: Optional[str] = None, ) -> None: # Distributed process group @@ -156,22 +144,25 @@ def _test_all_reduce( world_size = torch.distributed.get_world_size(process_group) # Tensor dimensions - in_shape = [world_size, local_size] - out_shape = [local_size] + in_shape = [world_size, local_size, local_size] + out_shape = [local_size, local_size] # Random data reset_rng() + with_quantization = quantization is not None x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) # Plain PyTorch implementation @@ -199,10 +190,10 @@ def _test_all_reduce( def _test_all_gather( *, - local_size: int = 13, + local_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8: bool = False, + quantization: Optional[str] = None, ) -> None: # Distributed process group @@ -211,26 +202,29 @@ def _test_all_gather( world_size = torch.distributed.get_world_size(process_group) # Tensor dimensions - in_shape = [world_size, local_size] - out_shape = [world_size, world_size * local_size] + in_shape = [world_size, local_size, local_size] + out_shape = [world_size, world_size * local_size, local_size] # Random data reset_rng() + with_quantization = quantization is not None x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) # Plain PyTorch implementation - y_ref = x_ref.tile((world_size, 1)).reshape(out_shape) + y_ref = x_ref.tile((world_size, 1, 1)).reshape(out_shape) y_ref.backward(dy_ref) # Convert to distributed tensors @@ -257,10 +251,10 @@ def _test_all_gather( def _test_reduce_scatter( *, - local_size: int = 11, + local_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8: bool = False, + quantization: Optional[str] = None, ) -> None: # Distributed process group @@ -269,22 +263,25 @@ def _test_reduce_scatter( world_size = torch.distributed.get_world_size(process_group) # Tensor dimensions - in_shape = [world_size, world_size * local_size] - out_shape = [world_size, local_size] + in_shape = [world_size, world_size * local_size, local_size] + out_shape = [world_size, local_size, local_size] # Random data reset_rng() + with_quantization = quantization is not None x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) # Plain PyTorch implementation @@ -324,7 +321,11 @@ def _test_basic_linear( tensor_parallel_mode: str = "column", sequence_parallel: bool = False, ) -> None: + + # Skip invalid configurations quantized_compute = quantization is not None + if not quantized_compute and quantized_weight: + return # Distributed process group process_group = world_group() @@ -348,30 +349,23 @@ def _test_basic_linear( reset_rng() x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) - if isinstance(w_test, QuantizedTensor): - w_test = w_test.dequantize() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -468,7 +462,11 @@ def _test_linear( tensor_parallel_mode: str = "column", sequence_parallel: bool = False, ) -> None: + + # Skip invalid configurations quantized_compute = quantization is not None + if not quantized_compute and quantized_weight: + return # Distributed process group process_group = world_group() @@ -492,21 +490,16 @@ def _test_linear( reset_rng() x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) - if isinstance(w_test, QuantizedTensor): - w_test = w_test.dequantize() b_ref, b_test = None, None if bias: if tensor_parallel_mode == "row": @@ -520,13 +513,11 @@ def _test_linear( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -773,9 +764,10 @@ def run_parallel_tests() -> None: if rank == 0: print(f"Running _test_all_reduce") _test_all_reduce() - if rank == 0: - print(f"Running _test_all_gather") - _test_all_gather() + for quantization in quantization_list: + if rank == 0: + print(f"Running _test_all_gather with quantization={quantization}") + _test_all_gather(quantization=quantization) if rank == 0: print(f"Running _test_reduce_scatter") _test_reduce_scatter() @@ -829,6 +821,7 @@ def run_parallel_tests() -> None: @pytest.mark.parametrize("world_size", _world_sizes) def test_distributed_fuser_ops(world_size: int) -> None: """Launch parallel job that runs parallel tests""" + #TODO: find out why cannot align the following two lines with NV upstream python_exe = sys.executable current_file = os.path.abspath(__file__) command = [ diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 42070ea0f..37f0e8669 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -21,26 +21,30 @@ import transformer_engine.pytorch.cpp_extensions as tex from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops -from transformer_engine.pytorch.ops._common import is_float8_tensor from transformer_engine.pytorch.ops.fused import ( UserbuffersBackwardLinear, UserbuffersForwardLinear, ) -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.utils import is_bf16_compatible # Import utility functions _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) -from utils import dtype_tols, str_to_dtype +from utils import dtype_tols, make_recipe, str_to_dtype # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() quantization_list: list[Optional[str]] = [None] if fp8_available: - quantization_list.append("fp8") + quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: quantization_list.append("mxfp8") @@ -118,11 +122,12 @@ def reset_rng(seed: int = 1234) -> None: @torch.no_grad() def make_reference_and_test_tensors( shape: int | Iterable[int], + quantization: Optional[str] = None, ref_dtype: torch.dtype = torch.float64, ref_device: torch.device = "cpu", test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", - test_is_fp8: bool = False, + test_is_quantized: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -131,47 +136,49 @@ def make_reference_and_test_tensors( operations in high precision. The test tensor is intended for use in Transformer Engine operations. + If a quantization scheme is provided, the tensor values are + quantized so that they are representable. + """ - # Random data + # Random reference tensor ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) - # Make copy of tensor + # Construct test tensor from reference tensor test = ref.to(device=test_device, dtype=test_dtype) - if test_is_fp8: + if quantization is None: + if test_is_quantized: + raise ValueError("Quantization scheme not provided") + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + elif quantization in ("fp8", "fp8_delayed_scaling"): quantizer = Float8Quantizer( - scale=torch.ones(1, dtype=torch.float32, device=test_device), + scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), fp8_dtype=tex.DType.kFloat8E4M3, ) test = quantizer(test) - elif test.data_ptr() == ref.data_ptr(): - test = test.clone() + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=test_device, + ) + test = quantizer(test) + elif quantization == "mxfp8": + test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + else: + raise ValueError(f"Unsupported quantization scheme ({quantization})") + if isinstance(test, QuantizedTensor) and not test_is_quantized: + test = test.dequantize() - # Make sure reference and test tensors represent exact same values + # Make sure reference and test tensors match each other ref.copy_(test) - # Return reference and test tensors ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test -def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: - """Make recipe for quantization scheme""" - if name is None: - return None - if name == "fp8": - return transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - if name == "mxfp8": - return transformer_engine.common.recipe.MXFP8BlockScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - raise ValueError(f"Unsupported quantization scheme ({name})") - - def _test_linear( *, model_config: ModelConfig, @@ -201,21 +208,16 @@ def _test_linear( reset_rng() x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(w_test, QuantizedTensor): - w_test = w_test.dequantize() b_ref, b_test = None, None if bias: if tensor_parallel_mode == "row": @@ -229,13 +231,11 @@ def _test_linear( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -370,7 +370,7 @@ def _test_linear( if quantized_compute: tols = dtype_tols( model[0].weight._fp8_dtype - if is_float8_tensor(model[0].weight) + if isinstance(model[0].weight, Float8Tensor) else tex.DType.kFloat8E4M3 ) diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index 632f50e90..1ff5aff99 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -56,7 +56,7 @@ def test_distributed(quantization): if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) if quantization == "fp8_cs" and not fp8_available: - pytest.skip(fp8_available) + pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if quantization == "fp8_block_scaling" and not fp8_block_scaling_available: diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 708cb15d7..672950f50 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -93,7 +93,7 @@ def run_dpa_with_cp( # instantiate core attn module core_attn = DotProductAttention( config.num_heads, - config.head_dim_qk, + (config.head_dim_qk, config.head_dim_v), num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, qkv_format=qkv_format, @@ -110,16 +110,22 @@ def run_dpa_with_cp( config.num_heads, config.head_dim_qk, ) - kv_input_shape = ( + k_input_shape = ( config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_qk, ) + v_input_shape = ( + config.batch_size, + config.max_seqlen_kv, + config.num_gqa_groups, + config.head_dim_v, + ) attn_output_shape = ( config.batch_size, config.max_seqlen_q, - config.num_heads * config.head_dim_qk, + config.num_heads * config.head_dim_v, ) cu_seqlens_q = None cu_seqlens_kv = None @@ -132,16 +138,22 @@ def run_dpa_with_cp( config.num_heads, config.head_dim_qk, ) - kv_input_shape = ( + k_input_shape = ( config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, config.head_dim_qk, ) + v_input_shape = ( + config.max_seqlen_kv, + config.batch_size, + config.num_gqa_groups, + config.head_dim_v, + ) attn_output_shape = ( config.max_seqlen_q, config.batch_size, - config.num_heads * config.head_dim_qk, + config.num_heads * config.head_dim_v, ) cu_seqlens_q = None cu_seqlens_kv = None @@ -153,14 +165,19 @@ def run_dpa_with_cp( config.num_heads, config.head_dim_qk, ) - kv_input_shape = ( + k_input_shape = ( config.batch_size * config.max_seqlen_q, config.num_gqa_groups, config.head_dim_qk, ) + v_input_shape = ( + config.batch_size * config.max_seqlen_q, + config.num_gqa_groups, + config.head_dim_v, + ) attn_output_shape = ( config.batch_size * config.max_seqlen_q, - config.num_heads * config.head_dim_qk, + config.num_heads * config.head_dim_v, ) seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32) seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2) @@ -183,8 +200,8 @@ def run_dpa_with_cp( assert False, f"{qkv_format} is an unsupported qkv_format!" q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda() - k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() - v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() + k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda() + v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda() dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda() dout_quantizer = Float8Quantizer( fp8_dtype=tex.DType.kFloat8E5M2, diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index a9e980288..f2bb09125 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -263,13 +263,19 @@ def test(): model_configs_base = { - # test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend - "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0 - "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0 - "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1 - "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1 - "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference - "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference + # test: b, h, hg, d, sq, skv, p, mask, bias + "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), + "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), + "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), + "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"), + "base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"), + "base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"), } @@ -297,6 +303,7 @@ def test_dot_product_mem_calc(): qkv_layout=qkv_layout, window_size=config.window_size, pad_between_seqs=pad_between_seqs, + is_training=is_training, ) if FusedAttnBackend["CK"] not in fused_attn_backends: pytest.skip("This test requires the CK fused attention backend.") @@ -350,14 +357,28 @@ def test_dot_product_attention( if config.window_size == (-1, -1) and swa: config.window_size = [2, 2] config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) + + is_training = True available_backends, _, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, window_size=config.window_size, pad_between_seqs=pad_between_seqs, + is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported: + is_training = False + available_backends, _, fused_attn_backends = _get_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + window_size=config.window_size, + pad_between_seqs=pad_between_seqs, + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention # mannually pads and unpads the input and output of FlashAttention for testing purposes @@ -465,6 +486,7 @@ def test_dot_product_attention( share_cu_seqlens_ref, ) + logging.info(f"[test_dot_product_attention]: is_training = {is_training}") if unfused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs flash attn") torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols) @@ -509,18 +531,27 @@ def test_dpa_checkpoint(dtype, model_configs, model): "mla_1_1": ModelConfig( 4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 ), # cross, 0 + "mla_1_2": ModelConfig( + 4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # cross, 0 "mla_2_0": ModelConfig( 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64 ), # self , 1 "mla_2_1": ModelConfig( 1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64 ), # cross, 1 + "mla_2_2": ModelConfig( + 1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128 + ), # cross, 1 "mla_3_0": ModelConfig( 8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64 ), # inference "mla_3_1": ModelConfig( 8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 ), # inference + "mla_3_2": ModelConfig( + 8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # inference "mla_4_0": ModelConfig( 10, 16, 16, 192, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=128 ), @@ -1183,6 +1214,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layer_number=1, attention_type=config.attn_type, ).to(dtype=dtype, device="cuda") + if not is_training: + block = block.eval() cu_seqlens_q_padded = None cu_seqlens_kv_padded = None @@ -1272,9 +1305,11 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), "te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), + "te_1_3": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "te_2_3": ModelConfig(1, 16, 16, 64, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"), "te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"), "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"), } @@ -1285,7 +1320,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @pytest.mark.parametrize("model", model_configs_te_layer.keys()) @pytest.mark.parametrize("ckpt_attn", [False]) -@pytest.mark.parametrize("qkv_format", ["sbhd"]) +@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"]) @pytest.mark.parametrize("fused_qkv_params", [False]) @pytest.mark.parametrize("RoPE", [False]) def test_transformer_layer( @@ -1298,22 +1333,37 @@ def test_transformer_layer( tols = dict(atol=5e-2, rtol=5e-2) workspace_opt = True - qkv_layout="sbh3d" if fused_qkv_params else "sb3hd" - # override the qkv_layout in mqa gqa mode in ROCm TE - if IS_HIP_EXTENSION and model_configs[model].num_gqa_groups != model_configs[model].num_heads: - qkv_layout = "sbhd_sbhd_sbhd" - # Test backend availability + is_training = True available_backends, _, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, - qkv_layout=qkv_layout, + qkv_layout=( + qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd") + ), + is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported: + is_training = False + available_backends, _, fused_attn_backends = _get_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=( + qkv_format.replace("hd", "h3d") + if fused_qkv_params + else qkv_format.replace("hd", "3hd") + ), + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends # Skip if only unfused backend is supported if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: pytest.skip("Less than two backends to compare.") + # Skip if qkv_format = thd and "padding" not in attn_mask_type + if qkv_format == "thd" and "padding" not in config.attn_mask_type: + pytest.skip("THD requires padding mask.") # UnfusedDotProductAttention backend if unfused_attn_supported: @@ -1326,6 +1376,7 @@ def test_transformer_layer( workspace_opt, fused_qkv_params, RoPE, + is_training, ) # FusedAttention backend @@ -1339,6 +1390,7 @@ def test_transformer_layer( workspace_opt, fused_qkv_params, RoPE, + is_training, ) # FlashAttention backend @@ -1352,8 +1404,10 @@ def test_transformer_layer( workspace_opt, fused_qkv_params, RoPE, + is_training, ) + logging.info(f"[test_transformer_layer]: is_training = {is_training}") if unfused_attn_supported and fused_attn_supported: logging.info("[test_transformer_layer]: unfused attn vs fused attn") torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) @@ -1420,6 +1474,7 @@ def _run_transformer_layer( workspace_opt: bool, fused_qkv_params: bool, RoPE: bool, + is_training: bool, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Run TransformerLayer module with one forward pass and one backward pass""" @@ -1434,48 +1489,84 @@ def _run_transformer_layer( _attention_backends["backend_selection_requires_update"] = True # Create input tensor - inp = torch.randn( - config.max_seqlen_q, - config.batch_size, - config.hidden_size, - dtype=dtype, - device="cuda", - requires_grad=True, - ) - # In case the format to be tested is batch-first, need to transpose the - # input tensor. + if qkv_format == "sbhd": + inp = torch.randn( + config.max_seqlen_q, + config.batch_size, + config.hidden_size, + dtype=dtype, + device="cuda", + requires_grad=True, + ) + inp_enc = torch.randn( + config.max_seqlen_kv, + config.batch_size, + config.hidden_size, + dtype=dtype, + device="cuda", + requires_grad=True, + ) if qkv_format == "bshd": - inp = inp.transpose(0, 1) + inp = torch.randn( + config.batch_size, + config.max_seqlen_q, + config.hidden_size, + dtype=dtype, + device="cuda", + requires_grad=True, + ) + inp_enc = torch.randn( + config.batch_size, + config.max_seqlen_kv, + config.hidden_size, + dtype=dtype, + device="cuda", + requires_grad=True, + ) # Create seqlens - if "padding" in config.attn_mask_type: - seqlens_q = torch.randint( - 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" - ) + if "padding" in config.attn_mask_type or qkv_format == "thd": + if config.attn_type == "self": + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + seqlens_kv = seqlens_q + if config.attn_type == "cross": + if config.max_seqlen_q > 1: + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + else: + seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda") + seqlens_kv = torch.randint( + 1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda" + ) else: seqlens_q = torch.full( [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" ) - - # Create attention mask if padding - attention_mask = None - if "padding" in config.attn_mask_type: - attention_mask_q = torch.Tensor([]).to(dtype=torch.bool) - for i in range(config.batch_size): - attention_mask_q = torch.cat( - [ - attention_mask_q, - torch.Tensor( - [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i]) - ) - .to(torch.bool) - .unsqueeze(0) - .unsqueeze(0) - .unsqueeze(0), - ], - dim=0, - ) - attention_mask = attention_mask_q.to(device="cuda") + seqlens_kv = torch.full( + [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" + ) + cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0) + if qkv_format == "thd": + inp = torch.randn( + cu_seqlens_q[-1], + config.hidden_size, + dtype=dtype, + device="cuda", + requires_grad=True, + ) + inp_enc = torch.randn( + cu_seqlens_kv[-1], + config.hidden_size, + dtype=dtype, + device="cuda", + requires_grad=True, + ) sigma = 0.02 init_method = init_method_normal(sigma) @@ -1527,7 +1618,7 @@ def _run_transformer_layer( sequence_parallel=False, apply_residual_connection_post_layernorm=False, output_layernorm=False, - layer_type="encoder", + layer_type="encoder" if config.attn_type == "self" else "decoder", drop_path_rate=drop_path_rates[layer_number - 1], set_parallel_mode=True, fuse_qkv_params=fused_qkv_params, @@ -1537,6 +1628,8 @@ def _run_transformer_layer( bias=True, attn_input_format=qkv_format, ).to(dtype=dtype, device="cuda") + if not is_training: + block = block.eval() # Create ALiBi slopes alibi_slopes = None @@ -1546,16 +1639,22 @@ def _run_transformer_layer( # Run a forward and backward pass out = block( inp, - attention_mask=attention_mask, self_attn_mask_type=config.attn_mask_type, + encoder_output=inp_enc if config.attn_type == "cross" else None, + enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None, checkpoint_core_attention=False, rotary_pos_emb=rotary_pos_emb, core_attention_bias_type=config.attn_bias_type, core_attention_bias=bias, alibi_slopes=alibi_slopes, + max_seqlen_q=config.max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, ) - loss = out.sum() - loss.backward() + if is_training: + loss = out.sum() + loss.backward() return out, inp.grad diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 78be0e952..edf518d6b 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -108,6 +108,18 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_2_4": ModelConfig( 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) ), # GQA + "cp_3_0": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64 + ), # MLA + "cp_3_1": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64 + ), # MLA + "cp_3_2": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64 + ), # MLA + "cp_3_3": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64 + ), # MLA } @@ -162,6 +174,10 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha ) if dtype != "fp8" and fp8_mha: pytest.skip("Only fp8 works with fp8_mha=True!") + if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: + pytest.skip("MLA CP currently only support KV P2P!") + if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: + pytest.skip("MLA CP currently does not support FP8 attention!") subprocess.run( get_bash_arguments( diff --git a/tests/pytorch/fused_attn/test_kv_cache.py b/tests/pytorch/fused_attn/test_kv_cache.py index eb3838ff1..967309459 100644 --- a/tests/pytorch/fused_attn/test_kv_cache.py +++ b/tests/pytorch/fused_attn/test_kv_cache.py @@ -52,7 +52,7 @@ 4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16 ), "infer_1": ModelConfig( - 2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16 + 2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16 ), } @@ -370,12 +370,24 @@ def generate_args( ] -def get_tols(module, backend, dtype): +def get_tols(config, module, backend, dtype): if module == "TransformerLayer": - tols = { - torch.half: (5e-3, 5e-3), - torch.bfloat16: (3.5e-2, 3.5e-2), - } + if config.head_dim_qk <= 128: + tols = { + torch.half: (5e-3, 5e-3), + torch.bfloat16: (3.5e-2, 3.5e-2), + } + else: + if backend == "UnfusedAttention": + tols = { + torch.half: (1.6e-2, 1.6e-2), + torch.bfloat16: (1.2e-1, 1e-1), + } + else: + tols = { + torch.half: (1e-2, 1e-2), + torch.bfloat16: (8e-2, 7e-2), + } if module == "DotProductAttention": tols = { torch.half: (1e-3, 1e-3), @@ -662,7 +674,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g incremental_output = incremental_output[0] # compare results - atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn) + atol, rtol = get_tols( + config, module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn + ) for i, seq in enumerate(sim.t_seq_ids): token_index = sim.step_lens[i] - 1 if qkv_format == "bshd": diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py index f5c9dc0e9..1ce7d3e42 100644 --- a/tests/pytorch/references/blockwise_quantizer_reference.py +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -260,6 +260,7 @@ def quantize( eps: float = 0.0, pow_2_scales: bool = False, quant_tile_shape: Tuple[int, int] = (128, 128), + munge_scale_shapes: bool = True, ) -> QuantizeResult: # sanity checks assert x.dim() == 2 @@ -277,27 +278,33 @@ def quantize( assert quant_tile_shape in ((1, 128), (128, 128)) if quant_tile_shape[0] == 1: # Quantize row-wise - return self.scale_munger.munge_scale_shapes_for_backend( - self._quantize_vector_tiling( - x, - quant_dtype, - tile_len=quant_tile_shape[1], - return_transpose=return_transpose, - pow_2_scales=pow_2_scales, - eps=eps, - ), - quant_tile_shape, + result = self._quantize_vector_tiling( + x, + quant_dtype, + tile_len=quant_tile_shape[1], + return_transpose=return_transpose, + pow_2_scales=pow_2_scales, + eps=eps, ) + if munge_scale_shapes: + result = self.scale_munger.munge_scale_shapes_for_backend( + result, + quant_tile_shape, + ) + return result else: # Quantize block-wise - return self.scale_munger.munge_scale_shapes_for_backend( - self._quantize_square_block_tiling( - x, - quant_dtype, - tile_len=quant_tile_shape[0], - return_transpose=return_transpose, - pow_2_scales=pow_2_scales, - eps=eps, - ), - quant_tile_shape, + result = self._quantize_square_block_tiling( + x, + quant_dtype, + tile_len=quant_tile_shape[0], + return_transpose=return_transpose, + pow_2_scales=pow_2_scales, + eps=eps, ) + if munge_scale_shapes: + result = self.scale_munger.munge_scale_shapes_for_backend( + result, + quant_tile_shape, + ) + return result diff --git a/tests/pytorch/test_checkpoint.py b/tests/pytorch/test_checkpoint.py new file mode 100644 index 000000000..16e7feb1b --- /dev/null +++ b/tests/pytorch/test_checkpoint.py @@ -0,0 +1,175 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +import argparse +import functools +import os +import pathlib + +import pytest +import torch + +import transformer_engine.pytorch as te + +from utils import make_recipe + +# Check supported quantization schemes +fp8_available, reason_for_no_fp8 = te.fp8.FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = te.fp8.FP8GlobalStateManager.is_mxfp8_available() + + +# Test cases for loading checkpoint files +_TestLoadCheckpoint_name_list: tuple[str, ...] = ( + "linear", + "layernorm_linear", + "layernorm_mlp", + "layernorm", + "rmsnorm", + "transformer_layer", + "ops_linear", + "linear.fp8", + "ops_linear.fp8", + "linear.mxfp8", + "ops_linear.mxfp8", +) + + +class TestLoadCheckpoint: + """Tests for loading checkpoint files + + Tests assume that checkpoint files have already been created. In + order to regenerate checkpoint files, e.g. after a breaking change + in the checkpoint format, run this file directly as a Python + script: `python3 test_checkpoint.py --save-checkpoint all`. + + """ + + @staticmethod + def _make_module(name: str) -> torch.nn.Module: + """Construct a module""" + if name == "linear": + return te.Linear(1, 1) + if name == "layernorm_linear": + return te.LayerNormLinear(1, 1) + if name == "layernorm_mlp": + return te.LayerNormMLP(1, 1) + if name == "layernorm": + return te.LayerNorm(1) + if name == "rmsnorm": + return te.RMSNorm(1) + if name == "transformer_layer": + return te.TransformerLayer(1, 1, 1) + if name == "ops_linear": + return te.ops.Linear(1, 1) + if name == "linear.fp8": + with te.fp8_model_init(recipe=make_recipe("fp8")): + return te.Linear(16, 16) + if name == "ops_linear.fp8": + with te.fp8_model_init(recipe=make_recipe("fp8")): + return te.ops.Linear(16, 16) + if name == "linear.mxfp8": + with te.fp8_model_init(recipe=make_recipe("mxfp8")): + return te.Linear(32, 32) + if name == "ops_linear.mxfp8": + with te.fp8_model_init(recipe=make_recipe("mxfp8")): + return te.ops.Linear(32, 32) + raise ValueError(f"Unrecognized module name ({name})") + + @staticmethod + @functools.lru_cache(maxsize=None) + def _checkpoint_dir() -> pathlib.Path: + """Path to directory with checkpoint files""" + + # Check environment variable + path = os.getenv("NVTE_TEST_CHECKPOINT_ARTIFACT_PATH") + if path: + return pathlib.Path(path).resolve() + + # Fallback to path in root dir + root_dir = pathlib.Path(__file__).resolve().parent.parent.parent + return root_dir / "artifacts" / "tests" / "pytorch" / "test_checkpoint" + + @staticmethod + def _save_checkpoint(name: str, checkpoint_dir: Optional[pathlib.Path] = None) -> None: + """Save a module's checkpoint file""" + + # Path to save checkpoint + if checkpoint_dir is None: + checkpoint_dir = TestLoadCheckpoint._checkpoint_dir() + checkpoint_dir.mkdir(exist_ok=True) + checkpoint_file = checkpoint_dir / f"{name}.pt" + + # Create module and save checkpoint + module = TestLoadCheckpoint._make_module(name) + torch.save(module.state_dict(), checkpoint_file) + print(f"Saved checkpoint for {name} at {checkpoint_file}") + + @pytest.mark.parametrize("name", _TestLoadCheckpoint_name_list) + def test_module(self, name: str) -> None: + """Test for loading a module's checkpoint file""" + + # Skip if quantization is not supported + quantization = None + if "." in name: + quantization = name.split(".")[1] + if quantization == "fp8" and not fp8_available: + pytest.skip(reason_for_no_fp8) + if quantization == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + # Construct module + module = self._make_module(name) + + # Load checkpoint from file + checkpoint_file = self._checkpoint_dir() / f"{name}.pt" + if not checkpoint_file.is_file(): + raise FileNotFoundError(f"Could not find checkpoint file at {checkpoint_file}") + state_dict = torch.load(checkpoint_file, weights_only=False) + + # Update module from checkpoint + module.load_state_dict(state_dict, strict=True) + + +def main() -> None: + """Main function + + Typically used to generate checkpoint files. + + """ + + # Parse command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument( + "--save-checkpoint", + type=str, + default=None, + help="Save checkpoint file for a module", + ) + parser.add_argument( + "--checkpoint-dir", + type=str, + default=None, + help="Directory to save checkpoint file in", + ) + args = parser.parse_args() + + # Save checkpoint files if needed + if args.save_checkpoint is not None: + checkpoint_dir = args.checkpoint_dir + if checkpoint_dir is not None: + checkpoint_dir = pathlib.Path(checkpoint_dir).resolve() + if args.save_checkpoint == "all": + for name in _TestLoadCheckpoint_name_list: + TestLoadCheckpoint._save_checkpoint(name, checkpoint_dir=checkpoint_dir) + else: + TestLoadCheckpoint._save_checkpoint( + args.save_checkpoint, + checkpoint_dir=checkpoint_dir, + ) + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 816df12f6..59383f21b 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -97,6 +97,8 @@ def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload max_mem_used = torch.cuda.memory_allocated() / (1024**2) torch.cuda.synchronize() + tensor.sum().backward() + return max_mem_used @@ -115,6 +117,9 @@ def test_cpu_offload(fp8_recipe, model_key) -> None: the difference being the size of the FP8 cache that is not offloaded to the CPU. We also expect this memory consumption to be smaller than in scenario (1). """ + import gc + + gc.collect() model_cls = model_types[model_key] models_list = [model_cls() for _ in range(NUM_LAYERS)] diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index 0baee4975..858ce73b6 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -88,6 +88,126 @@ def initialize_for_many_scales( return result +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (303, 300), + (305, 256), + # Some larger tiles. + (2000, 2000), + (2048, 2000), + (2000, 1024), + (2048, 1024), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) +def test_quantization_1D_block_tiling_with_compact_data_and_scales( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + pow_2_scales: bool, +) -> None: + te_dtype = TE_DType[quant_dtype] + tile_size = (1, 128) + # This test runs a comparison of the ref class versus the class using + # CUDA kernels to quantize. They should quantize identically for pixels + # that are not DC values in the scale factor shape. + ref_quantizer = BlockwiseQuantizerReference() + sut_quantizer = Float8BlockQuantizer( + fp8_dtype=te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=eps, + force_pow_2_scales=pow_2_scales, + block_scaling_dim=1, + all_gather_usage=True, + ) + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device) + + x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False) + x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut) + x_fp8_sut_cpp_alloc = sut_quantizer(x) + + assert x_fp8_sut._rowwise_data is not None + qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype) + assert x_fp8_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv + qx_t = x_fp8_sut._columnwise_data + sx_t = x_fp8_sut._columnwise_scale_inv + + qresult_ref = ref_quantizer.quantize( + x, + quant_dtype=quant_dtype, + return_transpose=True, + eps=eps, + pow_2_scales=pow_2_scales, + quant_tile_shape=tile_size, + munge_scale_shapes=False, + ) + qx_ref, sx_ref, qx_t_ref, sx_t_ref = ( + qresult_ref.data, + qresult_ref.scale, + qresult_ref.data_t, + qresult_ref.scale_t, + ) + + # match the reference quantize transpose output with the columnwise non-transpose method + qx_t_ref = qx_t_ref.transpose(-1, -2).contiguous() + sx_t_ref = sx_t_ref.transpose(-1, -2).contiguous() + + # Check + torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0) + torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0) + assert qx_t is not None + qx_t = qx_t.view(dtype=quant_dtype) + assert qx_t_ref is not None + assert sx_t is not None + assert sx_t_ref is not None + torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0) + torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0) + + # check that the C++ and Python allocators are equivalent + torch.testing.assert_close( + x_fp8_sut._rowwise_data, x_fp8_sut_cpp_alloc._rowwise_data, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + x_fp8_sut._rowwise_scale_inv, x_fp8_sut_cpp_alloc._rowwise_scale_inv, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + x_fp8_sut._columnwise_data, x_fp8_sut_cpp_alloc._columnwise_data, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + x_fp8_sut._columnwise_scale_inv, + x_fp8_sut_cpp_alloc._columnwise_scale_inv, + atol=0.0, + rtol=0.0, + ) + + # check if the fp8 output between C++ and Python are the same + assert x_fp8_sut._data_format == x_fp8_sut_cpp_alloc._data_format + + def check_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index b63e949e4..281fc67a5 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -385,7 +385,7 @@ def compare_recipe( ) # recipe1 - using_fp8_recipe = recipe1 != GetRecipes.none + using_fp8_recipe = recipe1() != GetRecipes.none() if using_fp8_recipe: with fp8_autocast(enabled=True, fp8_recipe=recipe1()): y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient) @@ -393,7 +393,7 @@ def compare_recipe( y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient) # recipe2 - using_fp8_recipe = recipe2 != GetRecipes.none + using_fp8_recipe = recipe2() != GetRecipes.none() if using_fp8_recipe: with fp8_autocast(enabled=True, fp8_recipe=recipe2()): y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient) @@ -608,7 +608,7 @@ def compare_recipe( ) # recipe1 - using_fp8_recipe = recipe1 != GetRecipes.none + using_fp8_recipe = recipe1() != GetRecipes.none() if using_fp8_recipe: with fp8_autocast(enabled=True, fp8_recipe=recipe1()): y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear( @@ -630,7 +630,7 @@ def compare_recipe( ) # recipe2 - using_fp8_recipe = recipe2 != GetRecipes.none + using_fp8_recipe = recipe2() != GetRecipes.none() if using_fp8_recipe: with fp8_autocast(enabled=True, fp8_recipe=recipe2()): y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear( diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 64aa1e4d2..63833b564 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -179,7 +179,40 @@ def test_quantize_dequantize_columnwise_only( ) @pytest.mark.parametrize("block_scaling_dim", [1, 2]) @pytest.mark.parametrize("dq_columnwise", [True, False]) + @pytest.mark.parametrize("all_gather_usage", [True, False]) def test_quantize_dequantize_dims( + self, + dims: DimsType, + block_scaling_dim: int, + dq_columnwise: bool, + all_gather_usage: bool, + ) -> None: + if all_gather_usage and block_scaling_dim != 1: + pytest.skip("all_gather_usage only implemented for 1D block quantization.") + atol = _tols[tex.DType.kFloat8E4M3]["atol"] + rtol = _tols[tex.DType.kFloat8E4M3]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=dq_columnwise, + block_scaling_dim=block_scaling_dim, + all_gather_usage=all_gather_usage, + ) + self._test_quantize_dequantize( + quantizer=quantizer, + dims=dims, + atol=atol, + rtol=rtol, + dequant_columnwise=dq_columnwise, + ) + + @pytest.mark.parametrize( + "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] + ) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + @pytest.mark.parametrize("dq_columnwise", [True, False]) + @pytest.mark.xfail(raises=NotImplementedError) + def test_quantize_dequantize_compact_format( self, dims: DimsType, block_scaling_dim: int, dq_columnwise: bool ) -> None: atol = _tols[tex.DType.kFloat8E4M3]["atol"] @@ -189,6 +222,7 @@ def test_quantize_dequantize_dims( rowwise=True, columnwise=dq_columnwise, block_scaling_dim=block_scaling_dim, + all_gather_usage=True, ) self._test_quantize_dequantize( quantizer=quantizer, @@ -253,8 +287,13 @@ def test_data_accessors(self, dims: DimsType, block_scaling_dim: int) -> None: @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) @pytest.mark.parametrize("block_scaling_dim", [1, 2]) - def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None: + @pytest.mark.parametrize("all_gather_usage", [True, False]) + def test_serialization( + self, dims: DimsType, block_scaling_dim: int, all_gather_usage: bool + ) -> None: """Test serialization of Float8BlockwiseQTensor""" + if all_gather_usage and block_scaling_dim != 1: + pytest.skip("all_gather_usage only implemented for 1D block quantization.") device = "cuda" dtype = torch.bfloat16 x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) @@ -263,6 +302,7 @@ def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None: rowwise=True, columnwise=True, block_scaling_dim=block_scaling_dim, + all_gather_usage=all_gather_usage, ) # Create FP8 tensor @@ -286,6 +326,7 @@ def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None: assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled assert x_fp8_loaded.dtype == x_fp8.dtype assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype + assert x_fp8_loaded._data_format == x_fp8._data_format # Test that dequantized values match x_fp8_dequant = x_fp8.dequantize() diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py new file mode 100644 index 000000000..d2cb85dd3 --- /dev/null +++ b/tests/pytorch/test_fused_router.py @@ -0,0 +1,394 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +import torch +import math +from typing import Optional, Dict +from transformer_engine.pytorch.router import ( + fused_topk_with_score_function, + fused_compute_score_for_moe_aux_loss, + fused_moe_aux_loss, +) +import pytest +from copy import deepcopy + +seed = 42 +torch.manual_seed(seed) +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + +# Pytorch-based group topk +def group_limited_topk( + scores: torch.Tensor, + topk: int, + num_tokens: int, + num_experts: int, + num_groups: int, + group_topk: int, +): + group_scores = ( + scores.view(num_tokens, num_groups, -1).topk(topk // group_topk, dim=-1)[0].sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + + # Mask the experts based on selection groups + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_tokens, num_groups, num_experts // num_groups) + .reshape(num_tokens, -1) + ) + + masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) + probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1) + + return probs, top_indices + + +# Pytorch-based topk softmax/sigmoid +def topk_softmax_sigmoid_pytorch( + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool = False, + num_groups: Optional[int] = None, + group_topk: Optional[int] = None, + scaling_factor: Optional[float] = None, + score_function: str = "softmax", + expert_bias: Optional[torch.Tensor] = None, +): + num_tokens, num_experts = logits.shape + + def compute_topk(scores, topk, num_groups=None, group_topk=None): + if group_topk: + return group_limited_topk( + scores=scores, + topk=topk, + num_tokens=num_tokens, + num_experts=num_experts, + num_groups=num_groups, + group_topk=group_topk, + ) + else: + return torch.topk(scores, k=topk, dim=1) + + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) + probs, top_indices = compute_topk(scores, topk, num_groups, group_topk) + else: + scores, top_indices = compute_topk(logits, topk, num_groups, group_topk) + probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits) + elif score_function == "sigmoid": + scores = torch.sigmoid(logits.float()).type_as(logits) + if expert_bias is not None: + scores_for_routing = scores + expert_bias + _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) + scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits) + else: + scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) + probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + else: + raise ValueError(f"Invalid score_function: {score_function}") + + if scaling_factor: + probs = probs * scaling_factor + + topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs) + topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool() + + return topk_masked_gates, topk_map + + +# Pytorch-based compute routing scores for aux loss +def compute_scores_for_aux_loss_pytorch( + logits: torch.Tensor, topk: int, score_function: str +) -> torch.Tensor: + if score_function == "softmax": + scores = torch.softmax(logits, dim=-1, dtype=torch.float32) + elif score_function == "sigmoid": + scores = torch.sigmoid(logits) + scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + else: + raise ValueError(f"Invalid score_function: {score_function}") + + _, top_indices = torch.topk(scores, k=topk, dim=1) + routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool() + return routing_map, scores + + +# Pytorch-based aux loss +def aux_loss_pytorch( + probs: torch.Tensor, + tokens_per_expert: torch.Tensor, + total_num_tokens: int, + topk: int, + num_experts: int, + moe_aux_loss_coeff: float, +): + aggregated_probs_per_expert = probs.sum(dim=0) + aux_loss = torch.sum(aggregated_probs_per_expert * tokens_per_expert) * ( + num_experts * moe_aux_loss_coeff / (topk * total_num_tokens * total_num_tokens) + ) + return aux_loss + + +def run_comparison( + dtype, + num_tokens, + num_experts, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + enable_bias, +): + # Set some parameters + if score_function == "sigmoid": + # Construct the special logits to avoid inf in the sigmoid function + offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 + logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 + logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) + else: + logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4 + logits = logits.view(num_tokens, num_experts) + logits.requires_grad = True + if enable_bias and score_function == "sigmoid": + expert_bias = torch.arange(num_experts, device="cuda") * 0.1 + expert_bias = torch.flip(expert_bias, dims=[0]) + expert_bias.requires_grad = True + else: + expert_bias = None + + # Clone the input tensor + logits_clone = deepcopy(logits) + logits_clone.requires_grad = True + if expert_bias is not None: + expert_bias_clone = deepcopy(expert_bias) + expert_bias_clone.requires_grad = True + else: + expert_bias_clone = None + + # Run the original implementation + # We do not support the capacity factor case + probs, routing_map = topk_softmax_sigmoid_pytorch( + logits=logits, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=expert_bias, + ) + + # Run the fused implementation + probs_fused, routing_map_fused = fused_topk_with_score_function( + logits=logits_clone, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=expert_bias_clone, + ) + + torch.testing.assert_close(probs, probs_fused) + torch.testing.assert_close(routing_map, routing_map_fused) + + # Fake the loss + loss = torch.sum(probs) + loss_fused = torch.sum(probs_fused) + + # Backward the loss + loss.backward() + loss_fused.backward() + + # Check the gradient + torch.testing.assert_close(logits.grad, logits_clone.grad) + + +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992]) +@pytest.mark.parametrize("num_experts", [128, 32]) +@pytest.mark.parametrize("topk", [4, 8]) +@pytest.mark.parametrize("group_topk", [None, 4]) +@pytest.mark.parametrize("scaling_factor", [None, 1.2]) +@pytest.mark.parametrize("enable_bias", [True, False]) +def test_topk_sigmoid( + dtype, + num_tokens, + num_experts, + topk, + group_topk, + scaling_factor, + enable_bias, +): + num_groups = 8 if group_topk else None + run_comparison( + dtype=dtype, + num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + use_pre_softmax=False, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function="sigmoid", + enable_bias=enable_bias, + ) + + +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) +@pytest.mark.parametrize("num_experts", [128, 32]) +@pytest.mark.parametrize("topk", [4, 8]) +@pytest.mark.parametrize("use_pre_softmax", [True, False]) +@pytest.mark.parametrize("group_topk", [None, 4]) +@pytest.mark.parametrize("scaling_factor", [None, 1.2]) +def test_topk_softmax( + dtype, + num_tokens, + num_experts, + topk, + use_pre_softmax, + group_topk, + scaling_factor, +): + num_groups = 8 if group_topk else None + run_comparison( + dtype=dtype, + num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function="softmax", + enable_bias=False, + ) + + +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) +@pytest.mark.parametrize("num_experts", [256, 128, 32]) +@pytest.mark.parametrize("topk", [4, 8]) +@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) +def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): + if score_function == "sigmoid": + # Construct the special logits to avoid inf in the sigmoid function + offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 + logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 + logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) + else: + logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4 + logits = logits.view(num_tokens, num_experts) + logits.requires_grad = True + + logits_clone = deepcopy(logits) + logits_clone.requires_grad = True + + routing_map, scores = compute_scores_for_aux_loss_pytorch( + logits=logits, + topk=topk, + score_function=score_function, + ) + + routing_map_fused, scores_fused = fused_compute_score_for_moe_aux_loss( + logits=logits_clone, + topk=topk, + score_function=score_function, + ) + + torch.testing.assert_close(scores, scores_fused) + torch.testing.assert_close(routing_map, routing_map_fused) + + loss = torch.sum(scores) + loss.backward() + loss_fused = torch.sum(scores_fused) + loss_fused.backward() + + torch.testing.assert_close(logits.grad, logits_clone.grad) + + +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) +@pytest.mark.parametrize("num_experts", [256, 128, 32]) +@pytest.mark.parametrize("topk", [4]) +def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): + # Construct the special probs to avoid inf in the sigmoid function + offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 + probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 + probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) + probs = probs.view(num_tokens, num_experts) + probs.requires_grad = True + + tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32) + coeff = 0.01 + + probs_clone = deepcopy(probs) + probs_clone.requires_grad = True + + aux_loss = aux_loss_pytorch( + probs=probs, + tokens_per_expert=tokens_per_expert, + total_num_tokens=num_tokens, + topk=topk, + num_experts=num_experts, + moe_aux_loss_coeff=coeff, + ) + + aux_loss_fused = fused_moe_aux_loss( + probs=probs_clone, + tokens_per_expert=tokens_per_expert, + total_num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ) + + torch.testing.assert_close(aux_loss, aux_loss_fused) + + # Backward + aux_loss.backward() + aux_loss_fused.backward() + + torch.testing.assert_close(probs.grad, probs_clone.grad) + + +def profile_topk_softmax( + dtype, + num_tokens, + num_experts, + topk, + enable_bias, + use_pre_softmax, +): + group_topk = 4 + scaling_factor = 1.2 + test_topk_sigmoid( + torch.float32, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias + ) + test_topk_softmax( + torch.float32, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor + ) + + +if __name__ == "__main__": + test_fused_scores_for_aux_loss( + dtype=torch.float32, num_tokens=2, num_experts=32, topk=8, score_function="softmax" + ) + test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=32, topk=4) + test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=128, topk=4) + test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=256, topk=4) + test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4) + test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4) + test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4) + test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4) + test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4) + test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index dd94f2435..78894d97d 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -9,6 +9,8 @@ from collections.abc import Iterable import io import math +import pathlib +import sys from typing import Optional import pytest @@ -20,16 +22,26 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops.fused import ( + BackwardBiasActivation, BackwardLinearAdd, ForwardLinearBiasActivation, ForwardLinearBiasAdd, ) from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Tensor, + Float8CurrentScalingQuantizer, + Float8Quantizer, +) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex +# Import utility functions +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent)) +from utils import dtype_tols, make_recipe + # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -42,6 +54,13 @@ # Supported devices _devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")] +# Supported quantization recipes +_quantization_list: list[Optional[str]] = [None] +if fp8_available: + _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) +if mxfp8_available: + _quantization_list.append("mxfp8") + def maybe_skip_quantization( quantization: Optional[str], @@ -49,13 +68,14 @@ def maybe_skip_quantization( dims: Optional[Iterable[int] | int] = None, device: Optional[torch.device | str] = None, ) -> None: + """Skip test case if a quantization scheme is not supported""" # Don't skip if there is no quantization if quantization is None: return # Check if quantization scheme is supported - if quantization == "fp8" and not fp8_available: + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) @@ -63,7 +83,7 @@ def maybe_skip_quantization( if dims is not None: if not isinstance(dims, Iterable): dims = (dims,) - if quantization == "fp8": + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"): if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: pytest.skip("FP8 GEMMs require dims that are divisible by 16") elif quantization == "mxfp8": @@ -75,47 +95,15 @@ def maybe_skip_quantization( pytest.skip("Quantization is only supported on CUDA devices") -def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: - """Estimated numerical error for a datatype - - Based on tolerances for torch.testing.assert_close. - - """ - - # Transformer Engine dtypes - if isinstance(dtype, tex.DType): - if dtype == tex.DType.kFloat8E4M3: - return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 - if dtype == tex.DType.kFloat8E5M2: - return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 - dtype = { - tex.DType.kByte: torch.uint8, - tex.DType.kInt32: torch.int32, - tex.DType.kFloat32: torch.float32, - tex.DType.kFloat16: torch.half, - tex.DType.kBFloat16: torch.bfloat16, - }[dtype] - - # PyTorch dtypes - if dtype == torch.float16: - return dict(rtol=1e-3, atol=1e-5) - if dtype == torch.bfloat16: - return dict(rtol=1.6e-2, atol=1e-5) - if dtype == torch.float32: - return dict(rtol=1.3e-6, atol=1e-5) - if dtype == torch.float64: - return dict(rtol=1e-7, atol=1e-7) - raise ValueError(f"Unsupported dtype ({dtype})") - - @torch.no_grad() def make_reference_and_test_tensors( shape: int | Iterable[int], + quantization: Optional[str] = None, ref_dtype: torch.dtype = torch.float64, ref_device: torch.device = "cpu", test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", - test_is_fp8: bool = False, + test_is_quantized: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -124,40 +112,50 @@ def make_reference_and_test_tensors( operations in high precision. The test tensor is intended for use in Transformer Engine operations. + If a quantization scheme is provided, the tensor values are + quantized so that they are representable. + """ + + # Random reference tensor ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + + # Construct test tensor from reference tensor test = ref.to(device=test_device, dtype=test_dtype) - if test_is_fp8: + if quantization is None: + if test_is_quantized: + raise ValueError("Quantization scheme not provided") + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + elif quantization in ("fp8", "fp8_delayed_scaling"): quantizer = Float8Quantizer( scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), fp8_dtype=tex.DType.kFloat8E4M3, ) test = quantizer(test) - elif test.data_ptr() == ref.data_ptr(): - test = test.clone() + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=test_device, + ) + test = quantizer(test) + elif quantization == "mxfp8": + test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + else: + raise ValueError(f"Unsupported quantization scheme ({quantization})") + if isinstance(test, QuantizedTensor) and not test_is_quantized: + test = test.dequantize() + + # Make sure reference and test tensors match each other ref.copy_(test) + ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test -def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: - """Make recipe for quantization scheme""" - if name is None: - return None - if name == "fp8": - return transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - if name == "mxfp8": - return transformer_engine.common.recipe.MXFP8BlockScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - raise ValueError(f"Unsupported quantization scheme ({name})") - - -class TestSequential: +class TestSequentialContainer: """Tests for sequential container""" def test_modules(self) -> None: @@ -366,7 +364,7 @@ def test_fp8_scale_update( @pytest.mark.parametrize("init_dtype", _dtypes) @pytest.mark.parametrize("final_dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_dtype_cast( self, *, @@ -379,8 +377,9 @@ def test_dtype_cast( """Check dtype cast functions""" # Skip invalid configurations - maybe_skip_quantization(quantization, device=device) + in_shape = (size, size) with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data dtype = torch.float32 @@ -390,9 +389,9 @@ def test_dtype_cast( dtype = torch.bfloat16 w_ref, w_test = make_reference_and_test_tensors( (size, size), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=with_quantization, ) # Construct operation @@ -414,11 +413,11 @@ def test_dtype_cast( assert isinstance(op.weight, QuantizedTensor) == with_quantization assert op.weight.dtype == final_dtype w_test = op.weight.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0) + torch.testing.assert_close(w_test, w_ref, **dtype_tols(dtype)) # Check forward and backward pass x = torch.zeros( - (size, size), + in_shape, dtype=init_dtype, device=device, requires_grad=True, @@ -431,7 +430,7 @@ def test_dtype_cast( @pytest.mark.parametrize("model_dtype", _dtypes) @pytest.mark.parametrize("autocast_dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_pyt_autocast( self, *, @@ -446,8 +445,9 @@ def test_pyt_autocast( device = torch.device(device) # Skip invalid configurations + in_shape = (size, size) quantized_compute = quantization is not None - maybe_skip_quantization(quantization) + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Construct operation recipe = make_recipe(quantization) @@ -456,7 +456,7 @@ def test_pyt_autocast( # Check forward and backward pass x = torch.zeros( - (size, size), + in_shape, dtype=model_dtype, device=device, requires_grad=True, @@ -494,33 +494,34 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_list) def test_identity( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -556,7 +557,7 @@ def test_identity( ), ) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling")) def test_reshape( self, *, @@ -564,31 +565,32 @@ def test_reshape( dtype: torch.dtype, device: torch.device = "cuda", memory_format: torch.memory_format = torch.contiguous_format, - fp8: bool, + quantization: Optional[str], ) -> None: in_shape, out_shape = shapes # Skip invalid configurations if memory_format == torch.channels_last and len(in_shape) != 4: pytest.skip("torch.channels_last only supports 4D tensors") - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + maybe_skip_quantization(quantization, device=device) + with_quantization = quantization is not None # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) x_test = x_test.contiguous(memory_format=memory_format) x_test = x_test.detach().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( x_ref.reshape(out_shape).size(), + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -617,10 +619,10 @@ def test_reshape( torch.testing.assert_close(dx_test, x_ref.grad, **tols) @pytest.mark.parametrize("size", (1, 7, 32)) - @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1))) + @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (4, 3, 8, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", _devices) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_list) def test_bias( self, *, @@ -628,24 +630,23 @@ def test_bias( in_shape: Iterable[int], dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: # Make input and bias shapes consistent in_shape = list(in_shape)[:-1] + [size] # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) b_ref, b_test = make_reference_and_test_tensors( size, @@ -654,8 +655,10 @@ def test_bias( ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -680,7 +683,7 @@ def test_bias( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.parametrize("quantization", ("fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("cast_forward", (False, True)) @pytest.mark.parametrize("cast_backward", (False, True)) def test_quantize( @@ -696,25 +699,26 @@ def test_quantize( """Quantize""" # Skip invalid configurations - maybe_skip_quantization(quantization) + with_quantization = quantization is not None + maybe_skip_quantization(quantization, device=device) + if quantization == "mxfp8": + maybe_skip_quantization(quantization, dims=in_shape) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - requires_grad=False, - test_is_fp8=True, + requires_grad=True, ) - x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, - test_is_fp8=True, ) - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = x_ref @@ -723,13 +727,14 @@ def test_quantize( # Implementation with fusible operation op = te_ops.Quantize(forward=cast_forward, backward=cast_backward) recipe = make_recipe(quantization) - with te.fp8_autocast(fp8_recipe=recipe): + with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) # Check tensor types - assert isinstance(y_test, QuantizedTensor) == cast_forward - assert isinstance(x_test.grad, QuantizedTensor) == cast_backward + if with_quantization: + assert isinstance(y_test, QuantizedTensor) == cast_forward + assert isinstance(x_test.grad, QuantizedTensor) == cast_backward # Check values tols = dict(rtol=0, atol=0) @@ -764,10 +769,25 @@ def _test_basic_linear( # Skip invalid configurations maybe_skip_quantization(quantization, dims=in_shape, device=device) maybe_skip_quantization(quantization, dims=out_shape) - if quantization == "fp8" and quantized_output and not quantized_compute: - pytest.skip("FP8 output is only supported with FP8 GEMMs") - if quantization == "fp8" and quantized_grad_input and not quantized_compute: - pytest.skip("FP8 grad input is only supported with FP8 GEMMs") + quantization_needed = any( + ( + quantized_compute, + quantized_input, + quantized_weight, + quantized_output, + quantized_grad_output, + quantized_grad_input, + ) + ) + if quantization is None and quantization_needed: + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not quantization_needed: + pytest.skip("Quantization scheme is not used") + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"): + if quantized_output and not quantized_compute: + pytest.skip("FP8 output is only supported with FP8 GEMMs") + if quantized_grad_input and not quantized_compute: + pytest.skip("FP8 grad input is only supported with FP8 GEMMs") if quantization == "mxfp8" and quantized_output: pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs") if quantization == "mxfp8" and quantized_grad_input: @@ -776,28 +796,25 @@ def _test_basic_linear( # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_input), + test_is_quantized=quantized_input, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_grad_output), + test_is_quantized=quantized_grad_output, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -860,7 +877,7 @@ def _test_basic_linear( @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5))) @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) def test_basic_linear( self, @@ -882,7 +899,7 @@ def test_basic_linear( ) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.parametrize("quantization", ("fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_compute", (False, True)) @pytest.mark.parametrize("quantized_input", (False, True)) @pytest.mark.parametrize("quantized_weight", (False, True)) @@ -901,6 +918,8 @@ def test_basic_linear_quantized( quantized_grad_input: bool, ) -> None: """GEMM with FP8 inputs and outputs""" + if quantization is None: + pytest.skip("Skipping case without quantization") self._test_basic_linear( dtype=torch.bfloat16, quantization=quantization, @@ -913,8 +932,11 @@ def test_basic_linear_quantized( ) @pytest.mark.parametrize("bias", (False, True)) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantized_compute", (False, True)) @pytest.mark.parametrize("quantized_weight", (False, True)) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("weight_requires_grad", (False, True)) def test_linear( self, *, @@ -924,7 +946,10 @@ def test_linear( dtype: torch.dtype = torch.float32, device: torch.device = "cuda", quantization: Optional[str], + quantized_compute: bool, quantized_weight: bool, + input_requires_grad: bool, + weight_requires_grad: bool, ) -> None: """GEMM + bias""" @@ -934,25 +959,25 @@ def test_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - quantized_compute = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device) maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not used") # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -963,6 +988,7 @@ def test_linear( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, @@ -988,9 +1014,12 @@ def test_linear( op.bias.copy_(b_test) del w_test del b_test + for param in op.parameters(): + param.requires_grad_(requires_grad=weight_requires_grad) with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = op(x_test) - y_test.backward(dy_test) + if input_requires_grad or weight_requires_grad: + y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) @@ -1001,20 +1030,22 @@ def test_linear( # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") - dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") - dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(y_test, y_ref, **tols) - torch.testing.assert_close(dx_test, x_ref.grad, **tols) - torch.testing.assert_close(dw_test, w_ref.grad, **tols) - if bias: - db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(db_test, b_ref.grad, **tols) + if input_requires_grad: + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + if weight_requires_grad: + dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + if bias: + db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, b_ref.grad, **tols) @pytest.mark.parametrize("weight_shape", ((7, 2), (32,))) @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_layer_norm( self, *, @@ -1184,7 +1215,7 @@ def test_layer_norm_autocast( @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_rmsnorm( self, *, @@ -1261,6 +1292,7 @@ def test_rmsnorm( tols = dtype_tols(y_test._quantizer.dtype) expected_tensor_cls = { Float8Quantizer:Float8Tensor, + Float8CurrentScalingQuantizer:Float8Tensor, MXFP8Quantizer:MXFP8Tensor }[type(y_test._quantizer)] assert isinstance(y_test, expected_tensor_cls) @@ -1274,16 +1306,68 @@ def test_rmsnorm( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dw_test, w_ref.grad, **tols) + @pytest.mark.parametrize("in_shape", ((32,), (6, 16, 64), (32, 64))) + @pytest.mark.parametrize("dtype", _dtypes) + def test_l2normalization( + self, + *, + in_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + eps: float = 1e-6, + ) -> None: + """L2 Normalization""" + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + # L2 norm: x / ||x||_2 = x / sqrt(sum(x^2) + eps) + l2_norm_squared = x_ref.pow(2).sum(dim=-1, keepdim=True) + rsqrt_norm = torch.rsqrt(l2_norm_squared + eps) + y_ref = x_ref * rsqrt_norm + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.L2Normalization( + eps=eps, + ) + y_test = op(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + + torch.testing.assert_close(y_test, y_ref, **tols) + # L2Norm backward pass requires slightly looser atol for bfloat16 + if dtype == torch.bfloat16: + tols["atol"] = 2e-3 + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_list) def test_add_in_place( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: """Add two tensors @@ -1292,28 +1376,30 @@ def test_add_in_place( """ # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x1_ref, x1_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) x2_ref, x2_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -1330,7 +1416,7 @@ def test_add_in_place( # Check results tols = dtype_tols(dtype) - if fp8: + if with_quantization: tols = dtype_tols(x1_test._fp8_dtype) y_test = y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") @@ -1341,14 +1427,14 @@ def test_add_in_place( @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_list) def test_make_extra_output( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: """Output tensor twice @@ -1357,28 +1443,31 @@ def test_make_extra_output( """ # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy1_ref, dy1_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) dy2_ref, dy2_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -1404,7 +1493,7 @@ def test_make_extra_output( @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("cache_quantized_input", (False, True)) def test_activation( self, @@ -1427,26 +1516,21 @@ def test_activation( quantized_compute = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device) if cache_quantized_input: - maybe_skip_quantization("fp8", device=device) + maybe_skip_quantization("fp8_current_scaling", device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization="fp8_current_scaling" if cache_quantized_input else None, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if quantized_compute: - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref: torch.Tensor @@ -1489,8 +1573,6 @@ def test_activation( tols = dtype_tols(dtype) if quantized_compute or cache_quantized_input: tols = dtype_tols(tex.DType.kFloat8E4M3) - if activation == "relu" and not cache_quantized_input: - tols = {"atol": 0, "rtol": 0} # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1499,7 +1581,7 @@ def test_activation( torch.testing.assert_close(dx_test, x_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantize_forward", (False, True)) @pytest.mark.parametrize("quantize_backward", (False, True)) def test_swiglu( @@ -1577,7 +1659,7 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5))) @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_weight", (False, True)) def test_forward_linear_bias_activation( self, @@ -1609,18 +1691,15 @@ def test_forward_linear_bias_activation( # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if quantized_compute: - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -1631,6 +1710,7 @@ def test_forward_linear_bias_activation( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, @@ -1687,7 +1767,7 @@ def test_forward_linear_bias_activation( @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_forward_linear_bias_add( self, *, @@ -1716,18 +1796,15 @@ def test_forward_linear_bias_add( # Random data x1_ref, x1_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x1_test, QuantizedTensor): - with torch.no_grad(): - x1_test = x1_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -1743,6 +1820,7 @@ def test_forward_linear_bias_add( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, @@ -1800,8 +1878,100 @@ def test_forward_linear_bias_add( db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(db_test, b_ref.grad, **tols) + @pytest.mark.parametrize("activation", ("relu", "gelu")) + @pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + def test_backward_bias_activation( + self, + *, + activation: str, + out_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + ) -> None: + """Backward dbias + dact + quantize""" + + # Tensor dimensions + in_shape = list(out_shape) + hidden_size = in_shape[-1] + + # Skip invalid configurations + with_quantization = quantization is not None + maybe_skip_quantization(quantization, device=device) + if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0): + pytest.skip("Unsupported tensor size for MXFP8") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + b_ref, b_test = make_reference_and_test_tensors( + hidden_size, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = x_ref + b_ref.reshape([1] * (len(in_shape) - 1) + [hidden_size]) + if activation == "gelu": + y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh") + elif activation == "relu": + y_ref = torch.nn.functional.relu(y_ref) + else: + raise ValueError(f"Unexpected activation function ({activation})") + y_ref.backward(dy_ref) + + # Implementation with fusible operations + recipe = make_recipe(quantization) + act_type = te_ops.GELU if activation == "gelu" else te_ops.ReLU + model = te_ops.Sequential( + te_ops.Quantize(forward=False, backward=True), + te_ops.Bias(hidden_size, device=device, dtype=dtype), + act_type(), + ) + with torch.no_grad(): + model[1].bias.copy_(b_test) + del b_test + with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe): + y_test = model(x_test) + y_test.backward(dy_test) + + # Check that backward operations have been fused + backward_ops = model._module_groups[0]._backward_ops + if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]: + assert len(backward_ops) == 2 + assert isinstance(backward_ops[0][0], BackwardBiasActivation) + assert isinstance(backward_ops[1][0], te_ops.Quantize) + else: + assert len(backward_ops) == 3 + assert isinstance(backward_ops[0][0], act_type) + assert isinstance(backward_ops[1][0], te_ops.Bias) + assert isinstance(backward_ops[2][0], te_ops.Quantize) + + # Expected numerical error + tols = dtype_tols(dtype) + if with_quantization: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + db_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(db_test, b_ref.grad, **tols) + @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_backward_linear_add( self, *, @@ -1829,27 +1999,26 @@ def test_backward_linear_add( # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) dy1_ref, dy1_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, ) dy2_ref, dy2_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, @@ -1913,7 +2082,7 @@ def setup_class(cls) -> None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_weight", (False, True)) def test_linear( self, @@ -2016,3 +2185,109 @@ def test_linear( torch.testing.assert_close(y_load, y_save, **tols) for x_load, x_save in zip(xs_load, xs_save): torch.testing.assert_close(x_load.grad, x_save.grad, **tols) + + +class TestSequentialModules: + """Test for larger Sequentials with modules commonly used together""" + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm")) + @pytest.mark.parametrize("quantized_compute", (False, True)) + @pytest.mark.parametrize("quantized_weight", (False, True)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + def test_layernorm_mlp( + self, + *, + bias: bool, + normalization: str, + quantized_compute: bool, + quantized_weight: bool, + dtype: torch.dtype, + quantization: Optional[str], + device: torch.device = "cuda", + hidden_size: int = 32, + sequence_length: int = 512, + batch_size: int = 4, + ffn_hidden_size: int = 64, + layernorm_epsilon: float = 1e-5, + ) -> None: + """ + LayerNorm/RMSNorm + Linear + GELU + Linear + + Note that this test checks only if the module runs + as when chaining multiple modules it is hard to validate + numerical accuracy. + """ + + # Make input shape + in_shape = (sequence_length, batch_size, hidden_size) + ffn_shape = in_shape[:-1] + (ffn_hidden_size,) + + # Skip invalid configurations + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=ffn_shape, device=device) + quantization_needed = quantized_compute or quantized_weight + if quantization is None and quantization_needed: + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not quantization_needed: + pytest.skip("Quantization scheme is not used") + + # Random data + _, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + _, dy_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Implementation with fusible operations + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + if normalization == "LayerNorm": + norm = te_ops.LayerNorm( + hidden_size, + eps=layernorm_epsilon, + device=device, + dtype=dtype, + ) + else: + norm = te_ops.RMSNorm( + hidden_size, + eps=layernorm_epsilon, + device=device, + dtype=dtype, + ) + ffn1 = te_ops.Linear( + hidden_size, + ffn_hidden_size, + bias=bias, + device=device, + dtype=dtype, + ) + act = te_ops.GELU() + ffn2 = te_ops.Linear( + ffn_hidden_size, + hidden_size, + bias=bias, + device=device, + dtype=dtype, + ) + forward = te_ops.Sequential(norm, ffn1, act, ffn2) + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = forward(x_test) + y_test.backward(dy_test) diff --git a/tests/pytorch/test_hf_integration.py b/tests/pytorch/test_hf_integration.py new file mode 100644 index 000000000..0b2468510 --- /dev/null +++ b/tests/pytorch/test_hf_integration.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import PreTrainedModel + +from transformer_engine.pytorch.transformer import TransformerLayer +from transformer_engine.pytorch.utils import is_bf16_compatible + + +class SimpleTEModel(PreTrainedModel): + config_class = PretrainedConfig + + def __init__(self, config: PretrainedConfig): + super().__init__(config) + self.my_layer = TransformerLayer( + hidden_size=320, + num_attention_heads=16, + ffn_hidden_size=1024, + layer_number=None, + ) + + def forward(self, hidden_states, attention_mask): + return self.my_layer(hidden_states, attention_mask) + + +def test_save_hf_model(tmp_path): + model = SimpleTEModel(PretrainedConfig()) + model.save_pretrained(tmp_path / "simple_te_model") + + +@pytest.mark.xfail(reason="This test is failing until huggingface/transformers#38155 is merged.") +def test_save_and_load_hf_model(tmp_path): + model = SimpleTEModel(PretrainedConfig()) + model.save_pretrained(tmp_path / "simple_te_model") + del model + model = SimpleTEModel.from_pretrained(tmp_path / "simple_te_model") + assert model is not None diff --git a/tests/pytorch/test_jit.py b/tests/pytorch/test_jit.py index a697cc048..e670070bc 100644 --- a/tests/pytorch/test_jit.py +++ b/tests/pytorch/test_jit.py @@ -63,3 +63,62 @@ def test_lazy_compile(): from transformer_engine.pytorch.jit import dgelu_fused_ dgelu_fused_(torch.randn(10, 10), torch.randn(10, 10)) + + +def test_l2normalization_fused(): + """Smoke test for L2Normalization fusion functions.""" + from transformer_engine.pytorch.jit import ( + l2normalization_fused, + l2normalization_fwd_fused, + l2normalization_backward_fused, + ) + + # Basic smoke test like other JIT functions + x = torch.randn(10, 128, device="cuda", dtype=torch.float32) + eps = 1e-6 + + # Test inference version + output_inf = l2normalization_fused(x, eps) + + # Test training version with backward + x_train = torch.randn(10, 128, device="cuda", dtype=torch.float32, requires_grad=True) + output_train, rsqrt_norm = l2normalization_fwd_fused(x_train, eps) + grad_output = torch.randn_like(output_train) + grad_input = l2normalization_backward_fused(grad_output, x_train, rsqrt_norm, eps) + + +def test_l2normalization_fused_correctness(): + """Simple verification that L2Normalization fusion matches reference implementation.""" + from transformer_engine.pytorch.jit import ( + l2normalization_fwd_fused, + l2normalization_backward_fused, + ) + + device = "cuda" if torch.cuda.is_available() else "cpu" + x = torch.randn(16, 64, device=device, dtype=torch.float32, requires_grad=True) + eps = 1e-6 + + # Test fused forward + output_fused, rsqrt_norm = l2normalization_fwd_fused(x, eps) + + # Reference implementation + x_ref = x.clone().detach().requires_grad_(True) + x_squared = x_ref.pow(2) + l2_norm_squared = x_squared.sum(dim=-1, keepdim=True) + rsqrt_norm_ref = torch.rsqrt(l2_norm_squared + eps) + output_ref = x_ref * rsqrt_norm_ref + + # Check forward pass matches + torch.testing.assert_close(output_fused, output_ref, atol=1e-6, rtol=1e-5) + torch.testing.assert_close(rsqrt_norm, rsqrt_norm_ref, atol=1e-6, rtol=1e-5) + + # Test fused backward + grad_output = torch.randn_like(output_fused) + grad_input_fused = l2normalization_backward_fused(grad_output, x, rsqrt_norm, eps) + + # Reference backward + output_ref.backward(grad_output) + grad_input_ref = x_ref.grad + + # Check backward pass matches + torch.testing.assert_close(grad_input_fused, grad_input_ref, atol=1e-5, rtol=1e-4) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 73c35eace..1787ab191 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -120,6 +120,20 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq mask_types = ["causal", "no_mask"] +NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0")) + +if NVTE_TEST_NVINSPECT_ENABLED: + # The numerics of all the layers should work the same, + # when debug=True. I fed them with dummy feature + # to prevent switching off debug, which can happen if + # no feature is active. + import nvdlfw_inspect.api as debug_api + + debug_api.initialize( + os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"], + feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"], + ) + fp8_recipes = [ recipe.MXFP8BlockScaling(), recipe.DelayedScaling(), @@ -621,6 +635,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) @@ -737,6 +753,8 @@ def test_gpt_full_activation_recompute( pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) @@ -1104,8 +1122,11 @@ def test_mha_accuracy(dtype, bs, model, mask_type): assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) -def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False): +def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False, recipe=None): reset_rng_states() + fp8 = recipe is not None + if fp8: + FP8GlobalStateManager.reset() inp_hidden_states = torch.randn( (config.seq_len, bs, config.hidden_size), @@ -1115,9 +1136,10 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False) ) inp_hidden_states.retain_grad() - out = block(inp_hidden_states) - if isinstance(out, (List, Tuple)): - out = out[0] + with fp8_autocast(enabled=fp8, fp8_recipe=recipe): + out = block(inp_hidden_states) + if isinstance(out, (List, Tuple)): + out = out[0] loss = out.sum() loss.backward() if delay_wgrad_compute: @@ -1394,6 +1416,64 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ torch.testing.assert_close(o, o_ref, rtol=0, atol=0) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("model", ["small"]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +def test_linear_accuracy_save_original_input(dtype, model, recipe): + bs = 1 + fuse_wgrad_accumulation = True + fp8_model_params = False + fp8 = recipe is not None + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8 and recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) + if fp8 and recipe.delayed(): + pytest.skip("DelayedScaling recipe is not supported with save_original_input") + + config = model_configs[model] + if config.seq_len % 16 != 0 and fp8: + pytest.skip("FP8 requires sequence length to be divisible by 16.") + + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + te_linear_ref = Linear( + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + save_original_input=False, + ).eval() + + te_linear = Linear( + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + save_original_input=True, + ).eval() + + # Share params + with torch.no_grad(): + te_linear_ref.weight = Parameter(te_linear.weight.clone()) + if fuse_wgrad_accumulation: + weight = getattr(te_linear, f"weight") + weight.main_grad = torch.rand_like(weight, dtype=torch.float32) + te_linear_ref.weight.main_grad = weight.main_grad.clone() + + te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe) + te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe) + + # Shoule be bit-wise match + for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @@ -1959,6 +2039,8 @@ def test_grouped_linear_accuracy( pytest.skip(reason_for_no_fp8) if fp8 and recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) @@ -1977,6 +2059,111 @@ def test_grouped_linear_accuracy( device="cuda", fuse_wgrad_accumulation=fuse_wgrad_accumulation, delay_wgrad_compute=delay_wgrad_compute, + save_original_input=False, + ).eval() + sequential_linear = torch.nn.ModuleList( + [ + Linear( + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + for _ in range(num_gemms) + ] + ) + + # Share params + with torch.no_grad(): + for i in range(num_gemms): + sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) + if bias: + sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + if fuse_wgrad_accumulation: + weight_i = getattr(grouped_linear, f"weight{i}") + weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) + sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() + + outputs_ref = _test_grouped_linear_accuracy( + sequential_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, + ) + outputs = _test_grouped_linear_accuracy( + grouped_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, + ) + + # Shoule be bit-wise match + for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + +@pytest.mark.parametrize("dtype", param_types, ids=str) +@pytest.mark.parametrize("num_gemms", [3]) +@pytest.mark.parametrize("bs", [1]) +@pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("fp8_model_params", [False]) +@pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) +@pytest.mark.parametrize("bias", [False]) +@pytest.mark.parametrize("delay_wgrad_compute", [True]) +def test_grouped_linear_accuracy_save_original_input( + dtype, + num_gemms, + bs, + model, + recipe, + fp8_model_params, + fuse_wgrad_accumulation, + bias, + delay_wgrad_compute, + parallel_mode=None, +): + fp8 = recipe is not None + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8 and recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) + if fp8 and recipe.delayed(): + pytest.skip("DelayedScaling recipe is not supported with save_original_input") + + config = model_configs[model] + if config.seq_len % 16 != 0 and fp8: + pytest.skip("FP8 requires sequence length to be divisible by 16.") + + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + grouped_linear = GroupedLinear( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + delay_wgrad_compute=delay_wgrad_compute, + save_original_input=True, ).eval() sequential_linear = torch.nn.ModuleList( [ @@ -2151,14 +2338,100 @@ def _generate_random_numbers(n, total_sum): @pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_padding_grouped_linear_accuracy( - dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None + dtype, + num_gemms, + bs, + model, + fp8, + recipe, + fp8_model_params, + parallel_mode=None, +): + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") + if recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) + + config = model_configs[model] + if config.seq_len % 16 != 0 and fp8: + pytest.skip("FP8 requires sequence length to be divisible by 16.") + + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + grouped_linear = TorchGroupedLinearWithPadding( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + parallel_mode=parallel_mode, + fp8=fp8, + ).eval() + + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + ref_grouped_linear = GroupedLinear( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + save_original_input=False, + ).eval() + + # Share params + with torch.no_grad(): + inner_grouped_linear = grouped_linear.linear_fn + for i in range(num_gemms): + setattr( + ref_grouped_linear, + f"weight{i}", + Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), + ) + + outputs = _test_padding_grouped_linear_accuracy( + grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 + ) + outputs_ref = _test_padding_grouped_linear_accuracy( + ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 + ) + + # Shoule be bit-wise match + for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("num_gemms", [3]) +@pytest.mark.parametrize("bs", [1]) +@pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("fp8", [True]) +@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_model_params", [False]) +def test_padding_grouped_linear_accuracy_save_original_input( + dtype, + num_gemms, + bs, + model, + fp8, + recipe, + fp8_model_params, + parallel_mode=None, ): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) + if fp8 and recipe.delayed(): + pytest.skip("DelayedScaling recipe is not supported with save_original_input") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: @@ -2184,6 +2457,7 @@ def test_padding_grouped_linear_accuracy( params_dtype=dtype, parallel_mode=parallel_mode, device="cuda", + save_original_input=True, ).eval() # Share params @@ -2278,6 +2552,8 @@ def test_gpt_cuda_graph(dtype, bs, model): if use_fa: pytest.skip(f"ROCm flash attention does not support cuda graph with {dtype}") + if NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("Cuda Graphs are not supported in debug mode.") config = model_configs[model] sigma = 0.023 @@ -2375,6 +2651,8 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe): pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) @@ -2567,9 +2845,9 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, not IS_HIP_EXTENSION and backend == "FusedAttention" and get_device_compute_capability() == (8, 9) - and get_cudnn_version() < (9, 11, 0) + and get_cudnn_version() < (9, 12, 0) ): - pytest.skip("Skip KV cache for sm89 and cuDNN < 9.11") + pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12") os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py new file mode 100644 index 000000000..ea9c85e37 --- /dev/null +++ b/tests/pytorch/test_onnx_export.py @@ -0,0 +1,1154 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +This file contains tests for exporting TransformerEngine models to ONNX. + +The purpose of these tests is validation that TE models are converted to their correct ONNX +representation. Toward this end, each test captures the output of a TE module forward pass, +converts the TE module to ONNX, and uses ONNX Runtime (ORT) to execute the ONNX graph and +validate the output against TE's output. + +Until FP8 is introduced to the ONNX standard, FP8 QuantizeLinear/DequantizeLinear is implemented +using custom ORT operations. + +To run many repetitive tests use pytest-loop: + $ python3 -m pip install pytest-loop + $ pytest --loop 1000 tests/pytorch/test_onnx_export.py::test_export_layernorm + +For reproducibility use: torch.manual_seed(0) +""" + +import os +import tempfile +import pytest +import warnings +import numpy as np +import onnxruntime as ort +import torch +import random +from torch import nn as nn +from typing import Optional, Union, Tuple, List +from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op +import transformer_engine.pytorch as te +from transformer_engine.common import recipe +import transformer_engine_torch as tex +from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.utils import get_default_init_method + +# Global test configuration knobs. + +# Enable this to serialize test inputs and outputs to file (as a Polygraphy RunResults instance). +SAVE_TEST_IO = bool(int(os.getenv("NVTE_ONNX_EXPORT_SAVE_TEST_IO", "0"))) + +if SAVE_TEST_IO: + from polygraphy.json import save_json + from polygraphy.comparator import RunResults + +# The directory where generated ONNX test models are stored. +NVTE_TEST_ARTIFACTS_DIR = os.environ.get("NVTE_TEST_ARTIFACTS_DIR") +NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join( + tempfile.gettempdir(), "./gen_onnx_models" +) + + +# The directory where this file is stored. +TESTS_DIR = os.path.dirname(os.path.abspath(__file__)) + +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +skip_MXFP8 = pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + +fp8_recipes = [ + None, + recipe.DelayedScaling(), + recipe.MXFP8BlockScaling(), +] + +supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] + +all_normalizations = ["LayerNorm", "RMSNorm"] + + +@onnx_op( + op_type="trt::TRT_FP8QuantizeLinear", + domain="trt", + inputs=[ + PyCustomOpDef.dt_float, + PyCustomOpDef.dt_float, + ], + outputs=[PyCustomOpDef.dt_uint8], +) +def trt_fp8_quantize(t, scale): + """FP8 quantization extension for ONNX Runtime.""" + x = torch.from_numpy(t).cuda() + q = te.tensor.float8_tensor.Float8Quantizer( + scale=1 / torch.from_numpy(scale).cuda(), + amax=torch.zeros([1]).cuda(), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + return q(x)._data.cpu().numpy() + + +@onnx_op( + op_type="trt::TRT_FP8DequantizeLinear", + domain="trt", + inputs=[ + PyCustomOpDef.dt_uint8, + PyCustomOpDef.dt_float, + ], + outputs=[PyCustomOpDef.dt_float], +) +def trt_fp8_dequantize(t, scale): + """FP8 dequantization extension for ONNX Runtime.""" + x = torch.from_numpy(t).cuda() + q = te.tensor.float8_tensor.Float8Quantizer( + scale=1 / torch.from_numpy(scale).cuda(), + amax=torch.zeros([1]).cuda(), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + quantizer_tensor = q.create_tensor_from_data(x, fake_dtype=torch.float32) + return quantizer_tensor.dequantize().cpu().numpy() + + +@onnx_op( + op_type="trt::TRT_MXFP8QuantizeLinear", + domain="trt", + inputs=[ + PyCustomOpDef.dt_float, + ], + outputs=[PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_uint8], +) +def trt_mxfp8_quantize(t): + """MXFP8 quantization extension for ONNX Runtime.""" + x = torch.from_numpy(t).cuda() + q = te.tensor.mxfp8_tensor.MXFP8Quantizer(tex.DType.kFloat8E4M3) + return q(x)._rowwise_data.cpu().numpy(), q(x)._rowwise_scale_inv.cpu().numpy() + + +@onnx_op( + op_type="trt::TRT_MXFP8DequantizeLinear", + domain="trt", + inputs=[ + PyCustomOpDef.dt_uint8, + PyCustomOpDef.dt_uint8, + ], + outputs=[PyCustomOpDef.dt_float], +) +def trt_mxfp8_dequantize(t, scale_inv): + """MXFP8 dequantization extension for ONNX Runtime.""" + x = torch.from_numpy(t).cuda() + scale_inv_tensor = torch.from_numpy(scale_inv).cuda() + q = te.tensor.mxfp8_tensor.MXFP8Quantizer(tex.DType.kFloat8E4M3) + quantizer_tensor = q.create_tensor_from_data(x, scale_inv_tensor, fake_dtype=torch.float32) + return quantizer_tensor.dequantize().cpu().numpy() + + +@pytest.fixture() +def seed_default_rng(): + """Reseed the PRNG for test reproducibility""" + torch.manual_seed(1234) + + +@pytest.fixture() +def set_max_seq_len(max_seq_len=128): + """Set the maximum sequence length that can be used for attention masking""" + os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{max_seq_len}" + + +@pytest.fixture(autouse=True) +def reset_global_fp8_state(): + yield + FP8GlobalStateManager.reset() + + +def do_export( + model: torch.nn.Module, + inp: torch.Tensor, + fname: str, + fp8_recipe: recipe.Recipe, + input_names: List[str] = None, + output_names: List[str] = None, + dynamic_shapes: List[str] = None, +): + """Export to ONNX""" + input_names = input_names or ["input"] + output_names = output_names or ["output"] + + with torch.inference_mode(), te.fp8_autocast( + enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe + ), warnings.catch_warnings(): + warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*") + + model.cuda().eval() + os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True) + fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname) + + inps = inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,) + assert len(inps) == len(input_names) + inds_to_del = [i for i in range(len(inps)) if inps[i] is None] + input_names = [input_names[i] for i in range(len(inps)) if i not in inds_to_del] + + model(*inps) # warm-up run + with te.export.onnx_export(True): + model(*inps) + with te.export.onnx_export(True): + torch.onnx.export( + model, + inps, + fname, + dynamo=True, + custom_translation_table=te_translation_table, + verbose=True, + dynamic_shapes=dynamic_shapes, + input_names=input_names, + output_names=output_names, + optimize=inps[0].dtype + != torch.bfloat16, # optimizer does not work with bfloat16 yet - will need to change that after onnxscript supports bfloat16 + ) + + +def to_numpy(tensor): + if isinstance(tensor, torch.Tensor): + if tensor.dtype == torch.bfloat16: + tensor = tensor.type(torch.float32) + tensor = tensor.detach().cpu().numpy() + return tensor + + +def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int): + """Initialize the FP8 quantization scales in module""" + module.init_fp8_metadata(num_gemms) + for quantizer in module.quantizers["scaling_fwd"]: + quantizer.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale + + +def te_infer( + model: torch.nn.Module, + inps: Union[Tuple[torch.Tensor], torch.Tensor], + is_fp8: bool, + fp8_recipe: recipe.Recipe, +): + """Transformer Engine forward propagation.""" + with torch.inference_mode(), te.fp8_autocast( + enabled=is_fp8, fp8_recipe=fp8_recipe + ), warnings.catch_warnings(): + te_outputs = model(*inps if isinstance(inps, tuple) else (inps,)) + if not isinstance(te_outputs, tuple): + te_outputs = (te_outputs,) + return te_outputs + + +def compare_outputs( + onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname +): + """Compare ORT and TE outputs.""" + assert len(onnx_outputs) == len(te_outputs) + # Compare ORT and PyTorch outputs. + for onnx_output, te_output in zip(onnx_outputs, te_outputs): + # np.isclose: abs(a - b) <= (atol + rtol * abs(b)) + te_output = to_numpy(te_output) + onnx_output = to_numpy(onnx_output) + ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol) + mismatches = ac.nonzero() + mismatched_ids = [loc for loc in zip(*mismatches)] + if mismatched_ids: + # Log some information in case of error. + print("*" * 100) + nb_errors = len(mismatched_ids) + nb_vals = min(nb_errors, max_errors_printed) + print(f"Detected {nb_errors} diverging values (output shape={onnx_output.shape})") + print(f"Showing first {nb_vals} errors (ONNX -- TE):") + abs_err = np.abs(onnx_output - te_output) + errors = abs_err[mismatches] + for loc in mismatched_ids[:nb_vals]: + ref = te_output[loc] + print( + f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} >" + f" {atol + rtol * abs(ref)}" + ) + print(f"Max error: {np.max(errors)}") + if nb_errors > allow_cnt_errors: + raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors") + + +def serialize_inputs_outputs( + fname: str, + inputs: Union[Tuple[torch.Tensor], torch.Tensor], + te_outputs: List[torch.Tensor], + input_names: Optional[List[str]] = None, + output_names: Optional[List[str]] = None, +): + if not SAVE_TEST_IO: + return + + fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname) + + input_names = input_names or ["input"] + output_names = output_names or ["output"] + inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,) + named_inputs = zip(input_names, inputs) + input_data = [{k: v.cpu() for k, v in named_inputs if v is not None}] + json_fname = fname[: -len(".onnx")] + "_inputs.json" + save_json(input_data, json_fname, description="custom input data") + + json_fname = fname[: -len(".onnx")] + "_output.json" + named_outputs = zip(output_names, te_outputs) + output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None} + custom_outputs = RunResults() + custom_outputs.add([output_data], runner_name="custom_runner") + custom_outputs.save(json_fname) + + +def validate_result( + fname: str, + inps: Union[Tuple[torch.Tensor], torch.Tensor], + model: torch.nn.Module, + atol: float = 1.0e-8, # np.isclose default atol + rtol: float = 1.0e-5, # np.isclose default rtol + max_errors_printed: int = 10, + is_fp8: bool = False, + allow_cnt_errors: int = 0, + input_names: List[str] = None, + output_names: List[str] = None, + te_outputs: List[torch.Tensor] = None, +): + """Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX + representation using ONNX Runtime (ORT) and ensure they are close. + + The purpose of the output comparison is to validate that TE models are converted to + their correct ONNX representation by testing that TE and ORT outputs match within some + small threshold (allowing for finite precision errors). + + Argument `allow_cnt_errors` reduces test failure noise due to spurious errors by ignoring, + a very small number (0-3) of outliers. This is fine to do because these outliers are due to + small kernel implementation differences between TE and ORT and do not imply an incorrect ONNX + representation (the tests assume both ORT or TE kernels are correct). + + Argument `te_outputs` can be used to provide pre-computed TE outputs. + """ + + def create_ort_session(fname: str, is_fp8: bool): + def load_custom_ops(session_opts: ort.SessionOptions): + """For FP8 validation with ORT we need to load our custom FP8 Q/DQ extension.""" + session_opts.register_custom_ops_library(get_library_path()) + print("registered custom FP8 Q/DQ ops!") + + """Create an ONNX Runtime session for validation.""" + kwargs = {"providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]} + if is_fp8: + sess_options = ort.SessionOptions() + load_custom_ops(sess_options) + kwargs["sess_options"] = sess_options + + s = ort.InferenceSession(fname, **kwargs) + return s + + def create_ort_input_dict(session, inputs): + inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,) + input_names = [x.name for x in session.get_inputs()] + inps = [to_numpy(x) for x in inputs if x is not None] + inp_dict = dict(zip(input_names, inps)) + return inp_dict + + input_names = input_names or ["input"] + output_names = output_names or ["output"] + + # Run ORT session and TE model. + fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname) + if not te_outputs: + te_outputs = te_infer(model, inps, is_fp8) + ort_s = create_ort_session(fname, is_fp8) + input_feed = create_ort_input_dict(ort_s, inps) + onnx_outputs = ort_s.run(None, input_feed=input_feed) + compare_outputs( + onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname + ) + + +def create_meta(scale_factor: float, size: int = 1): + meta = tex.FP8TensorMeta() + meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") + meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor + meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor + return meta + + +def dtype2str(dtype: torch.dtype, fake_bf16_io=False): + if fake_bf16_io: + assert dtype == torch.bfloat16 + return "_fake_bf16" + return { + torch.float32: "_fp32", + torch.float16: "_fp16", + torch.bfloat16: "_bf16", + }[dtype] + + +def as_te_type(dtype: torch.dtype): + return { + torch.float32: tex.DType.kFloat32, + torch.float16: tex.DType.kFloat16, + torch.bfloat16: tex.DType.kBFloat16, + }[dtype] + + +def get_attn_mask_str(use_mask, attn_mask_type): + # See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names. + if attn_mask_type is None: + return "_mask" if use_mask else "_no-mask" + attn_mask_str = "_arbitrary-no-mask" + attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str + attn_mask_str = ( + "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str + ) + return attn_mask_str + + +""" +Test cases begin here. +""" + + +@pytest.mark.parametrize("scale_factor", [112]) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +# Returning the bias is a TE fusion optimization we don't care about. +@pytest.mark.parametrize("return_bias", [True, False]) +@pytest.mark.parametrize( + "precision, use_bias", + [ + (torch.float32, False), + (torch.float32, True), + (torch.float16, False), + (torch.float16, True), + # Todo: cannot configure BF16 when bias is disabled (ORT issue?) + (torch.bfloat16, False), + # Todo: cannot configure BF16 when bias is enabled (ORT issue?) + (torch.bfloat16, True), + ], +) +def test_export_linear( + seed_default_rng, + scale_factor: float, + fp8_recipe: recipe.Recipe, + use_bias: bool, + return_bias: bool, + precision: torch.dtype, +): + # Skip FP8 tests on non-hopper devices + if fp8_recipe is not None and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if return_bias and not use_bias: + pytest.skip("Cannot return bias when bias is disabled") + + # Set dimensions (these are arbitrary). + batch_size = 4 + in_features = 64 + out_features = 64 + hidden_size = 64 + + class Test_Linear(nn.Module): + def __init__(self, in_features, out_features, use_bias, return_bias, precision): + super().__init__() + self.linear = te.Linear( + in_features, + out_features, + bias=use_bias, + return_bias=return_bias, + params_dtype=precision, + ) + + def forward(self, inp): + ret = self.linear(inp) + return ret + + inp = torch.randn(batch_size, hidden_size, in_features, device="cuda", dtype=precision) + fp8_str = "_fp8" if fp8_recipe is not None else "" + bias_str = "_bias" if use_bias else "" + high_prec_str = dtype2str(precision) + fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx" + with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to( + device="cuda" + ) + # dynamic shape + bs = torch.export.Dim("bs", min=2, max=1256) + do_export( + model, + inp, + fname, + fp8_recipe, + dynamic_shapes={"inp": {0: bs}}, + ) + te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe) + serialize_inputs_outputs(fname, inp, te_outputs) + + if precision in (torch.bfloat16,): + return + if fp8_recipe is None: + validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs) + else: + validate_result( + fname, inp, model, atol=1e-2, is_fp8=fp8_recipe is not None, te_outputs=te_outputs + ) + + +@pytest.mark.parametrize("scale_factor", [112]) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize( + "precision", + [ + torch.float32, + torch.float16, + torch.bfloat16, + ], +) +@pytest.mark.parametrize("zero_centered_gamma", [False, True]) +@pytest.mark.parametrize("normalization", all_normalizations) +def test_export_layernorm( + seed_default_rng, + scale_factor: float, + fp8_recipe: recipe.Recipe, + precision: torch.dtype, + zero_centered_gamma: bool, + normalization: str, +): + # Skip FP8 tests on non-hopper devices + if fp8_recipe is not None and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + # Set dimensions (these are arbitrary). + batch_size = 4 + in_features = 64 + out_features = 256 + hidden_size = 256 + + inp = torch.ones(batch_size, in_features, out_features, device="cuda", dtype=precision) + fp8_str = "_fp8" if fp8_recipe is not None else "" + high_prec_str = dtype2str(precision) + fname = f"te.layernorm_linear{fp8_str}{high_prec_str}.onnx" + + with torch.no_grad(): + with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + layernorm_cls = te.LayerNorm if normalization == "LayerNorm" else te.RMSNorm + model = layernorm_cls( + hidden_size, + params_dtype=precision, + zero_centered_gamma=zero_centered_gamma, + ).to(device="cuda") + + # dynamic shape + bs = torch.export.Dim("bs", min=2, max=1256) + do_export(model, inp, fname, fp8_recipe, dynamic_shapes={"input": {0: bs}}) + te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe) + serialize_inputs_outputs(fname, inp, te_outputs) + if precision in (torch.bfloat16,): + return + if fp8_recipe is None: + validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs) + elif precision != torch.bfloat16: + validate_result( + fname, + inp, + model, + atol=1e-3, + is_fp8=fp8_recipe is not None, + te_outputs=te_outputs, + ) + + +@pytest.mark.parametrize("scale_factor", [112]) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("return_bias", [True, False]) +@pytest.mark.parametrize("return_layernorm_output", [True, False]) +@pytest.mark.parametrize( + "precision, use_bias", + [ + (torch.float32, False), + (torch.float32, True), + (torch.float16, True), + (torch.float16, False), + (torch.bfloat16, True), + (torch.bfloat16, False), + ], +) +@pytest.mark.parametrize("zero_centered_gamma", [False, True]) +@pytest.mark.parametrize("normalization", all_normalizations) +def test_export_layernorm_linear( + seed_default_rng, + scale_factor: float, + fp8_recipe: recipe.Recipe, + use_bias: bool, + return_bias: bool, + return_layernorm_output: bool, + precision: torch.dtype, + zero_centered_gamma: bool, + normalization: str, +): + # Skip FP8 tests on non-hopper devices + if fp8_recipe is not None and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if return_bias and not use_bias: + pytest.skip("Cannot return bias when bias is disabled") + + # Set dimensions (these are arbitrary). + in_features = 64 + out_features = 256 + hidden_size = 256 + + inp = torch.randn(in_features, out_features, device="cuda", dtype=precision) + fp8_str = "_fp8" if fp8_recipe is not None else "" + bias_str = "_bias" if use_bias else "" + high_prec_str = dtype2str(precision) + fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx" + + with torch.no_grad(): + with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + model = te.LayerNormLinear( + hidden_size, + 3 * hidden_size, + bias=use_bias, + return_bias=return_bias, + return_layernorm_output=return_layernorm_output, + params_dtype=precision, + zero_centered_gamma=zero_centered_gamma, + normalization=normalization, + ).to(device="cuda") + if fp8_recipe is not None: + set_layer_scale(model, scale_factor, num_gemms=2) + do_export(model, inp, fname, fp8_recipe) + + te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe) + serialize_inputs_outputs(fname, inp, te_outputs) + if precision in (torch.bfloat16,): + return + if fp8_recipe is None: + validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs) + elif precision != torch.bfloat16: + validate_result( + fname, + inp, + model, + atol=1e-3, + is_fp8=fp8_recipe is not None, + te_outputs=te_outputs, + ) + + +@pytest.mark.parametrize("scale_factor", [112]) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("return_bias", [True, False]) +@pytest.mark.parametrize("return_layernorm_output", [True, False]) +@pytest.mark.parametrize( + "precision, use_bias", + [ + (torch.float32, False), + (torch.float32, True), + (torch.float16, True), + (torch.float16, False), + (torch.bfloat16, True), + (torch.bfloat16, False), + ], +) +@pytest.mark.parametrize("zero_centered_gamma", [False, True]) +@pytest.mark.parametrize("activation", supported_activations) +@pytest.mark.parametrize("normalization", all_normalizations) +def test_export_layernorm_mlp( + seed_default_rng, + scale_factor: float, + fp8_recipe: recipe.Recipe, + use_bias: bool, + return_bias: bool, + return_layernorm_output: bool, + precision: torch.dtype, + zero_centered_gamma: bool, + activation: str, + normalization: str, +): + # Skip FP8 tests on non-hopper devices + if fp8_recipe is not None and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if return_bias and not use_bias: + pytest.skip("Cannot return bias when bias is disabled") + + # Set dimensions (these are arbitrary). + in_features = 64 + out_features = 256 + hidden_size = 256 + ffn_hidden_size = 256 + + inp = torch.randn(in_features, out_features, device="cuda", dtype=precision) + fp8_str = "_fp8" if fp8_recipe is not None else "" + bias_str = "_bias" if use_bias else "" + high_prec_str = dtype2str(precision) + fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx" + with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + model = te.LayerNormMLP( + hidden_size, + ffn_hidden_size, + bias=use_bias, + return_bias=return_bias, + return_layernorm_output=return_layernorm_output, + params_dtype=precision, + zero_centered_gamma=zero_centered_gamma, + activation=activation, + normalization=normalization, + ).to(device="cuda") + if fp8_recipe is not None: + set_layer_scale(model, scale_factor, num_gemms=2) + do_export(model, inp, fname, fp8_recipe) + te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe) + serialize_inputs_outputs(fname, inp, te_outputs) + if precision in (torch.bfloat16,): + return + atol = ( + 2e-2 if fp8_recipe is not None else (5e-1 if activation == "swiglu" else 1e-3) + ) # TODO(pgadzinski) - check 2e-2 + validate_result( + fname, inp, model, atol=atol, is_fp8=fp8_recipe is not None, te_outputs=te_outputs + ) + + +@pytest.mark.parametrize( + "precision, use_mask, attn_mask_type", + [ + (torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) + (torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask) + (torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) + (torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) + (torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) + (torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) + (torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) + (torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) + ], +) +def test_export_core_attention( + seed_default_rng, + set_max_seq_len, + precision: torch.dtype, + use_mask: bool, + attn_mask_type: str, +): + # Set dimensions (these are arbitrary). + seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64) + qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels) + qkv_format = "sbhd" + + query_layer = torch.randn(qkv_size, dtype=precision, device="cuda") + key_layer = torch.randn(qkv_size, dtype=precision, device="cuda") + value_layer = torch.randn(qkv_size, dtype=precision, device="cuda") + input_names = ["query", "key", "value", "attention_mask"] + attention_mask = None + if use_mask: + # Generate a random mask with 50% probability for 0 or 1. + probs = 0.5 * torch.ones(batch_size, 1, 1, seq_len, device="cuda", dtype=precision) + attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) + inp = (query_layer, key_layer, value_layer, attention_mask) + + mask_str = get_attn_mask_str(use_mask, attn_mask_type) + high_prec_str = dtype2str(precision) + fname = f"te.core_attention{mask_str}{high_prec_str}.onnx" + + model = te.attention.DotProductAttention( + num_attention_heads=num_attention_heads, + kv_channels=kv_channels, + attention_dropout=0.5, + qkv_format=qkv_format, + attn_mask_type=attn_mask_type, + ).to(device="cuda") + do_export(model, inp, fname, input_names=input_names, fp8_recipe=None) + te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None) + serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) + if precision in (torch.bfloat16,): + return + validate_result( + fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs + ) + + +test_configs_multihead_attention = [ + # "use_mask, attn_mask_type" + (False, "no_mask"), # calls ScaledSoftmax + (True, "arbitrary"), # calls ScaledMaskedSoftmax +] +test_configs_attention_type = [ + # "input_layernorm, attention_type, fuse_qkv_params" + (True, "self", True), + (False, "self", True), + (True, "self", False), + (False, "self", False), + (True, "cross", True), + (False, "cross", True), + (True, "cross", False), + (False, "cross", False), +] + + +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) +@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("return_layernorm_output", [False]) +@pytest.mark.parametrize( + "input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type +) +def test_export_multihead_attention( + seed_default_rng, + set_max_seq_len, + fp8_recipe: recipe.Recipe, + use_mask: bool, + attn_mask_type: str, + precision: torch.dtype, + return_layernorm_output: bool, + input_layernorm: bool, + attention_type: str, + fuse_qkv_params: bool, +): + # Skip FP8 tests on non-hopper devices + if fp8_recipe is not None and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + hidden_size = 256 + sequence_length = 128 + batch_size = 4 + num_attention_heads = 32 + kv_channels = 8 + attention_dropout = 0.1 + layernorm_epsilon = 1e-5 + init_method = output_layer_init_method = get_default_init_method() + attention_args = ( + hidden_size, + num_attention_heads, + kv_channels, + attention_dropout, + layernorm_epsilon, + init_method, + output_layer_init_method, + ) + + hidden_states_context = torch.randn( + sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" + ) + attention_mask = None + if use_mask and attn_mask_type != "causal": + # Generate a random mask with 50% probability for 0 or 1. + probs = 0.5 * torch.ones( + batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision + ) + attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) + + encoder_output = None + + if attention_type == "cross": + encoder_output = torch.randn( + sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" + ) + + fp8_str = "_fp8" if fp8_recipe is not None else "" + dtype_str = dtype2str(precision) + attn_type_str = "_self-attention" if attention_type == "self" else "_cross-attention" + fuse_qkv_str = "_fused-qkv" if fuse_qkv_params else "" + attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) + input_ln_str = "_input-ln" if input_layernorm else "" + fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx" + + model = te.MultiheadAttention( + *attention_args, + attn_mask_type=attn_mask_type, + params_dtype=precision, + return_layernorm_output=return_layernorm_output, + input_layernorm=input_layernorm, + attention_type=attention_type, + fuse_qkv_params=fuse_qkv_params, + return_bias=True, + ).to(device="cuda") + + inp_context = (hidden_states_context, attention_mask, encoder_output) + input_names = ["hidden_states", "attention_mask", "encoder_output"] + output_names = ["attention_output", "attention_bias"] + seq = torch.export.Dim("seq", min=2, max=1256) + bs = torch.export.Dim("bs", min=2, max=1256) + do_export( + model, + inp_context, + fname, + fp8_recipe, + input_names=input_names, + output_names=output_names, + dynamic_shapes={ + "hidden_states": {0: seq, 1: bs}, + "attention_mask": {2: seq, 0: bs} if use_mask else None, + "encoder_output": {0: seq, 1: bs} if attention_type == "cross" else None, + }, + ) + te_outputs = te_infer(model, inp_context, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe) + serialize_inputs_outputs( + fname, inp_context, te_outputs, input_names=input_names, output_names=output_names + ) + if precision in (torch.bfloat16,): + return + + if fp8_recipe is None: + validate_result( + fname, + inp_context, + model, + atol=1e-3, + input_names=input_names, + output_names=output_names, + te_outputs=te_outputs, + ) + else: + validate_result( + fname, + inp_context, + model, + atol=1e-2, + is_fp8=fp8_recipe is not None, + input_names=input_names, + output_names=output_names, + allow_cnt_errors=3, + te_outputs=te_outputs, + ) + + # In GPT generative phase (inference) the input sequence is smaller than the maximum + # allowed sequence length and we want to test this condition. + # Pretend that we're in generative phase when it makes sense (causal mask and self-attention). + is_generative_phase = attn_mask_type == "causal" and attention_type == "self" + if is_generative_phase: + seq_len_offset = 8 + hidden_states_generative = torch.randn( + sequence_length - seq_len_offset, + batch_size, + hidden_size, + dtype=precision, + device="cuda", + ) + inp_generative = (hidden_states_generative, attention_mask, encoder_output) + if fp8_recipe is None: + validate_result( + fname, + inp_generative, + model, + atol=1e-3, + input_names=input_names, + output_names=output_names, + ) + else: + validate_result( + fname, + inp_generative, + model, + atol=1e-2, + is_fp8=fp8_recipe is not None, + input_names=input_names, + output_names=output_names, + allow_cnt_errors=3, + ) + + +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) +@pytest.mark.parametrize("output_layernorm", [True, False]) +@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("fuse_qkv_params", [False, True]) +@pytest.mark.parametrize("zero_centered_gamma", [False, True]) +@pytest.mark.parametrize("activation", supported_activations) +def test_export_transformer_layer( + seed_default_rng, + set_max_seq_len, + fp8_recipe: recipe.Recipe, + use_mask: bool, + attn_mask_type: str, + output_layernorm: bool, + precision: torch.dtype, + fuse_qkv_params: bool, + zero_centered_gamma: bool, + activation: str, +): + # Skip FP8 tests on non-hopper devices + if fp8_recipe is not None and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + # Layer configuration + hidden_size = 64 + sequence_length = 128 + batch_size = 1 + ffn_hidden_size = 256 + num_attention_heads = 4 + + input_tensor = torch.rand( + sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" + ) + input_names = ["input", "attention_mask"] + attention_mask = None + if use_mask and attn_mask_type != "causal": + # Generate a random mask with 50% probability for 0 or 1. + probs = 0.5 * torch.ones( + batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision + ) + attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) + inp = (input_tensor, attention_mask) + + fp8_str = "_fp8" if fp8_recipe is not None else "" + fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else "" + high_prec_str = dtype2str(precision) + attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) + fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}_{activation}.onnx" + + model = te.TransformerLayer( + hidden_size, + ffn_hidden_size, + num_attention_heads, + self_attn_mask_type=attn_mask_type, + output_layernorm=output_layernorm, + params_dtype=precision, + fuse_qkv_params=fuse_qkv_params, + zero_centered_gamma=zero_centered_gamma, + activation=activation, + ).to(device="cuda") + do_export(model, inp, fname, fp8_recipe, input_names=input_names) + te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe) + serialize_inputs_outputs( + fname, + inp, + te_outputs, + input_names=input_names, + ) + if precision in (torch.bfloat16,): + return + atol = 5e-1 if fp8_recipe is not None else (5e-1 if activation == "swiglu" else 5e-3) + validate_result( + fname, + inp, + model, + atol=atol, + is_fp8=fp8_recipe is not None, + input_names=input_names, + te_outputs=te_outputs, + ) + + +@skip_FP8 +@skip_MXFP8 +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("zero_centered_gamma", [True]) +def test_export_gpt_generation( + seed_default_rng, + set_max_seq_len, + fp8_recipe: recipe.Recipe, + precision: torch.dtype, + zero_centered_gamma: bool, +): + """Test that the ONNX model can correctly handle inputs with different shapes and that + the attention mask is adjusted on-the-fly to different sequence lengths. + """ + + # Skip FP8 tests on non-hopper devices + if fp8_recipe is not None and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + # Layer configuration + hidden_size = 64 + sequence_length = 128 + batch_size = 4 + ffn_hidden_size = 256 + num_attention_heads = 4 + attention_mask = None + use_mask = True + attn_mask_type = "causal" + fuse_qkv_params = True + output_layernorm = False + + fp8_str = "_fp8" if fp8_recipe is not None else "" + fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else "" + high_prec_str = dtype2str(precision) + attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) + fname = f"te.transformer_layer_generative{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx" + + model = te.TransformerLayer( + hidden_size, + ffn_hidden_size, + num_attention_heads, + self_attn_mask_type=attn_mask_type, + output_layernorm=output_layernorm, + params_dtype=precision, + fuse_qkv_params=fuse_qkv_params, + zero_centered_gamma=zero_centered_gamma, + ).to(device="cuda") + + # "Context phase": use full input sequence length + input_names = ["input"] + output_names = ["output"] + input_tensor = torch.rand( + sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" + ) + inp = (input_tensor,) + # dynamic shape + seq = torch.export.Dim("seq", min=2, max=1256) + bs = torch.export.Dim("bs", min=2, max=1256) + do_export( + model, + inp, + fname, + fp8_recipe, + dynamic_shapes={"hidden_states": {0: seq, 1: bs}}, + ) + te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe) + serialize_inputs_outputs( + fname, inp, te_outputs, input_names=input_names, output_names=output_names + ) + if precision not in (torch.bfloat16,): + validate_result( + fname, + inp, + model, + atol=1e-2, + is_fp8=fp8_recipe is not None, + input_names=input_names, + te_outputs=te_outputs, + ) + + # "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8 and for MXFP8 we need to pad to mult of 32. + sequence_length = 1 if fp8_recipe is None else 32 + input_tensor = torch.rand( + sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" + ) + inp = (input_tensor, attention_mask) + te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe) + serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) + if precision not in (torch.bfloat16,): + validate_result( + fname, + inp, + model, + atol=1e-2, + is_fp8=fp8_recipe is not None, + input_names=input_names, + te_outputs=te_outputs, + ) + + +@pytest.mark.parametrize("enabled", [True, False]) +def test_export_ctx_manager(enabled): + assert is_in_onnx_export_mode() == False + with te.onnx_export(enabled): + assert is_in_onnx_export_mode() == enabled + assert is_in_onnx_export_mode() == False diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index fdb9b7f0b..dd6c6a3b0 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -61,22 +61,26 @@ def one_iteration_test( test_loss = self.test_loss_func( self.input_test, self.tar_test, label_smoothing, reduce_loss, None ) - if reduce_loss: - test_loss.backward() ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref) + + # Handle backward pass based on the test scenario if reduce_loss: + test_loss.backward() ref_loss.backward() + else: + test_loss.sum().backward() + ref_loss.sum().backward() test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss - torch.testing.assert_close(test_loss, ref_loss, check_dtype=False) if ignore_idx: print(test_loss, ref_loss) - if reduce_loss: - torch.testing.assert_close( - torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad - ) + + # Compare gradients when backward pass was called + torch.testing.assert_close( + torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad + ) self.input_test = None self.input_ref = None diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 40e964d9c..5202155e2 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -328,33 +328,37 @@ def _test_permutation_index_map( te_unpermute_output_ = te_unpermute_output.float() te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() - torch.testing.assert_close( - pytorch_permute_output.float(), - te_permute_output_, - msg=f"Mismatch in te_permute fwd", - ) - torch.testing.assert_close( - pytorch_permute_fwd_input.grad.float(), - te_permute_fwd_input_grad, - msg=f"Mismatch in te_permute bwd", - **tols, - ) - torch.testing.assert_close( - pytorch_unpermute_output.float(), - te_unpermute_output_, - msg=f"Mismatch in te_unpermute fwd", - **tols, - ) - torch.testing.assert_close( - pytorch_unpermute_fwd_input.grad.float(), - te_unpermute_fwd_input_grad, - msg=f"Mismatch in te_unpermute bwd", - **tols, - ) - if with_probs: + if not BENCHMARK: torch.testing.assert_close( - probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols + pytorch_permute_output.float(), + te_permute_output_, + msg=f"Mismatch in te_permute fwd", ) + torch.testing.assert_close( + pytorch_permute_fwd_input.grad.float(), + te_permute_fwd_input_grad, + msg=f"Mismatch in te_permute bwd", + **tols, + ) + torch.testing.assert_close( + pytorch_unpermute_output.float(), + te_unpermute_output_, + msg=f"Mismatch in te_unpermute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_unpermute_fwd_input.grad.float(), + te_unpermute_fwd_input_grad, + msg=f"Mismatch in te_unpermute bwd", + **tols, + ) + if with_probs: + torch.testing.assert_close( + probs.grad.float(), + te_probs.grad.float(), + msg=f"Mismatch in te_unpermute bwd", + **tols, + ) if not pytorch_permute_fwd_input.numel(): print("Empty pytorch_permute_fwd_input activation test passed.") @@ -544,34 +548,38 @@ def _test_permutation_mask_map( te_unpermute_output_ = te_unpermute_output.float() te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() - torch.testing.assert_close( - pytorch_permute_output.float(), - te_permute_output_, - msg=f"Mismatch in te_permute fwd", - **tols, - ) - torch.testing.assert_close( - pytorch_permute_fwd_input.grad.float(), - te_permute_fwd_input_grad, - msg=f"Mismatch in te_permute bwd", - **tols, - ) - torch.testing.assert_close( - pytorch_unpermute_output.float(), - te_unpermute_output_, - msg=f"Mismatch in te_unpermute fwd", - **tols, - ) - torch.testing.assert_close( - pytorch_unpermute_fwd_input.grad.float(), - te_unpermute_fwd_input_grad, - msg=f"Mismatch in te_unpermute bwd", - **tols, - ) - if with_probs: + if not BENCHMARK: + torch.testing.assert_close( + pytorch_permute_output.float(), + te_permute_output_, + msg=f"Mismatch in te_permute fwd", + **tols, + ) torch.testing.assert_close( - probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols + pytorch_permute_fwd_input.grad.float(), + te_permute_fwd_input_grad, + msg=f"Mismatch in te_permute bwd", + **tols, ) + torch.testing.assert_close( + pytorch_unpermute_output.float(), + te_unpermute_output_, + msg=f"Mismatch in te_unpermute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_unpermute_fwd_input.grad.float(), + te_unpermute_fwd_input_grad, + msg=f"Mismatch in te_unpermute bwd", + **tols, + ) + if with_probs: + torch.testing.assert_close( + probs.grad.float(), + te_probs.grad.float(), + msg=f"Mismatch in te_unpermute bwd", + **tols, + ) if not pytorch_permute_fwd_input.numel(): print("Empty pytorch_permute_fwd_input activation test passed.") @@ -833,18 +841,19 @@ def _test_moe_chunk_sort( te_output_ = te_output.float() te_fwd_input_grad = te_fwd_input.grad.float() - torch.testing.assert_close( - pytorch_output.float(), - te_output_, - msg=f"Mismatch in te_permute fwd", - **tols, - ) - torch.testing.assert_close( - pytorch_fwd_input.grad.float(), - te_fwd_input_grad, - msg=f"Mismatch in te_permute bwd", - **tols, - ) + if not BENCHMARK: + torch.testing.assert_close( + pytorch_output.float(), + te_output_, + msg=f"Mismatch in te_permute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_fwd_input.grad.float(), + te_fwd_input_grad, + msg=f"Mismatch in te_permute bwd", + **tols, + ) if not pytorch_fwd_input.numel(): print("Empty pytorch_fwd_input activation test passed.") @@ -893,6 +902,7 @@ def _test_permutation_mask_map_alongside_probs( topK, num_out_tokens, tp_size, + BENCHMARK=False, ): if topK > num_expert: pytest.skip("topK should be smaller than the number of experts.") @@ -1022,21 +1032,73 @@ def _test_permutation_mask_map_alongside_probs( te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() te_unpermute_output_ = te_unpermute_output.float() - torch.testing.assert_close( - pytorch_unpermute_output.float(), - te_unpermute_output_, - msg=f"Mismatch in fused_unpermute fwd", - **tols, - ) - torch.testing.assert_close( - pytorch_permute_fwd_input.grad.float(), - te_permute_fwd_input_grad, - msg=f"Mismatch in fused_permute bwd", - **tols, - ) - torch.testing.assert_close( - probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols - ) + if not BENCHMARK: + torch.testing.assert_close( + pytorch_unpermute_output.float(), + te_unpermute_output_, + msg=f"Mismatch in fused_unpermute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_permute_fwd_input.grad.float(), + te_permute_fwd_input_grad, + msg=f"Mismatch in fused_permute bwd", + **tols, + ) + torch.testing.assert_close( + probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols + ) + + if BENCHMARK: + t1 = perf_test_cuda_kernel( + lambda: te_permute_with_probs( + te_permute_fwd_input, te_probs, routing_map, num_out_tokens=num_out_tokens + ) + ) + print(f"permute\t\tfwd: TE: {t1:.3f} ms") + + te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs( + te_permute_fwd_input, + te_probs, + routing_map, + num_out_tokens=num_out_tokens, + ) + te_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + te_permute_output, + te_permute_bwd_input, + forward_input=[te_permute_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"permute\t\tbwd: TE: {t2:.3f} ms") + + chunk_sort_fwd_input = te_permute_output.detach() + chunk_sort_fwd_input.requires_grad_(True) + chunk_sort_fwd_probs = te_permuted_probs.detach() + chunk_sort_fwd_probs.requires_grad_(True) + t1 = perf_test_cuda_kernel( + lambda: te_sort_chunks_by_index_with_probs( + chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda + ) + ) + print(f"chunk sort\t\tfwd: TE: {t1:.3f} ms") + + chunk_sort_output, _ = te_sort_chunks_by_index_with_probs( + chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + chunk_sort_output, + te_permute_bwd_input, + forward_input=[chunk_sort_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"chunk sort\t\tbwd: TE: {t2:.3f} ms") def perf_test_cuda_kernel(cuda_kernel_fn): @@ -1069,7 +1131,7 @@ def perf_test_cuda_kernel(cuda_kernel_fn): @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) @@ -1098,7 +1160,7 @@ def test_permutation_index_map( @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) @@ -1144,7 +1206,7 @@ def test_permutation_mask_map_empty_input(te_dtype): @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) @@ -1199,7 +1261,7 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("num_tokens", [2048]) -@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) @@ -1231,7 +1293,7 @@ def test_permutation_mask_map_fp8( @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) def test_permutation_index_map_topk1_no_probs( te_dtype, @@ -1258,7 +1320,7 @@ def test_permutation_index_map_topk1_no_probs( @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) def test_permutation_mask_map_topk1_no_probs( te_dtype, @@ -1285,7 +1347,7 @@ def test_permutation_mask_map_topk1_no_probs( @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("tp_size", [1, 2, 8]) @pytest.mark.parametrize("hidden_size", [4096]) def test_chunk_permutation( @@ -1378,5 +1440,108 @@ def test_permutation_single_case(): ) +def benchmark_single_case( + te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size +): + torch.cuda.nvtx.range_push( + f"{num_tokens}-{num_expert}-{hidden_size}-{topK}-{ep_size}-{tp_size}" + ) + + torch.cuda.nvtx.range_push("permutation_index_map_with_probs") + _test_permutation_index_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=True, + BENCHMARK=True, + ) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push("permutation_mask_map_with_probs") + _test_permutation_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=True, + BENCHMARK=True, + ) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push("permutation_mask_map_without_probs") + _test_permutation_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=False, + BENCHMARK=True, + ) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs") + _test_permutation_mask_map_alongside_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + tp_size=tp_size, + BENCHMARK=True, + ) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_pop() + + +def benchmark_multiple_cases(): + print("GPU:", torch.cuda.get_device_name(0)) + + # te_dtype = tex.DType.kFloat32 + # te_dtype = tex.DType.kFloat16 + te_dtype = tex.DType.kBFloat16 + + ep_size = 64 + tp_size = 2 + num_tokens = 4096 + num_expert = 256 + hidden_size = 7168 + topK = 8 + num_out_tokens = num_tokens * topK + benchmark_single_case( + te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size + ) + + ep_size = 8 + tp_size = 1 + num_tokens = 8192 * 2 + num_expert = 128 + hidden_size = 4096 + topK = 6 + num_out_tokens = num_tokens * topK + benchmark_single_case( + te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size + ) + + ep_size = 64 + tp_size = 2 + num_tokens = 16384 + num_expert = 4 + hidden_size = 7168 + topK = 1 + num_out_tokens = num_tokens * topK + benchmark_single_case( + te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size + ) + + if __name__ == "__main__": - test_permutation_single_case() + benchmark_multiple_cases() diff --git a/tests/pytorch/test_qk_norm.py b/tests/pytorch/test_qk_norm.py new file mode 100644 index 000000000..6f4e62f81 --- /dev/null +++ b/tests/pytorch/test_qk_norm.py @@ -0,0 +1,242 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from transformer_engine.pytorch import MultiheadAttention + +import pytest +import torch + + +@pytest.mark.parametrize("use_qk_norm", [False, True]) +@pytest.mark.parametrize("attention_type", ["self", "cross"]) +@pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5]) +def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None: + """Test QK normalization functionality, module structure, and numerical behavior.""" + hidden_size = 256 + num_attention_heads = 8 + seq_len = 128 + + # Create MultiheadAttention module + mha = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_type=attention_type, + use_qk_norm=use_qk_norm, + qk_norm_eps=qk_norm_eps, + bias=False, + device="cuda", + ).cuda() + + # Check module structure based on use_qk_norm parameter + if use_qk_norm: + assert hasattr(mha, "qk_norm"), "Should have qk_norm module when use_qk_norm=True" + assert not hasattr(mha, "q_l2norm"), "Should not have separate q_l2norm module" + assert not hasattr(mha, "k_l2norm"), "Should not have separate k_l2norm module" + # Check that the module is L2Norm type + from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization + + assert isinstance( + mha.qk_norm, L2Normalization + ), "qk_norm should be an L2Normalization module" + else: + assert not hasattr(mha, "qk_norm"), "Should not have qk_norm module when use_qk_norm=False" + + # Create input tensors + batch_size = 2 # Use a fixed batch size for testing + hidden_states = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + + if attention_type == "cross": + encoder_output = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + else: + encoder_output = None + + # Test forward pass + with torch.no_grad(): + if attention_type == "cross": + output = mha(hidden_states, encoder_output=encoder_output) + else: + output = mha(hidden_states) + + # Check output shape and numerical properties + assert output.shape == ( + seq_len, + batch_size, + hidden_size, + ), f"Output shape mismatch: {output.shape}" + assert not torch.isnan(output).any(), "Output contains NaN" + assert not torch.isinf(output).any(), "Output contains Inf" + + # Test with RoPE (if self-attention) + if attention_type == "self": + head_dim = hidden_size // num_attention_heads + rotary_dim = head_dim // 2 + rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32) + + with torch.no_grad(): + output_with_rope = mha(hidden_states, rotary_pos_emb=rotary_pos_emb) + + assert output_with_rope.shape == ( + seq_len, + batch_size, + hidden_size, + ), "Output shape with RoPE mismatch" + assert not torch.isnan(output_with_rope).any(), "RoPE output contains NaN" + assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf" + + +def test_qk_norm_output_difference() -> None: + """Test that QK normalization actually changes the output compared to no normalization.""" + hidden_size = 256 + num_attention_heads = 8 + seq_len = 128 + batch_size = 2 + + # Use same random seed to ensure identical weight initialization + current_rng_state = torch.get_rng_state() + current_cuda_rng_state = torch.cuda.get_rng_state() + + # Reset to a known seed for reproducible initialization + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Create model with QK normalization + mha_with_norm = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_qk_norm=True, + bias=False, + device="cuda", + ).cuda() + + # Reset to same seed for identical initialization + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Create identical model without QK normalization + mha_no_norm = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_qk_norm=False, + bias=False, + device="cuda", + ).cuda() + + # Create input tensors + hidden_states = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + + # Compare outputs with identical weights but different QK norm settings + with torch.no_grad(): + output_with_norm = mha_with_norm(hidden_states) + output_no_norm = mha_no_norm(hidden_states) + + # Outputs should be different when QK normalization is enabled + assert not torch.allclose( + output_with_norm, output_no_norm, atol=1e-6 + ), "QK normalization should change the output, but outputs are identical" + + +def test_qk_norm_with_fused_qkv() -> None: + """Test QK normalization works with fused QKV parameters.""" + hidden_size = 256 + num_attention_heads = 8 + seq_len = 64 + + mha = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + fuse_qkv_params=True, + use_qk_norm=True, + bias=False, + device="cuda", + ).cuda() + + # Create input and test forward pass + batch_size = 2 # Use a fixed batch size for testing + hidden_states = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + + with torch.no_grad(): + output = mha(hidden_states) + + assert output.shape == ( + seq_len, + batch_size, + hidden_size, + ), f"Output shape mismatch: {output.shape}" + + +def test_qk_norm_transformer_layer_output_difference() -> None: + """Test that QK normalization actually changes TransformerLayer output compared to no normalization.""" + from transformer_engine.pytorch import TransformerLayer + + hidden_size = 256 + ffn_hidden_size = 1024 + num_attention_heads = 8 + seq_len = 128 + batch_size = 2 + + # Use same random seed to ensure identical weight initialization + current_rng_state = torch.get_rng_state() + current_cuda_rng_state = torch.cuda.get_rng_state() + + # Reset to a known seed for reproducible initialization + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Create TransformerLayer with QK normalization + transformer_with_norm = TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + use_qk_norm=True, + bias=False, + device="cuda", + ).cuda() + + # Reset to same seed for identical initialization + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Create identical TransformerLayer without QK normalization + transformer_no_norm = TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + use_qk_norm=False, + bias=False, + device="cuda", + ).cuda() + + # Create input tensors + hidden_states = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + + # Compare outputs with identical weights but different QK norm settings + with torch.no_grad(): + output_with_norm = transformer_with_norm(hidden_states) + output_no_norm = transformer_no_norm(hidden_states) + + # Outputs should be different when QK normalization is enabled + assert not torch.allclose( + output_with_norm, output_no_norm, atol=1e-6 + ), "QK normalization should change the TransformerLayer output, but outputs are identical" + + # Check that outputs have expected shapes and properties + assert output_with_norm.shape == ( + seq_len, + batch_size, + hidden_size, + ), f"Output shape mismatch: {output_with_norm.shape}" + assert not torch.isnan(output_with_norm).any(), "Output with QK norm contains NaN" + assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf" + assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN" + assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf" diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 02ff9367a..5aa91de52 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -8,22 +8,32 @@ import pytest import torch +import warnings import transformer_engine.common.recipe import transformer_engine.pytorch as te +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import ( FP8GlobalStateManager, _amax_and_scale_update, - get_default_fp8_recipe, + fp8_model_init, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.utils import is_fp8_fnuz +from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear +from transformer_engine.pytorch.distributed import fp8_autocast +from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling import transformer_engine_torch as tex # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) # FP8 per tensor delayed scaling @@ -370,3 +380,127 @@ def setup_fp8_meta(): ) torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale) + + @pytest.mark.parametrize( + "model_init_recipe", + [ + pytest.param( + MXFP8BlockScaling(), + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + ), + pytest.param( + Float8BlockScaling(), + marks=pytest.mark.skipif( + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling + ), + ), + ], + ) + def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe): + with fp8_model_init(enabled=True, recipe=model_init_recipe): + linear = Linear(32, 32).cuda() + + x = torch.randn(32, 32, device="cuda") + with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()): + with pytest.raises(RuntimeError) as excinfo: + _ = linear(x) + assert "Recipe mismatch for " in str(excinfo.value) + + @pytest.mark.parametrize( + "target_recipe_class, expected_quantizer_type, available_flag, reason", + [ + pytest.param( + MXFP8BlockScaling, + MXFP8Quantizer, + mxfp8_available, + reason_for_no_mxfp8, + id="DelayedScaling->MXFP8BlockScaling", + ), + pytest.param( + Float8BlockScaling, + Float8BlockQuantizer, + fp8_block_scaling_available, + reason_for_no_fp8_block_scaling, + id="DelayedScaling->Float8BlockScaling", + ), + ], + ) + def test_dynamic_recipe_update( + self, target_recipe_class, expected_quantizer_type, available_flag, reason + ): + if not available_flag: + pytest.skip(reason) + + in_features = 32 + out_features = 32 + batch_size = 32 + linear = Linear(in_features, out_features).cuda() + initial_recipe = DelayedScaling() + + # Run initial iterations with DelayedScaling + for _ in range(3): + x = torch.randn(batch_size, in_features, device="cuda") + with fp8_autocast(enabled=True, fp8_recipe=initial_recipe): + y = linear(x) + loss = y.mean() + loss.backward() + + for quantizer in linear.quantizers["scaling_fwd"]: + assert isinstance(quantizer, Float8Quantizer) + + # Change recipe + target_recipe = target_recipe_class() + + # Run subsequent iterations with the target recipe + for i in range(3): + x = torch.randn(batch_size, in_features, device="cuda") + if i == 0: + # Expect a warning on the first iteration with the new recipe + with pytest.warns(UserWarning, match="Recipe type changed"): + with fp8_autocast(enabled=True, fp8_recipe=target_recipe): + y = linear(x) + for quantizer in linear.quantizers["scaling_fwd"]: + assert isinstance(quantizer, expected_quantizer_type) + else: + # No warning expected on subsequent iterations + with warnings.catch_warnings(): + warnings.simplefilter("error") # Raise error if unexpected warning occurs + with fp8_autocast(enabled=True, fp8_recipe=target_recipe): + y = linear(x) + loss = y.mean() + loss.backward() + + # Final check + for quantizer in linear.quantizers["scaling_fwd"]: + assert isinstance(quantizer, expected_quantizer_type) + + @pytest.mark.parametrize( + "module_class", + [ + Linear, + LayerNormLinear, + LayerNormMLP, + GroupedLinear, + ], + ) + def test_quantizer_update(self, module_class): + in_features = 32 + out_features = 32 + batch_size = 32 + + recipe = DelayedScaling(amax_history_len=1024) + with fp8_model_init(recipe=recipe): + if module_class == GroupedLinear: + module = module_class(1, in_features, out_features).cuda() + else: + module = module_class(in_features, out_features).cuda() + + x = torch.randn(batch_size, in_features, device="cuda") + recipe = DelayedScaling(amax_history_len=1) + with fp8_autocast(enabled=True, fp8_recipe=recipe): + warn_msg = "Quantizer is being updated, this may affect model behavior" + with pytest.warns(UserWarning, match=warn_msg): + if module_class == GroupedLinear: + y = module(x, [batch_size]) + else: + y = module(x) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index c8e02ca49..9fbadd4b9 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -14,6 +14,7 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION import os +import transformer_engine.pytorch from transformer_engine.pytorch.fp8 import ( fp8_autocast, FP8GlobalStateManager, @@ -42,12 +43,14 @@ from transformer_engine.pytorch.module.base import get_workspace from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import ( - Float8Quantizer, Float8CurrentScalingQuantizer, + Float8Quantizer, + Float8Tensor, ) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.distributed import checkpoint -from test_numerics import reset_rng_states, dtype_tols +from utils import dtype_tols # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -56,6 +59,28 @@ ) mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +# Record initial RNG state from script run. +seed = 1234 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) +_cpu_rng_state = torch.get_rng_state() +_cuda_rng_state = torch.cuda.get_rng_state() + +NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0")) + + +if NVTE_TEST_NVINSPECT_ENABLED: + # The sanity tests should work the same, + # when debug=True. I fed them with dummy feature + # to prevent switching off debug, which can happen if + # no feature is active. + import nvdlfw_inspect.api as debug_api + + debug_api.initialize( + os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"], + feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"], + ) + def create_meta(scale_factor: float, size: int = 1): meta = tex.FP8TensorMeta() @@ -84,6 +109,13 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor: return torch.min(amax_history, dim=0).values +def reset_rng_states() -> None: + """revert back to initial RNG state.""" + global _cpu_rng_state, _cuda_rng_state + torch.set_rng_state(_cpu_rng_state) + torch.cuda.set_rng_state(_cuda_rng_state) + + @dataclass class ModelConfig: """Transformer model configuration""" @@ -524,6 +556,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias): + if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params: + pytest.skip("Quantized model parameters are not supported in debug mode.") config = model_configs[model] ffn_hidden_size = 4 * config.hidden_size num_tokens = bs * config.seq_len @@ -565,6 +599,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ def test_sanity_grouped_linear( dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split ): + if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params: + pytest.skip("FP8 model parameters are not supported in debug mode.") config = model_configs[model] ffn_hidden_size = 4 * config.hidden_size # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. @@ -680,6 +716,8 @@ def test_sanity_gpt( if IS_HIP_EXTENSION and cpu_offload: pytest.skip("cpu_offloading not supported in rocm TE") + if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("CPU offload is not supported in debug mode.") config = model_configs[model] if fp8_recipe is not None: @@ -1348,3 +1386,82 @@ def backward(ctx, grad_output): # Assert that gradients are the same torch.testing.assert_close(grad_checkpoint, grad_standard) + + +@pytest.mark.parametrize( + "module_name", + ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"), +) +@pytest.mark.parametrize( + "quantization", + (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"), +) +def test_inference_mode( + module_name: str, + quantization: Optional[str], +) -> None: + """Test heuristics for initializing quantized weights""" + if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None: + pytest.skip("Quantized model parameters are not supported in debug mode.") + + # Tensor dimensions + sequence_length = 32 + hidden_size = 32 + + # Skip invalid configurations + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: + pytest.skip(reason_for_no_fp8) + if quantization == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + # Construct quantization recipe + with_quantization = quantization not in (None, "None") + quantization_recipe = None + if quantization == "fp8_delayed_scaling": + quantization_recipe = recipe.DelayedScaling() + elif quantization == "fp8_current_scaling": + quantization_recipe = recipe.Float8CurrentScaling() + elif quantization == "mxfp8": + quantization_recipe = recipe.MXFP8BlockScaling() + + # Construct module + module = None + with torch.no_grad(): + with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe): + if module_name == "Linear": + module = Linear(hidden_size, hidden_size) + elif module_name == "LayerNormLinear": + module = LayerNormLinear(hidden_size, hidden_size) + elif module_name == "LayerNormMLP": + module = LayerNormMLP(hidden_size, hidden_size) + elif module_name == "GroupedLinear": + module = GroupedLinear(1, hidden_size, hidden_size) + elif module_name == "ops.Linear": + module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size) + + def check_weights(): + """Helper function to check that weight parameters have expected data""" + for param in module.parameters(): + if isinstance(param, Float8Tensor): + assert param._data is not None, "Missing FP8 data" + assert ( + param._transpose is None and param._transpose_invalid + ), "FP8 transpose is not expected for inference" + if isinstance(param, MXFP8Tensor): + assert param._rowwise_data is not None, "Missing row-wise MXFP8 data" + assert ( + param._columnwise_data is None + ), "Column-wise MXFP8 data is not expected for inference" + + # Check that modules have expected weights after initialization + check_weights() + + # Check that modules have expected weights after forward pass + with torch.inference_mode(): + x = torch.zeros(sequence_length, hidden_size, device="cuda") + kwargs = {} + if module_name == "GroupedLinear": + kwargs["m_splits"] = [sequence_length] + with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe): + y = module(x, **kwargs) + check_weights() diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 30b6737c0..0c50592bd 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -8,6 +8,8 @@ import torch +import transformer_engine +import transformer_engine.common.recipe import transformer_engine.pytorch as te from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type import transformer_engine_torch as tex @@ -87,3 +89,25 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: if dtype == torch.float8_e5m2 or dtype == torch.float8_e5m2fnuz: return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 raise ValueError(f"Unsupported dtype ({dtype})") + + +def make_recipe(name: Optional[str]) -> Optional[Recipe]: + """Make recipe for quantization scheme""" + if name is None: + return None + if name in ("fp8", "fp8_delayed_scaling"): + return transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + amax_history_len=8, + ) + if name == "fp8_current_scaling": + return transformer_engine.common.recipe.Float8CurrentScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "mxfp8": + return transformer_engine.common.recipe.MXFP8BlockScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "fp8_block_scaling": + return transformer_engine.common.recipe.Float8BlockScaling() + raise ValueError(f"Unsupported quantization scheme ({name})") diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index fa59a79b2..050abc8f7 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -8,6 +8,7 @@ # pylint: disable=unused-import +import os from importlib import metadata import os import transformer_engine.common @@ -30,17 +31,56 @@ try: if _use_pytorch: from . import pytorch -except (ImportError, FileNotFoundError): +except ImportError: pass +except FileNotFoundError as e: + if "Could not find shared object file" not in str(e): + raise e # Unexpected error + else: + if os.getenv("NVTE_FRAMEWORK"): + frameworks = os.getenv("NVTE_FRAMEWORK").split(",") + if "pytorch" in frameworks or "all" in frameworks: + raise e + else: + # If we got here, we could import `torch` but could not load the framework extension. + # This can happen when a user wants to work only with `transformer_engine.jax` on a system that + # also has a PyTorch installation. In order to enable that use case, we issue a warning here + # about the missing PyTorch extension in case the user hasn't set NVTE_FRAMEWORK. + import warnings -try: - if _use_jax: from . import jax -except (ImportError, FileNotFoundError): - pass + warnings.warn( + "Detected a PyTorch installation but could not find the shared object file for the " + "Transformer Engine PyTorch extension library. If this is not intentional, please " + "reinstall Transformer Engine with `pip install transformer_engine[pytorch]` or " + "build from source with `NVTE_FRAMEWORK=pytorch`.", + category=RuntimeWarning, + ) try: - if _use_jax: import transformer_engine_jax + if _use_jax: from . import jax except ImportError: pass +except FileNotFoundError as e: + if "Could not find shared object file" not in str(e): + raise e # Unexpected error + else: + if os.getenv("NVTE_FRAMEWORK"): + frameworks = os.getenv("NVTE_FRAMEWORK").split(",") + if "jax" in frameworks or "all" in frameworks: + raise e + else: + # If we got here, we could import `jax` but could not load the framework extension. + # This can happen when a user wants to work only with `transformer_engine.pytorch` on a system + # that also has a Jax installation. In order to enable that use case, we issue a warning here + # about the missing Jax extension in case the user hasn't set NVTE_FRAMEWORK. + import warnings + + warnings.warn( + "Detected a Jax installation but could not find the shared object file for the " + "Transformer Engine Jax extension library. If this is not intentional, please " + "reinstall Transformer Engine with `pip install transformer_engine[jax]` or " + "build from source with `NVTE_FRAMEWORK=jax`.", + category=RuntimeWarning, + ) __version__ = str(metadata.version("transformer_engine")) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 5b0f1981d..55fbdb996 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -34,43 +34,46 @@ endif() # Language options if(USE_CUDA) - if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) - else () - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) - endif() - endif() - set(CMAKE_CXX_STANDARD 17) - set(CMAKE_CUDA_STANDARD 17) - set(CMAKE_CUDA_STANDARD_REQUIRED ON) - if (CMAKE_BUILD_TYPE STREQUAL "Debug") - set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G") +# Removed indent to minimize code diff with NV upstream +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) + else () + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) endif() +endif() +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) +if (CMAKE_BUILD_TYPE STREQUAL "Debug") + set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G") +endif() - # Hide non-necessary symbols in shared object. - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") - - # Transformer Engine library - project(transformer_engine LANGUAGES CUDA CXX) - - # CUDA Toolkit - find_package(CUDAToolkit REQUIRED) - if (CUDAToolkit_VERSION VERSION_LESS 12.0) - message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}") - endif() +# Hide non-necessary symbols in shared object. +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") - # cuDNN frontend API - set(CUDNN_FRONTEND_INCLUDE_DIR - "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") - if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") - message(FATAL_ERROR - "Could not find cuDNN frontend API at ${CUDNN_FRONTEND_INCLUDE_DIR}. " - "Try running 'git submodule update --init --recursive' " - "within the Transformer Engine source.") - endif() - include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) +# Transformer Engine library +project(transformer_engine LANGUAGES CUDA CXX) + +# CUDA Toolkit +find_package(CUDAToolkit REQUIRED) +if (CUDAToolkit_VERSION VERSION_LESS 12.0) + message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}") +endif() + +# cuDNN frontend API +set(CUDNN_FRONTEND_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") +if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") + message(FATAL_ERROR + "Could not find cuDNN frontend API at ${CUDNN_FRONTEND_INCLUDE_DIR}. " + "Try running 'git submodule update --init --recursive' " + "within the Transformer Engine source.") +endif() +include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) else() set(CMAKE_CXX_STANDARD 17) project(transformer_engine LANGUAGES HIP CXX) @@ -114,110 +117,78 @@ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) set(transformer_engine_SOURCES) +# Source files in both cuda and rocm +list(APPEND transformer_engine_SOURCES + transformer_engine.cpp + common.cu + multi_tensor/adam.cu + multi_tensor/compute_scale.cu + multi_tensor/l2norm.cu + multi_tensor/scale.cu + multi_tensor/sgd.cu + transpose/cast_transpose.cu + transpose/transpose.cu + transpose/cast_transpose_fusion.cu + transpose/transpose_fusion.cu + transpose/multi_cast_transpose.cu + activation/gelu.cu + fused_attn/flash_attn.cu + fused_attn/context_parallel.cu + fused_attn/kv_cache.cu + activation/relu.cu + activation/swiglu.cu + gemm/cublaslt_gemm.cu + normalization/common.cpp + normalization/layernorm/ln_api.cpp + normalization/layernorm/ln_bwd_semi_cuda_kernel.cu + normalization/layernorm/ln_fwd_cuda_kernel.cu + normalization/rmsnorm/rmsnorm_api.cpp + normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu + normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu + permutation/permutation.cu + util/cast.cu + util/padding.cu + util/cuda_driver.cpp + util/cuda_runtime.cpp + util/multi_stream.cpp + util/rtc.cpp + swizzle/swizzle.cu + fused_softmax/scaled_masked_softmax.cu + fused_softmax/scaled_upper_triang_masked_softmax.cu + fused_softmax/scaled_aligned_causal_masked_softmax.cu + fused_rope/fused_rope.cu + fused_router/fused_moe_aux_loss.cu + fused_router/fused_score_for_moe_aux_loss.cu + fused_router/fused_topk_with_score_function.cu + recipe/current_scaling.cu + recipe/delayed_scaling.cu + recipe/fp8_block_scaling.cu) if(USE_CUDA) - list(APPEND transformer_engine_SOURCES - cudnn_utils.cpp - transformer_engine.cpp - common.cu - multi_tensor/adam.cu - multi_tensor/compute_scale.cu - multi_tensor/l2norm.cu - multi_tensor/scale.cu - multi_tensor/sgd.cu - transpose/cast_transpose.cu - transpose/transpose.cu - transpose/cast_transpose_fusion.cu - transpose/transpose_fusion.cu - transpose/multi_cast_transpose.cu - transpose/quantize_transpose_square_blockwise.cu - transpose/quantize_transpose_vector_blockwise.cu - activation/gelu.cu - fused_attn/flash_attn.cu - fused_attn/context_parallel.cu - fused_attn/kv_cache.cu - fused_attn/fused_attn_f16_max512_seqlen.cu - fused_attn/fused_attn_f16_arbitrary_seqlen.cu - activation/relu.cu - activation/swiglu.cu - fused_attn/fused_attn_fp8.cu - fused_attn/fused_attn.cpp - fused_attn/utils.cu - gemm/cublaslt_gemm.cu - normalization/common.cpp - normalization/layernorm/ln_api.cpp - normalization/layernorm/ln_bwd_semi_cuda_kernel.cu - normalization/layernorm/ln_fwd_cuda_kernel.cu - normalization/rmsnorm/rmsnorm_api.cpp - normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu - normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu - permutation/permutation.cu - util/cast.cu - util/padding.cu - util/cuda_driver.cpp - util/cuda_nvml.cpp - util/cuda_runtime.cpp - util/rtc.cpp - swizzle/swizzle.cu - fused_softmax/scaled_masked_softmax.cu - fused_softmax/scaled_upper_triang_masked_softmax.cu - fused_softmax/scaled_aligned_causal_masked_softmax.cu - fused_rope/fused_rope.cu - recipe/current_scaling.cu - recipe/delayed_scaling.cu - recipe/fp8_block_scaling.cu - comm_gemm_overlap/userbuffers/ipcsocket.cc - comm_gemm_overlap/userbuffers/userbuffers-host.cpp - comm_gemm_overlap/userbuffers/userbuffers.cu - comm_gemm_overlap/comm_gemm_overlap.cpp) - add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) +# Removed indent to minimize code diff with NV upstream +# Files unique in cuda building +list(APPEND transformer_engine_SOURCES + cudnn_utils.cpp + transpose/quantize_transpose_square_blockwise.cu + transpose/quantize_transpose_vector_blockwise.cu + fused_attn/fused_attn_f16_max512_seqlen.cu + fused_attn/fused_attn_f16_arbitrary_seqlen.cu + fused_attn/fused_attn_fp8.cu + fused_attn/fused_attn.cpp + fused_attn/utils.cu + util/cuda_nvml.cpp + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/userbuffers/userbuffers.cu + comm_gemm_overlap/comm_gemm_overlap.cpp) +add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) else() list(APPEND transformer_engine_SOURCES - transformer_engine.cpp - common.cu - multi_tensor/adam.cu - multi_tensor/compute_scale.cu - multi_tensor/l2norm.cu - multi_tensor/scale.cu - multi_tensor/sgd.cu - transpose/cast_transpose.cu - transpose/transpose.cu - transpose/cast_transpose_fusion.cu - transpose/transpose_fusion.cu - transpose/multi_cast_transpose.cu - activation/gelu.cu - activation/relu.cu - activation/swiglu.cu - fused_attn/flash_attn.cu - fused_attn/context_parallel.cu - fused_attn/kv_cache.cu fused_attn_rocm/fused_attn.cpp fused_attn_rocm/fused_attn_aotriton.cpp fused_attn_rocm/fused_attn_ck.cpp fused_attn_rocm/utils.cpp - gemm/cublaslt_gemm.cu gemm/rocm_gemm.cu - normalization/common.cpp - normalization/layernorm/ln_api.cpp - normalization/layernorm/ln_bwd_semi_cuda_kernel.cu - normalization/layernorm/ln_fwd_cuda_kernel.cu - normalization/rmsnorm/rmsnorm_api.cpp - normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu - normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu - permutation/permutation.cu - util/cast.cu - util/padding.cu - util/cuda_driver.cpp - util/cuda_runtime.cpp - util/rtc.cpp - amd_detail/system.cpp - swizzle/swizzle.cu - fused_softmax/scaled_masked_softmax.cu - fused_softmax/scaled_upper_triang_masked_softmax.cu - fused_softmax/scaled_aligned_causal_masked_softmax.cu - fused_rope/fused_rope.cu - recipe/current_scaling.cu - recipe/delayed_scaling.cu - recipe/fp8_block_scaling.cu) + amd_detail/system.cpp) # process source code files set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../..) @@ -238,6 +209,8 @@ else() IGNORES "*/amd_detail/*" IGNORES "*/aotriton/*" IGNORES "*/ck_fused_attn/*" + IGNORES "*/pytorch/csrc/*" + IGNORES "*/jax/csrc/*" CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json" NO_MATH_REPLACE ) @@ -255,26 +228,33 @@ target_include_directories(transformer_engine PUBLIC # Configure dependencies if (USE_CUDA) - target_link_libraries(transformer_engine PUBLIC +target_link_libraries(transformer_engine PUBLIC CUDA::cublas CUDA::cudart CUDNN::cudnn_all) - target_include_directories(transformer_engine PRIVATE +target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) - target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") - - # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI - # Changed - option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) - if (NVTE_UB_WITH_MPI) - find_package(MPI REQUIRED) - target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) - target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) - target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI) - endif() - - # Hack to enable dynamic loading in cuDNN frontend - target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) +target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") + +# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI +option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) +if (NVTE_UB_WITH_MPI) + find_package(MPI REQUIRED) + target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) + target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) + target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI) +endif() + +option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF) +if (NVTE_ENABLE_NVSHMEM) + add_subdirectory(nvshmem_api) + target_link_libraries(transformer_engine PUBLIC nvshmemapi) + target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR}) +endif() + +# Hack to enable dynamic loading in cuDNN frontend +target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) + else() set(GPU_TARGETS ${CMAKE_HIP_ARCHITECTURES}) @@ -372,18 +352,6 @@ else() target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS}) endif() -if(USE_CUDA) - option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF) - if (NVTE_ENABLE_NVSHMEM) - add_subdirectory(nvshmem_api) - target_link_libraries(transformer_engine PUBLIC nvshmemapi) - target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR}) - endif() - -# Hack to enable dynamic loading in cuDNN frontend -target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) -endif() - # Helper functions to make header files with C++ strings function(make_string_header STRING STRING_NAME) configure_file(util/string_header.h.in @@ -398,18 +366,18 @@ function(make_string_header_from_file file_ STRING_NAME) endfunction() if(USE_CUDA) - # Header files with C++ strings - list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path) - make_string_header("${cuda_include_path}" - string_path_cuda_include) - make_string_header_from_file(transpose/rtc/cast_transpose_fusion.cu - string_code_transpose_rtc_cast_transpose_fusion_cu) - make_string_header_from_file(transpose/rtc/cast_transpose.cu - string_code_transpose_rtc_cast_transpose_cu) - make_string_header_from_file(transpose/rtc/transpose.cu - string_code_transpose_rtc_transpose_cu) - make_string_header_from_file(utils.cuh - string_code_utils_cuh) +# Header files with C++ strings +list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path) +make_string_header("${cuda_include_path}" + string_path_cuda_include) +make_string_header_from_file(transpose/rtc/cast_transpose_fusion.cu + string_code_transpose_rtc_cast_transpose_fusion_cu) +make_string_header_from_file(transpose/rtc/cast_transpose.cu + string_code_transpose_rtc_cast_transpose_cu) +make_string_header_from_file(transpose/rtc/transpose.cu + string_code_transpose_rtc_transpose_cu) +make_string_header_from_file(utils.cuh + string_code_utils_cuh) else() make_string_header_from_file(utils_hip.cuh string_code_utils_cuh) @@ -444,34 +412,34 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu PROPERTIES COMPILE_OPTIONS "--use_fast_math") if(USE_CUDA) - option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) - if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) - set_source_files_properties(activation/gelu.cu - activation/relu.cu - activation/swiglu.cu - PROPERTIES - COMPILE_OPTIONS "--use_fast_math") - endif() - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") - - # Number of parallel build jobs - if(ENV{MAX_JOBS}) - set(BUILD_JOBS_STR "$ENV{MAX_JOBS}") - elseif(ENV{NVTE_BUILD_MAX_JOBS}) - set(BUILD_JOBS_STR "$ENV{NVTE_BUILD_MAX_JOBS}") - else() - set(BUILD_JOBS_STR "max") - endif() - message(STATUS "Parallel build jobs: ${BUILD_JOBS_STR}") - - # Number of threads per parallel build job - set(BUILD_THREADS_PER_JOB $ENV{NVTE_BUILD_THREADS_PER_JOB}) - if (NOT BUILD_THREADS_PER_JOB) - set(BUILD_THREADS_PER_JOB 1) - endif() - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads ${BUILD_THREADS_PER_JOB}") - message(STATUS "Threads per parallel build job: ${BUILD_THREADS_PER_JOB}") +option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) +if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) + set_source_files_properties(activation/gelu.cu + activation/relu.cu + activation/swiglu.cu + PROPERTIES + COMPILE_OPTIONS "--use_fast_math") +endif() +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") + +# Number of parallel build jobs +if(ENV{MAX_JOBS}) + set(BUILD_JOBS_STR "$ENV{MAX_JOBS}") +elseif(ENV{NVTE_BUILD_MAX_JOBS}) + set(BUILD_JOBS_STR "$ENV{NVTE_BUILD_MAX_JOBS}") +else() + set(BUILD_JOBS_STR "max") +endif() +message(STATUS "Parallel build jobs: ${BUILD_JOBS_STR}") + +# Number of threads per parallel build job +set(BUILD_THREADS_PER_JOB $ENV{NVTE_BUILD_THREADS_PER_JOB}) +if (NOT BUILD_THREADS_PER_JOB) + set(BUILD_THREADS_PER_JOB 1) +endif() +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads ${BUILD_THREADS_PER_JOB}") +message(STATUS "Threads per parallel build job: ${BUILD_THREADS_PER_JOB}") else() set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3") set(HIP_HCC_FLAGS "${CMAKE_HIP_FLAGS} -mavx2 -mf16c -mfma -std=c++17") @@ -485,9 +453,8 @@ endif() # Install library install(TARGETS transformer_engine DESTINATION .) -set_target_properties(transformer_engine PROPERTIES INSTALL_RPATH "$ORIGIN/lib;$ORIGIN/transformer_engine/lib") - if (USE_ROCM) + set_target_properties(transformer_engine PROPERTIES INSTALL_RPATH "$ORIGIN/lib;$ORIGIN/transformer_engine/lib") if("$ENV{ROCM_PATH}" STREQUAL "") set(ROCM_PATH "/opt/rocm") else() diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index c72eca543..8a73138e3 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -6,28 +6,27 @@ """FW agnostic user-end APIs""" +import ctypes import functools -import sys import glob -import sysconfig -import subprocess -import ctypes +import importlib +from importlib.metadata import version, metadata, PackageNotFoundError import logging import os -import platform -import importlib -import functools from pathlib import Path -from importlib.metadata import version, metadata, PackageNotFoundError +import platform +import subprocess +import sys +import sysconfig +from typing import Optional import transformer_engine - _logger = logging.getLogger(__name__) @functools.lru_cache(maxsize=None) -def _is_pip_package_installed(package): +def _is_pip_package_installed(package) -> bool: """Check if the given package is installed via pip.""" # This is needed because we only want to return true @@ -42,37 +41,37 @@ def _is_pip_package_installed(package): @functools.lru_cache(maxsize=None) -def _find_shared_object_in_te_dir(te_path: Path, prefix: str): +def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]: """ - Find a shared object file of given prefix in the top level TE directory. - Only the following locations are searched to avoid stray SOs and build - artifacts: - 1. The given top level directory (editable install). - 2. `transformer_engine` named directories (source install). - 3. `wheel_lib` named directories (PyPI install). + Find a shared object file with the given prefix within the top level TE directory. + + The following locations are searched: + 1. Top level directory (editable install). + 2. `transformer_engine` directory (source install). + 3. `wheel_lib` directory (PyPI install). Returns None if no shared object files are found. Raises an error if multiple shared object files are found. """ - # Ensure top level dir exists and has the module. before searching. - if not te_path.exists() or not (te_path / "transformer_engine").exists(): + # Ensure top level dir exists and has the module before searching. + if not te_path.is_dir() or not (te_path / "transformer_engine").exists(): return None files = [] search_paths = ( - te_path, - te_path / "transformer_engine", - te_path / "transformer_engine/wheel_lib", - te_path / "wheel_lib", + te_path, # Editable build. + te_path / "transformer_engine", # Regular source build. + te_path / "transformer_engine/wheel_lib", # PyPI. ) # Search. - for dirname, _, names in os.walk(te_path): - if Path(dirname) in search_paths: - for name in names: - if name.startswith(prefix) and name.endswith(f".{_get_sys_extension()}"): - files.append(Path(dirname, name)) + for dir_path in search_paths: + if not dir_path.is_dir(): + continue + for file_path in dir_path.iterdir(): + if file_path.name.startswith(prefix) and file_path.suffix == _get_sys_extension(): + files.append(file_path) if len(files) == 0: return None @@ -84,16 +83,12 @@ def _find_shared_object_in_te_dir(te_path: Path, prefix: str): @functools.lru_cache(maxsize=None) def _get_shared_object_file(library: str) -> Path: """ - Return the path of the shared object file for the given TE - library, one of 'core', 'torch', or 'jax'. - - Several factors affect finding the correct location of the shared object: - 1. System and environment. - 2. If the installation is from source or via PyPI. - - Source installed .sos are placed in top level dir - - Wheel/PyPI installed .sos are placed in 'wheel_lib' dir to avoid conflicts. - 3. For source installations, is the install editable/inplace? - 4. The user directory from where TE is being imported. + Path to shared object file for a Transformer Engine library. + + TE libraries are 'core', 'torch', or 'jax'. This function first + searches in the imported TE directory, and then in the + site-packages directory. + """ # Check provided input and determine the correct prefix for .so. @@ -103,46 +98,25 @@ def _get_shared_object_file(library: str) -> Path: else: so_prefix = f"transformer_engine_{library}" - # Check TE install location (will be local if TE is available in current dir for import). - te_install_dir = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent - so_path_in_install_dir = _find_shared_object_in_te_dir(te_install_dir, so_prefix) - - # Check default python package install location in system. - site_packages_dir = Path(sysconfig.get_paths()["purelib"]) - so_path_in_default_dir = _find_shared_object_in_te_dir(site_packages_dir, so_prefix) - - # Case 1: Typical user workflow: Both locations are the same, return any result. - if te_install_dir == site_packages_dir: - if so_path_in_install_dir is not None: - return so_path_in_install_dir - raise FileNotFoundError(f"Could not find shared object file for Transformer Engine {library} lib.") - - # Case 2: ERR! Both locations are different but returned a valid result. - # NOTE: Unlike for source installations, pip does not wipe out artifacts from - # editable builds. In case developers are executing inside a TE directory via - # an inplace build, and then move to a regular build, the local shared object - # file will be incorrectly picked up without the following logic. - if so_path_in_install_dir is not None and so_path_in_default_dir is not None: - raise RuntimeError( - f"Found multiple shared object files: {so_path_in_install_dir} and" - f" {so_path_in_default_dir}. Remove local shared objects installed" - f" here {so_path_in_install_dir} or change the working directory to" - "execute from outside TE." - ) - - # Case 3: Typical dev workflow: Editable install - if so_path_in_install_dir is not None: - return so_path_in_install_dir + # Search for shared lib in imported directory + te_path = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent + so_path = _find_shared_object_in_te_dir(te_path, so_prefix) + if so_path is not None: + return so_path - # Case 4: Executing from inside a TE directory without an inplace build available. - if so_path_in_default_dir is not None: - return so_path_in_default_dir + # Search for shared lib in site-packages directory + te_path = Path(sysconfig.get_paths()["purelib"]) + so_path = _find_shared_object_in_te_dir(te_path, so_prefix) + if so_path is not None: + return so_path - raise FileNotFoundError(f"Could not find shared object file for Transformer Engine {library} lib.") + raise FileNotFoundError( + f"Could not find shared object file for Transformer Engine {library} lib." + ) @functools.lru_cache(maxsize=None) -def load_framework_extension(framework: str): +def load_framework_extension(framework: str) -> None: """ Load shared library with Transformer Engine framework bindings and check verify correctness if installed via PyPI. @@ -164,7 +138,6 @@ def load_framework_extension(framework: str): # If the framework extension pip package is installed, it means that TE is installed via # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework # extension are all installed via PyPI and have matching version. - ''' if _is_pip_package_installed(module_name): assert _is_pip_package_installed( "transformer_engine" @@ -183,7 +156,6 @@ def load_framework_extension(framework: str): f" v{version(f'transformer-engine-{te_cuda_vers}')}. Install transformer-engine using " f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'" ) - ''' # If the core package is installed via PyPI, log if # the framework extension is not found from PyPI. @@ -204,18 +176,17 @@ def load_framework_extension(framework: str): @functools.lru_cache(maxsize=None) -def _get_sys_extension(): +def _get_sys_extension() -> str: + """File extension for shared objects.""" system = platform.system() - if system == "Linux": - extension = "so" - elif system == "Darwin": - extension = "dylib" - elif system == "Windows": - extension = "dll" - else: - raise RuntimeError(f"Unsupported operating system ({system})") - return extension + if system == "Linux": + return ".so" + if system == "Darwin": + return ".dylib" + if system == "Windows": + return ".dll" + raise RuntimeError(f"Unsupported operating system ({system})") @functools.lru_cache(maxsize=None) @@ -229,7 +200,7 @@ def _load_nvidia_cuda_library(lib_name: str): so_paths = glob.glob( os.path.join( sysconfig.get_path("purelib"), - f"nvidia/{lib_name}/lib/lib*.{_get_sys_extension()}.*[0-9]", + f"nvidia/{lib_name}/lib/lib*{_get_sys_extension()}.*[0-9]", ) ) @@ -244,7 +215,7 @@ def _load_nvidia_cuda_library(lib_name: str): @functools.lru_cache(maxsize=None) -def _nvidia_cudart_include_dir(): +def _nvidia_cudart_include_dir() -> str: """Returns the include directory for cuda_runtime.h if exists in python environment.""" try: @@ -263,14 +234,14 @@ def _load_cudnn(): # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH") if cudnn_home: - libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) + libs = glob.glob(f"{cudnn_home}/**/libcudnn{_get_sys_extension()}*", recursive=True) libs.sort(reverse=True, key=os.path.basename) if libs: return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) # Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" - libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) + libs = glob.glob(f"{cuda_home}/**/libcudnn{_get_sys_extension()}*", recursive=True) libs.sort(reverse=True, key=os.path.basename) if libs: return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) @@ -281,7 +252,7 @@ def _load_cudnn(): return handle # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise - return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) + return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) @functools.lru_cache(maxsize=None) @@ -289,7 +260,7 @@ def _load_nvrtc(): """Load NVRTC shared library.""" # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" - libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True) + libs = glob.glob(f"{cuda_home}/**/libnvrtc{_get_sys_extension()}*", recursive=True) libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs)) libs.sort(reverse=True, key=os.path.basename) if libs: @@ -313,7 +284,7 @@ def _load_nvrtc(): return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise - return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) + return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) te_rocm_build = False diff --git a/transformer_engine/common/amd_detail/hip_float8.h b/transformer_engine/common/amd_detail/hip_float8.h index 4ae6b09ce..eea337489 100644 --- a/transformer_engine/common/amd_detail/hip_float8.h +++ b/transformer_engine/common/amd_detail/hip_float8.h @@ -4,7 +4,6 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ #pragma once - #include #if !defined(__HIP_DEVICE_COMPILE__) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 98a970c98..40595ea98 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -196,7 +196,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz if (param_type == NVTETensorParam::kNVTERowwiseData || param_type == NVTETensorParam::kNVTEColumnwiseData) { // Offset data pointer - param_dptr += chunk_offset * typeToSize(param_dtype); + param_dptr += get_buffer_size_bytes(chunk_offset, param_dtype); param_shape = chunk_shape; if (param_type == NVTETensorParam::kNVTEColumnwiseData && @@ -217,7 +217,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz } else { chunk_scale_height /= 32; } - param_dptr += (chunk_offset / 32) * typeToSize(param_dtype); + param_dptr += get_buffer_size_bytes(chunk_offset / 32, param_dtype); param_shape = {chunk_scale_height, chunk_scale_width}; } @@ -236,7 +236,7 @@ TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape); // Update chunk with offset data pointers from the communication buffer - auto ubuf_ptr = reinterpret_cast(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size()); + auto ubuf_ptr = reinterpret_cast(_ubuf.dptr()) + chunk_offset * _ubuf.element_size(); if (chunk.dptr() != nullptr) { chunk.set_rowwise_data(reinterpret_cast(ubuf_ptr), chunk.dtype(), chunk.shape()); } @@ -269,7 +269,7 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType "or 2 (multi-atomic)."); NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); - size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); + size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg); @@ -306,7 +306,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_comm, 0)); // Communication: AG and RS - int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size + int comm_elements = _ubuf.bytes() / 2; // UBUF uses 2Byte element size if (comm_type == CommOverlapType::AG) { allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, (cudaEvent_t)_comm_launch_event); @@ -606,7 +606,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, // Create workspace tensor with userbuffer NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); - size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); + size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); int buffer_chunk_bytes = buffer_bytes / tp_size; _num_ubuf_chunks = tp_size; if (_is_reduce_scatter) { @@ -704,7 +704,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( assert(pre_gelu_out.numel() == 0); // Get communication and GEMM output chunk sizes - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int comm_bytes = _ubufs[0].bytes(); // Create an GEMM output buffer with N+1 chunks in a contiguous memory void *D_buffer_ptr; @@ -762,21 +762,20 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( if (B_copy.numel() > 0) { assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); - NVTE_CHECK_CUDA( - cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), - _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send[0])); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), + _ubufs[_self_chunk_id].bytes(), cudaMemcpyDeviceToDevice, + _stream_send[0])); NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); } // Copy the first GEMM output chunk to the end chunk position of D_buffer char *src_ptr = reinterpret_cast(D_buffer.dptr()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes, + NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + D.bytes(), src_ptr, D_chunk_bytes, cudaMemcpyDeviceToDevice, stream_main)); // Return the last N rows of D_buffer - NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(), + NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.bytes(), cudaMemcpyDeviceToDevice, stream_main)); // Clean up buffer allocation @@ -806,7 +805,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, const size_t n_chunk = _ubufs[0].size(0); // Get communication and GEMM output chunk sizes - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int comm_bytes = _ubufs[0].bytes(); const bool do_gelu = pre_gelu_out.numel() > 0; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); @@ -822,7 +821,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, // Chunk dims std::vector input_b_chunk_shape = (transb ? std::vector{k, 2 * n_chunk} : std::vector{2 * n_chunk, k}); - std::vector output_chunk_shape = {2 * n_chunk, k}; + std::vector output_chunk_shape = {2 * n_chunk, m}; size_t input_b_chunk_size = 2 * n_chunk * k; size_t output_chunk_size = 2 * n_chunk * m; @@ -853,13 +852,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, // GEMM auto input_b_chunk = - get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape); + get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape); auto output_chunk = - get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape); - auto aux_chunk = - (do_gelu) - ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k}) - : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + get_tensor_chunk(D, output_chunk_size * send_chunk_id / 2, output_chunk_shape); + auto aux_chunk = (do_gelu) + ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id / 2, + {n_chunk * 2, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); auto workspace_chunk = get_tensor_chunk( workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); @@ -882,8 +881,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, assert(B_copy.numel() == _ubufs[_tp_id].numel()); assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send[0])); + _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, + _stream_send[0])); } } } else { @@ -935,8 +934,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, assert(B_copy.numel() == _ubufs[_tp_id].numel()); assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send[0])); + _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, + _stream_send[0])); } } } @@ -966,7 +965,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs( _ub_comm->cga_size = _cga_size; // Get communication and GEMM input chunk sizes - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int comm_bytes = _ubufs[0].bytes(); // Reset counters int *counter_ptr = reinterpret_cast(_counter.dptr()); @@ -1033,7 +1032,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, size_t m = transa ? A.size(0) : A.size(1); size_t k = transa ? A.size(1) : A.size(0); size_t n_chunk = _ubufs[0].size(0); - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int comm_bytes = _ubufs[0].bytes(); // Get input and workspace data pointers size_t input_chunk_size = n_chunk * k; diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index e52cdd8a1..65da58d5f 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -100,6 +100,16 @@ bool has_mnnvl_fabric(int device_id) { } return false; #else + // Check run-time CUDA version + if (transformer_engine::cuda::cudart_version() < 12040) { + if (getenv("NVTE_UBDEBUG")) { + printf( + "TransformerEngine does not support multi-node NVLINK " + "since it is not being run with CUDA version >= 12.4.\n"); + } + return false; + } + bool mnnvl_fabric_support = false; CUdevice dev; NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, device_id); @@ -248,7 +258,8 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, CUmemFabricHandle *tmphndl = reinterpret_cast(malloc(sizeof(CUmemFabricHandle))); CUmemFabricHandle *exphndls; - NVTE_CHECK_CUDA(cudaMallocHost(&exphndls, (*comm)->nvsize * sizeof(CUmemFabricHandle))); + NVTE_CHECK_CUDA(cudaMallocHost(reinterpret_cast(&exphndls), + (*comm)->nvsize * sizeof(CUmemFabricHandle))); if ((*comm)->ar2_nvrank == 0) NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, static_cast(tmphndl), (*comm)->mc_handle, CU_MEM_HANDLE_TYPE_FABRIC, 0); @@ -345,8 +356,10 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, NVTE_CHECK_CUDA(cudaDeviceSynchronize()); register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true); - NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int))); - NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); + NVTE_CHECK_CUDA( + cudaMalloc(reinterpret_cast(&(*comm)->send_id), (*comm)->nranks * sizeof(int))); + NVTE_CHECK_CUDA(cudaMalloc(reinterpret_cast(&(*comm)->recv_id), + NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int))); NVTE_CHECK_CUDA( cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); @@ -358,10 +371,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, #define GPU_PAGE_OFFSET (GPU_PAGE_SIZE - 1) #define GPU_PAGE_MASK (~GPU_PAGE_OFFSET) - NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE)); - NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE)); - (*comm)->flags = - reinterpret_cast(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); + NVTE_CHECK_CUDA( + cudaMalloc(reinterpret_cast(&(*comm)->flags_baseptr), 2 * GPU_PAGE_SIZE)); + NVTE_CHECK_CUDA(cudaMemset((*comm)->flags_baseptr, 0, 2 * GPU_PAGE_SIZE)); + (*comm)->flags = reinterpret_cast( + ((CUdeviceptr)(*comm)->flags_baseptr + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); using namespace std; @@ -438,20 +452,31 @@ int create_communicator_mpi(communicator **comm) { } void destroy_communicator(communicator *comm) { - for (int hndl = 0; hndl < comm->free_region; hndl++) { + // Clear memory allocated in register_user_buffer_collective calls + for (int hndl = comm->free_region - 1; hndl >= 0; hndl--) { if (comm->use_mc && comm->mem_dealloc[hndl]) { + // Unbind the local device buffer from the Multicast handle + CUdevice dev; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, comm->mydev); + NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastUnbind, comm->mc_handle, dev, comm->uc_offsets[hndl], + comm->mem_size[hndl]); + + // Unmap memory addresses and release handles for both peer and own buffers for (int rank = 0; rank < comm->nvsize; rank++) { - if (rank == comm->nvrank) { - NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]); - } else { - comm->uchandles[hndl][rank] = 0; - } + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemUnmap, + reinterpret_cast(comm->peer_ptr[hndl][rank]), + comm->mem_size[hndl]); + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]); } free(reinterpret_cast(comm->uchandles[hndl])); + + // Free memory reserved for buffer allocations + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressFree, comm->ucbase_ptr[hndl], + static_cast(comm->mem_size[hndl] * comm->nvsize)); } else { for (int rank = 0; rank < comm->nvsize; rank++) { if (rank != comm->nvrank) { - cudaIpcCloseMemHandle(comm->peer_ptr[hndl][rank]); + NVTE_CHECK_CUDA(cudaIpcCloseMemHandle(comm->peer_ptr[hndl][rank])); } else if (comm->mem_dealloc[hndl]) { NVTE_CHECK_CUDA(cudaFree(comm->peer_ptr[hndl][rank])); } else { @@ -460,11 +485,16 @@ void destroy_communicator(communicator *comm) { } } free(comm->peer_ptr[hndl]); - comm->mem_ptr[hndl] = nullptr; + comm->mem_ptr[hndl] = nullptr; // this points to already cleaned up local device buffer } - cudaFree(reinterpret_cast(comm->recv_id)); - cudaFree(reinterpret_cast(comm->send_id)); + // Clear memory allocated in the communicator constructor + NVTE_CHECK_CUDA(cudaFree(reinterpret_cast(comm->recv_id))); + NVTE_CHECK_CUDA(cudaFree(reinterpret_cast(comm->send_id))); + NVTE_CHECK_CUDA(cudaFree(reinterpret_cast(comm->flags_baseptr))); if (comm->use_mc) { + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemUnmap, reinterpret_cast(comm->mc_baseptr), + comm->mc_maxsize); + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressFree, comm->mc_baseptr, comm->mc_maxsize); NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->mc_handle); } delete comm; @@ -531,7 +561,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * CUmemFabricHandle myhndl; NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, &myhndl, comm->uchandles[hndl][myrank], CU_MEM_HANDLE_TYPE_FABRIC, 0); - NVTE_CHECK_CUDA(cudaMallocHost(&exphndl, comm->nvsize * sizeof(CUmemFabricHandle))); + NVTE_CHECK_CUDA(cudaMallocHost(reinterpret_cast(&exphndl), + comm->nvsize * sizeof(CUmemFabricHandle))); comm->_allgather(reinterpret_cast(exphndl), comm->nvsize * sizeof(CUmemFabricHandle), reinterpret_cast(&myhndl), sizeof(CUmemFabricHandle), comm->comm_intra); @@ -615,6 +646,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * aligned_size, (uint64_t)0); comm->memflags[hndl] |= NVTE_UB_MEM_MC_CREATED; comm->mc_ptr[hndl] = reinterpret_cast(comm->mc_baseptr) + comm->mc_offset; + comm->uc_offsets[hndl] = comm->mc_offset; comm->mc_offset += aligned_size; } else if (!comm->myrank) { printf("UB: warning region %d size %ld MB registered without MC access\n", hndl, diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 84defcdb2..03e45b978 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -107,6 +107,7 @@ struct communicator { CUmemGenericAllocationHandle *uchandles[NVTE_MAX_REGIONS]; void *ucbase_ptr[NVTE_MAX_REGIONS]; // only for cuMem allocated memory + size_t uc_offsets[NVTE_MAX_REGIONS]; size_t mem_size[NVTE_MAX_REGIONS]; bool mem_dealloc[NVTE_MAX_REGIONS]; @@ -125,7 +126,7 @@ struct communicator { // max value for running block counters in hostflags int basecounter[userbuffers_op_types]; // NOLINT(*) - int *flags, *map_flags; + int *flags_baseptr, *flags, *map_flags; void *mem_mr[NVTE_MAX_REGIONS]; diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 70ec47ba7..483444751 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -119,13 +119,20 @@ void checkCuDriverContext(CUstream stream) { #ifndef __HIP_PLATFORM_AMD__ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { - static const std::unordered_map dtypeMapping = { - {DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, - {DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32}, - {DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16}, - {DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16}, - {DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, - {DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}}; + static const std::unordered_map dtypeMapping = []() { + std::unordered_map typeMapping = { + {DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, + {DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32}, + {DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16}, + {DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16}, + {DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, + {DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}}; +#if FP4_TYPE_SUPPORTED + typeMapping.insert( + {DType::kFloat4E2M1, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B}); +#endif + return typeMapping; + }(); return dtypeMapping.at(dtype); } @@ -133,18 +140,19 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, - const uint32_t offset_elems, const size_t type_size) { + const uint32_t offset_elems, const size_t type_num_bits) { // Get a function pointer to the cuTensorMapEncodeTiled driver API - static PFN_cuTensorMapEncodeTiled cuDriverTensorMapEncodeTiled = []() { + // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13 + static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { void *driver_ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled"); - return reinterpret_cast(driver_ptr); + return reinterpret_cast(driver_ptr); }(); // rank is the number of dimensions of the array constexpr uint32_t rank = 2; uint64_t size[rank] = {globalX, globalY}; // The stride is the number of bytes to traverse from the first element of one row to the next - uint64_t stride[rank - 1] = {stride_elems * type_size}; + uint64_t stride[rank - 1] = {(stride_elems * type_num_bits) / 8}; // The boxSize is the size of the shared memory buffer that is used as the // source/destination of a TMA transfer @@ -154,15 +162,15 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, uint32_t elemStride[rank] = {1, 1}; const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype); - void *dataPtr = - reinterpret_cast(reinterpret_cast(tensor.dptr) + offset_elems * type_size); + void *dataPtr = reinterpret_cast(reinterpret_cast(tensor.dptr) + + (offset_elems * type_num_bits) / 8); NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment), "Tensor data pointer must be 16B aligned"); - const int TMA_needed_size = TMA_gmem_alignment / type_size; - NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_size, - "-byte data type, expected multiple of ", TMA_needed_size, ", got ", globalX); + const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits; + NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits, + "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX); // Create the tensor descriptor. NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( @@ -203,10 +211,24 @@ std::vector> convert_tensor_array(NVTETensor **nvte_tensor for (size_t i = 0; i < outer_size; ++i) { ret.emplace_back(); for (size_t j = 0; j < inner_size; ++j) { - ret.back().push_back(reinterpret_cast(nvte_tensors[i][j])); + ret.back().push_back(convertNVTETensor(nvte_tensors[i][j])); } } return ret; } +size_t get_buffer_size_bytes(const size_t elements_num, const DType buffer_dtype) { + return (elements_num * typeToNumBits(buffer_dtype)) / 8; +} + +size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last, + const DType buffer_dtype) { + if (buffer_dtype == DType::kFloat4E2M1) { + NVTE_CHECK(dim_last % 2 == 0, + "Last dimension of a tensor with FP4 type of data must be an even number!"); + } + const size_t elements_num = dim_first * dim_last; + return get_buffer_size_bytes(elements_num, buffer_dtype); +} + } // namespace transformer_engine diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 94652af44..39038724a 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -11,10 +11,18 @@ #ifndef __HIP_PLATFORM_AMD__ #include +#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080) +#else +#define FP4_TYPE_SUPPORTED false #endif //#ifndef __HIP_PLATFORM_AMD__ + #include #include #include +#if FP4_TYPE_SUPPORTED +#include +#endif + #include #include @@ -93,9 +101,16 @@ struct SimpleTensor { } return acc; } + + void clear() { + dptr = nullptr; + shape.resize(0); + dtype = DType::kFloat32; + } }; struct Tensor { + public: SimpleTensor data; SimpleTensor columnwise_data; SimpleTensor amax; @@ -103,8 +118,8 @@ struct Tensor { SimpleTensor scale_inv; SimpleTensor columnwise_scale_inv; - public: NVTEScalingMode scaling_mode; + NVTETensor nvte_tensor; Tensor() : data(), @@ -113,7 +128,20 @@ struct Tensor { scale(nullptr, {1}, DType::kFloat32), scale_inv(nullptr, {1}, DType::kFloat32), columnwise_scale_inv(nullptr, {1}, DType::kFloat32), - scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {} + scaling_mode(NVTE_DELAYED_TENSOR_SCALING), + nvte_tensor(0) {} + + void clear() { + data.clear(); + columnwise_data.clear(); + amax.clear(); + scale.clear(); + scale_inv.clear(); + columnwise_scale_inv.clear(); + scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + } + + explicit operator NVTETensor() const noexcept { return nvte_tensor; } size_t numel() const { size_t acc = 1; @@ -167,6 +195,7 @@ struct Tensor { } break; case NVTE_MXFP8_1D_SCALING: + case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: if (!has_data() && has_columnwise_data()) { return columnwise_data.shape; } else { @@ -236,11 +265,14 @@ struct QuantizationConfig { bool force_pow_2_scales = false; float amax_epsilon = 0.0f; NVTETensor noop_tensor = nullptr; + Float8BlockScaleTensorFormat float8_block_scale_tensor_format = + Float8BlockScaleTensorFormat::GEMM_READY; static constexpr size_t attr_sizes[] = { - sizeof(bool), // force_pow_2_scales - sizeof(float), // amax_epsilon - sizeof(NVTETensor) // noop_tensor + sizeof(bool), // force_pow_2_scales + sizeof(float), // amax_epsilon + sizeof(NVTETensor), // noop_tensor + sizeof(Float8BlockScaleTensorFormat) // float8_block_scale_tensor_format }; }; @@ -249,6 +281,13 @@ constexpr T DIVUP(const T &x, const T &y) { return (((x) + ((y)-1)) / (y)); } +template +constexpr __device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(const T1 &N, const T2 &M) { + static_assert(std::is_integral::value && std::is_integral::value, + "Integral type required."); + return DIVUP(static_cast(N), static_cast(M)) * M; +} + using byte = uint8_t; using int16 = int16_t; using int32 = int32_t; @@ -262,11 +301,15 @@ using fp8e5m2 = __nv_fp8_e5m2; #if CUDA_VERSION >= 12080 using fp8e8m0 = __nv_fp8_e8m0; #endif // CUDA_VERSION >= 12080 +#if FP4_TYPE_SUPPORTED +using fp4e2m1 = __nv_fp4_e2m1; +#endif //FP4_TYPE_SUPPORTED #else using bf16 = hip_bfloat16; using fp8e4m3 = te_hip_fp8_e4m3; using fp8e5m2 = te_hip_fp8_e5m2; #endif //__HIP_PLATFORM_AMD__ + using e8m0_t = uint8_t; namespace detail { @@ -295,12 +338,22 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2) #if CUDA_VERSION >= 12080 TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0) #endif // CUDA_VERSION >= 12080 +#if FP4_TYPE_SUPPORTED +TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp4_e2m1) +#endif #endif // #ifdef __HIP_PLATFORM_AMD__ #undef TRANSFORMER_ENGINE_TYPE_NAME template struct TypeExtrema; +#if FP4_TYPE_SUPPORTED +template <> +struct TypeExtrema { + static constexpr float max = 6.0f; +}; +#endif + template <> struct TypeExtrema { #ifndef __HIP_PLATFORM_AMD__ @@ -336,9 +389,28 @@ struct TypeExtrema { } // namespace detail +template +struct BitsNumber; + +#if FP4_TYPE_SUPPORTED +template <> +struct BitsNumber { + static constexpr size_t num_bits = 4; +}; +#endif + +template +struct BitsNumber { + static constexpr size_t num_bits = 8 * sizeof(T); +}; + template struct TypeInfo { +#if FP4_TYPE_SUPPORTED + using types = std::tuple; +#else using types = std::tuple; +#endif template struct Helper { @@ -363,11 +435,21 @@ struct TypeInfo { } constexpr static DType dtype = getType(); - constexpr static size_t size = sizeof(T); + constexpr static size_t size = BitsNumber::num_bits; constexpr static float max_finite_value = detail::TypeExtrema::max; constexpr static const char *name = detail::type_name(); }; +#if FP4_TYPE_SUPPORTED +#define SWITCH_FP4_TYPE_HANDLE(type, ...) \ + case DType::kFloat4E2M1: { \ + using type = fp4e2m1; \ + { __VA_ARGS__ } \ + } break; +#else +#define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing +#endif + #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ @@ -411,6 +493,7 @@ struct TypeInfo { using type = byte; \ { __VA_ARGS__ } \ } break; \ + SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ default: \ NVTE_ERROR("Invalid type."); \ } @@ -522,6 +605,9 @@ struct TypeInfo { case DType::kFloat8E4M3: { \ NVTE_ERROR("FP8 type not instantiated for input."); \ } break; \ + case DType::kFloat4E2M1: { \ + NVTE_ERROR("FP4 type not instantiated for input."); \ + } break; \ default: \ NVTE_ERROR("Invalid type."); \ } @@ -592,6 +678,14 @@ struct is_fp8 : std::true_type {}; template <> struct is_fp8 : std::true_type {}; +template +struct is_fp4 : std::false_type {}; + +#if FP4_TYPE_SUPPORTED +template <> +struct is_fp4 : std::true_type {}; +#endif + // [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_rowwise = 128; @@ -610,13 +704,16 @@ inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) { } size_t typeToSize(const DType type); +size_t typeToNumBits(const DType type); + +size_t get_buffer_size_bytes(const size_t N, const DType buffer_dtype); +size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last, + const DType buffer_dtype); void CheckNoopTensor(const Tensor &t, const std::string &name); void CheckInputTensor(const Tensor &t, const std::string &name); void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false); -bool is_fp8_dtype(const DType t); - /*! \brief Update a tensor's FP8 scale-inverse * * The FP8 scale-inverse (dequantization scaling factor) is updated @@ -636,7 +733,7 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype); void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, - const uint32_t offset_elems, const size_t type_size); + const uint32_t offset_elems, const size_t type_num_bits); #endif //#ifndef __HIP_PLATFORM_AMD__ bool is_supported_by_CC_100(); @@ -644,6 +741,8 @@ bool is_supported_by_CC_100(); std::vector> convert_tensor_array(NVTETensor **nvte_tensors, size_t outer_size, size_t inner_size); +Tensor *convertNVTETensor(const NVTETensor tensor); +Tensor *convertNVTETensorCheck(const NVTETensor tensor); } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_ diff --git a/transformer_engine/common/fused_attn/context_parallel.cu b/transformer_engine/common/fused_attn/context_parallel.cu index 721cdb230..15708d2d5 100644 --- a/transformer_engine/common/fused_attn/context_parallel.cu +++ b/transformer_engine/common/fused_attn/context_parallel.cu @@ -325,7 +325,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor int batch = cu_seqlens_shape[0] - 1; int num_heads = tensor_shape[seq_dim + 1]; int dim_per_head = tensor_shape[seq_dim + 2]; - int hidden_size_in_bytes = num_heads * dim_per_head * typeToSize(tensor.dtype()); + int hidden_size_in_bytes = (num_heads * dim_per_head * typeToNumBits(tensor.dtype())) / 8; // For 128-bits load/store NVTE_CHECK(hidden_size_in_bytes % 16 == 0); @@ -582,7 +582,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step, NVTE_CHECK(grad_per_step_shape[seq_dim + 2] == dim_per_head); size_t hidden_size = num_heads * dim_per_head; - NVTE_CHECK((hidden_size * typeToSize(grad.dtype())) % 16 == 0); + NVTE_CHECK(((hidden_size * typeToNumBits(grad.dtype())) / 8) % 16 == 0); constexpr unsigned int block = 256; unsigned int grid_x; @@ -677,9 +677,9 @@ void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu NVTE_API_CALL(nvte_thd_read_half_tensor); using namespace transformer_engine; - context_parallel::thd_read_half_tensor(*reinterpret_cast(tensor), - *reinterpret_cast(cu_seqlens), - *reinterpret_cast(half), half_idx, stream); + context_parallel::thd_read_half_tensor(*convertNVTETensorCheck(tensor), + *convertNVTETensorCheck(cu_seqlens), + *convertNVTETensorCheck(half), half_idx, stream); } void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step, @@ -689,8 +689,8 @@ void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &ls using namespace transformer_engine; context_parallel::thd_second_half_lse_correction( - *reinterpret_cast(lse), *reinterpret_cast(lse_per_step), - *reinterpret_cast(cu_seqlens), lse_packed, stream); + *convertNVTETensorCheck(lse), *convertNVTETensorCheck(lse_per_step), + *convertNVTETensorCheck(cu_seqlens), lse_packed, stream); } void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens, @@ -700,8 +700,8 @@ void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &c using namespace transformer_engine; context_parallel::thd_read_second_half_lse( - *reinterpret_cast(lse), *reinterpret_cast(cu_seqlens), - *reinterpret_cast(half_lse), lse_packed, second_half_lse_seqlen, stream); + *convertNVTETensorCheck(lse), *convertNVTETensorCheck(cu_seqlens), + *convertNVTETensorCheck(half_lse), lse_packed, second_half_lse_seqlen, stream); } void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step, @@ -712,9 +712,9 @@ void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step, using namespace transformer_engine; context_parallel::thd_out_correction( - *reinterpret_cast(out), *reinterpret_cast(out_per_step), - *reinterpret_cast(lse), *reinterpret_cast(lse_per_step), - *reinterpret_cast(cu_seqlens), only_second_half, lse_packed, stream); + *convertNVTETensorCheck(out), *convertNVTETensorCheck(out_per_step), + *convertNVTETensorCheck(lse), *convertNVTETensorCheck(lse_per_step), + *convertNVTETensorCheck(cu_seqlens), only_second_half, lse_packed, stream); } void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step, @@ -727,8 +727,8 @@ void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_ste std::string second_half_str(second_half); context_parallel::thd_grad_correction( - *reinterpret_cast(grad), *reinterpret_cast(grad_per_step), - *reinterpret_cast(cu_seqlens), first_half_str, second_half_str, stream); + *convertNVTETensorCheck(grad), *convertNVTETensorCheck(grad_per_step), + *convertNVTETensorCheck(cu_seqlens), first_half_str, second_half_str, stream); } void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output, @@ -737,7 +737,7 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso NVTE_API_CALL(nvte_thd_get_partitioned_indices); using namespace transformer_engine; - context_parallel::thd_get_partitioned_indices(*reinterpret_cast(cu_seqlens), - *reinterpret_cast(output), total_tokens, + context_parallel::thd_get_partitioned_indices(*convertNVTETensorCheck(cu_seqlens), + *convertNVTETensorCheck(output), total_tokens, world_size, rank, stream); } diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index 97074724b..0c261d0fa 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -138,8 +138,8 @@ void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t s NVTE_API_CALL(nvte_prepare_flash_attn_fwd); using namespace transformer_engine; - flash_attention::prepare_flash_attn_fwd(*reinterpret_cast(qkvi), - *reinterpret_cast(qkv), stream); + flash_attention::prepare_flash_attn_fwd(*convertNVTETensorCheck(qkvi), + *convertNVTETensorCheck(qkv), stream); } void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv, @@ -147,7 +147,7 @@ void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTET NVTE_API_CALL(nvte_prepare_flash_attn_bwd); using namespace transformer_engine; - flash_attention::prepare_flash_attn_bwd( - *reinterpret_cast(q), *reinterpret_cast(k), - *reinterpret_cast(v), *reinterpret_cast(qkv), stream); + flash_attention::prepare_flash_attn_bwd(*convertNVTETensorCheck(q), *convertNVTETensorCheck(k), + *convertNVTETensorCheck(v), *convertNVTETensorCheck(qkv), + stream); } diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 25340dd87..9d4701730 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -134,10 +134,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right) { + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -216,12 +216,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } if ( // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging - // special conditions for blackwell - // TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7 - !(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) && // architecture - ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || - (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) && + ((cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)) || + (cudnn_runtime_version >= 8903 && sm_arch_ >= 80 && sm_arch_ < 100) || + (cudnn_runtime_version >= 90700 && sm_arch_ >= 80)) && // sequence length ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) || (cudnn_runtime_version >= 90000)) && @@ -229,11 +227,32 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || (cudnn_runtime_version >= 8907)) && // head dimension - ((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) || - // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // d=256 only supported for forward - (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 && - head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) && + // multiples of 8 + (head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 && + // <= 128 + ((head_dim_qk <= 128 && head_dim_v <= 128) || + // 9.1: <= 256 + Hopper + fprop + // 9.5: <= 256 + Hopper + bprop + (head_dim_qk <= 256 && head_dim_v <= 256 && + ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) || + (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) || + // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 + (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && + layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || + // 9.10: any head_dim + any arch + fprop + paged + // 9.10: any head_dim + any arch + fprop + non_paged + sq > 1 + // 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} + (!is_training && cudnn_runtime_version >= 91000 && + (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 || + (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || + // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged + (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && + cudnn_runtime_version >= 91100)) && + // 9.11 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA + (!(cudnn_runtime_version == 91100 && is_training && sm_arch_ == 90 && head_dim_qk >= 128 && + head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && + head_dim_qk != head_dim_v))) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || (cudnn_runtime_version >= 8906 && @@ -392,14 +411,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; - const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); - const Tensor *input_cu_seqlens_padded = reinterpret_cast(cu_seqlens_padded); - const Tensor *input_rng_state = reinterpret_cast(rng_state); - const Tensor *input_QKV = reinterpret_cast(QKV); - const Tensor *input_Bias = reinterpret_cast(Bias); - Tensor *input_output_S = reinterpret_cast(S); - Tensor *output_O = reinterpret_cast(O); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); + const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); + const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); + const Tensor *input_QKV = convertNVTETensorCheck(QKV); + const Tensor *input_Bias = convertNVTETensorCheck(Bias); + Tensor *input_output_S = convertNVTETensorCheck(S); + Tensor *output_O = convertNVTETensorCheck(O); + Tensor *wkspace = convertNVTETensor(workspace); auto ndim = input_QKV->data.shape.size(); size_t b = input_cu_seqlens->data.shape[0] - 1; @@ -423,8 +442,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); + is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -472,16 +491,16 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; - const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); - const Tensor *input_cu_seqlens_padded = reinterpret_cast(cu_seqlens_padded); - const Tensor *input_QKV = reinterpret_cast(QKV); - const Tensor *input_O = reinterpret_cast(O); - const Tensor *input_dO = reinterpret_cast(dO); - const Tensor *input_S = reinterpret_cast(S); - Tensor *input_output_dP = reinterpret_cast(dP); - Tensor *output_dQKV = reinterpret_cast(dQKV); - Tensor *output_dBias = reinterpret_cast(dBias); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); + const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); + const Tensor *input_QKV = convertNVTETensorCheck(QKV); + const Tensor *input_O = convertNVTETensorCheck(O); + const Tensor *input_dO = convertNVTETensorCheck(dO); + const Tensor *input_S = convertNVTETensorCheck(S); + Tensor *input_output_dP = convertNVTETensorCheck(dP); + Tensor *output_dQKV = convertNVTETensorCheck(dQKV); + Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *wkspace = convertNVTETensor(workspace); auto ndim = input_QKV->data.shape.size(); size_t b = input_cu_seqlens->data.shape[0] - 1; @@ -505,12 +524,12 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, + true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); fused_attn_max_512_bwd_qkvpacked( b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV, input_dO, output_S, output_dQKV, output_dBias, input_cu_seqlens, wkspace, stream, handle); @@ -519,13 +538,13 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); Tensor *input_Bias, *input_rng_state; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - input_Bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); } else { - input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd_qkvpacked( b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, @@ -540,9 +559,9 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQKV, input_cu_seqlens, @@ -566,19 +585,19 @@ void nvte_fused_attn_fwd_kvpacked( cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); - const Tensor *input_page_table_k = reinterpret_cast(page_table_k); - const Tensor *input_page_table_v = reinterpret_cast(page_table_v); - const Tensor *input_rng_state = reinterpret_cast(rng_state); - const Tensor *input_Q = reinterpret_cast(Q); - const Tensor *input_KV = reinterpret_cast(KV); - const Tensor *input_Bias = reinterpret_cast(Bias); - Tensor *input_output_S = reinterpret_cast(S); - Tensor *output_O = reinterpret_cast(O); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); + const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); + const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_KV = convertNVTETensorCheck(KV); + const Tensor *input_Bias = convertNVTETensorCheck(Bias); + Tensor *input_output_S = convertNVTETensorCheck(S); + Tensor *output_O = convertNVTETensorCheck(O); + Tensor *wkspace = convertNVTETensor(workspace); size_t b = input_cu_seqlens_q->data.shape[0] - 1; auto ndim = input_Q->data.shape.size(); @@ -636,8 +655,8 @@ void nvte_fused_attn_fwd_kvpacked( const NVTEDType KV_type = static_cast(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, d, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -686,20 +705,20 @@ void nvte_fused_attn_bwd_kvpacked( cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); - const Tensor *input_Q = reinterpret_cast(Q); - const Tensor *input_KV = reinterpret_cast(KV); - const Tensor *input_O = reinterpret_cast(O); - const Tensor *input_dO = reinterpret_cast(dO); - const Tensor *input_S = reinterpret_cast(S); - Tensor *input_output_dP = reinterpret_cast(dP); - Tensor *output_dQ = reinterpret_cast(dQ); - Tensor *output_dKV = reinterpret_cast(dKV); - Tensor *output_dBias = reinterpret_cast(dBias); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_KV = convertNVTETensorCheck(KV); + const Tensor *input_O = convertNVTETensorCheck(O); + const Tensor *input_dO = convertNVTETensorCheck(dO); + const Tensor *input_S = convertNVTETensorCheck(S); + Tensor *input_output_dP = convertNVTETensorCheck(dP); + Tensor *output_dQ = convertNVTETensorCheck(dQ); + Tensor *output_dKV = convertNVTETensorCheck(dKV); + Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *wkspace = convertNVTETensor(workspace); size_t b = input_cu_seqlens_q->data.shape[0] - 1; auto ndim = input_Q->data.shape.size(); @@ -731,12 +750,12 @@ void nvte_fused_attn_bwd_kvpacked( const NVTEDType KV_type = static_cast(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, d, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); fused_attn_max_512_bwd_kvpacked( b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_dO, output_S, output_dQ, output_dKV, output_dBias, @@ -746,13 +765,13 @@ void nvte_fused_attn_bwd_kvpacked( #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); Tensor *input_Bias, *input_rng_state; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - input_Bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); } else { - input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, @@ -768,9 +787,9 @@ void nvte_fused_attn_bwd_kvpacked( #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, @@ -797,20 +816,20 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); - const Tensor *input_page_table_k = reinterpret_cast(page_table_k); - const Tensor *input_page_table_v = reinterpret_cast(page_table_v); - const Tensor *input_rng_state = reinterpret_cast(rng_state); - const Tensor *input_Q = reinterpret_cast(Q); - const Tensor *input_K = reinterpret_cast(K); - const Tensor *input_V = reinterpret_cast(V); - const Tensor *input_Bias = reinterpret_cast(Bias); - Tensor *input_output_S = reinterpret_cast(S); - Tensor *output_O = reinterpret_cast(O); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); + const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); + const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_K = convertNVTETensorCheck(K); + const Tensor *input_V = convertNVTETensorCheck(V); + const Tensor *input_Bias = convertNVTETensorCheck(Bias); + Tensor *input_output_S = convertNVTETensorCheck(S); + Tensor *output_O = convertNVTETensorCheck(O); + Tensor *wkspace = convertNVTETensor(workspace); auto ndim = input_Q->data.shape.size(); auto ndim_kv = input_K->data.shape.size(); @@ -862,8 +881,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -914,22 +933,22 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); - const Tensor *input_Q = reinterpret_cast(Q); - const Tensor *input_K = reinterpret_cast(K); - const Tensor *input_V = reinterpret_cast(V); - const Tensor *input_O = reinterpret_cast(O); - const Tensor *input_dO = reinterpret_cast(dO); - const Tensor *input_S = reinterpret_cast(S); - Tensor *input_output_dP = reinterpret_cast(dP); - Tensor *output_dQ = reinterpret_cast(dQ); - Tensor *output_dK = reinterpret_cast(dK); - Tensor *output_dV = reinterpret_cast(dV); - Tensor *output_dBias = reinterpret_cast(dBias); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_K = convertNVTETensorCheck(K); + const Tensor *input_V = convertNVTETensorCheck(V); + const Tensor *input_O = convertNVTETensorCheck(O); + const Tensor *input_dO = convertNVTETensorCheck(dO); + const Tensor *input_S = convertNVTETensorCheck(S); + Tensor *input_output_dP = convertNVTETensorCheck(dP); + Tensor *output_dQ = convertNVTETensorCheck(dQ); + Tensor *output_dK = convertNVTETensorCheck(dK); + Tensor *output_dV = convertNVTETensorCheck(dV); + Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *wkspace = convertNVTETensor(workspace); auto ndim = input_Q->data.shape.size(); auto ndim_kv = input_K->data.shape.size(); @@ -954,12 +973,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias, @@ -969,13 +988,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); Tensor *input_Bias, *input_rng_state; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - input_Bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); } else { - input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, @@ -991,9 +1010,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 2ce93f196..0932b2cf8 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -377,7 +377,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t)); const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0; const size_t num_bytes_per_ragged_offset = - alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); + alignTo<16>(((b + 1) * typeToNumBits(ragged_offset_type)) / 8); size_t seqlen_offsets_workspace_size = 0; if (is_ragged_q || is_ragged_kv) { size_t count = 2 * (static_cast(is_ragged_q) + static_cast(is_ragged_kv)); @@ -831,7 +831,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t)); const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0; const size_t num_bytes_per_ragged_offset = - alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); + alignTo<16>(((b + 1) * typeToNumBits(ragged_offset_type)) / 8); size_t seqlen_offsets_workspace_size = 0; if (is_ragged_q || is_ragged_kv) { size_t count = 2 * (static_cast(is_ragged_q) + static_cast(is_ragged_kv)); @@ -957,9 +957,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = typeToSize(QKV_type) * num_attn_heads * head_dim; + stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void *devPtrQ = static_cast(devPtrQKV); void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); @@ -990,7 +990,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( const auto cudnn_runtime_version = cudnnGetVersion(); if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens, num_attn_heads, 1}; @@ -998,17 +998,17 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen}; output_bias->data.dtype = QKV_type; } else { Aux_CTX_Tensors->size = 2; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens, num_attn_heads, 1}; @@ -1016,22 +1016,22 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_bias->data.dptr = devPtrBias; } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); @@ -1082,9 +1082,9 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = typeToSize(QKV_type) * num_attn_heads * head_dim; + stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void *devPtrQ = devPtrQKV; void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); @@ -1173,9 +1173,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; + stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void *devPtrK = devPtrKV; void *devPtrV = static_cast(static_cast(devPtrKV) + stride); @@ -1216,7 +1216,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( const auto cudnn_runtime_version = cudnnGetVersion(); if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; @@ -1224,17 +1224,17 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; output_bias->data.dtype = QKV_type; } else { Aux_CTX_Tensors->size = 2; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; @@ -1242,22 +1242,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_bias->data.dptr = devPtrBias; } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); @@ -1313,9 +1313,9 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; + stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void *devPtrK = devPtrKV; void *devPtrV = static_cast(static_cast(devPtrKV) + stride); @@ -1446,7 +1446,7 @@ void fused_attn_arbitrary_seqlen_fwd( const auto cudnn_runtime_version = cudnnGetVersion(); if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; @@ -1454,17 +1454,17 @@ void fused_attn_arbitrary_seqlen_fwd( output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; output_bias->data.dtype = QKV_type; } else { Aux_CTX_Tensors->size = 2; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; @@ -1472,22 +1472,22 @@ void fused_attn_arbitrary_seqlen_fwd( output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_bias->data.dptr = devPtrBias; } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index 08e0642b2..89528fa3c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -1239,12 +1239,12 @@ void fused_attn_max_512_fwd_qkvpacked( if (Aux_CTX_Tensors->size == 0) { Aux_CTX_Tensors->size = 1; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen}; output_S->data.dtype = input_QKV->data.dtype; } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); @@ -1317,12 +1317,12 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max if (Aux_CTX_Tensors->size == 0) { Aux_CTX_Tensors->size = 1; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; output_S->data.dtype = q_type; } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); @@ -1386,12 +1386,12 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, if (Aux_CTX_Tensors->size == 0) { Aux_CTX_Tensors->size = 1; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; output_S->data.dtype = q_type; } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index eacd8b53b..3e38a5066 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2364,9 +2364,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = typeToSize(QKV_type) * num_attn_heads * head_dim; + stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void* devPtrQ = static_cast(devPtrQKV); void* devPtrK = static_cast(static_cast(devPtrQKV) + stride); @@ -2383,9 +2383,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { Aux_CTX_Tensors->size = 3; - Tensor* output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_M->data.dptr = nullptr; output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1}; output_M->data.dtype = DType::kFloat32; @@ -2396,9 +2396,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); devPtrM = output_M->data.dptr; devPtrZInv = output_ZInv->data.dptr; output_rng_state->data.dptr = rng_state->data.dptr; @@ -2466,9 +2466,9 @@ void fused_attn_fp8_bwd_qkvpacked( NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = typeToSize(QKV_type) * num_attn_heads * head_dim; + stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void* devPtrQ = devPtrQKV; void* devPtrK = static_cast(static_cast(devPtrQKV) + stride); @@ -2564,9 +2564,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; + stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void* devPtrK = devPtrKV; void* devPtrV = static_cast(static_cast(devPtrKV) + stride); @@ -2582,9 +2582,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { Aux_CTX_Tensors->size = 3; - Tensor* output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_M->data.dptr = nullptr; output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; output_M->data.dtype = DType::kFloat32; @@ -2595,9 +2595,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); devPtrM = output_M->data.dptr; devPtrZInv = output_ZInv->data.dptr; output_rng_state->data.dptr = rng_state->data.dptr; @@ -2671,9 +2671,9 @@ void fused_attn_fp8_bwd_kvpacked( NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; + stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void* devPtrK = devPtrKV; void* devPtrV = static_cast(static_cast(devPtrKV) + stride); @@ -2779,9 +2779,9 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { Aux_CTX_Tensors->size = 3; - Tensor* output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_M->data.dptr = nullptr; output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; output_M->data.dtype = DType::kFloat32; @@ -2792,9 +2792,9 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); devPtrM = output_M->data.dptr; devPtrZInv = output_ZInv->data.dptr; output_rng_state->data.dptr = rng_state->data.dptr; diff --git a/transformer_engine/common/fused_attn/kv_cache.cu b/transformer_engine/common/fused_attn/kv_cache.cu index 0ad5ab01b..9bdc41e9e 100644 --- a/transformer_engine/common/fused_attn/kv_cache.cu +++ b/transformer_engine/common/fused_attn/kv_cache.cu @@ -10,6 +10,8 @@ namespace transformer_engine { namespace kv_cache { +constexpr int block_size = 1024; + template __global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices, int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k, @@ -22,21 +24,29 @@ __global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *bat actual_b = i + 1; } } + bool flag = (batch_indices[0] != 0); for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) { - int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; - int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; - for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; token_idx += gridDim.x) { + if (flag || ((batch_indices[batch_idx] - batch_indices[0]) != batch_idx)) { + int num_tokens = (cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]) - + (cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]); int num_elts_k = h_kv * d_k; int num_elts_v = h_kv * d_v; - int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k; - int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k; - int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v; - int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v; - for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) { - *(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i); - } - for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) { - *(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i); + int num_elts = max(num_elts_k, num_elts_v); + for (int token_idx = blockIdx.x; token_idx < num_tokens; token_idx += gridDim.x) { + int src_offset = batch_indices[batch_idx] * max_seq_len + token_idx; + int des_offset = batch_idx * max_seq_len + token_idx; + dtype *k_cache_src_offset = k_cache + src_offset * num_elts_k; + dtype *k_cache_des_offset = k_cache + des_offset * num_elts_k; + dtype *v_cache_src_offset = v_cache + src_offset * num_elts_v; + dtype *v_cache_des_offset = v_cache + des_offset * num_elts_v; + for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { + if (i < num_elts_k) { + *(k_cache_des_offset + i) = *(k_cache_src_offset + i); + } + if (i < num_elts_v) { + *(v_cache_des_offset + i) = *(v_cache_src_offset + i); + } + } } } } @@ -55,19 +65,26 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; - int new_token_offset = batch_idx * max_ctx_len; int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; - for (int i = threadIdx.x; i < new_len; i += blockDim.x) { + int num_elts_k = h_kv * d_k; + int num_elts_v = h_kv * d_v; + int hd = h_kv * max(d_k, d_v); + for (int i = blockIdx.y; i < new_len; i += gridDim.y) { int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size]; - int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; - for (int j = 0; j < h_kv * d_k; j++) { - *(k_cache + token_idx * h_kv * d_k + j) = - *(new_k + (new_token_offset + i) * h_kv * d_k + j); - } - for (int j = 0; j < h_kv * d_v; j++) { - *(v_cache + token_idx * h_kv * d_v + j) = - *(new_v + (new_token_offset + i) * h_kv * d_v + j); + dtype *new_token_id_k = new_k + (batch_idx * max_ctx_len + i) * num_elts_k; + dtype *new_token_id_v = new_v + (batch_idx * max_ctx_len + i) * num_elts_v; + dtype *token_id_k = + k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k; + dtype *token_id_v = + v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v; + for (int j = threadIdx.x; j < hd; j += blockDim.x) { + if (j < num_elts_k) { + *(token_id_k + j) = *(new_token_id_k + j); + } + if (j < num_elts_v) { + *(token_id_v + j) = *(new_token_id_v + j); + } } } } @@ -76,14 +93,24 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; - for (int i = threadIdx.x; i < new_len; i += blockDim.x) { + int num_elts_k = h_kv * d_k; + int num_elts_v = h_kv * d_v; + int hd = h_kv * max(d_k, d_v); + for (int i = blockIdx.y; i < new_len; i += gridDim.y) { int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size]; - int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; - for (int j = 0; j < h_kv * d_k; j++) { - *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j); - } - for (int j = 0; j < h_kv * d_v; j++) { - *(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v + j); + dtype *new_token_id_k = new_k + (i * b + batch_idx) * num_elts_k; + dtype *new_token_id_v = new_v + (i * b + batch_idx) * num_elts_v; + dtype *token_id_k = + k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k; + dtype *token_id_v = + v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v; + for (int j = threadIdx.x; j < hd; j += blockDim.x) { + if (j < num_elts_k) { + *(token_id_k + j) = *(new_token_id_k + j); + } + if (j < num_elts_v) { + *(token_id_v + j) = *(new_token_id_v + j); + } } } } @@ -92,16 +119,24 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; - for (int i = threadIdx.x; i < new_len; i += blockDim.x) { + int num_elts_k = h_kv * d_k; + int num_elts_v = h_kv * d_v; + int hd = h_kv * max(d_k, d_v); + for (int i = blockIdx.y; i < new_len; i += gridDim.y) { int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size]; - int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; - for (int j = 0; j < h_kv * d_k; j++) { - *(k_cache + token_idx * h_kv * d_k + j) = - *(new_k + (cu_new_lens[batch_idx] + i) * h_kv * d_k + j); - } - for (int j = 0; j < h_kv * d_v; j++) { - *(v_cache + token_idx * h_kv * d_v + j) = - *(new_v + (cu_new_lens[batch_idx] + i) * h_kv * d_v + j); + dtype *new_token_id_k = new_k + (cu_new_lens[batch_idx] + i) * num_elts_k; + dtype *new_token_id_v = new_v + (cu_new_lens[batch_idx] + i) * num_elts_v; + dtype *token_id_k = + k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k; + dtype *token_id_v = + v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v; + for (int j = threadIdx.x; j < hd; j += blockDim.x) { + if (j < num_elts_k) { + *(token_id_k + j) = *(new_token_id_k + j); + } + if (j < num_elts_v) { + *(token_id_v + j) = *(new_token_id_v + j); + } } } } @@ -116,14 +151,15 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso bool is_non_paged, cudaStream_t stream) { if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) { if (is_non_paged) { - reindex_kv_cache_kernel<<<16, 256, 0, stream>>>( + reindex_kv_cache_kernel<<>>( reinterpret_cast(k_cache.data.dptr), reinterpret_cast(v_cache.data.dptr), reinterpret_cast(page_table.data.dptr), reinterpret_cast(cu_new_lens.data.dptr), reinterpret_cast(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len); } - copy_to_kv_cache_kernel<<<16, 256, 0, stream>>>( + dim3 grid_size(b, max_ctx_len); + copy_to_kv_cache_kernel<<>>( reinterpret_cast(new_k.data.dptr), reinterpret_cast(new_v.data.dptr), reinterpret_cast(k_cache.data.dptr), reinterpret_cast(v_cache.data.dptr), reinterpret_cast(page_table.data.dptr), @@ -260,12 +296,12 @@ void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cach NVTE_API_CALL(nvte_copy_to_kv_cache); using namespace transformer_engine; - kv_cache::copy_to_kv_cache( - *reinterpret_cast(new_k), *reinterpret_cast(new_v), - *reinterpret_cast(k_cache), *reinterpret_cast(v_cache), - *reinterpret_cast(page_table), *reinterpret_cast(cu_new_lens), - *reinterpret_cast(cu_cached_lens), qkv_format, b, max_ctx_len, max_seq_len, - max_pages_per_seq, is_non_paged, stream); + kv_cache::copy_to_kv_cache(*convertNVTETensorCheck(new_k), *convertNVTETensorCheck(new_v), + *convertNVTETensorCheck(k_cache), *convertNVTETensorCheck(v_cache), + *convertNVTETensorCheck(page_table), + *convertNVTETensorCheck(cu_new_lens), + *convertNVTETensorCheck(cu_cached_lens), qkv_format, b, max_ctx_len, + max_seq_len, max_pages_per_seq, is_non_paged, stream); } /*************************************************************************************************** @@ -277,9 +313,9 @@ void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETens NVTE_API_CALL(nvte_convert_thd_to_bshd); using namespace transformer_engine; - kv_cache::convert_thd_to_bshd(*reinterpret_cast(tensor), - *reinterpret_cast(cu_seqlens), - *reinterpret_cast(new_tensor), b, max_seq_len, stream); + kv_cache::convert_thd_to_bshd(*convertNVTETensorCheck(tensor), + *convertNVTETensorCheck(cu_seqlens), + *convertNVTETensorCheck(new_tensor), b, max_seq_len, stream); } /*************************************************************************************************** @@ -291,7 +327,7 @@ void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETens NVTE_API_CALL(nvte_convert_bshd_to_thd); using namespace transformer_engine; - kv_cache::convert_bshd_to_thd(*reinterpret_cast(tensor), - *reinterpret_cast(cu_seqlens), - *reinterpret_cast(new_tensor), t, stream); + kv_cache::convert_bshd_to_thd(*convertNVTETensorCheck(tensor), + *convertNVTETensorCheck(cu_seqlens), + *convertNVTETensorCheck(new_tensor), t, stream); } diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index 66fa72c0c..94c9e4993 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -275,10 +275,10 @@ void log_fused_attn_config( // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right) { + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { using namespace transformer_engine; // by default, fused attn is enabled @@ -350,14 +350,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; - const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); - const Tensor *input_cu_seqlens_padded = reinterpret_cast(cu_seqlens_padded); - const Tensor *input_rng_state = reinterpret_cast(rng_state); - const Tensor *input_QKV = reinterpret_cast(QKV); - const Tensor *input_Bias = reinterpret_cast(Bias); - Tensor *input_output_S = reinterpret_cast(S); - Tensor *output_O = reinterpret_cast(O); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); + const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); + const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); + const Tensor *input_QKV = convertNVTETensorCheck(QKV); + const Tensor *input_Bias = convertNVTETensorCheck(Bias); + Tensor *input_output_S = convertNVTETensorCheck(S); + Tensor *output_O = convertNVTETensorCheck(O); + Tensor *wkspace = convertNVTETensorCheck(workspace); auto ndim = input_QKV->data.shape.size(); size_t b = input_cu_seqlens->data.shape[0] - 1; @@ -384,8 +384,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); + is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd_qkvpacked( @@ -426,19 +426,19 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; - const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); - const Tensor *input_cu_seqlens_padded = reinterpret_cast(cu_seqlens_padded); - const Tensor *input_QKV = reinterpret_cast(QKV); - const Tensor *input_O = reinterpret_cast(O); - const Tensor *input_dO = reinterpret_cast(dO); - Tensor *output_dQKV = reinterpret_cast(dQKV); - Tensor *output_dBias = reinterpret_cast(dBias); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); + const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); + const Tensor *input_QKV = convertNVTETensorCheck(QKV); + const Tensor *input_O = convertNVTETensorCheck(O); + const Tensor *input_dO = convertNVTETensorCheck(dO); + Tensor *output_dQKV = convertNVTETensorCheck(dQKV); + Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *wkspace = convertNVTETensorCheck(workspace); // auxiliary tensors - const Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); //softmax lse + const Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); //softmax lse //extract the saved rng state from aux_ctx_tensor - const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); Tensor *input_Bias = nullptr; auto ndim = input_QKV->data.shape.size(); @@ -466,12 +466,12 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, + true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)){ - input_Bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); } fused_attn_ck_bwd_qkvpacked( b, h, max_seqlen, d, @@ -513,18 +513,18 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); - const Tensor *input_page_table_k = reinterpret_cast(page_table_k); - const Tensor *input_page_table_v = reinterpret_cast(page_table_v); - const Tensor *input_rng_state = reinterpret_cast(rng_state); - const Tensor *input_Q = reinterpret_cast(Q); - const Tensor *input_KV = reinterpret_cast(KV); - const Tensor *input_Bias = reinterpret_cast(Bias); - Tensor *output_O = reinterpret_cast(O); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); + const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); + const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_KV = convertNVTETensorCheck(KV); + const Tensor *input_Bias = convertNVTETensorCheck(Bias); + Tensor *output_O = convertNVTETensorCheck(O); + Tensor *wkspace = convertNVTETensorCheck(workspace); size_t b = input_cu_seqlens_q->data.shape[0] - 1; auto ndim = input_Q->data.shape.size(); @@ -554,8 +554,8 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, d, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd_kvpacked( @@ -601,22 +601,22 @@ void nvte_fused_attn_bwd_kvpacked( cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); - const Tensor *input_Q = reinterpret_cast(Q); - const Tensor *input_KV = reinterpret_cast(KV); - const Tensor *input_O = reinterpret_cast(O); - const Tensor *input_dO = reinterpret_cast(dO); - Tensor *output_dQ = reinterpret_cast(dQ); - Tensor *output_dKV = reinterpret_cast(dKV); - Tensor *wkspace = reinterpret_cast(workspace); - Tensor *output_dBias = reinterpret_cast(dBias); + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_KV = convertNVTETensorCheck(KV); + const Tensor *input_O = convertNVTETensorCheck(O); + const Tensor *input_dO = convertNVTETensorCheck(dO); + Tensor *output_dQ = convertNVTETensorCheck(dQ); + Tensor *output_dKV = convertNVTETensorCheck(dKV); + Tensor *wkspace = convertNVTETensorCheck(workspace); + Tensor *output_dBias = convertNVTETensorCheck(dBias); // auxiliary tensors (to be propagated to the backward pass later) - const Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); //softmax lse - const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + const Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); //softmax lse + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); Tensor *input_Bias = nullptr; size_t b = input_cu_seqlens_q->data.shape[0] - 1; @@ -647,12 +647,12 @@ void nvte_fused_attn_bwd_kvpacked( std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_Bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); } fused_attn_ck_bwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, @@ -703,19 +703,19 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); - const Tensor *input_page_table_k = reinterpret_cast(page_table_k); - const Tensor *input_page_table_v = reinterpret_cast(page_table_v); - const Tensor *input_rng_state = reinterpret_cast(rng_state); - const Tensor *input_Q = reinterpret_cast(Q); - const Tensor *input_K = reinterpret_cast(K); - const Tensor *input_V = reinterpret_cast(V); - const Tensor *input_Bias = reinterpret_cast(Bias); - Tensor *output_O = reinterpret_cast(O); - Tensor *wkspace = reinterpret_cast(workspace); + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); + const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); + const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_K = convertNVTETensorCheck(K); + const Tensor *input_V = convertNVTETensorCheck(V); + const Tensor *input_Bias = convertNVTETensorCheck(Bias); + Tensor *output_O = convertNVTETensorCheck(O); + Tensor *wkspace = convertNVTETensorCheck(workspace); auto ndim = input_Q->data.shape.size(); size_t b = input_cu_seqlens_q->data.shape[0] - 1; @@ -737,8 +737,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd( @@ -786,24 +786,24 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); - const Tensor *input_Q = reinterpret_cast(Q); - const Tensor *input_K = reinterpret_cast(K); - const Tensor *input_V = reinterpret_cast(V); - const Tensor *input_O = reinterpret_cast(O); - const Tensor *input_dO = reinterpret_cast(dO); - - Tensor *output_dQ = reinterpret_cast(dQ); - Tensor *output_dK = reinterpret_cast(dK); - Tensor *output_dV = reinterpret_cast(dV); - Tensor *output_dBias = reinterpret_cast(dBias); - Tensor *wkspace = reinterpret_cast(workspace); - - const Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); //softmax lse - const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_K = convertNVTETensorCheck(K); + const Tensor *input_V = convertNVTETensorCheck(V); + const Tensor *input_O = convertNVTETensorCheck(O); + const Tensor *input_dO = convertNVTETensorCheck(dO); + + Tensor *output_dQ = convertNVTETensorCheck(dQ); + Tensor *output_dK = convertNVTETensorCheck(dK); + Tensor *output_dV = convertNVTETensorCheck(dV); + Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *wkspace = convertNVTETensorCheck(workspace); + + const Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); //softmax lse + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); Tensor *input_Bias = nullptr; auto ndim = input_Q->data.shape.size(); @@ -826,12 +826,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_Bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); } fused_attn_ck_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, @@ -874,10 +874,10 @@ uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspac return transformer_engine::fused_attn_rocm::GetRuntimeNumSegments(cu_seqlen, workspace, len, stream); } -void nvte_populate_rng_state_async(void *rng_state_dst, const void *const seed, - size_t batch_size, size_t num_heads, size_t q_max_seqlen, - size_t kv_max_seqlen, cudaStream_t stream) { +void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed, + size_t q_max_seqlen, size_t kv_max_seqlen, + NVTE_Fused_Attn_Backend backend, cudaStream_t stream) { NVTE_API_CALL(nvte_populate_rng_state_async); - transformer_engine::fused_attn_rocm::PopulateRngStateAsync(rng_state_dst, seed, batch_size, - num_heads, q_max_seqlen, kv_max_seqlen, stream); + using namespace transformer_engine::fused_attn_rocm; + PopulateRngStateAsync(rng_state_dst, seed, q_max_seqlen, kv_max_seqlen, backend, stream); } diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 3fe7ec854..3a36227cd 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -58,6 +58,10 @@ bool is_aotriton_backend_supported( return false; } + if(head_dim_qk >= 512 || head_dim_v >= 512){ + return false; + } + //TODO: release after TE integrates swa into AOTriton bool is_no_mask_window_size= window_size_left == -1 && window_size_right == -1; bool is_causal_mask_window_size = window_size_left ==-1 && window_size_right ==0; @@ -372,18 +376,18 @@ void fused_attn_aotriton_fwd_qkvpacked( if (Aux_CTX_Tensors->size == 0) { Aux_CTX_Tensors->size = 2; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; output_S->data.shape = {b, h, max_seqlen, 1}; output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); @@ -527,18 +531,18 @@ void fused_attn_aotriton_fwd_kvpacked( if (Aux_CTX_Tensors->size == 0) { Aux_CTX_Tensors->size = 2; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; output_S->data.shape = {b, h_q, max_seqlen_q, 1}; output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); @@ -676,18 +680,18 @@ void fused_attn_aotriton_fwd( if (Aux_CTX_Tensors->size == 0) { Aux_CTX_Tensors->size = 2; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; output_S->data.shape = {b, h_q, max_seqlen_q, 1}; output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index dff1a7626..572a3b5df 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -99,6 +99,15 @@ bool is_ck_backend_supported( return false; } + // filter based on head_dim + // AITER/ck does not support hdim>=512 + if(head_dim_qk >= 512 || head_dim_v >= 512){ + if(nvte_log_ck_config){ + std::cout<<"AITER/CK fused attn does not support head dim >=512 yet"<size == 0) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; if(is_ragged){ output_S->data.shape = {max_tokens, h, 1}; @@ -1276,17 +1285,17 @@ void fused_attn_ck_fwd_qkvpacked( output_S->data.shape = {b, h, max_seqlen, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen}; output_bias->data.dtype = QKV_type; } else { Aux_CTX_Tensors->size = 2; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; if(is_ragged){ output_S->data.shape = {max_tokens, h, 1}; @@ -1294,22 +1303,22 @@ void fused_attn_ck_fwd_qkvpacked( output_S->data.shape = {b, h, max_seqlen, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_bias->data.dptr = devPtrBias; } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); @@ -1523,7 +1532,7 @@ void fused_attn_ck_fwd_kvpacked( if (Aux_CTX_Tensors->size == 0) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; if(is_ragged){ output_S->data.shape = {max_tokens_q, h_q, 1}; @@ -1531,17 +1540,17 @@ void fused_attn_ck_fwd_kvpacked( output_S->data.shape = {b, h_q, max_seqlen_q, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; output_bias->data.dtype = QKV_type; } else { Aux_CTX_Tensors->size = 2; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; if(is_ragged){ output_S->data.shape = {max_tokens_q, h_q, 1}; @@ -1549,22 +1558,22 @@ void fused_attn_ck_fwd_kvpacked( output_S->data.shape = {b, h_q, max_seqlen_q, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_bias->data.dptr = devPtrBias; } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); @@ -1766,7 +1775,7 @@ void fused_attn_ck_fwd( if (Aux_CTX_Tensors->size == 0) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; if(is_ragged){ output_S->data.shape = {max_tokens_q, h_q, 1}; @@ -1774,17 +1783,17 @@ void fused_attn_ck_fwd( output_S->data.shape = {b, h_q, max_seqlen_q, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; output_bias->data.dtype = QKV_type; } else { Aux_CTX_Tensors->size = 2; - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; if(is_ragged){ output_S->data.shape = {max_tokens_q, h_q, 1}; @@ -1792,22 +1801,22 @@ void fused_attn_ck_fwd( output_S->data.shape = {b, h_q, max_seqlen_q, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; } } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); output_bias->data.dptr = devPtrBias; } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); diff --git a/transformer_engine/common/fused_attn_rocm/utils.cpp b/transformer_engine/common/fused_attn_rocm/utils.cpp index 5e9b0b67f..166e22bbf 100644 --- a/transformer_engine/common/fused_attn_rocm/utils.cpp +++ b/transformer_engine/common/fused_attn_rocm/utils.cpp @@ -239,18 +239,15 @@ __global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, } } -void PopulateRngStateAsync(void *rng_state_dst, - const void *const seed, - size_t batch_size, - size_t num_heads, - size_t q_max_seqlen, - size_t kv_max_seqlen, +void PopulateRngStateAsync(void *rng_state_dst, const void *seed, size_t q_max_seqlen, + size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend, cudaStream_t stream) { - size_t increment = batch_size*num_heads*q_max_seqlen*kv_max_seqlen; - auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment); - populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast(rng_state_dst), - reinterpret_cast(seed), offset); - NVTE_CHECK_CUDA(cudaGetLastError()); + //both aiter and aotriton now follows flash-attn rng design + size_t increment = 16; + auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment); + populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast(rng_state_dst), + reinterpret_cast(seed), offset); + NVTE_CHECK_CUDA(cudaGetLastError()); } uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) { diff --git a/transformer_engine/common/fused_attn_rocm/utils.h b/transformer_engine/common/fused_attn_rocm/utils.h index 2d03e715c..5d65bc8ee 100644 --- a/transformer_engine/common/fused_attn_rocm/utils.h +++ b/transformer_engine/common/fused_attn_rocm/utils.h @@ -59,10 +59,9 @@ class FusedAttnOffsetManager { void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, - size_t batch_size, - size_t num_heads, size_t q_max_seqlen, size_t kv_max_seqlen, + NVTE_Fused_Attn_Backend backend, cudaStream_t stream); uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream); diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index 42dac53e4..df9ea6ee5 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -308,11 +308,10 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens const int stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_forward); using namespace transformer_engine; - fused_rope_forward( - *reinterpret_cast(input), *reinterpret_cast(cu_seqlens), - *reinterpret_cast(freqs), *reinterpret_cast(start_positions), - reinterpret_cast(output), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, - stride_s_or_t, stride_b, stride_h, stride_d, stream); + fused_rope_forward(*convertNVTETensorCheck(input), *convertNVTETensorCheck(cu_seqlens), + *convertNVTETensorCheck(freqs), *convertNVTETensorCheck(start_positions), + convertNVTETensorCheck(output), qkv_format, interleaved, cp_size, cp_rank, s, + b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream); } void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, @@ -324,9 +323,8 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_backward); using namespace transformer_engine; - fused_rope_backward(*reinterpret_cast(output_grads), - *reinterpret_cast(cu_seqlens), - *reinterpret_cast(freqs), - reinterpret_cast(input_grads), qkv_format, interleaved, cp_size, - cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream); + fused_rope_backward(*convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(cu_seqlens), + *convertNVTETensorCheck(freqs), convertNVTETensorCheck(input_grads), + qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, + stride_b, stride_h, stride_d, stream); } diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu new file mode 100644 index 000000000..3af7e42c2 --- /dev/null +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -0,0 +1,297 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../utils.cuh" +#include "common/util/cuda_runtime.h" +#include "utils.h" + +namespace transformer_engine { + +// Using Double to hanld all the calculations +using CompType = double; + +template +__global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, + const IndexType* tokens_per_expert, + int total_num_tokens, int num_experts, + int num_rows, int num_cols, int topk, float coeff, + DataType* aux_loss, float* Const_buf) { +#if __CUDA_ARCH__ >= 900 + // Using cooperative_groups to manage the cluster + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + int thread_id = cg::this_grid().thread_rank(); + int lane_id = thread_id % kThreadsPerWarp; + int warp_id = thread_id / kThreadsPerWarp; + int warp_num = blockDim.x * gridDim.x / kThreadsPerWarp; + // Only 1 block in the cluster + int block_id = cluster.block_rank(); + int block_num = cluster.dim_blocks().x; + int cluster_id = blockIdx.x / block_num; + if (cluster_id > 0) return; // Only use the cluster 0 + + extern __shared__ float shmem_aux_loss[]; + CompType* aggregated_probs_per_expert = reinterpret_cast(shmem_aux_loss); + // Clear the shmem + for (int i = threadIdx.x; i < num_cols; i += blockDim.x) { + aggregated_probs_per_expert[i] = CompType(0); + } + __syncthreads(); + + /** + * Section: Reduce the probs to the aggregated_probs_per_expert + * 1. reduce on the block + * 2. reduce on the cluster + */ + // Loop: for all positions in each row + for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { + CompType tmp = CompType(0); + // Loop: for all rows that this warp is responsible for + for (int j = warp_id; j < num_rows; j += warp_num) { + tmp += CompType(probs[j * num_cols + i]); + } + atomicAdd(&aggregated_probs_per_expert[i], tmp); + } + cluster.sync(); + // The block 0 will reduce the results of all blocks + if (block_id == 0) { + for (int i = 1; i < block_num; i++) { + // Map the shared memory of the block i to the current block + CompType* dst_smem = reinterpret_cast(cluster.map_shared_rank(shmem_aux_loss, i)); + for (int j = threadIdx.x; j < num_cols; j += blockDim.x) { + atomicAdd(&aggregated_probs_per_expert[j], dst_smem[j]); + } + } + } + cluster.sync(); + + /** + * Section: aggregated_probs_per_expert * tokens_per_expert + * In-place update on shmem + */ + if (block_id == 0) { + for (int i = threadIdx.x; i < num_cols; i += blockDim.x) { + aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]); + } + __syncthreads(); + + if (warp_id == 0) { + /** + * Section: Reduce to get the sum of aggregated_probs_per_expert + */ + CompType intermediate_result = + warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id); + __syncwarp(); + + if (lane_id == 0) { + /** + * Section: Compute the aux_loss + */ + float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens; + aux_loss[0] = static_cast(static_cast(intermediate_result) * C_coeff); + Const_buf[0] = C_coeff; + } + } + } +#else + // Use Only 1 block/1024 threads to avoid the grid sync + if (blockIdx.x > 0) return; + int warp_num = blockDim.x / kThreadsPerWarp; + int warp_id = threadIdx.x / kThreadsPerWarp; + int lane_id = threadIdx.x % kThreadsPerWarp; + extern __shared__ float shmem_aux_loss[]; + CompType* aggregated_probs_per_expert = reinterpret_cast(shmem_aux_loss); + + // Clear the shmem + for (int i = threadIdx.x; i < num_cols; i += blockDim.x) { + aggregated_probs_per_expert[i] = CompType(0); + } + __syncthreads(); + + /** + * Section: Reduce the probs to the aggregated_probs_per_expert + */ + // Loop: for all positions in each row + for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { + CompType tmp = CompType(0); + // Loop: for all rows that this warp is responsible for + for (int j = warp_id; j < num_rows; j += warp_num) { + tmp += CompType(probs[j * num_cols + i]); + } + atomicAdd(&aggregated_probs_per_expert[i], tmp); + } + __syncthreads(); + + /** + * Section: aggregated_probs_per_expert * tokens_per_expert + * In-place update on shmem + */ + for (int i = threadIdx.x; i < num_cols; i += blockDim.x) { + aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]); + } + __syncthreads(); + + if (warp_id == 0) { + /** + * Section: Reduce to get the sum of aggregated_probs_per_expert + */ + CompType intermediate_result = + warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id); + __syncwarp(); + + if (lane_id == 0) { + /** + * Section: Compute the aux_loss + */ + float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens; + aux_loss[0] = static_cast(static_cast(intermediate_result) * C_coeff); + Const_buf[0] = C_coeff; + } + } +#endif +} + +template +void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, + const IndexType* tokens_per_expert, + int total_num_tokens, int num_experts, int num_rows, + int num_cols, int topk, float coeff, + DataType* aux_loss, float* Const_buf, + cudaStream_t stream) { +// TODO: unblock after rocm support thread block cluster +#ifndef __HIP_PLATFORM_AMD__ + if (cuda::sm_arch(cuda::current_device()) >= 90) { + cudaLaunchConfig_t config = {0}; + int cluster_size = 8; + config.gridDim = cluster_size; + config.blockDim = 1024; + config.dynamicSmemBytes = sizeof(CompType) * num_cols; + config.stream = stream; + + // Update the max cluster size based on the device + cudaOccupancyMaxPotentialClusterSize( + &cluster_size, + reinterpret_cast(fused_moe_aux_loss_forward_kernel), &config); + + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeClusterDimension; + attribute[0].val.clusterDim.x = cluster_size; + attribute[0].val.clusterDim.y = 1; + attribute[0].val.clusterDim.z = 1; + config.numAttrs = 1; + config.attrs = attribute; + + cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel, probs, + tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk, + coeff, aux_loss, Const_buf); + } else { +#endif + size_t smem_size = sizeof(CompType) * num_cols; + fused_moe_aux_loss_forward_kernel + <<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts, + num_rows, num_cols, topk, coeff, aux_loss, Const_buf); +#ifndef __HIP_PLATFORM_AMD__ + } +#endif +} + +void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert, + int total_num_tokens, int num_experts, int num_rows, int num_cols, + int topk, float coeff, Tensor& aux_loss, Tensor& Const_buf, + cudaStream_t stream) { + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( + probs.data.dtype, DataType, + TE_ROUTER_INDEX_TYPE_SWITCH_ALL( + tokens_per_expert.data.dtype, IndexType, + fused_moe_aux_loss_forward_kernel_launcher( + reinterpret_cast(probs.data.dptr), + reinterpret_cast(tokens_per_expert.data.dptr), total_num_tokens, + num_experts, num_rows, num_cols, topk, coeff, + reinterpret_cast(aux_loss.data.dptr), + reinterpret_cast(Const_buf.data.dptr), stream););); +} + +template +__global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf, + const IndexType* tokens_per_expert, int num_rows, + int num_cols, DataType* grad_aux_loss, + DataType* grad_probs) { + int global_warp_num = gridDim.x * blockDim.x / kThreadsPerWarp; + int global_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kThreadsPerWarp; + int lane_id = threadIdx.x % kThreadsPerWarp; + + // Loop: for all positions in each row + for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { + float C_coeff = Const_buf[0]; + IndexType tokens_per_expert_i = tokens_per_expert[i]; + double grad_aux_loss_value = static_cast(grad_aux_loss[0]); + // Loop: for all rows + for (int j = global_warp_id; j < num_rows; j += global_warp_num) { + grad_probs[j * num_cols + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value; + } + } +} + +template +void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf, + const IndexType* tokens_per_expert, int num_rows, + int num_cols, DataType* grad_aux_loss, + DataType* grad_probs, cudaStream_t stream) { + // Meta data for the kernel + int block_size = 256; + int grid_size = (num_rows + block_size - 1) / block_size; + fused_moe_aux_loss_backward_kernel<<>>( + Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs); +} + +void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_expert, + int num_rows, int num_cols, Tensor& grad_aux_loss, + Tensor& grad_probs, cudaStream_t stream) { + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( + grad_aux_loss.data.dtype, DataType, + TE_ROUTER_INDEX_TYPE_SWITCH_ALL( + tokens_per_expert.data.dtype, IndexType, + fused_moe_aux_loss_backward_kernel_launcher( + reinterpret_cast(Const_buf.data.dptr), + reinterpret_cast(tokens_per_expert.data.dptr), num_rows, num_cols, + reinterpret_cast(grad_aux_loss.data.dptr), + reinterpret_cast(grad_probs.data.dptr), stream););); +} + +} // namespace transformer_engine + +void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert, + int total_num_tokens, int num_experts, int num_rows, + int num_cols, int topk, float coeff, NVTETensor aux_loss, + NVTETensor Const_buf, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_moe_aux_loss_forward); + using namespace transformer_engine; + fused_moe_aux_loss_forward( + *convertNVTETensorCheck(probs), *convertNVTETensorCheck(tokens_per_expert), total_num_tokens, + num_experts, num_rows, num_cols, topk, coeff, *convertNVTETensorCheck(aux_loss), + *convertNVTETensorCheck(Const_buf), stream); +} + +void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf, + const NVTETensor tokens_per_expert, int num_rows, + int num_cols, NVTETensor grad_aux_loss, NVTETensor grad_probs, + cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_moe_aux_loss_backward); + using namespace transformer_engine; + fused_moe_aux_loss_backward(*convertNVTETensorCheck(Const_buf), + *convertNVTETensorCheck(tokens_per_expert), num_rows, num_cols, + *convertNVTETensorCheck(grad_aux_loss), + *convertNVTETensorCheck(grad_probs), stream); +} diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu new file mode 100644 index 000000000..91a4bbb53 --- /dev/null +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -0,0 +1,324 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../utils.cuh" +#include "utils.h" + +namespace transformer_engine { + +template +__global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens, + int num_experts, int topk, + int score_function, DataType *scores, + bool *routing_map, + DataType *intermediate_output) { + /*** + * Section: Global Variables/Addresses init + * - Assume the sizeof(DataType) >= sizeof(int), + * So DataType address is assigned firstly to avoid the alignment issue + * - Each warp is responsible for one token, and has own shared memory buffer. + * Then __syncwarp() is used instead of __syncthreads() + */ + // Used variables/addresses init + int num_token_per_block = blockDim.x / kThreadsPerWarp; + int warp_id = threadIdx.x / kThreadsPerWarp; + int lane_id = threadIdx.x % kThreadsPerWarp; + extern __shared__ float shmem_scores_for_aux_loss[]; + DataType *logits_buf = reinterpret_cast(shmem_scores_for_aux_loss); + DataType *topk_logits_buf = + reinterpret_cast(logits_buf + num_experts * num_token_per_block); + int *topk_indices_buf = reinterpret_cast(topk_logits_buf + topk * num_token_per_block); + // The address of buffers on the current warp + DataType *local_logits = logits_buf + warp_id * num_experts; + DataType *topk_logits = topk_logits_buf + warp_id * topk; + int *topk_indices = topk_indices_buf + warp_id * topk; + + /*** + * Section: Main Loop + * - Each warp is responsible for one token + */ + int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; + for (int round = blockIdx.x; round < total_round; round += gridDim.x) { + int token_offset_cur_warp = round * num_token_per_block + warp_id; + // Each warp is responsible for one token + if (token_offset_cur_warp >= num_tokens) break; + + /*** + * Section: Init buffer + * - Clear the global buffer which will accept the result of this round + * - Clear/Init the shmem buffer used by current warp this round + * - Load the logits to shmem + */ + int pos_offset = token_offset_cur_warp * num_experts; + // Clear the routing_map (num_experts) + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + routing_map[pos_offset + i] = false; + if (score_function == 1) { + intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); + } + } + // Load the logits to shmem + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_logits[i] = logits[pos_offset + i]; + } + __threadfence_block(); + __syncwarp(); + + /*** + * Section: Preprocess + * Possible preprocess the scores before the topk operation + * - Pre-softmax + * - Sigmoid + * - Sigmoid post-processing when topk > 1 + * This is in-place scores update + */ + // score_function == 1 means softmax + if (score_function == 1) { + // Apply softmax to the logits before the topk + apply_softmax_on_float(local_logits, num_experts, lane_id); + __syncwarp(); + // Save the softmax output for backward + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = local_logits[i]; + } + } + + // score_function == 0 means sigmoid + if (score_function == 0) { + // Apply sigmoid to the logits + apply_sigmoid_on_float(local_logits, num_experts, lane_id); + __syncwarp(); + // Save the sigmoid output for backward + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = local_logits[i]; + } + } + + __syncwarp(); //Confirm the scores is written to the softmax/sigmoid output + + if (score_function == 0) { + if (topk > 1) { + auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, sum, lane_id); + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_logits[i] = static_cast(static_cast(local_logits[i]) / + (static_cast(sum_logits) + epsilon)); + } + } + __syncwarp(); + } + + /*** + * Section: Topk + * Get the topk indices + */ + naive_topk_and_mask(local_logits, num_experts, topk, topk_indices, topk_logits, lane_id); + __syncwarp(); + + // Write the routing_map to the output tensor + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + routing_map[pos_offset + topk_indices[i]] = true; + } + // Write the scores to the output tensor + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + scores[pos_offset + i] = local_logits[i]; + } + __threadfence_block(); + __syncwarp(); + } +} + +template +void fused_score_for_moe_aux_loss_forward_kernel_launcher( + const DataType *logits, int num_tokens, int num_experts, int topk, int score_function, + DataType *scores, bool *routing_map, DataType *intermediate_output, cudaStream_t stream) { + // Meta data for the kernel + size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; + size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; + size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // logits + + topk * num_token_per_block * sizeof(DataType) // topk_logits + + topk * num_token_per_block * sizeof(int); // topk_indices + fused_score_for_moe_aux_loss_forward_kernel + <<>>( + logits, num_tokens, num_experts, topk, score_function, scores, routing_map, + intermediate_output); +} + +void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, int num_experts, + int topk, int score_function, Tensor &scores, + Tensor &routing_map, Tensor &intermediate_output, + cudaStream_t stream) { + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( + logits.data.dtype, DataType, + fused_score_for_moe_aux_loss_forward_kernel_launcher( + reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, + score_function, reinterpret_cast(scores.data.dptr), + reinterpret_cast(routing_map.data.dptr), + reinterpret_cast(intermediate_output.data.dptr), stream);); +} + +template +__global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *intermediate_output, + const DataType *grad_scores, + int num_tokens, int num_experts, + int topk, int score_function, + DataType *grad_logits) { + /*** + * Section: Global Variables/Addresses init + * - Assume the sizeof(DataType) >= sizeof(int), + * - Each warp is responsible for one token, and has own shared memory buffer. + * Then __syncwarp() is used instead of __syncthreads() + */ + // Used variables/addresses init + int num_token_per_block = blockDim.x / kThreadsPerWarp; + int warp_id = threadIdx.x / kThreadsPerWarp; + int lane_id = threadIdx.x % kThreadsPerWarp; + extern __shared__ float shmem[]; + DataType *grad_scores_buf = reinterpret_cast(shmem); + // To store the output of softmax/sigmoid from the fwd + DataType *act_from_fwd_buf = + reinterpret_cast(grad_scores_buf + num_experts * num_token_per_block); + DataType *comp_buf = + reinterpret_cast(act_from_fwd_buf + num_experts * num_token_per_block); + // The address of buffers on the current warp + DataType *local_grad = grad_scores_buf + warp_id * num_experts; + DataType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts; + DataType *local_comp_buf = comp_buf + warp_id * num_experts; + + /*** + * Section: Main Loop + * - Each warp is responsible for one token + */ + int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; + for (int round = blockIdx.x; round < total_round; round += gridDim.x) { + int token_offset_cur_warp = round * num_token_per_block + warp_id; + // Each warp is responsible for one token + if (token_offset_cur_warp >= num_tokens) break; + + /*** + * Section: Init buffer + * - Clear the global buffer which will accept the result of this round + * - Clear/Init the shmem buffer used by current warp this round + * - Load the dgrad/output_from_fwd to shmem + */ + int pos_offset = token_offset_cur_warp * num_experts; + // Clear the logits_grad in global mem + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + grad_logits[pos_offset + i] = 0.0f; + } + // Load the dgrad/output_from_fwd to shmem + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_grad[i] = grad_scores[pos_offset + i]; + local_act_from_fwd[i] = intermediate_output[pos_offset + i]; + } + __threadfence_block(); + __syncwarp(); + + /*** + * Section: Backward of ops before the topk + * - Pre-softmax bwd + * - Sigmoid Post-processing bwd when topk > 1 + * - Sigmoid bwd + * - Write the grad_logits to the global mem + */ + // Sigmoid Post-processing bwd when topk > 1 + if (topk > 1 && score_function == 0) { + auto sum_fwd_input = warp_reduce_on_shmem(local_act_from_fwd, num_experts, sum, lane_id); + // Put the result of output * grad to the comp_buf + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i]; + } + __syncwarp(); + auto sum_Output_x_Grad = warp_reduce_on_shmem(local_comp_buf, num_experts, sum, lane_id); + // In-place update + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_grad[i] = + static_cast(local_grad[i]) / (static_cast(sum_fwd_input) + epsilon) - + static_cast(sum_Output_x_Grad) / + ((static_cast(sum_fwd_input) + epsilon) * + (static_cast(sum_fwd_input) + epsilon)); + } + } + __syncwarp(); + + // Pre-softmax bwd + if (score_function == 1) { + apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, nullptr, + num_experts, lane_id); + __syncwarp(); + } + // Sigmoid bwd + if (score_function == 0) { + apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id); + __syncwarp(); + } + // Write the grad_logits to the global mem + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + grad_logits[pos_offset + i] = local_grad[i]; + } + __syncwarp(); + } +} + +template +void fused_score_for_moe_aux_loss_backward_kernel_launcher( + const DataType *intermediate_output, const DataType *grad_scores, int num_tokens, + int num_experts, int topk, int score_function, DataType *grad_logits, cudaStream_t stream) { + // Meta data for the kernel + size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; + size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; + size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // grad_scores + + + num_experts * num_token_per_block * sizeof(DataType) // act_from_fwd + + num_experts * num_token_per_block * sizeof(DataType); // comp_buf + fused_score_for_moe_aux_loss_backward_kernel + <<>>( + intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function, + grad_logits); +} + +void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output, + const Tensor &grad_scores, int num_tokens, + int num_experts, int topk, int score_function, + Tensor &grad_logits, cudaStream_t stream) { + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( + grad_scores.data.dtype, DataType, + fused_score_for_moe_aux_loss_backward_kernel_launcher( + reinterpret_cast(intermediate_output.data.dptr), + reinterpret_cast(grad_scores.data.dptr), num_tokens, num_experts, topk, + score_function, reinterpret_cast(grad_logits.data.dptr), stream);); +} + +} // namespace transformer_engine + +void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_tokens, + int num_experts, int topk, int score_function, + NVTETensor scores, const NVTETensor routing_map, + const NVTETensor intermediate_output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_forward); + using namespace transformer_engine; + fused_score_for_moe_aux_loss_forward(*convertNVTETensorCheck(logits), num_tokens, num_experts, + topk, score_function, *convertNVTETensorCheck(scores), + *convertNVTETensorCheck(routing_map), + *convertNVTETensorCheck(intermediate_output), stream); +} + +void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_output, + const NVTETensor grad_scores, int num_tokens, + int num_experts, int topk, int score_function, + NVTETensor grad_logits, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_backward); + using namespace transformer_engine; + fused_score_for_moe_aux_loss_backward( + *convertNVTETensorCheck(intermediate_output), *convertNVTETensorCheck(grad_scores), + num_tokens, num_experts, topk, score_function, *convertNVTETensorCheck(grad_logits), stream); +} diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu new file mode 100644 index 000000000..443931aa3 --- /dev/null +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -0,0 +1,504 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../utils.cuh" +#include "utils.h" + +namespace transformer_engine { + +template +__global__ void fused_topk_with_score_function_forward_kernel( + const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, + int num_groups, int group_topk, float scaling_factor, int score_function, + const BiasType *expert_bias, DataType *probs, bool *routing_map, + DataType *intermediate_output) { + /*** + * Section: Global Variables/Addresses init + * - Assume the sizeof(DataType) >= sizeof(int), + * So DataType address is assigned firstly to avoid the alignment issue + * - Each warp is responsible for one token, and has own shared memory buffer. + * Then __syncwarp() is used instead of __syncthreads() + */ + // Used variables/addresses init + int num_token_per_block = blockDim.x / kThreadsPerWarp; + int warp_id = threadIdx.x / kThreadsPerWarp; + int lane_id = threadIdx.x % kThreadsPerWarp; + extern __shared__ float shmem[]; + DataType *scores_buf = reinterpret_cast(shmem); + DataType *topk_scores_buf = + reinterpret_cast(scores_buf + num_experts * num_token_per_block); + DataType *group_scores_buf = nullptr, *masked_scores_buf = nullptr; + int *topk_indices_buf = nullptr; + if (group_topk > 0) { + masked_scores_buf = reinterpret_cast(topk_scores_buf + topk * num_token_per_block); + group_scores_buf = + reinterpret_cast(masked_scores_buf + num_experts * num_token_per_block); + topk_indices_buf = reinterpret_cast(group_scores_buf + num_groups * num_token_per_block); + } else { + topk_indices_buf = reinterpret_cast(topk_scores_buf + topk * num_token_per_block); + } + // The address of buffers on the current warp + DataType *scores = scores_buf + warp_id * num_experts; + DataType *topk_scores = topk_scores_buf + warp_id * topk; + DataType *masked_scores = masked_scores_buf + warp_id * num_experts; + DataType *group_scores = group_scores_buf + warp_id * num_groups; + int *topk_indices = topk_indices_buf + warp_id * topk; + + /*** + * Section: Main Loop + * - Each warp is responsible for one token + */ + int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; + for (int round = blockIdx.x; round < total_round; round += gridDim.x) { + int token_offset_cur_warp = round * num_token_per_block + warp_id; + // Each warp is responsible for one token + if (token_offset_cur_warp >= num_tokens) break; + + /*** + * Section: Init buffer + * - Clear the global buffer which will accept the result of this round + * - Clear/Init the shmem buffer used by current warp this round + * - Load the logits to shmem + */ + int pos_offset = token_offset_cur_warp * num_experts; + // Clear the probs/routing_map (num_experts) + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + probs[pos_offset + i] = 0.0f; + routing_map[pos_offset + i] = false; + if (score_function == 1) { + intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); + } + } + // Load the logits to shmem + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + scores[i] = logits[pos_offset + i]; + } + // If group_topk > 0, init the masked_scores to -inf + if (group_topk > 0) { + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + masked_scores[i] = -std::numeric_limits::infinity(); + } + } + __threadfence_block(); + __syncwarp(); + + /*** + * Section: Preprocess + * Possible preprocess the scores before the topk operation + * - Pre-softmax + * - Sigmoid + * - Expert bias + * This is in-place scores update + */ + // score_function == 1 means softmax + if (use_pre_softmax && score_function == 1) { + // Apply softmax to the logits before the topk + apply_softmax_on_float(scores, num_experts, lane_id); + __syncwarp(); + // Save the softmax output for backward + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = scores[i]; + } + } + + // score_function == 0 means sigmoid + if (score_function == 0) { + // Apply sigmoid to the logits + apply_sigmoid_on_float(scores, num_experts, lane_id); + __syncwarp(); + // Save the sigmoid output for backward + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = scores[i]; + } + } + + __syncwarp(); //Confirm the scores is written to the softmax/sigmoid output + + // Expert bias is only used at the sigmoid case + if (expert_bias && score_function == 0) { + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + scores[i] = static_cast(static_cast(scores[i]) + + static_cast(expert_bias[i])); + } + } + __syncwarp(); + + /*** + * Section: Topk + * Get the topk indices + * - group_topk + * - naive topk + * - topk with expert bias + */ + // Topk on the scores + // The bias is not empty only happens at the sigmod case + if (group_topk > 0) { + int group_size = num_experts / num_groups; + // Top2 + for (int i = 0; i < num_groups; i++) { + naive_topk_and_mask( + /*scores ptr = */ scores + i * group_size, + /*data size = */ group_size, + /*topk = */ topk / group_topk, + /*topk indices ptr = */ topk_indices, + /*topk scores ptr = */ topk_scores, + /*lane id = */ lane_id); + __syncwarp(); + // Compute the group score + if (lane_id == 0) { +//TODO: release after /opt/rocm/include/hip/amd_detail/amd_hip_bfloat16.h remove explict constructor restriction +#ifdef __HIP_PLATFORM_AMD__ + DataType tmp(0.0f); +#else + DataType tmp = 0.0f; +#endif + for (int j = 0; j < topk / group_topk; j++) { + tmp = tmp + topk_scores[j]; + } + group_scores[i] = tmp; + } + __syncwarp(); + } + + // select the topk groups + naive_topk_and_mask( + /*scores ptr = */ group_scores, + /*data size = */ num_groups, + /*topk = */ group_topk, + /*topk indices ptr = */ topk_indices, + /*topk scores ptr = */ topk_scores, + /*lane id = */ lane_id); + __syncwarp(); + // Copy the unmasked scores to the buffer + for (int i = 0; i < group_topk; i++) { + int st = topk_indices[i] * group_size; + int ed = st + group_size; + for (int j = st + lane_id; j < ed; j += kThreadsPerWarp) { + masked_scores[j] = scores[j]; + } + } + __syncwarp(); + naive_topk_and_mask(masked_scores, num_experts, topk, topk_indices, topk_scores, lane_id); + + } else { + naive_topk_and_mask(scores, num_experts, topk, topk_indices, topk_scores, lane_id); + } + __syncwarp(); + + /*** + * Section: Postprocess + * Possible postprocess the scores after the topk operation + * - Revert Expert bias + * - Softmax + * - Sigmoid post-processing when topk > 1 + * - Write the result with scaling_factor + */ + // Revert Expert bias from the topk scores + if (expert_bias && score_function == 0) { + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + topk_scores[i] = + static_cast(topk_scores[i]) - static_cast(expert_bias[topk_indices[i]]); + } + } + __syncwarp(); + + // score_function == 1 means softmax + if (!use_pre_softmax && score_function == 1) { + // Apply softmax to the topk logits + apply_softmax_on_float(topk_scores, topk, lane_id); + __syncwarp(); + // Save the softmax output for backward + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + intermediate_output[pos_offset + topk_indices[i]] = topk_scores[i]; + } + } + + // score_function == 0 means sigmoid + if (score_function == 0) { + if (topk > 1) { + double sum_scores = warp_reduce_on_shmem(topk_scores, topk, sum, lane_id); + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + topk_scores[i] = static_cast(topk_scores[i]) / (sum_scores + epsilon); + } + } + __syncwarp(); + } + + // Write the probs/routing_map to the output tensor + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + routing_map[pos_offset + topk_indices[i]] = true; + probs[pos_offset + topk_indices[i]] = scaling_factor * static_cast(topk_scores[i]); + } + __threadfence_block(); + __syncwarp(); + } +} + +template +void fused_topk_with_score_function_forward_kernel_launcher( + const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, + int num_groups, int group_topk, float scaling_factor, int score_function, + const BiasType *expert_bias, DataType *probs, bool *routing_map, DataType *intermediate_output, + cudaStream_t stream) { + size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; + size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; + size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // scores + + topk * num_token_per_block * sizeof(DataType) // topk_scores + + topk * num_token_per_block * sizeof(int); // topk_indices + if (group_topk > 0) { + shared_memory_size += num_groups * num_token_per_block * sizeof(DataType); // group_scores + shared_memory_size += num_experts * num_token_per_block * sizeof(DataType); // maksed_scores + } + fused_topk_with_score_function_forward_kernel + <<>>( + logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); +} + +void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, int num_experts, + int topk, bool use_pre_softmax, int num_groups, + int group_topk, float scaling_factor, + int score_function, const Tensor expert_bias, + Tensor probs, Tensor routing_map, + Tensor intermediate_output, cudaStream_t stream) { + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( + logits.data.dtype, DataType, + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( + expert_bias.data.dtype, BiasType, + fused_topk_with_score_function_forward_kernel_launcher( + reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, + use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, + reinterpret_cast(expert_bias.data.dptr), + reinterpret_cast(probs.data.dptr), + reinterpret_cast(routing_map.data.dptr), + reinterpret_cast(intermediate_output.data.dptr), stream););); +} + +template +__global__ void fused_topk_with_score_function_backward_kernel( + // Inputs tensor + const bool *routing_map, const DataType *intermediate_output, const DataType *grad_probs, + // Other parameters + int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, + int score_function, + // Output tensor + DataType *grad_logits) { + /*** + * Section: Global Variables/Addresses init + * - Assume the sizeof(DataType) >= sizeof(int), + * - Each warp is responsible for one token, and has own shared memory buffer. + * Then __syncwarp() is used instead of __syncthreads() + */ + // Used variables/addresses init + int num_token_per_block = blockDim.x / kThreadsPerWarp; + int warp_id = threadIdx.x / kThreadsPerWarp; + int lane_id = threadIdx.x % kThreadsPerWarp; + extern __shared__ float shmem[]; + DataType *grad_probs_buf = reinterpret_cast(shmem); + // To store the output of softmax/sigmoid from the fwd + DataType *act_from_fwd_buf = + reinterpret_cast(grad_probs_buf + num_experts * num_token_per_block); + DataType *comp_buf = + reinterpret_cast(act_from_fwd_buf + num_experts * num_token_per_block); + // To store the routing_map from the fwd + bool *routing_map_buf = reinterpret_cast(comp_buf + num_experts * num_token_per_block); + // The address of buffers on the current warp + DataType *local_grad = grad_probs_buf + warp_id * num_experts; + DataType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts; + DataType *local_comp_buf = comp_buf + warp_id * num_experts; + bool *local_routing_map = routing_map_buf + warp_id * num_experts; + + /*** + * Section: Main Loop + * - Each warp is responsible for one token + */ + int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; + for (int round = blockIdx.x; round < total_round; round += gridDim.x) { + int token_offset_cur_warp = round * num_token_per_block + warp_id; + // Each warp is responsible for one token + if (token_offset_cur_warp >= num_tokens) break; + + /*** + * Section: Init buffer + * - Clear the global buffer which will accept the result of this round + * - Clear/Init the shmem buffer used by current warp this round + * - Load the dgrad/output_from_fwd to shmem + */ + int pos_offset = token_offset_cur_warp * num_experts; + // Clear the logits_grad in global mem + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + grad_logits[pos_offset + i] = 0.0f; + } + // Load the dgrad/output_from_fwd to shmem + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_grad[i] = grad_probs[pos_offset + i]; + local_act_from_fwd[i] = intermediate_output[pos_offset + i]; + local_routing_map[i] = routing_map[pos_offset + i]; + } + __threadfence_block(); + __syncwarp(); + + /*** + * Section: Backward of ops after the topk + * - Backward of the used scaling_factor + * - Sigmoid Post-processing bwd when topk > 1 + * - Softmax bwd if use_pre_softmax is false + */ + // Backward of the used scaling_factor + // In-place update + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + if (local_routing_map[i]) { + local_grad[i] = static_cast(local_grad[i]) * scaling_factor; + } + } + __syncwarp(); + // Sigmoid Post-processing bwd when topk > 1 + if (topk > 1 && score_function == 0) { + double sum_fwd_input = masked_warp_reduce_on_shmem( + /*data ptr = */ local_act_from_fwd, + /*mask ptr = */ local_routing_map, + /*data size = */ num_experts, + /*reduce func = */ sum, lane_id); + // Put the result of output * grad to the comp_buf + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_comp_buf[i] = (local_routing_map[i] ? static_cast(local_grad[i]) * + static_cast(local_act_from_fwd[i]) + : 0.0f); + } + __syncwarp(); + double sum_Output_x_Grad = masked_warp_reduce_on_shmem( + /*data ptr = */ local_comp_buf, + /*mask ptr = */ local_routing_map, + /*data size = */ num_experts, + /*reduce func = */ sum, lane_id); + // In-place update + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + if (local_routing_map[i]) { + local_grad[i] = + static_cast(local_grad[i]) / (sum_fwd_input + epsilon) - + sum_Output_x_Grad / ((sum_fwd_input + epsilon) * (sum_fwd_input + epsilon)); + } else { + local_grad[i] = 0.0f; + } + } + } + __syncwarp(); + // Softmax bwd if use_pre_softmax is false + if (!use_pre_softmax && score_function == 1) { + apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, local_routing_map, + num_experts, lane_id); + __syncwarp(); + } + + /*** + * Section: Backward of topk + * mask the unselected position in the grad + */ + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + if (!local_routing_map[i]) { + local_grad[i] = 0.0f; + } + } + __syncwarp(); + + /*** + * Section: Backward of ops before the topk + * - Pre-softmax bwd + * - Sigmoid bwd + * - Write the grad_logits to the global mem + */ + // Pre-softmax bwd + if (score_function == 1 && use_pre_softmax) { + apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, nullptr, + num_experts, lane_id); + __syncwarp(); + } + // Sigmoid bwd + if (score_function == 0) { + apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id); + __syncwarp(); + } + // Write the grad_logits to the global mem + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + grad_logits[pos_offset + i] = local_grad[i]; + } + __syncwarp(); + } +} + +template +void fused_topk_with_score_function_backward_kernel_launcher( + const bool *routing_map, const DataType *intermediate_output, const DataType *grad_probs, + int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, + int score_function, DataType *grad_logits, cudaStream_t stream) { + // Meta data for the kernel + size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; + size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; + size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // grad_probs + + + num_experts * num_token_per_block * sizeof(DataType) // act_from_fwd + + num_experts * num_token_per_block * sizeof(DataType) // comp_buf + + num_experts * num_token_per_block * sizeof(bool); // routing_map + fused_topk_with_score_function_backward_kernel + <<>>( + routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, + use_pre_softmax, scaling_factor, score_function, grad_logits); +} + +void fused_topk_with_score_function_backward(const Tensor &routing_map, + const Tensor &intermediate_output, + const Tensor &grad_probs, int num_tokens, + int num_experts, int topk, bool use_pre_softmax, + float scaling_factor, int score_function, + Tensor &grad_logits, cudaStream_t stream) { + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( + grad_logits.data.dtype, DataType, + fused_topk_with_score_function_backward_kernel_launcher( + reinterpret_cast(routing_map.data.dptr), + reinterpret_cast(intermediate_output.data.dptr), + reinterpret_cast(grad_probs.data.dptr), num_tokens, num_experts, topk, + use_pre_softmax, scaling_factor, score_function, + reinterpret_cast(grad_logits.data.dptr), stream);); +} + +} // namespace transformer_engine + +void nvte_fused_topk_with_score_function_forward( + const NVTETensor logits, int num_tokens, int num_experts, int topk, int use_pre_softmax, + int num_groups, int group_topk, float scaling_factor, int score_function, + const NVTETensor expert_bias, NVTETensor probs, NVTETensor routing_map, + NVTETensor intermediate_output, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_topk_with_score_function_forward); + using namespace transformer_engine; + fused_topk_with_score_function_forward( + *convertNVTETensorCheck(logits), num_tokens, num_experts, topk, + static_cast(use_pre_softmax), num_groups, group_topk, scaling_factor, score_function, + *convertNVTETensorCheck(expert_bias), *convertNVTETensorCheck(probs), + *convertNVTETensorCheck(routing_map), *convertNVTETensorCheck(intermediate_output), stream); +} + +void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map, + const NVTETensor intermediate_output, + const NVTETensor grad_probs, int num_tokens, + int num_experts, int topk, int use_pre_softmax, + float scaling_factor, int score_function, + NVTETensor grad_logits, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_topk_with_score_function_backward); + using namespace transformer_engine; + fused_topk_with_score_function_backward( + *convertNVTETensorCheck(routing_map), *convertNVTETensorCheck(intermediate_output), + *convertNVTETensorCheck(grad_probs), num_tokens, num_experts, topk, + static_cast(use_pre_softmax), scaling_factor, score_function, + *convertNVTETensorCheck(grad_logits), stream); +} diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h new file mode 100644 index 000000000..4d2462815 --- /dev/null +++ b/transformer_engine/common/fused_router/utils.h @@ -0,0 +1,265 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ +#define TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ + +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { + +#ifdef __HIP_PLATFORM_AMD__ +// TODO: remove after rocm supports NV __syncwarp equivalent +__device__ inline void __syncwarp() +{ + __builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront"); + __builtin_amdgcn_wave_barrier(); + __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront"); + +} + +#endif +constexpr size_t kThreadsPerWarp = 32; +constexpr int kThreadsPerBlock = + 128; // Using 4 warps in 1 CTA, Each warp is responsible for 1 token. +constexpr float epsilon = 1e-20; + +template +__device__ inline T max(T a, T b) { + return a > b ? a : b; +} + +template +__device__ inline T sum(T a, T b) { + return a + b; +} + +template +__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_func)(T, T), + int lane_id) { + // Some value is hanlded in local thread + // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... + // Reduce the value in local thread + volatile double val = + lane_id < data_size ? static_cast(data_ptr[lane_id]) : static_cast(0); + for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { +//TODO: release after /opt/rocm/include/hip/amd_detail/amd_hip_bfloat16.h provide bf16 constructor from double +#ifdef __HIP_PLATFORM_AMD__ + val = reduce_func(static_cast(val), data_ptr[i]); +#else + val = reduce_func(val, data_ptr[i]); +#endif + } + + // Warp shuffle between threads +#ifdef __HIP_PLATFORM_AMD__ + val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 16, kThreadsPerWarp))); + val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 8, kThreadsPerWarp))); + val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 4, kThreadsPerWarp))); + val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 2, kThreadsPerWarp))); + val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 1, kThreadsPerWarp))); +#else + val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16)); + val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8)); + val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4)); + val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2)); + val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1)); +#endif + __syncwarp(); + return T(val); +} + +template +__device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, int lane_id) { + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + scores[i] = static_cast(1.0f / (1.0f + exp(-static_cast(scores[i])))); + } +} + +template +__device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size, + T (*reduce_func)(T, T), int lane_id) { + // Some value is hanlded in local thread + // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... + // Reduce the value in local thread + volatile double val = lane_id < data_size && mask[lane_id] + ? static_cast(data_ptr[lane_id]) + : static_cast(0); + for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { + if (mask[i]) { + val = reduce_func(static_cast(val), data_ptr[i]); + } + } + + // Warp shuffle between threads +#ifdef __HIP_PLATFORM_AMD__ + val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 16, kThreadsPerWarp))); + val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 8, kThreadsPerWarp))); + val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 4, kThreadsPerWarp))); + val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 2, kThreadsPerWarp))); + val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 1, kThreadsPerWarp))); +#else + val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16)); + val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8)); + val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4)); + val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2)); + val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1)); +#endif + __syncwarp(); + return T(val); +} + +template +__device__ inline void apply_sigmoid_bwd_on_float(DataType *grad, DataType *fwd_output, + int data_size, int lane_id) { + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + grad[i] = static_cast(grad[i]) * static_cast(fwd_output[i]) * + (1 - static_cast(fwd_output[i])); + } +} + +template +__device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_output, + DataType *comp_buf, bool *mask, int data_size, + int lane_id) { + // Put the result of output * grad to the comp_buf + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + if (mask) { + if (mask[i]) + comp_buf[i] = static_cast(grad[i]) * static_cast(fwd_output[i]); + else + comp_buf[i] = 0.0f; + } else { + comp_buf[i] = static_cast(grad[i]) * static_cast(fwd_output[i]); + } + } + __syncwarp(); + float sum_Output_x_Grad = warp_reduce_on_shmem( + /*data ptr = */ comp_buf, + /*data size = */ data_size, + /*reduce func = */ sum, lane_id); + // In-place update + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + if (mask) { + if (mask[i]) + grad[i] = + static_cast(fwd_output[i]) * (static_cast(grad[i]) - sum_Output_x_Grad); + else + grad[i] = 0.0f; + } else { + grad[i] = + static_cast(fwd_output[i]) * (static_cast(grad[i]) - sum_Output_x_Grad); + } + } +} + +template +__device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) { + // 1. compute the max of value + float max_val = static_cast(warp_reduce_on_shmem(scores, data_size, max, lane_id)); + // 2. value -> exp_value + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + scores[i] = static_cast(exp(static_cast(scores[i]) - max_val)); + } + __syncwarp(); + // 3. compute the sum of exp_value + float sum_val = static_cast(warp_reduce_on_shmem(scores, data_size, sum, lane_id)); + // 4. update the softmax value + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + scores[i] = static_cast(scores[i]) / sum_val; + } + __syncwarp(); +} + +template +__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices, + T *topk_scores, int lane_id) { + // Topk Times: Find the max value and its index + // Then mask it, and record the index in the topk_indices + // After looping topk times, the topk_indices will be the topk indices + for (int k = 0; k < topk; k++) { + // Find the max value and its index + volatile double val = + (lane_id < data_size) ? static_cast(scores[lane_id]) : static_cast(0); + volatile int index = (lane_id < data_size) ? lane_id : 0; + // Some value is hanlded in local thread + // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... + // Reduce the value in local thread + for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { + volatile double cur_val = scores[i]; + if (cur_val > val) { + val = cur_val; + index = i; + } + } + // Warp shuffle between threads + for (int s = 16; s > 0; s /= 2) { +#ifdef __HIP_PLATFORM_AMD__ + volatile auto shuffled_val = __shfl_xor(val, s, kThreadsPerWarp); + volatile auto shuffled_index = __shfl_xor(index, kThreadsPerWarp); +#else + volatile auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s); + volatile auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s); +#endif + if (shuffled_val > val) { + val = shuffled_val; + index = shuffled_index; + } + } + if (lane_id == 0) { + topk_indices[k] = index; + topk_scores[k] = val; + scores[index] = + static_cast(-1.0) - val; // make the selected experts using val = - 1 - val + } + __syncwarp(); + } + + // Reset the scores to the original value + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + scores[topk_indices[i]] = + static_cast(-1.0) - static_cast(scores[topk_indices[i]]); + } +} + +// Current TE only support float32/bf16/fp16, float64 probs should be considered in the future +#define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + +#define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kInt32: { \ + using type = int32_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt64: { \ + using type = int64_t; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } +} // namespace transformer_engine +#endif diff --git a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu index 80348029a..a06c8493a 100644 --- a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu @@ -557,8 +557,8 @@ void nvte_scaled_aligned_causal_masked_softmax_forward(const NVTETensor input, float scale_factor, cudaStream_t stream) { NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_forward); using namespace transformer_engine; - scaled_aligned_causal_masked_softmax_forward(*reinterpret_cast(input), - reinterpret_cast(softmax_results), + scaled_aligned_causal_masked_softmax_forward(*convertNVTETensorCheck(input), + convertNVTETensorCheck(softmax_results), scale_factor, stream); } @@ -569,6 +569,6 @@ void nvte_scaled_aligned_causal_masked_softmax_backward(const NVTETensor incomin NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_backward); using namespace transformer_engine; scaled_aligned_causal_masked_softmax_backward( - *reinterpret_cast(output_grads), *reinterpret_cast(incoming_grads), - *reinterpret_cast(softmax_results), scale_factor, stream); + *convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(incoming_grads), + *convertNVTETensorCheck(softmax_results), scale_factor, stream); } diff --git a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu index ae3baacae..5eca4947f 100644 --- a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu @@ -821,8 +821,8 @@ void nvte_scaled_softmax_forward(const NVTETensor input, NVTETensor softmax_resu float scale_factor, cudaStream_t stream) { NVTE_API_CALL(nvte_scaled_softmax_forward); using namespace transformer_engine; - scaled_softmax_forward(*reinterpret_cast(input), - reinterpret_cast(softmax_results), scale_factor, stream); + scaled_softmax_forward(*convertNVTETensorCheck(input), convertNVTETensorCheck(softmax_results), + scale_factor, stream); } void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETensor softmax_results, @@ -830,9 +830,9 @@ void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETen cudaStream_t stream) { NVTE_API_CALL(nvte_scaled_softmax_backward); using namespace transformer_engine; - scaled_softmax_backward(*reinterpret_cast(output_grads), - *reinterpret_cast(incoming_grads), - *reinterpret_cast(softmax_results), scale_factor, stream); + scaled_softmax_backward(*convertNVTETensorCheck(output_grads), + *convertNVTETensorCheck(incoming_grads), + *convertNVTETensorCheck(softmax_results), scale_factor, stream); } void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor mask, @@ -840,9 +840,8 @@ void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor cudaStream_t stream) { NVTE_API_CALL(nvte_scaled_masked_softmax_forward); using namespace transformer_engine; - scaled_masked_softmax_forward(*reinterpret_cast(input), - *reinterpret_cast(mask), - reinterpret_cast(softmax_results), scale_factor, stream); + scaled_masked_softmax_forward(*convertNVTETensorCheck(input), *convertNVTETensorCheck(mask), + convertNVTETensorCheck(softmax_results), scale_factor, stream); } void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads, @@ -850,7 +849,7 @@ void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads, float scale_factor, cudaStream_t stream) { NVTE_API_CALL(nvte_scaled_masked_softmax_backward); using namespace transformer_engine; - scaled_masked_softmax_backward( - *reinterpret_cast(output_grads), *reinterpret_cast(incoming_grads), - *reinterpret_cast(softmax_results), scale_factor, stream); + scaled_masked_softmax_backward(*convertNVTETensorCheck(output_grads), + *convertNVTETensorCheck(incoming_grads), + *convertNVTETensorCheck(softmax_results), scale_factor, stream); } diff --git a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu index 00efee873..af3027039 100644 --- a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu @@ -605,9 +605,9 @@ void nvte_scaled_upper_triang_masked_softmax_forward(const NVTETensor input, NVTETensor softmax_results, float scale_factor, cudaStream_t stream) { using namespace transformer_engine; - scaled_upper_triang_masked_softmax_forward(*reinterpret_cast(input), - reinterpret_cast(softmax_results), - scale_factor, stream); + scaled_upper_triang_masked_softmax_forward(*convertNVTETensorCheck(input), + convertNVTETensorCheck(softmax_results), scale_factor, + stream); } void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_grads, @@ -616,6 +616,6 @@ void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_ cudaStream_t stream) { using namespace transformer_engine; scaled_upper_triang_masked_softmax_backward( - *reinterpret_cast(output_grads), *reinterpret_cast(incoming_grads), - *reinterpret_cast(softmax_results), scale_factor, stream); + *convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(incoming_grads), + *convertNVTETensorCheck(softmax_results), scale_factor, stream); } diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 36cbcd330..cca994299 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -13,15 +13,16 @@ #endif // #ifndef __HIP_PLATFORM_AMD__ #include #include +#include #include #include #include #include "../common.h" -#include "../util/vectorized_pointwise.h" #include "../util/handle_manager.h" #include "../util/logging.h" +#include "../util/multi_stream.h" #include "common/util/cuda_runtime.h" #ifndef __HIP_PLATFORM_AMD__ @@ -94,7 +95,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla A.scaling_mode == B.scaling_mode || (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) || (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D), - "Inputs A and B to GEMM need to have compatible scaling modes!"); + "Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = " + + to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode)); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); GemmParam ret; @@ -226,6 +228,13 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla return ret; } +/* cuBLAS version number at run-time */ +size_t cublas_version() { + // Cache version to avoid cuBLAS logging overhead + static size_t version = cublasLtGetVersion(); + return version; +} + } // namespace #endif // __HIP_PLATFORM_AMD__ @@ -357,10 +366,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, &fastAccuMode, sizeof(fastAccuMode))); // Scaling factors. -#if CUDA_VERSION >= 12080 +#if CUBLAS_VERSION >= 120800 cublasLtMatmulMatrixScale_t scaling_mode_a; cublasLtMatmulMatrixScale_t scaling_mode_b; -#endif +#endif // CUBLAS_VERSION >= 120800 if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) { void *A_scale_inverse = param.A_scale_inv; void *B_scale_inverse = param.B_scale_inv; @@ -370,10 +379,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); -#if CUDA_VERSION >= 12080 +#if CUBLAS_VERSION >= 120800 scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; +#endif // CUBLAS_VERSION >= 120800 } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) { +#if CUBLAS_VERSION >= 120800 + NVTE_CHECK(cublas_version() >= 120800, + "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); fp8e8m0 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); fp8e8m0 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -386,17 +399,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. - if (cublasLtGetVersion() <= 120803) { + if (cublas_version() <= 120803) { const int64_t dummy_a_vec_stride = 1; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, sizeof(dummy_a_vec_stride))); } +#else + NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", + CUBLAS_VERSION); +#endif // CUBLAS_VERSION >= 120800 } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { -#if CUDA_VERSION >= 12090 +#if CUBLAS_VERSION >= 120900 + NVTE_CHECK(cublas_version() >= 120900, + "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ", + cublas_version()); float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -415,20 +435,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; #else - NVTE_ERROR("FP8 block scaling requires CUDA 12.9+"); -#endif // CUDA_VERSION >= 12090 -#endif // CUDA_VERSION >= 12080 + NVTE_ERROR("FP8 block scaling requires cuBLAS 12.9+, but compile-time cuBLAS version is ", + CUBLAS_VERSION); +#endif // CUBLAS_VERSION >= 120900 } else { NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " + to_string(inputB->scaling_mode) + "."); } -#if CUDA_VERSION >= 12080 - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b))); -#endif +#if CUBLAS_VERSION >= 120800 + if (cublas_version() >= 120800) { + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_MODE, + &scaling_mode_a, sizeof(scaling_mode_a))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_MODE, + &scaling_mode_b, sizeof(scaling_mode_b))); + } +#endif // CUBLAS_VERSION >= 120800 if (is_fp8_dtype(outputD->data.dtype)) { // Accumulation mode not supported for FP8 output C = nullptr; @@ -436,13 +460,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); -#if CUDA_VERSION >= 12080 - // NOTE: In all current cases where FP8 output is supported, the input is - // scaled identically to the output. - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_D_SCALE_MODE, - &scaling_mode_a, sizeof(scaling_mode_a))); -#endif +#if CUBLAS_VERSION >= 120800 + if (cublas_version() >= 120800) { + // NOTE: In all current cases where FP8 output is supported, the input is + // scaled identically to the output. + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_D_SCALE_MODE, + &scaling_mode_a, sizeof(scaling_mode_a))); + } +#endif // CUBLAS_VERSION >= 120800 // For FP8 output, cuBLAS requires C_type to match bias_type and // be FP16/BF16 const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF; @@ -510,8 +536,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); -#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 if (counter != nullptr) { +#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000) + NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ", + CUDA_VERSION); +#endif +#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) + NVTE_ERROR( + "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ", + CUBLAS_VERSION); +#endif +#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \ + CUBLAS_VERSION < 130000 + NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, + "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA verson is ", + cuda::cudart_version()); + NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000, + "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ", + cublas_version()); if (m_split == 0) m_split = 1; if (n_split == 0) n_split = 1; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( @@ -529,8 +571,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter, sizeof(counter))); } - } #endif + } NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -539,6 +581,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const auto B_alignment = _getAlignment(reinterpret_cast(param.B)); const auto C_alignment = _getAlignment(reinterpret_cast(C)); const auto D_alignment = _getAlignment(reinterpret_cast(D)); + const auto workspace_alignment = _getAlignment(reinterpret_cast(workspace)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -547,6 +590,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment))); + NVTE_CHECK(workspace_alignment % 256 == 0, + "cuBLAS workspace pointer must be aligned to 256 bytes, got ", workspace_alignment); const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, @@ -585,18 +630,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } #endif // __HIP_PLATFORM_AMD__ -static std::once_flag init_flag; -static cudaStream_t compute_streams[num_streams]; -static cudaEvent_t cublas_event[num_streams]; - -// Warning: only call once per device! -static void init_streams_and_events() { - for (int i = 0; i < num_streams; i++) { - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, -1)); - NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event[i])); - } -} - } // namespace transformer_engine // compute_stream_offset = -1 means the stream from outer rather than compute_streams @@ -605,12 +638,12 @@ static void cublas_gemm_ex(const NVTETensor A, const NVTETensor B, NVTETensor D, NVTETensor workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream, int compute_stream_offset = -1) { using namespace transformer_engine; - const Tensor *inputA = reinterpret_cast(A); - const Tensor *inputB = reinterpret_cast(B); - Tensor *outputD = reinterpret_cast(D); - const Tensor *biasTensor = reinterpret_cast(bias); - Tensor *outputGelu = reinterpret_cast(pre_gelu_out); - Tensor *wspace = reinterpret_cast(workspace); + const Tensor *inputA = convertNVTETensorCheck(A); + const Tensor *inputB = convertNVTETensorCheck(B); + Tensor *outputD = convertNVTETensorCheck(D); + const Tensor *biasTensor = convertNVTETensorCheck(bias); + Tensor *outputGelu = convertNVTETensorCheck(pre_gelu_out); + Tensor *wspace = convertNVTETensorCheck(workspace); cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, #ifdef __HIP_PLATFORM_AMD__ @@ -632,12 +665,12 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons int math_sm_count, cudaStream_t stream) { NVTE_API_CALL(nvte_cublas_gemm); using namespace transformer_engine; - const Tensor *inputA = reinterpret_cast(A); - const Tensor *inputB = reinterpret_cast(B); - Tensor *outputD = reinterpret_cast(D); - const Tensor *biasTensor = reinterpret_cast(bias); - Tensor *outputGelu = reinterpret_cast(pre_gelu_out); - Tensor *wspace = reinterpret_cast(workspace); + const Tensor *inputA = convertNVTETensorCheck(A); + const Tensor *inputB = convertNVTETensorCheck(B); + Tensor *outputD = convertNVTETensor(D); + const Tensor *biasTensor = convertNVTETensor(bias); + Tensor *outputGelu = convertNVTETensor(pre_gelu_out); + Tensor *wspace = convertNVTETensor(workspace); cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, #ifdef __HIP_PLATFORM_AMD__ @@ -657,22 +690,34 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor int n_split, bool gemm_producer, const NVTETensor counter, cudaStream_t stream) { NVTE_API_CALL(nvte_cublas_atomic_gemm); + using namespace transformer_engine; #ifndef __HIP_PLATFORM_AMD__ - int cudart_version; - NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version)); - NVTE_CHECK(cudart_version >= 12020, "Cuda version 12.2 is required for atomic gemm."); - NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm."); + // Check CUDA and cuBLAS versions +#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000) + NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ", + CUDA_VERSION); #endif +#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) + NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ", + CUBLAS_VERSION); +#endif + NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, + "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA verson is ", + cuda::cudart_version()); + NVTE_CHECK( + cublas_version() >= 120205 && cublas_version() < 130000, + "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ", + cublas_version()); +#endif //__HIP_PLATFORM_AMD__ - using namespace transformer_engine; - const Tensor *inputA = reinterpret_cast(A); - const Tensor *inputB = reinterpret_cast(B); - Tensor *outputD = reinterpret_cast(D); - const Tensor *biasTensor = reinterpret_cast(bias); - Tensor *outputGelu = reinterpret_cast(pre_gelu_out); - const Tensor *inputCounter = reinterpret_cast(counter); - Tensor *wspace = reinterpret_cast(workspace); + const Tensor *inputA = convertNVTETensorCheck(A); + const Tensor *inputB = convertNVTETensorCheck(B); + Tensor *outputD = convertNVTETensor(D); + const Tensor *biasTensor = convertNVTETensor(bias); + Tensor *outputGelu = convertNVTETensor(pre_gelu_out); + const Tensor *inputCounter = convertNVTETensor(counter); + Tensor *wspace = convertNVTETensor(workspace); NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), @@ -696,29 +741,31 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT cudaStream_t stream) { NVTE_API_CALL(nvte_multi_stream_cublas_gemm); using namespace transformer_engine; - // Inits streams and events (once, globally) - std::call_once(init_flag, init_streams_and_events); + + int num_streams = nvte_get_num_compute_streams(); int num_stream_used = std::min(num_streams, num_gemms); // wait for current stream to finish - NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream)); + NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream)); for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0])); + NVTE_CHECK_CUDA( + cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0))); } for (int i = 0; i < num_gemms; i++) { cublas_gemm_ex(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count, - compute_streams[i % num_streams], i % num_streams); + detail::get_compute_stream(i % num_streams), i % num_streams); } // record events on compute streams for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[s], compute_streams[s])); + NVTE_CHECK_CUDA( + cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s))); } // wait for all compute streams to finish for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s))); } } diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 9de4cfad7..50710ee1b 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -5,6 +5,7 @@ ************************************************************************/ #include #include +#include #include #include #include @@ -189,7 +190,7 @@ static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) { } } -//TODO: unified with cublaslt_gemm.cu +//TODO: merge duplicated logics with cublaslt_gemm.cu struct GemmParam { void *A = nullptr; void *B = nullptr; @@ -933,7 +934,7 @@ static inline int getIntEnv(const char *name, int defval, int minval) */ static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) { NVTE_CHECK(hipblaslt_handles != nullptr); - for (int i = 0; i < num_streams; i++) { + for (int i = 0; i < nvte_get_num_compute_streams(); i++) { NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&hipblaslt_handles[i])); } } @@ -1550,14 +1551,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, bool use_service_stream = (math_sm_count != 0) ? get_service_stream(math_sm_count, stream, ss_ctl) : false; + int num_streams = nvte_get_num_compute_streams(); NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < num_streams); hipblasLtHandle_t handle = nullptr; if (compute_stream_offset != -1) { // Init hipblaslt handles (once, globally) static std::once_flag init_flag; - static hipblasLtHandle_t hipblaslt_handles[num_streams]; - std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles); + static std::vector hipblaslt_handles(num_streams); + std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles.data()); handle = hipblaslt_handles[compute_stream_offset]; } diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 64136b2c4..a3235e84f 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -259,6 +259,17 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp */ void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Casts multiple input tensors to quantized output tensors. + * + * \param[in] inputs List of input tensors to be cast. + * \param[in,out] outputs List of output quantized tensors. + * \param[in] quant_config (Optional) Quantization configurations. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, + const NVTEQuantizationConfig quant_config, const size_t num_tensors, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h index 678ffe919..649b5ced5 100644 --- a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h +++ b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h @@ -17,23 +17,21 @@ extern "C" { #endif -/*! \brief Transposes the input, providing the option to immediately exit the kernel - * based on the value of the 'noop' tensor. +/*! \brief Transposes the input. * - * \param[in] input Input tensor. - * \param[in] noop Noop tensor. + * \param[in] input Input tensor to be cast. + * \param[in] noop If this single element tensor has non-zero value, kernel will exit immediately. * \param[in,out] output Output tensor. * \param[in] stream CUDA stream used for the operation. */ void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream); -/*! \brief Casts and transposes the input, providing the option to immediately exit the kernel - * based on the value of the 'noop' tensor. +/*! \brief Casts and transposes the input. * - * \param[in] input Input tensor. - * \param[in] noop Noop tensor. - * \param[in,out] output Output tensor. + * \param[in] input Input tensor to be cast. + * \param[in] noop If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] output Output quantized tensor. * \param[in] stream CUDA stream used for the operation. */ void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 3400eaaeb..851032e04 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -185,6 +185,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); /*! \brief Get fused attention backend based on input parameters. * + * \param[in] is_training Whether the model is in training mode. * \param[in] q_dtype The data type of Tensor Q. * \param[in] kv_dtype The data type of Tensors K, V. * \param[in] qkv_layout The layout of Tensors Q, K, V. @@ -201,10 +202,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] window_size_right Sliding window size (the right half). */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right); + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); /*! \brief Compute dot product attention with packed QKV input. * @@ -632,8 +633,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream); -#ifndef __HIP_PLATFORM_AMD__ /*! \brief Update the RNG state with the seed and calculated offset. + * + * \warning This API is **experimental** and subject to change. * * \param[in] rng_state_dst RNG state to store seed and offset. * \param[in] seed Seed for RNG state. @@ -647,25 +649,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed, size_t q_max_seqlen, size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend, cudaStream_t stream); -#else -/*! \brief Update the RNG state with the seed and calculated offset. - * - * \param[in] rng_state_dst RNG state to store seed and offset. - * \param[in] seed Seed for RNG state. - * \param[in] batch_size Batch size. - * \param[in] num_heads # of attention heads. - * \param[in] q_max_seqlen Max sequence length used for computing for Q. - * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. - * \param[in] kv_max_seqlen Max sequence length used for computing for K and V. - * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. - * \param[in] stream CUDA stream used for this operation. - */ -void nvte_populate_rng_state_async(void *rng_state_dst, const void *const seed, - size_t batch_size, size_t num_heads, size_t q_max_seqlen, - size_t kv_max_seqlen, cudaStream_t stream); -#endif /*! \brief Get KV format for a given QKV layout. + * + * \warning This API is **experimental** and subject to change. * * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. * \param[in] workspace Workspace tensor. @@ -675,48 +662,187 @@ void nvte_populate_rng_state_async(void *rng_state_dst, const void *const seed, uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len, cudaStream_t stream); +/*! \brief Set the seed and offset for RNG state. + * + * \warning This API is **experimental** and subject to change. + * + * \param[out] rng_state_ptr A size 2 array storing the RNG's seed and offset respectively. + * \param[in] captured Whether a CUDA graph is being captured. + * \param[in] seed_ptr Seed pointer. + * \param[in] seed_val Seed value. + * \param[in] offset_ptr Offset pointer. + * \param[in] offset_val Offset value. + * \param[in] offset_intragraph Intragraph offset in RNG states. For use with CUDA Graphs. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr, uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val, uint32_t offset_intragraph, cudaStream_t stream); +/*! \brief Copy keys and values into the KV cache. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] new_k Key tensor. + * \param[in] new_v Value tensor. + * \param[out] k_cache Key cache. + * \param[out] v_cache Value cache. + * \param[in] page_table Page table for K cache, [batch_size, max_pages_per_seq]. + * \param[in] cu_new_lens Cumulative sequence lengths. + * \param[in] cu_cached_lens Cached cumulative sequence lengths. + * \param[in] qkv_format QKV format, e.g. sbhd. + * \param[in] b Batch size. + * \param[in] max_ctx_len Maximum context length. + * \param[in] max_seq_len Maximum sequence length. + * \param[in] max_pages_per_seq Maximum number of pages per sequence. + * \param[in] is_non_paged Whether the cache is paged or not. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache, NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens, NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, int is_non_paged, cudaStream_t stream); +/*! \brief Extract the first half (half_idx=0) or second half (half_idx=1) of a THD tensor. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] tensor Input tensor. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[out] half Output tensor. + * \param[in] half_idx Whether to read first or second half of input tensor. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens, NVTETensor half, int half_idx, cudaStream_t stream); +/*! \brief Correct the second half of the softmax LSE (LogSumExp) for context parallelism. + * + * \warning This API is **experimental** and subject to change. + * + * \param[out] lse Output tensor. + * \param[in] lse_per_step Input tensor. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[in] lse_packed Whether or not lse_per_step is packed. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step, const NVTETensor &cu_seqlens, int lse_packed, cudaStream_t stream); +/*! \brief Read the second half of the softmax LSE (LogSumExp) for context parallelism. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] lse Input tensor. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[out] half_lse Output tensor. + * \param[in] lse_packed Whether or the softmax LSE is in packed format. + * \param[in] second_half_lse_seqlen Sequence length. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens, NVTETensor half_lse, int lse_packed, int second_half_lse_seqlen, cudaStream_t stream); +/*! \brief Correct the THD format output of context parallelism in forward pass. + * + * \warning This API is **experimental** and subject to change. + * + * \param[out] out Output tensor. + * \param[in] out_per_step THD format output of context parallelism in forward pass. + * \param[in] lse Softmax LSE. + * \param[in] lse_per_step Softmax LSE per step. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[in] only_second_half Whether or not to correct only second half. + * \param[in] lse_packed Whether or the softmax LSE is in packed format. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step, const NVTETensor &lse, const NVTETensor &lse_per_step, const NVTETensor &cu_seqlens, int only_second_half, int lse_packed, cudaStream_t stream); +/*! \brief Correct the THD format output of context parallelism in forward pass. + * + * \warning This API is **experimental** and subject to change. + * + * \param[out] grad Output tensor. + * \param[in] grad_per_step THD format gradient of context parallelism. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[in] first_half One of ("add", "copy", "none") correction op for first half. + * \param[in] second_half One of ("add", "copy", "none") correction op for second half. + Must be different from first_half. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step, const NVTETensor &cu_seqlens, const char *first_half, const char *second_half, cudaStream_t stream); +/*! \brief Generate partitioned indices for inputs in THD format. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[out] output Output tensor. + * \param[in] total_tokens Total number of tokens. + * \param[in] world_size Total number of devices for context parallelism. + * \param[in] rank Device ID for current device. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output, int total_tokens, int world_size, int rank, cudaStream_t stream); +/*! \brief Convert tensor from THD to BSHD format. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] tensor Input tensor. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[out] new_tensor Output tensor. + * \param[in] b Batch size. + * \param[in] max_seq_len Maximum sequence length. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor, int b, int max_seq_len, cudaStream_t stream); +/*! \brief Convert tensor from BSHD to THD format. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] tensor Input tensor. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[out] new_tensor Output tensor. + * \param[in] b Batch size. + * \param[in] max_seq_len Maximum sequence length. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor, int t, cudaStream_t stream); +/*! \brief Prepare QKV tensor for Flash Attention forward kernel. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] qkvi Input tensor. + * \param[out] qkv Output tensor. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream); +/*! \brief Prepare QKV tensor for Flash Attention backward kernel. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] q Input query tensor. + * \param[in] k Input key tensor. + * \param[in] v Input value tensor. + * \param[out] qkv Output tensor. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv, cudaStream_t stream); diff --git a/transformer_engine/common/include/transformer_engine/fused_router.h b/transformer_engine/common/include/transformer_engine/fused_router.h new file mode 100644 index 000000000..8cf4b222a --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/fused_router.h @@ -0,0 +1,132 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_H_ +#define TRANSFORMER_ENGINE_FUSED_ROUTER_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Apply topk + softmax/sigmoid to the input tensor. Grouped topk is supported. + * + * \param[in] logits Logits from the gating GEMM. + * \param[in] num_tokens Number of tokens. + * \param[in] num_experts Number of experts. + * \param[in] topk Topk value. + * \param[in] use_pre_softmax Whether to use softmax before topk. + * \param[in] num_groups Number of groups in grouped topk. + * \param[in] group_topk Grouped topk value. + * \param[in] scaling_factor Scaling factor. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax. + * \param[in] expert_bias Expert bias. (Only used at the sigmoid case) + * \param[out] probs Output tensor for probabilities. + * \param[out] routing_map Output tensor for routing map. + * \param[out] intermediate_output Output tensor for intermediate output. (Softmax/sigmoid output) + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_topk_with_score_function_forward( + const NVTETensor logits, int num_tokens, int num_experts, int topk, int use_pre_softmax, + int num_groups, int group_topk, float scaling_factor, int score_function, + const NVTETensor expert_bias, NVTETensor probs, NVTETensor routing_map, + NVTETensor intermediate_output, cudaStream_t stream); + +/*! \brief Backward pass for fused topk + softmax/sigmoid. + * + * \param[in] routing_map Routing map. + * \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output) + * \param[in] grad_probs Gradient of probs. + * \param[in] num_tokens Number of tokens. + * \param[in] num_experts Number of experts. + * \param[in] topk Topk value. + * \param[in] use_pre_softmax Whether to use softmax before topk. + * \param[in] scaling_factor Scaling factor. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax. + * \param[out] grad_logits Gradient of logits. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map, + const NVTETensor intermediate_output, + const NVTETensor grad_probs, int num_tokens, + int num_experts, int topk, int use_pre_softmax, + float scaling_factor, int score_function, + NVTETensor grad_logits, cudaStream_t stream); + +/*! \brief Forward pass for computing scores/routing map for auxiliary loss. + * + * \param[in] logits Logits from the gating GEMM. + * \param[in] num_tokens Number of tokens. + * \param[in] num_experts Number of experts. + * \param[in] topk Topk value. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax. + * \param[out] scores Output tensor for scores. + * \param[in] routing_map Routing map. + * \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output) + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_tokens, + int num_experts, int topk, int score_function, + NVTETensor scores, const NVTETensor routing_map, + const NVTETensor intermediate_output, + cudaStream_t stream); + +/*! \brief Backward pass for computing scores/routing map for auxiliary loss. + * + * \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output) + * \param[in] grad_scores Gradient of scores. + * \param[in] num_tokens Number of tokens. + * \param[in] num_experts Number of experts. + * \param[in] topk Topk value. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax. + * \param[out] grad_logits Gradient of logits. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_output, + const NVTETensor grad_scores, int num_tokens, + int num_experts, int topk, int score_function, + NVTETensor grad_logits, cudaStream_t stream); + +/*! \brief Forward pass for auxiliary loss. + * + * \param[in] probs Probabilities from the forward pass. + * \param[in] tokens_per_expert Number of tokens per expert. + * \param[in] total_num_tokens Number of total tokens. Will be used in seq/global aux loss. + * \param[in] num_experts Number of experts. + * \param[in] num_rows Number of rows of probs. + * \param[in] num_cols Number of columns of probs. + * \param[in] topk Topk value. + * \param[in] coeff Coefficient. + * \param[out] aux_loss Output GPU scalar for auxiliary loss. + * \param[out] Const_buf Output GPU scalar for temporary constant buffer for backward pass. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert, + int total_num_tokens, int num_experts, int num_rows, + int num_cols, int topk, float coeff, NVTETensor aux_loss, + NVTETensor Const_buf, cudaStream_t stream); + +/*! \brief Backward pass for auxiliary loss. + * + * \param[in] Const_buf Constant buffer from the forward pass. + * \param[in] tokens_per_expert Number of tokens per expert. + * \param[in] num_rows Number of rows of probs. + * \param[in] num_cols Number of columns of probs. + * \param[in] grad_aux_loss Gradient of auxiliary loss. + * \param[out] grad_probs Gradient of probs. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf, + const NVTETensor tokens_per_expert, int num_rows, + int num_cols, NVTETensor grad_aux_loss, NVTETensor grad_probs, + cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_FUSED_ROPE_H_ diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index c463f6b9d..a68070308 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -119,8 +119,6 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVT */ namespace transformer_engine { -constexpr int num_streams = 4; - /*! \brief TE/JAX cudaGraph requires the cuBLAS initialization to happen outside of the capturing * region. This function is a helper to call cublasCreate() which allocate memory for the handle. * The function will be called in the initialize phase of the related XLA custom calls. diff --git a/transformer_engine/common/include/transformer_engine/multi_stream.h b/transformer_engine/common/include/transformer_engine/multi_stream.h new file mode 100644 index 000000000..e406a0786 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/multi_stream.h @@ -0,0 +1,47 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file multi_stream.h + * \brief Functions for multi streams executions. + */ + +#ifndef TRANSFORMER_ENGINE_MULTI_STREAM_H +#define TRANSFORMER_ENGINE_MULTI_STREAM_H + +#include "cuda_runtime.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Number of CUDA streams to use in multi-stream operations */ +int nvte_get_num_compute_streams(); + +/*! \brief Get a CUDA stream for compute operations. + * + * \param[in] idx Index of the stream to retrieve.Add commentMore actions + * \return A cudaStream_t. + * + * This function returns a CUDA stream that can be used for compute operations. + * The index should be in the range [0, nvte_get_num_compute_streams() - 1]. + */ +cudaStream_t nvte_get_compute_stream(const int idx); + +/*! \brief Get a CUDA event for compute operations. + * + * \param[in] idx Index of the event to retrieve. + * \return A cudaEvent_t. + * + * This function returns a CUDA event that can be used to synchronize compute operations. + * The index should be in the range [0, nvte_get_num_compute_streams() - 1]. + */ +cudaEvent_t nvte_get_compute_stream_event(const int idx); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_MULTI_STREAM_H diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h index e78b31d77..c21fd2627 100644 --- a/transformer_engine/common/include/transformer_engine/multi_tensor.h +++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h @@ -17,6 +17,25 @@ extern "C" { #endif +/*! \brief Computes L2 norm for a list of tensors. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] output Scratch space. Required size grows with number of inputs. + * \param[in] output_per_tensor Fixed size auxilliary scratch space. + * \param[out] ret L2 norm of all inputs. + * \param[out] ret_per_tensor L2 norm for each tensor. + * \param[in] per_tensor Whether to calculate per tensor or cumulative norm. + * \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret, @@ -24,6 +43,28 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen int max_chunks_per_tensor, const int device_id, cudaStream_t stream); +/*! \brief Computes L2 norm for a list of tensors after unscaling. + * + * Unscaling is only done for computing the L2 norm. The tensors themselves are not updated. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] output Scratch space. Required size grows with number of inputs. + * \param[in] output_per_tensor Fixed size auxilliary scratch space. + * \param[out] ret L2 norm of all inputs. + * \param[out] ret_per_tensor L2 norm for each tensor. + * \param[in] inv_scale Scalar for the unscaling operation. + * \param[in] per_tensor Whether to calculate per tensor or cumulative norm. + * \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor output, @@ -32,6 +73,27 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, int per_tensor, int max_chunks_per_tensor, const int device_id, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for Adam optimizer. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] lr Learning rate. + * \param[in] beta1 Coefficient for first moment of gradient. + * \param[in] beta2 Coefficient for second moment of gradient. + * \param[in] epsilon Term added to the denominator for numerical stability. + * \param[in] step Iteration counter. + * \param[in] mode Whether to use AdamW (L2 penalty applied to params). + * \param[in] bias_correction Whether to apply correction factor for moment estimates. + * \param[in] weight_decay L2 penalty for weight decay. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2, @@ -39,12 +101,57 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso const int bias_correction, const float weight_decay, const int device_id, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for Adam optimizer + * where the master parameters only store the remainder bits. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] lr Learning rate. + * \param[in] beta1 Coefficient for first moment of gradient. + * \param[in] beta2 Coefficient for second moment of gradient. + * \param[in] epsilon Term added to the denominator for numerical stability. + * \param[in] step Iteration counter. + * \param[in] mode Whether to use AdamW (L2 penalty applied to params). + * \param[in] bias_correction Whether to apply correction factor for moment estimates. + * \param[in] weight_decay L2 penalty for weight decay. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_adam_param_remainder_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, const int device_id, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for Adam optimizer + * when model parameters are in Float8 precision. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] lr Learning rate. + * \param[in] beta1 Coefficient for first moment of gradient. + * \param[in] beta2 Coefficient for second moment of gradient. + * \param[in] epsilon Term added to the denominator for numerical stability. + * \param[in] step Iteration counter. + * \param[in] mode Whether to use AdamW (L2 penalty applied to params). + * \param[in] bias_correction Whether to apply correction factor for moment estimates. + * \param[in] weight_decay L2 penalty for weight decay. + * \param[in] fp8_dtype FP8 data type for model parameters. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, const float lr, @@ -53,28 +160,125 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, const float weight_decay, const NVTEDType fp8_dtype, const int device_id, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for Adam optimizer + * with CUDA graph support and LR scheduling. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] lr Learning rate. + * \param[in] beta1 Coefficient for first moment of gradient. + * \param[in] beta2 Coefficient for second moment of gradient. + * \param[in] epsilon Term added to the denominator for numerical stability. + * \param[in] step Iteration counter. + * \param[in] mode Whether to use AdamW (L2 penalty applied to params). + * \param[in] bias_correction Whether to apply correction factor for moment estimates. + * \param[in] weight_decay L2 penalty for weight decay. + * \param[in] inv_scale Scalar for the unscaling operation. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_adam_capturable_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode, const int bias_correction, const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for Adam optimizer + * with CUDA graph support, LR scheduling, and FP32 master weights. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] lr Learning rate. + * \param[in] beta1 Coefficient for first moment of gradient. + * \param[in] beta2 Coefficient for second moment of gradient. + * \param[in] epsilon Term added to the denominator for numerical stability. + * \param[in] step Iteration counter. + * \param[in] mode Whether to use AdamW (L2 penalty applied to params). + * \param[in] bias_correction Whether to apply correction factor for moment estimates. + * \param[in] weight_decay L2 penalty for weight decay. + * \param[in] inv_scale Scalar for the unscaling operation. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_adam_capturable_master_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode, const int bias_correction, const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for SGD optimizer. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] wd Weight decay (L2 penalty). + * \param[in] momentum Momentum factor. + * \param[in] dampening Dampening factor. + * \param[in] lr Learning rate. + * \param[in] nesterov Whether or not to enable nesterov momentum. + * \param[in] first_run Whether momentum buffers have been initialized. + * \param[in] wd_after_momentum Whether to applied weight decay after momentum update. + * \param[in] scale Scalar for the scaling operation. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, float wd, float momentum, float dampening, float lr, int nesterov, int first_run, int wd_after_momentum, float scale, const int device_id, cudaStream_t stream); +/*! \brief Check overflow and scale a list of tensors. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] scale Scalar for the scaling operation. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, float scale, const int device_id, cudaStream_t stream); +/*! \brief Check overflow and scale a list of tensors. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] max_fp8 Maximum representible value in underlying FP8 format. + * \param[in] force_pow_2_scales Ensure scaling factors are a power of 2. + * \param[in] epsilon Term added to the denominator for numerical stability. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon, diff --git a/transformer_engine/common/include/transformer_engine/padding.h b/transformer_engine/common/include/transformer_engine/padding.h index 4258463b1..0783fc2b2 100644 --- a/transformer_engine/common/include/transformer_engine/padding.h +++ b/transformer_engine/common/include/transformer_engine/padding.h @@ -44,6 +44,33 @@ extern "C" { void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, const int* padded_num_rows_list, cudaStream_t stream); +/*! \brief Unpadding multiple tensors (reverse operation of padding). + * + * NOTE: Unpadding mode only removes bottom rows. + * + * For example, 4x3 matrix unpad to 3x3 matrix. + * + * source + * | 1 | 2 | 3 | + * | 4 | 5 | 6 | + * | 7 | 8 | 9 | + * | 0 | 0 | 0 | + * + * destination + * | 1 | 2 | 3 | + * | 4 | 5 | 6 | + * | 7 | 8 | 9 | + * + * \param[in] num_tensors Number of tensors. + * \param[in] input_list List of 2D padded input tensors. + * \param[in,out] output_list List of unpadded tensors. Dimensions + * match original unpadded tensors. + * \param[in] unpadded_num_rows_list List of unpadded num rows corresponding to input tensors. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_multi_unpadding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, + const int* unpadded_num_rows_list, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 25d9a471f..70f90fa76 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -22,17 +24,24 @@ extern "C" { * \brief TE datatype. */ enum NVTEDType { - kNVTEByte = 0, /*!< Byte */ - kNVTEInt16 = 1, /*!< 16-bit integer */ - kNVTEInt32 = 2, /*!< 32-bit integer */ - kNVTEInt64 = 3, /*!< 64-bit integer */ - kNVTEFloat32 = 4, /*!< 32-bit float */ - kNVTEFloat16 = 5, /*!< 16-bit float (E5M10) */ - kNVTEBFloat16 = 6, /*!< 16-bit bfloat (E8M7) */ - kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */ - kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */ - kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */ - kNVTENumTypes /*!< Number of supported types */ + kNVTEByte = 0, /*!< Byte */ + kNVTEInt16 = 1, /*!< 16-bit integer */ + kNVTEInt32 = 2, /*!< 32-bit integer */ + kNVTEInt64 = 3, /*!< 64-bit integer */ + kNVTEFloat32 = 4, /*!< 32-bit float */ + kNVTEFloat16 = 5, /*!< 16-bit float (E5M10) */ + kNVTEBFloat16 = 6, /*!< 16-bit bfloat (E8M7) */ + kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */ + kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */ + kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */ +#ifndef __HIP_PLATFORM_AMD__ + kNVTEFloat4E2M1 = 10, /*!< 4-bit float (E2M1) */ + kNVTENumTypes /*!< Number of supported types */ +#else + //switch the order since rocm platform does not support e2m1 + kNVTENumTypes = 10, /*!< Number of supported types */ + kNVTEFloat4E2M1 = 11 /*!< 4-bit float (E2M1) */ +#endif // #ifndef __HIP_PLATFORM_AMD__ }; /*! \struct NVTEShape @@ -87,6 +96,10 @@ enum NVTEScalingMode { */ NVTE_BLOCK_SCALING_1D = 2, NVTE_BLOCK_SCALING_2D = 3, + /*! Single NVFP4 scale per block of 16 contiguous elements in forward pass (FWD), + and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD). + */ + NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4, NVTE_INVALID_SCALING = 100 }; @@ -177,6 +190,14 @@ size_t nvte_tensor_ndims(const NVTETensor tensor); */ size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim); +/*! \brief Get the byte size for the tensor. + * + * \param[in] tensor Tensor. + * + * \return Byte size of the tensor. + */ +size_t nvte_tensor_size_bytes(const NVTETensor tensor); + /*! \brief Get a tensor's total number of elements. * * \param[in] tensor Tensor. @@ -193,6 +214,14 @@ size_t nvte_tensor_numel(const NVTETensor tensor); */ size_t nvte_tensor_element_size(const NVTETensor tensor); +/*! \brief Get the bit size for the tensor's data type. + * + * \param[in] tensor Tensor. + * + * \return Bit size of the tensor's data type. + */ +size_t nvte_tensor_element_size_bits(const NVTETensor tensor); + /*! \brief Get a tensor's data type. * * \param[in] tensor Tensor. @@ -302,6 +331,13 @@ enum NVTEQuantizationConfigAttribute { conditional early even when captured in a static CUDA graph. */ kNVTEQuantizationConfigNoopTensor = 2, + /*! Data format for an FP8 block-scaled tensor + * + * This is not the right design since the tensor format is a + * property of the tensor, not the quantization. This enum will + * likely be refactored away in the future. + */ + kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3, kNVTEQuantizationConfigNumAttributes }; @@ -383,7 +419,13 @@ enum class DType { kFloat8E4M3 = 7, kFloat8E5M2 = 8, kFloat8E8M0 = 9, +#ifndef __HIP_PLATFORM_AMD__ + kFloat4E2M1 = 10, kNumTypes +#else + kNumTypes = 10, + kFloat4E2M1 +#endif // #ifndef __HIP_PLATFORM_AMD__ }; /*! \brief Check if TE datatype is FP8 @@ -391,7 +433,21 @@ enum class DType { * Return true if TE datatype is FP8 * \param[in] DType TE Datatype of interest */ -bool is_fp8_dtype(const DType t); +inline bool is_fp8_dtype(const DType t) { + return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2; +} + +#ifndef __HIP_PLATFORM_AMD__ +/*! \brief Check if TE datatype is FP4 + * + * Return true if TE datatype is FP4 + * \param[in] DType TE Datatype of interest + */ +inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; } +#else +//TODO: fp4 types not supported on AMD GPUs +inline bool is_fp4_dtype(const DType t) { return false; } +#endif // #ifndef __HIP_PLATFORM_AMD__ /*! \struct TensorWrapper * \brief C++ wrapper for the NVTETensor class. @@ -620,6 +676,15 @@ class TensorWrapper { return nvte_tensor_element_size(tensor_); } + /*! \brief Get the tensor's element size in bits. + * + * \return Element size in bits. + */ + size_t element_size_bits() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_tensor_element_size_bits(tensor_); + } + /*! \brief Get the tensor's allocated size in bytes. This will return 0 for tensors with nullptr * data even if the TensorWrapper has a non-zero shape and valid dtype. * @@ -627,7 +692,7 @@ class TensorWrapper { */ size_t bytes() const noexcept { if (tensor_ == nullptr || this->dptr() == nullptr) return 0; - return nvte_tensor_numel(tensor_) * nvte_tensor_element_size(tensor_); + return nvte_tensor_size_bytes(tensor_); } /*! \brief Get the data type of this TensorWrapper. @@ -721,6 +786,16 @@ class TensorWrapper { NVTETensor tensor_ = nullptr; }; +/*! \enum Float8BlockScaleTensorFormat + * \brief Data format for an FP8 block-scaled tensor + */ +enum class Float8BlockScaleTensorFormat { + /*! FP8 data is transposed if needed and scales are swizzled */ + GEMM_READY = 0, + /*! FP8 data is untransposed and scales are not swizzled or padded */ + COMPACT = 1 +}; + /*! \struct QuantizationConfigWrapper * \brief C++ wrapper for NVTEQuantizationConfigWrapper. */ @@ -774,6 +849,13 @@ class QuantizationConfigWrapper { sizeof(NVTETensor)); } + /*! \brief Set FP8 block-scaled tensor format */ + void set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat format) { + nvte_set_quantization_config_attribute(config_, + kNVTEQuantizationConfigFloat8BlockScaleTensorFormat, + &format, sizeof(Float8BlockScaleTensorFormat)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/multi_tensor/adam.cu b/transformer_engine/common/multi_tensor/adam.cu index c14742bbe..72f150512 100644 --- a/transformer_engine/common/multi_tensor/adam.cu +++ b/transformer_engine/common/multi_tensor/adam.cu @@ -232,8 +232,8 @@ struct AdamFunctorMasterParamRemainder { r_m[ii] = static_cast(m[i]); r_v[ii] = static_cast(v[i]); - local_p[ii] = static_cast(p[i]); - local_p_rem[ii] = static_cast(p_remainder[i]); + local_p[ii] = p[i]; + local_p_rem[ii] = p_remainder[i]; } else { r_g[ii] = MATH_T(0); r_m[ii] = MATH_T(0); @@ -287,8 +287,8 @@ struct AdamFunctorMasterParamRemainder { for (int ii = 0; ii < ILP; ii++) { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { - p_remainder[i] = static_cast(local_p_rem[ii]); - p[i] = static_cast(local_p[ii]); + p_remainder[i] = local_p_rem[ii]; + p[i] = local_p[ii]; m[i] = static_cast(r_m[ii]); v[i] = static_cast(r_v[ii]); @@ -473,8 +473,8 @@ struct AdamCapturableFunctor { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { p[i] = static_cast(r_p[ii]); - m[i] = static_cast(r_m[ii]); - v[i] = static_cast(r_v[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); } } } @@ -584,9 +584,6 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, const int device_id, cudaStream_t stream) { - const size_t num_tensor_lists = tensor_lists.size(); - const size_t num_tensors_per_list = tensor_lists[0].size(); - // Handle bias correction mode float bias_correction1 = 1.0f, bias_correction2 = 1.0f; if (bias_correction == 1) { @@ -594,16 +591,48 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, bias_correction2 = 1 - std::pow(beta2, step); } - size_t max_size = 0; + // Check tensor list sizes + // 4 tensor lists: g, p, m, v + // 5 tensor lists: g, p, m, v, p_master + const size_t num_tensor_lists = tensor_lists.size(); + NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5, + "Expected 4 or 5 tensor lists, but found ", num_tensor_lists); + const size_t num_tensors_per_list = tensor_lists[0].size(); + for (size_t i = 1; i < num_tensor_lists; i++) { + NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i, + " has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list); + } + + // Check tensor dtypes + const auto g_in_type_te = tensor_lists[0][0]->dtype(); + const auto p_in_type_te = tensor_lists[1][0]->dtype(); + for (size_t j = 0; j < num_tensors_per_list; j++) { + NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j, + " has dtype=", to_string(tensor_lists[0][j]->dtype()), + ", but expected dtype=", to_string(g_in_type_te)); + NVTE_CHECK(tensor_lists[1][j]->dtype() == p_in_type_te, "Param tensor ", j, + " has dtype=", to_string(tensor_lists[1][j]->dtype()), + ", but expected dtype=", to_string(p_in_type_te)); + NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j, + " has dtype=", to_string(tensor_lists[2][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j, + " has dtype=", to_string(tensor_lists[3][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + if (num_tensor_lists == 5) { + NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j, + " has dtype=", to_string(tensor_lists[4][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + } + } + + // Check if 64-bit indices are required bool requires_64bit_indexing = false; for (size_t i = 0; i < num_tensor_lists; i++) { for (size_t j = 0; j < num_tensors_per_list; j++) { - if (tensor_lists[i][j]->numel() > max_size) { - max_size = tensor_lists[i][j]->numel(); - if (max_size >= INT_MAX) { - requires_64bit_indexing = true; - break; - } + if (tensor_lists[i][j]->numel() >= INT_MAX) { + requires_64bit_indexing = true; + break; } } if (requires_64bit_indexing) { @@ -611,16 +640,10 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, } } - const auto g_in_type_te = tensor_lists[0][0]->dtype(); - const auto p_in_type_te = tensor_lists[1][0]->dtype(); - - // case 4: g, p, m, v - // case 5: g, p, m, v, p_master - NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5, "tensor list must contain 4 or 5"); - + // Launch kernel if (requires_64bit_indexing) { if (num_tensor_lists == 4) { - // Assume single type across p,g,m1,m2 now + // g, p, m, v TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( p_in_type_te, p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( @@ -644,7 +667,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, } } else { if (num_tensor_lists == 4) { - // Assume single type across p,g,m1,m2 now + // g, p, m, v TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( p_in_type_te, p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( @@ -654,6 +677,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);)); } else { + // g, p, m, v, p_master TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( p_in_type_te, p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( @@ -674,8 +698,6 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, const int device_id, cudaStream_t stream) { - const size_t num_tensor_lists = tensor_lists.size(); - // Handle bias correction mode float bias_correction1 = 1.0f, bias_correction2 = 1.0f; if (bias_correction == 1) { @@ -683,23 +705,43 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, bias_correction2 = 1 - std::pow(beta2, step); } - const auto g_in_type_te = tensor_lists[0][0]->dtype(); - const auto p_in_type_te = tensor_lists[1][0]->dtype(); - - // case 5: g, p, m, v, p_master - NVTE_CHECK(num_tensor_lists == 5, "tensor list must contain 5"); - NVTE_CHECK(p_in_type_te == DType::kBFloat16, - "Adam with BF16 param remainders requires BF16 params"); + // Check tensor list sizes + // 5 tensor lists: g, p, m, v, p_remainder + const size_t num_tensor_lists = tensor_lists.size(); + NVTE_CHECK(num_tensor_lists == 5, "Expected 5 tensor lists, but found ", num_tensor_lists); + const size_t num_tensors_per_list = tensor_lists[0].size(); + for (size_t i = 1; i < num_tensor_lists; i++) { + NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i, + " has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list); + } - // g, p, m, v, p_master + // Check tensor dtypes + const auto g_in_type_te = tensor_lists[0][0]->dtype(); + for (size_t j = 0; j < num_tensors_per_list; j++) { + NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j, + " has dtype=", to_string(tensor_lists[0][j]->dtype()), + ", but expected dtype=", to_string(g_in_type_te)); + NVTE_CHECK(tensor_lists[1][j]->dtype() == DType::kBFloat16, "Param tensor ", j, + " has dtype=", to_string(tensor_lists[1][j]->dtype()), + ", but expected dtype=", to_string(DType::kBFloat16)); + NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j, + " has dtype=", to_string(tensor_lists[2][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j, + " has dtype=", to_string(tensor_lists[3][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kInt16, "Param remainder tensor ", j, + " has dtype=", to_string(tensor_lists[4][j]->dtype()), + ", but expected dtype=", to_string(DType::kInt16)); + } + // Launch kernel TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, AdamFunctorMasterParamRemainder(), device_id, stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);); - NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -709,9 +751,6 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, const int step, const int mode, const int bias_correction, const float weight_decay, const DType fp8_dtype, const int device_id, cudaStream_t stream) { - const size_t num_tensor_lists = tensor_lists.size(); - const size_t num_tensors_per_list = tensor_lists[0].size(); - // Handle bias correction mode float bias_correction1 = 1.0f, bias_correction2 = 1.0f; if (bias_correction == 1) { @@ -719,16 +758,53 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, bias_correction2 = 1 - std::pow(beta2, step); } - size_t max_size = 0; + // Check tensor list sizes + // 8 tensor lists: g, p_fp8, m, v, p_master, scale, amax, scale_inv + const size_t num_tensor_lists = tensor_lists.size(); + NVTE_CHECK(num_tensor_lists == 8, "Expected 8 tensor lists, but found ", num_tensor_lists); + const size_t num_tensors_per_list = tensor_lists[0].size(); + for (size_t i = 1; i < num_tensor_lists; i++) { + NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i, + " has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list); + } + + // Check tensor dtypes + const auto g_in_type_te = tensor_lists[0][0]->dtype(); + for (size_t j = 0; j < num_tensors_per_list; j++) { + NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j, + " has dtype=", to_string(tensor_lists[0][j]->dtype()), + ", but expected dtype=", to_string(g_in_type_te)); + NVTE_CHECK( + tensor_lists[1][j]->dtype() == fp8_dtype || tensor_lists[1][j]->dtype() == DType::kByte, + "Param tensor ", j, " has dtype=", to_string(tensor_lists[1][j]->dtype()), + ", but expected dtype=", to_string(fp8_dtype)); + NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j, + " has dtype=", to_string(tensor_lists[2][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j, + " has dtype=", to_string(tensor_lists[3][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j, + " has dtype=", to_string(tensor_lists[4][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + NVTE_CHECK(tensor_lists[5][j]->dtype() == DType::kFloat32, "Scale tensor ", j, + " has dtype=", to_string(tensor_lists[5][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + NVTE_CHECK(tensor_lists[6][j]->dtype() == DType::kFloat32, "Absmax tensor ", j, + " has dtype=", to_string(tensor_lists[6][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + NVTE_CHECK(tensor_lists[7][j]->dtype() == DType::kFloat32, "Scale-inverse tensor ", j, + " has dtype=", to_string(tensor_lists[7][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + } + + // Check if 64-bit indices are required bool requires_64bit_indexing = false; for (size_t i = 0; i < num_tensor_lists; i++) { for (size_t j = 0; j < num_tensors_per_list; j++) { - if (tensor_lists[i][j]->numel() > max_size) { - max_size = tensor_lists[i][j]->numel(); - if (max_size >= INT_MAX) { - requires_64bit_indexing = true; - break; - } + if (tensor_lists[i][j]->numel() >= INT_MAX) { + requires_64bit_indexing = true; + break; } } if (requires_64bit_indexing) { @@ -736,11 +812,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, } } - const auto g_in_type_te = tensor_lists[0][0]->dtype(); - - // case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv - NVTE_CHECK(num_tensor_lists == 8, "tensor list must contain 8 tensors"); - + // Launch kernel if (requires_64bit_indexing) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( fp8_dtype, FP8_T, @@ -771,6 +843,34 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag, Tensor step, const int mode, const int bias_correction, const float weight_decay, Tensor inv_scale, const int device_id, cudaStream_t stream) { + // Check tensor list sizes + // 4 tensor lists: g, p, m, v + const size_t num_tensor_lists = tensor_lists.size(); + NVTE_CHECK(num_tensor_lists == 4, "Expected 4 tensor lists, but found ", num_tensor_lists); + const size_t num_tensors_per_list = tensor_lists[0].size(); + for (size_t i = 1; i < num_tensor_lists; i++) { + NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i, + " has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list); + } + + // Check tensor dtypes + const auto g_in_type_te = tensor_lists[0][0]->dtype(); + for (size_t j = 0; j < num_tensors_per_list; j++) { + NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j, + " has dtype=", to_string(tensor_lists[0][j]->dtype()), + ", but expected dtype=", to_string(g_in_type_te)); + NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j, + " has dtype=", to_string(tensor_lists[1][j]->dtype()), + ", but expected dtype=", to_string(g_in_type_te)); + NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j, + " has dtype=", to_string(tensor_lists[2][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j, + " has dtype=", to_string(tensor_lists[3][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + } + + // Launch kernel TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, @@ -789,6 +889,37 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag, const int bias_correction, const float weight_decay, Tensor inv_scale, const int device_id, cudaStream_t stream) { + // Check tensor list sizes + // 4 tensor lists: g, p, m, v, p_master + const size_t num_tensor_lists = tensor_lists.size(); + NVTE_CHECK(num_tensor_lists == 5, "Expected 4 tensor lists, but found ", num_tensor_lists); + const size_t num_tensors_per_list = tensor_lists[0].size(); + for (size_t i = 1; i < num_tensor_lists; i++) { + NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i, + " has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list); + } + + // Check tensor dtypes + const auto g_in_type_te = tensor_lists[0][0]->dtype(); + for (size_t j = 0; j < num_tensors_per_list; j++) { + NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j, + " has dtype=", to_string(tensor_lists[0][j]->dtype()), + ", but expected dtype=", to_string(g_in_type_te)); + NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j, + " has dtype=", to_string(tensor_lists[1][j]->dtype()), + ", but expected dtype=", to_string(g_in_type_te)); + NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j, + " has dtype=", to_string(tensor_lists[2][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j, + " has dtype=", to_string(tensor_lists[3][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j, + " has dtype=", to_string(tensor_lists[4][j]->dtype()), + ", but expected dtype=", to_string(DType::kFloat32)); + } + + // Launch kernel TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, @@ -813,7 +944,7 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso using namespace transformer_engine; multi_tensor_adam::multi_tensor_adam_cuda( - chunk_size, *reinterpret_cast(noop_flag), + chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2, epsilon, step, mode, bias_correction, weight_decay, device_id, stream); } @@ -827,7 +958,7 @@ void nvte_multi_tensor_adam_param_remainder_cuda( using namespace transformer_engine; multi_tensor_adam::multi_tensor_adam_param_remainder_cuda( - chunk_size, *reinterpret_cast(noop_flag), + chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2, epsilon, step, mode, bias_correction, weight_decay, device_id, stream); } @@ -843,7 +974,7 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, using namespace transformer_engine; multi_tensor_adam::multi_tensor_adam_fp8_cuda( - chunk_size, *reinterpret_cast(noop_flag), + chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2, epsilon, step, mode, bias_correction, weight_decay, static_cast(fp8_dtype), device_id, stream); @@ -858,11 +989,10 @@ void nvte_multi_tensor_adam_capturable_cuda( using namespace transformer_engine; multi_tensor_adam::multi_tensor_adam_capturable_cuda( - chunk_size, *reinterpret_cast(noop_flag), + chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), - *reinterpret_cast(lr), beta1, beta2, epsilon, *reinterpret_cast(step), - mode, bias_correction, weight_decay, *reinterpret_cast(inv_scale), device_id, - stream); + *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode, + bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream); } void nvte_multi_tensor_adam_capturable_master_cuda( @@ -874,9 +1004,8 @@ void nvte_multi_tensor_adam_capturable_master_cuda( using namespace transformer_engine; multi_tensor_adam::multi_tensor_adam_capturable_master_cuda( - chunk_size, *reinterpret_cast(noop_flag), + chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), - *reinterpret_cast(lr), beta1, beta2, epsilon, *reinterpret_cast(step), - mode, bias_correction, weight_decay, *reinterpret_cast(inv_scale), device_id, - stream); + *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode, + bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream); } diff --git a/transformer_engine/common/multi_tensor/compute_scale.cu b/transformer_engine/common/multi_tensor/compute_scale.cu index b27d5cdd0..ebdcfbb56 100644 --- a/transformer_engine/common/multi_tensor/compute_scale.cu +++ b/transformer_engine/common/multi_tensor/compute_scale.cu @@ -77,7 +77,7 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda( using namespace transformer_engine; multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda( - chunk_size, *reinterpret_cast(noop_flag), + chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8, force_pow_2_scales, epsilon, device_id, stream); } diff --git a/transformer_engine/common/multi_tensor/l2norm.cu b/transformer_engine/common/multi_tensor/l2norm.cu index 54d28be98..e80b04f97 100644 --- a/transformer_engine/common/multi_tensor/l2norm.cu +++ b/transformer_engine/common/multi_tensor/l2norm.cu @@ -470,10 +470,10 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen using namespace transformer_engine; multi_tensor_l2norm::multi_tensor_l2norm_cuda( - chunk_size, *reinterpret_cast(noop_flag), + chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), - *reinterpret_cast(output), *reinterpret_cast(output_per_tensor), - *reinterpret_cast(ret), *reinterpret_cast(ret_per_tensor), per_tensor, + *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor), + *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), per_tensor, max_chunks_per_tensor, device_id, stream); } @@ -488,9 +488,9 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, using namespace transformer_engine; multi_tensor_l2norm::multi_tensor_unscale_l2norm_cuda( - chunk_size, *reinterpret_cast(noop_flag), + chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), - *reinterpret_cast(output), *reinterpret_cast(output_per_tensor), - *reinterpret_cast(ret), *reinterpret_cast(ret_per_tensor), - *reinterpret_cast(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream); + *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor), + *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), + *convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream); } diff --git a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh index 77e436936..4727f3964 100644 --- a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh +++ b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh @@ -52,7 +52,7 @@ class OptionalCUDAGuard { ~OptionalCUDAGuard() { if (device_changed_) { - NVTE_CHECK_CUDA(cudaSetDevice(prev_device_)); + cudaSetDevice(prev_device_); } } diff --git a/transformer_engine/common/multi_tensor/scale.cu b/transformer_engine/common/multi_tensor/scale.cu index 170565831..66a173bdb 100644 --- a/transformer_engine/common/multi_tensor/scale.cu +++ b/transformer_engine/common/multi_tensor/scale.cu @@ -124,7 +124,7 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens using namespace transformer_engine; multi_tensor_scale::multi_tensor_scale_cuda( - chunk_size, *reinterpret_cast(noop_flag), + chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id, stream); } diff --git a/transformer_engine/common/multi_tensor/sgd.cu b/transformer_engine/common/multi_tensor/sgd.cu index 08482e99a..05106e46d 100644 --- a/transformer_engine/common/multi_tensor/sgd.cu +++ b/transformer_engine/common/multi_tensor/sgd.cu @@ -196,7 +196,7 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor using namespace transformer_engine; multi_tensor_sgd::multi_tensor_sgd_cuda( - chunk_size, *reinterpret_cast(noop_flag), + chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum, dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream); } diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 3be7d5004..6189be7b3 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -53,20 +53,24 @@ cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, - bool is_tuned, NVTEScalingMode mode, bool training) { - // TODO: Add scaling_mode to general_key is needed - uint64_t general_key = static_cast(itype) | (static_cast(otype) << 3) | - (static_cast(ctype) << 6) | (static_cast(wtype) << 9) | - (uint32_t(NormType) << 12) | (uint32_t(NormStage)) << 14 | - (uint32_t(NormBackend) << 16) | (uint32_t(zero_centered_gamma) << 18) | - (uint32_t(mode) << 19) | (uint32_t(training) << 22); + bool is_tuned, NVTEScalingMode mode, bool training, + bool gamma_in_weight_dtype) { + static_assert(NVTE_INVALID_SCALING < 1024, + "This function assumes at most 10 bits used in the scaling mode."); + static_assert(kNVTENumTypes < 32, "This function assumes at most 5 bits used in the NVTEDType"); + uint64_t general_key = static_cast(itype) | (static_cast(otype) << 5) | + (static_cast(ctype) << 10) | + (static_cast(wtype) << 15) | (uint64_t(NormType) << 20) | + (uint64_t(NormStage)) << 22 | (uint64_t(NormBackend) << 24) | + (uint64_t(zero_centered_gamma) << 26) | (uint64_t(mode) << 27) | + (uint64_t(training) << 37) | (uint64_t(gamma_in_weight_dtype) << 38); return std::make_tuple(general_key, batch_size, hidden_size, is_tuned); } template TeNormalizationPlan::TeNormalizationPlan( NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, - DType ctype, const size_t batch_size, const size_t hidden_size, const size_t sm_count, + DType ctype, const size_t batch_size, const size_t hidden_size, const size_t sm_count, const bool zero_centered_gamma, const bool is_tuned #ifdef __HIP_PLATFORM_AMD__ , const NVTEScalingMode mode, const bool training @@ -245,8 +249,11 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor } const auto gamma_dtype = use_zero_centered_gamma_in_weight_dtype() ? wtype : ctype; + NVTE_CHECK(gamma_dtype == DType::kFloat32 || gamma_dtype == DType::kFloat16 || + gamma_dtype == DType::kBFloat16, + "Gamma of type FP4 is not supported"); - _scalar_dptr = std::make_unique(typeToSize(gamma_dtype)); + _scalar_dptr = std::make_unique(typeToNumBits(gamma_dtype) / 8); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( gamma_dtype, cpp_dtype, *(reinterpret_cast(_scalar_dptr.get())) = (cpp_dtype)1.0f;); @@ -502,11 +509,12 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size, const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned, - const NVTEScalingMode mode, const bool training) { + const NVTEScalingMode mode, const bool training, const bool gamma_in_weight_dtype) { const DType ctype = DType::kFloat32; bool is_tuned = is_aligned && (batch_size % 4 == 0); - auto key = get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size, - hidden_size, zero_centered_gamma, is_tuned, mode, training); + auto key = + get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, + zero_centered_gamma, is_tuned, mode, training, gamma_in_weight_dtype); auto it = normalizationPlanMap.find(key); if (it != normalizationPlanMap.end()) { @@ -578,6 +586,7 @@ void nvte_enable_cudnn_norm_bwd(bool enable) { transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable; } +// Only for testing, not thread-safe void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) { NVTE_API_CALL(nvte_enable_zero_centered_gamma_in_weight_dtype); transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable; diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index fb525c9db..31d2c0b74 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -46,7 +46,6 @@ struct LaunchParams { size_t workspace_bytes = 0; size_t barrier_bytes = 0; size_t dgamma_part_bytes = 0; - int multiprocessorCount; cudaStream_t stream; @@ -196,7 +195,7 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, - bool training = true); + bool training = true, bool gamma_in_weight_dtype = false); template class TeNormalizationRegistry { @@ -350,7 +349,8 @@ class NormalizationPlanRegistry { NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size, const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned, - const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true); + const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true, + const bool gamma_in_weight_dtype = false); private: NormalizationPlanRegistry() {} @@ -430,6 +430,8 @@ bool is_ptr_aligned(const Args*... ptrs) { #ifndef __HIP_PLATFORM_AMD__ bool use_cudnn_norm_fwd(); bool use_cudnn_norm_bwd(); + +bool& use_zero_centered_gamma_in_weight_dtype(); #endif #ifdef __HIP_PLATFORM_AMD__ diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index f660ca5b7..8c4cdbcc0 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -17,6 +17,7 @@ #include "../../common.h" #include "../common.h" +#include "transformer_engine/transformer_engine.h" namespace transformer_engine { @@ -66,12 +67,16 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size bool is_aligned = true; #ifndef __HIP_PLATFORM_AMD__ bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); +#endif //__HIP_PLATFORM_AMD__ + bool gamma_in_weight_dtype = false; +#ifndef __HIP_PLATFORM_AMD__ if (cudnn_backend) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; + gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); } else -#endif +#endif //__HIP_PLATFORM_AMD__ { norm_backend = NVTE_Norm_Backend::Te; is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr, @@ -88,7 +93,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size z->data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training); + multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training, + gamma_in_weight_dtype); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); @@ -106,11 +112,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size // Compute FP8 transpose if required if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) { - Tensor transpose_data; - transpose_data.data = z->columnwise_data; - transpose_data.scaling_mode = z->scaling_mode; - nvte_transpose(reinterpret_cast(z), reinterpret_cast(&transpose_data), - stream); + NVTETensor transpose_data = nvte_create_tensor(z->scaling_mode); + Tensor& t = *convertNVTETensor(transpose_data); + t.data = z->columnwise_data; + nvte_transpose(static_cast(*z), transpose_data, stream); + nvte_destroy_tensor(transpose_data); } return; @@ -155,10 +161,12 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te NVTE_Norm_Backend norm_backend; bool is_aligned = true; + bool gamma_in_weight_dtype = false; #ifndef __HIP_PLATFORM_AMD__ if (use_cudnn_norm_bwd()) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; + gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); } else #endif { @@ -173,7 +181,8 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te gamma.data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned); + multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true, + gamma_in_weight_dtype); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); @@ -196,11 +205,10 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size const bool zero_centered_gamma, cudaStream_t stream) { NVTE_API_CALL(nvte_layernorm_fwd); using namespace transformer_engine; - layernorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), - *reinterpret_cast(beta), epsilon, reinterpret_cast(z), - reinterpret_cast(mu), reinterpret_cast(rsigma), - reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, - stream); + layernorm_fwd(*convertNVTETensorCheck(x), *convertNVTETensorCheck(gamma), + *convertNVTETensorCheck(beta), epsilon, convertNVTETensor(z), convertNVTETensor(mu), + convertNVTETensor(rsigma), convertNVTETensor(workspace), multiprocessorCount, + zero_centered_gamma, stream); } void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size @@ -213,10 +221,9 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size cudaStream_t stream) { NVTE_API_CALL(nvte_layernorm_bwd); using namespace transformer_engine; - layernorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), - *reinterpret_cast(mu), *reinterpret_cast(rsigma), - *reinterpret_cast(gamma), reinterpret_cast(dx), - reinterpret_cast(dgamma), reinterpret_cast(dbeta), - reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, - stream); + layernorm_bwd(*convertNVTETensorCheck(dz), *convertNVTETensorCheck(x), + *convertNVTETensorCheck(mu), *convertNVTETensorCheck(rsigma), + *convertNVTETensorCheck(gamma), convertNVTETensor(dx), convertNVTETensor(dgamma), + convertNVTETensor(dbeta), convertNVTETensor(workspace), multiprocessorCount, + zero_centered_gamma, stream); } diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh index a13976e6f..b68e79cd9 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh @@ -130,7 +130,6 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel( compute_t dy_tmp = dy[it * NUM_ELTS + jt]; compute_t y_tmp = y[it * NUM_ELTS + jt]; compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local)); - // HIP-TODO: dx[it].data.elt[jt] = Converter::convert(dx_tmp); dx[it].data.elt[jt] = dx_tmp; } dx[it].store_to(params.dx, idx); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index eabed2bd5..d084e5c06 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -15,6 +15,7 @@ #include "../../common.h" #include "../common.h" #include "transformer_engine/normalization.h" +#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transpose.h" namespace transformer_engine { @@ -57,10 +58,14 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens bool training = is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; -#ifndef __HIP_PLATFORM_AMD__ +#ifdef __HIP_PLATFORM_AMD__ + constexpr bool gamma_in_weight_dtype = false; +#else + bool gamma_in_weight_dtype = false; if (cudnn_backend) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; + gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); } else #endif { @@ -75,7 +80,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens z->data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training); + multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training, + gamma_in_weight_dtype); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); @@ -93,11 +99,12 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens // Compute FP8 transpose if required if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) { - Tensor transpose_data; - transpose_data.data = z->columnwise_data; - transpose_data.scaling_mode = z->scaling_mode; - nvte_transpose(reinterpret_cast(z), reinterpret_cast(&transpose_data), - stream); + NVTETensor transpose_data = nvte_create_tensor(z->scaling_mode); + auto *t = convertNVTETensor(transpose_data); + t->data = z->columnwise_data; + + nvte_transpose(static_cast(*z), transpose_data, stream); + nvte_destroy_tensor(transpose_data); } return; @@ -133,10 +140,12 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const NVTE_Norm_Backend norm_backend; bool is_aligned = true; + bool gamma_in_weight_dtype = false; #ifndef __HIP_PLATFORM_AMD__ if (use_cudnn_norm_bwd()) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; + gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); } else #endif { @@ -151,7 +160,8 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const gamma.data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned); + multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true, + gamma_in_weight_dtype); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); @@ -174,10 +184,9 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size cudaStream_t stream) { NVTE_API_CALL(nvte_rmsnorm_fwd); using namespace transformer_engine; - rmsnorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), - epsilon, reinterpret_cast(z), reinterpret_cast(rsigma), - reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, - stream); + rmsnorm_fwd(*convertNVTETensorCheck(x), *convertNVTETensorCheck(gamma), epsilon, + convertNVTETensor(z), convertNVTETensor(rsigma), convertNVTETensor(workspace), + multiprocessorCount, zero_centered_gamma, stream); } void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size @@ -189,9 +198,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size cudaStream_t stream) { NVTE_API_CALL(nvte_rmsnorm_bwd); using namespace transformer_engine; - rmsnorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), - *reinterpret_cast(rsigma), *reinterpret_cast(gamma), - reinterpret_cast(dx), reinterpret_cast(dgamma), - reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, - stream); + rmsnorm_bwd(*convertNVTETensorCheck(dz), *convertNVTETensorCheck(x), + *convertNVTETensorCheck(rsigma), *convertNVTETensorCheck(gamma), + convertNVTETensor(dx), convertNVTETensor(dgamma), convertNVTETensor(workspace), + multiprocessorCount, zero_centered_gamma, stream); } diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index 8c038b24a..5e4f2d0f1 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -342,22 +342,16 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad, const NVTETensor input_fwd, const int num_rows, const int topK, const int num_cols, const int num_out_tokens, cudaStream_t stream) { + using namespace transformer_engine; NVTE_API_CALL(nvte_permute); - const transformer_engine::Tensor *input_cu = - reinterpret_cast(input); - const transformer_engine::Tensor *output_cu = - reinterpret_cast(output); - const transformer_engine::Tensor *sorted_row_id_cu = - reinterpret_cast(sorted_row_id); - const transformer_engine::Tensor *row_id_map_cu = - reinterpret_cast(row_id_map); - const transformer_engine::Tensor *prob_cu = - reinterpret_cast(prob); - const transformer_engine::Tensor *prob_grad_cu = - reinterpret_cast(prob_grad); - const transformer_engine::Tensor *input_fwd_cu = - reinterpret_cast(input_fwd); + const Tensor *input_cu = convertNVTETensorCheck(input); + const Tensor *output_cu = convertNVTETensorCheck(output); + const Tensor *sorted_row_id_cu = convertNVTETensorCheck(sorted_row_id); + const Tensor *row_id_map_cu = convertNVTETensorCheck(row_id_map); + const Tensor *prob_cu = convertNVTETensorCheck(prob); + const Tensor *prob_grad_cu = convertNVTETensorCheck(prob_grad); + const Tensor *input_fwd_cu = convertNVTETensorCheck(input_fwd); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( input_cu->data.dtype, T, @@ -374,16 +368,13 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, const NVTETensor prob, const int num_rows, const int topK, const int num_cols, cudaStream_t stream) { + using namespace transformer_engine; NVTE_API_CALL(nvte_unpermute); - const transformer_engine::Tensor *input_cu = - reinterpret_cast(input); - const transformer_engine::Tensor *output_cu = - reinterpret_cast(output); - const transformer_engine::Tensor *row_id_map_cu = - reinterpret_cast(row_id_map); - const transformer_engine::Tensor *prob_cu = - reinterpret_cast(prob); + const Tensor *input_cu = convertNVTETensorCheck(input); + const Tensor *output_cu = convertNVTETensorCheck(output); + const Tensor *row_id_map_cu = convertNVTETensorCheck(row_id_map); + const Tensor *prob_cu = convertNVTETensorCheck(prob); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( input_cu->data.dtype, T, diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 9426d1621..466c2e605 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -196,6 +196,7 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return ( + f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " f"amax_history_len={self.amax_history_len}, " @@ -208,42 +209,12 @@ def __repr__(self) -> str: class Float8CurrentScaling(Recipe): """ Use the per-tensor current scaling factor strategy. + Parameters ---------- fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID Controls the FP8 data format used during forward and backward pass. - fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} - used for quantization of input tensor x - fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} - used for quantization of weight tensor w - fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} - used for quantization of gradient tensor dY - fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False - used for calculating output y in forward pass - fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True - use for calculating dgrad in backward pass - fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True - use for calculating dgrad in backward pass - fp8_dpa: bool, default = `False` - Whether to enable FP8 dot product attention (DPA). When the model is placed in an - `fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the - inputs from higher precision to FP8, performs attention in FP8, and casts tensors - back to higher precision as outputs. FP8 DPA currently is only supported in the - `FusedAttention` backend. - fp8_mha: bool, default = `False` - Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting - operations mentioned above at the DPA boundaries. Currently only standard MHA modules - i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When - `fp8_mha = False, fp8_dpa = True`, a typical MHA module works as - `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. - When `fp8_mha = True, fp8_dpa = True`, it becomes - `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. - - Notes - ----- - * `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are - subject to change in future Transformer Engine releases. """ fp8_format: Format = Format.HYBRID @@ -258,9 +229,13 @@ class Float8CurrentScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert ( + not self.fp8_dpa and not self.fp8_mha + ), "FP8 attention is not supported for Float8CurrentScaling." def __repr__(self) -> str: return ( + f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " @@ -307,7 +282,11 @@ def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." def __repr__(self) -> str: - return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," + return ( + f"recipe_type={self.__class__.__name__}, " + f"margin={self.margin}, " + f"format={str(self.fp8_format).split('.')[1]}" + ) @dataclass() @@ -329,32 +308,12 @@ class Float8BlockScaling(Recipe): NOTE: To relax the default constraint that scales be powers of 2, set env variable NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it for the recipe defaults. - export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 - Or initialize the Recipe with non-default QParams in code for increased control. Parameters ---------- fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. - fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} - used for quantization of input tensor x - fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} - used for quantization of weight tensor w - fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} - used for quantization of gradient tensor dY - x_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) - qblock scaling for x. - w_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) - qblock scaling for w. - grad_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) - qblock scaling for grad. - fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False - used for calculating output y in forward pass - fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True - use for calculating dgrad in backward pass - fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True - use for calculating dgrad in backward pass """ use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" @@ -388,9 +347,13 @@ def __post_init__(self) -> None: assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop." assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad." assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad." + assert ( + not self.fp8_dpa and not self.fp8_mha + ), "FP8 attention is not supported for Float8BlockScaling." def __repr__(self) -> str: return ( + f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index 709ab200f..2d43d47db 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -120,7 +120,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt // Check input tensor NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)"); - const auto &input = *reinterpret_cast(input_); + const auto &input = *convertNVTETensorCheck(input_); NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "Input tensor for amax computation must unquantized, " "but got scaling_mode=", @@ -133,7 +133,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt // Check output tensor NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)"); - auto &output = *reinterpret_cast(output_); + auto &output = *convertNVTETensorCheck(output_); NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, " "but got scaling_mode=", @@ -178,7 +178,7 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf // Check output tensor NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)"); - auto &output = *reinterpret_cast(output_); + auto &output = *convertNVTETensorCheck(output_); NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "Tensor must be FP8 tensor with per-tensor scaling, " "but got scaling_mode=", diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 4872bbffd..9bdfc0a5c 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -405,9 +405,9 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update( NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update); using namespace transformer_engine; delayed_scaling_recipe::amax_and_scale_update( - *reinterpret_cast(amax_history), *reinterpret_cast(scale), - reinterpret_cast(updated_amax_history), reinterpret_cast(updated_scale), - amax_compute_algo, static_cast(fp8_dtype), margin, stream); + *convertNVTETensorCheck(amax_history), *convertNVTETensorCheck(scale), + convertNVTETensor(updated_amax_history), convertNVTETensor(updated_scale), amax_compute_algo, + static_cast(fp8_dtype), margin, stream); } void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( @@ -419,10 +419,10 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( size_t num_tensors = amax_histories.size(); std::vector t_amax_histories, t_scales; for (size_t i = 0; i < num_tensors; i++) { - t_amax_histories.push_back(reinterpret_cast(amax_histories[i])); - t_scales.push_back(reinterpret_cast(scales[i])); + t_amax_histories.push_back(convertNVTETensor(amax_histories[i])); + t_scales.push_back(convertNVTETensor(scales[i])); } delayed_scaling_recipe::amax_and_scale_update_after_reduction( - *reinterpret_cast(amax_reduction_buffer), t_amax_histories, t_scales, - amax_compute_algo, static_cast(fp8_dtype), margin, stream); + *convertNVTETensorCheck(amax_reduction_buffer), t_amax_histories, t_scales, amax_compute_algo, + static_cast(fp8_dtype), margin, stream); } diff --git a/transformer_engine/common/recipe/fp8_block_scaling.cu b/transformer_engine/common/recipe/fp8_block_scaling.cu index 0eae3d07a..cdb307238 100644 --- a/transformer_engine/common/recipe/fp8_block_scaling.cu +++ b/transformer_engine/common/recipe/fp8_block_scaling.cu @@ -241,8 +241,8 @@ void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETenso NVTE_API_CALL(nvte_fp8_block_scaling_compute_partial_amax); using namespace transformer_engine; fp8_block_scaling_recipe::fp8_block_scaling_compute_partial_amax( - *reinterpret_cast(inp), *reinterpret_cast(amax), h, w, - amax_stride_h, amax_stride_w, start_offset, block_len, stream); + *convertNVTETensorCheck(inp), *convertNVTETensorCheck(amax), h, w, amax_stride_h, + amax_stride_w, start_offset, block_len, stream); } void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out, @@ -253,7 +253,7 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out, NVTE_API_CALL(nvte_fp8_block_scaling_partial_cast); using namespace transformer_engine; fp8_block_scaling_recipe::fp8_block_scaling_partial_cast( - *reinterpret_cast(inp), *reinterpret_cast(out), - *reinterpret_cast(scale), h, w, scale_stride_h, scale_stride_w, start_offset, - block_len, static_cast(out_dtype), stream); + *convertNVTETensorCheck(inp), *convertNVTETensorCheck(out), *convertNVTETensorCheck(scale), h, + w, scale_stride_h, scale_stride_w, start_offset, block_len, static_cast(out_dtype), + stream); } diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 11056917f..81a150283 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -357,6 +357,5 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swizzle_scaling_factors); using namespace transformer_engine; - swizzle_scaling_factors(reinterpret_cast(input), reinterpret_cast(output), - stream); + swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream); } diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index d154e224e..e24e4d33d 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -8,20 +8,27 @@ #include +#include +#include #include #include +#include #include "common.h" #include "common/util/cuda_runtime.h" +#include "common/util/logging.h" namespace transformer_engine { -size_t typeToSize(const DType type) { +size_t typeToNumBits(const DType type) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, return TypeInfo::size;); // NOLINT(*) } -bool is_fp8_dtype(const DType t) { return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2; } +size_t typeToSize(const DType type) { + NVTE_CHECK(type != DType::kFloat4E2M1, "typeToSize() Does not support FP4 data type."); + return typeToNumBits(type) / 8; +} std::string to_string(const DType type) { switch (type) { @@ -39,6 +46,10 @@ std::string to_string(const DType type) { return "Float8E5M2"; case DType::kFloat8E8M0: return "Float8E8M0"; + case DType::kFloat4E2M1: + return "Float4E2M1"; + case DType::kInt16: + return "Int16"; case DType::kInt32: return "Int32"; case DType::kInt64: @@ -54,6 +65,8 @@ std::string to_string(const NVTEScalingMode &mode) { return "NVTE_DELAYED_TENSOR_SCALING"; case NVTE_MXFP8_1D_SCALING: return "NVTE_MXFP8_1D_SCALING"; + case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: + return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING"; case NVTE_INVALID_SCALING: return "NVTE_INVALID_SCALING"; } @@ -83,10 +96,13 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { t.columnwise_scale_inv.shape, ")"); } } else { - if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) { + if (t.scaling_mode == NVTE_MXFP8_1D_SCALING || + t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) { // Need (4, 128) alignment even for e8 scaling factor auto block_alignment = std::vector{128ul, 4ul}; size_t expected_x, expected_y, alignment; + const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16; + const size_t block_size_colwise = 32; if (t.has_data()) { alignment = block_alignment[0]; @@ -94,7 +110,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { DIVUP(DIVUP(t.flat_first_dim(), static_cast(1)), alignment) * alignment; alignment = block_alignment[1]; expected_y = - DIVUP(DIVUP(t.flat_last_dim(), static_cast(32)), alignment) * alignment; + DIVUP(DIVUP(t.flat_last_dim(), static_cast(block_size_rowwise)), alignment) * + alignment; const auto &expected = std::vector{expected_x, expected_y}; NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, "\" has invalid scale_inv shape (expected ", expected, ", got ", @@ -103,7 +120,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { if (t.has_columnwise_data()) { alignment = block_alignment[1]; expected_x = - DIVUP(DIVUP(t.flat_first_dim(), static_cast(32)), alignment) * alignment; + DIVUP(DIVUP(t.flat_first_dim(), static_cast(block_size_colwise)), alignment) * + alignment; alignment = block_alignment[0]; expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(1)), alignment) * alignment; const auto &expected = std::vector{expected_x, expected_y}; @@ -194,23 +212,139 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt CheckScaleTensorShape(t, name); } +class TensorAllocator { + public: + static TensorAllocator &instance() { + static TensorAllocator allocator; + return allocator; + } + + ~TensorAllocator() {} + + NVTETensor Allocate(NVTEScalingMode mode) { + std::lock_guard lock(mutex); + if (!free_list.empty()) { + uintptr_t index = free_list.back(); + NVTETensor ret = reinterpret_cast(index); + free_list.pop_back(); + if (debug) { + std::cout << "Allocated " << index + << " from free list. Free list size: " << free_list.size() << " and capacity " + << free_list.capacity() << std::endl; + } + // 1-based indexing + memory[index - 1].scaling_mode = mode; + return ret; + } + if (memory.size() < memory.capacity()) { + memory.emplace_back(); + Tensor &t = memory.back(); + size = memory.size(); + // 1-based indexing + uintptr_t index = memory.size(); + if (debug) { + std::cout << "Allocated " << index << ". Memory size: " << memory.size() << " and capacity " + << memory.capacity() << std::endl; + } + t.scaling_mode = mode; + t.nvte_tensor = reinterpret_cast(index); + return reinterpret_cast(index); + } + NVTE_ERROR("Cannot allocate a new NVTETensor. Maximum number of tensors reached: ", + MAX_TENSOR_NUM, ". There is probably a memory leak in your application."); + } + + void Free(NVTETensor t) { + std::lock_guard lock(mutex); + uintptr_t index = reinterpret_cast(t); + if (index == 0) return; + NVTE_CHECK(index <= memory.size(), "Invalid tensor."); + free_list.push_back(index); + // Clean up + memory[index - 1].clear(); + if (debug) { + std::cout << "Freed " << index << ". Free list size: " << free_list.size() << " and capacity " + << free_list.capacity() << std::endl; + } + } + + void Free(NVTETensor *t, size_t N) { + std::lock_guard lock(mutex); + for (size_t i = 0; i < N; ++i) { + uintptr_t index = reinterpret_cast(t[i]); + if (index == 0) continue; + NVTE_CHECK(index <= memory.size(), "Invalid tensor."); + free_list.push_back(index); + // Clean up + memory[index - 1].clear(); + } + if (debug) { + std::cout << "Freed range of" << N << " tensors. Free list size: " << free_list.size() + << " and capacity " << free_list.capacity() << std::endl; + } + } + + Tensor *convertNVTETensor(NVTETensor t) { + uintptr_t index = reinterpret_cast(t); + // 1-based indexing to enable 0-initialization of NVTETensor + // to be invalid tensor + static_assert(nullptr == 0); + if (index != 0 && index <= size) { + return &(memory[index - 1]); + } + return nullptr; + } + + void setDebug(bool debug) { + std::lock_guard lock(mutex); + this->debug = debug; + } + + private: + TensorAllocator() { + std::lock_guard lock(mutex); + memory.reserve(MAX_TENSOR_NUM); + } + + std::mutex mutex; + std::atomic size; + // Allocate at most 20 MB for tensors + // Should be replaced by virtual memory allocation + const size_t MAX_TENSOR_NUM = 20 * 1024 * 1024 / sizeof(Tensor); + std::vector free_list; + std::vector memory; + bool debug = false; +}; + +Tensor *convertNVTETensor(const NVTETensor t) { + return TensorAllocator::instance().convertNVTETensor(t); +} + +Tensor *convertNVTETensorCheck(const NVTETensor t) { + Tensor *ptr = TensorAllocator::instance().convertNVTETensor(t); + NVTE_CHECK(ptr != nullptr, "Invalid tensor."); + return ptr; +} + } // namespace transformer_engine NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) { - transformer_engine::Tensor *ret = new transformer_engine::Tensor; - ret->scaling_mode = scaling_mode; + NVTETensor ret = transformer_engine::TensorAllocator::instance().Allocate(scaling_mode); return ret; } + void nvte_destroy_tensor(NVTETensor tensor) { - if (tensor == nullptr) return; - auto *t = reinterpret_cast(tensor); - delete t; + transformer_engine::TensorAllocator::instance().Free(tensor); +} + +void nvte_destroy_tensors(NVTETensor *tensors, size_t N) { + transformer_engine::TensorAllocator::instance().Free(tensors, N); } NVTEDType nvte_tensor_type(const NVTETensor tensor) { - if (tensor == nullptr) return kNVTEFloat32; - return static_cast( - reinterpret_cast(tensor)->dtype()); + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) return kNVTEFloat32; + return static_cast(t->dtype()); } NVTEShape nvte_make_shape(const size_t *data, size_t ndim) { @@ -228,23 +362,24 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) { } NVTEShape nvte_tensor_shape(const NVTETensor tensor) { - if (tensor == nullptr) { + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) { NVTE_ERROR("Invalid tensor"); } // Determine tensor shape depending on tensor format - const auto &t = *reinterpret_cast(tensor); - std::vector shape = t.shape(); + const std::vector &shape = t->shape(); return nvte_make_shape(shape.data(), shape.size()); } NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { - if (tensor == nullptr) { + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) { NVTE_ERROR("Invalid tensor"); } - const auto &t = *reinterpret_cast(tensor); - return nvte_make_shape(t.columnwise_data.shape.data(), t.columnwise_data.shape.size()); + const std::vector &shape = t->columnwise_data.shape; + return nvte_make_shape(shape.data(), shape.size()); } size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; } @@ -265,83 +400,97 @@ size_t nvte_tensor_numel(const NVTETensor tensor) { return numel; } +size_t nvte_tensor_element_size_bits(const NVTETensor tensor) { + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) return 8 * sizeof(float); + return transformer_engine::typeToNumBits(t->dtype()); +} + size_t nvte_tensor_element_size(const NVTETensor tensor) { - if (tensor == nullptr) return sizeof(float); - const auto &t = *reinterpret_cast(tensor); - return transformer_engine::typeToSize(t.dtype()); + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) return sizeof(float); + NVTE_CHECK(!is_fp4_dtype(t->dtype()), + "For FP4 type please use the nvte_tensor_element_size_bits."); + return nvte_tensor_element_size_bits(tensor) / 8; +} + +size_t nvte_tensor_size_bytes(const NVTETensor tensor) { + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) return 0; + return (nvte_tensor_numel(tensor) * nvte_tensor_element_size_bits(tensor)) / 8; } void *nvte_tensor_data(const NVTETensor tensor) { - if (tensor == nullptr) return nullptr; - const auto &t = *reinterpret_cast(tensor); - return t.data.dptr; + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) return nullptr; + return t->data.dptr; } void *nvte_tensor_columnwise_data(const NVTETensor tensor) { - if (tensor == nullptr) return nullptr; - const auto &t = *reinterpret_cast(tensor); - return t.columnwise_data.dptr; + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) return nullptr; + return t->columnwise_data.dptr; } float *nvte_tensor_amax(const NVTETensor tensor) { - if (tensor == nullptr) return nullptr; - const auto &t = *reinterpret_cast(tensor); - NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32, + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) return nullptr; + NVTE_CHECK(t->amax.dtype == transformer_engine::DType::kFloat32, "Tensor's amax must have Float32 type!"); - return reinterpret_cast(t.amax.dptr); + return reinterpret_cast(t->amax.dptr); } float *nvte_tensor_scale(const NVTETensor tensor) { - if (tensor == nullptr) return nullptr; - const auto &t = *reinterpret_cast(tensor); - NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32, + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) return nullptr; + NVTE_CHECK(t->scale.dtype == transformer_engine::DType::kFloat32, "Tensor's scale must have Float32 type!"); - return reinterpret_cast(t.scale.dptr); + return reinterpret_cast(t->scale.dptr); } float *nvte_tensor_scale_inv(const NVTETensor tensor) { - if (tensor == nullptr) return nullptr; - const auto &t = *reinterpret_cast(tensor); - return reinterpret_cast(t.scale_inv.dptr); + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) return nullptr; + return reinterpret_cast(t->scale_inv.dptr); } void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) { - if (tensor == nullptr) return nullptr; - const auto &t = *reinterpret_cast(tensor); - return t.columnwise_scale_inv.dptr; + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) return nullptr; + return t->columnwise_scale_inv.dptr; } NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) { - if (tensor == nullptr) { + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) { return nvte_make_shape(nullptr, 0); } - const auto &t = *reinterpret_cast(tensor); - return nvte_make_shape(t.scale_inv.shape.data(), t.scale_inv.shape.size()); + return nvte_make_shape(t->scale_inv.shape.data(), t->scale_inv.shape.size()); } void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, const NVTEBasicTensor *param) { NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL."); - NVTE_CHECK(*tensor != nullptr, "Tensor is not allocated."); - auto &t = *reinterpret_cast(*tensor); + auto *t = transformer_engine::convertNVTETensor(*tensor); + NVTE_CHECK(t != nullptr, "Tensor is not allocated."); switch (param_name) { case kNVTERowwiseData: - t.data = *param; + t->data = *param; break; case kNVTEColumnwiseData: - t.columnwise_data = *param; + t->columnwise_data = *param; break; case kNVTEScale: - t.scale = *param; + t->scale = *param; break; case kNVTEAmax: - t.amax = *param; + t->amax = *param; break; case kNVTERowwiseScaleInv: - t.scale_inv = *param; + t->scale_inv = *param; break; case kNVTEColumnwiseScaleInv: - t.columnwise_scale_inv = *param; + t->columnwise_scale_inv = *param; break; default: NVTE_ERROR("Unknown tensor parameter!"); @@ -352,7 +501,7 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p if (tensor == nullptr) { return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)}; } - const auto &t = *reinterpret_cast(tensor); + const auto &t = *transformer_engine::convertNVTETensorCheck(tensor); switch (param_name) { case kNVTERowwiseData: return t.data; @@ -372,28 +521,30 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p } NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) { - const auto &t = *reinterpret_cast(tensor); + if (tensor == nullptr) { + return NVTE_DELAYED_TENSOR_SCALING; + } + const auto &t = *transformer_engine::convertNVTETensorCheck(tensor); return t.scaling_mode; } void nvte_tensor_pack_create(NVTETensorPack *pack) { for (int i = 0; i < pack->MAX_SIZE; i++) { - pack->tensors[i] = reinterpret_cast(new transformer_engine::Tensor); + pack->tensors[i] = + transformer_engine::TensorAllocator::instance().Allocate(NVTE_DELAYED_TENSOR_SCALING); } } void nvte_tensor_pack_destroy(NVTETensorPack *pack) { - for (int i = 0; i < pack->MAX_SIZE; i++) { - auto *t = reinterpret_cast(pack->tensors[i]); - delete t; - } + transformer_engine::TensorAllocator::instance().Free(pack->tensors, pack->MAX_SIZE); } void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { - const auto &t = *reinterpret_cast(tensor); + if (tensor == nullptr) return; + const auto &t = *transformer_engine::convertNVTETensorCheck(tensor); // Zero out tensor data if allocated if (t.data.dptr != nullptr) { - size_t size_in_bytes = nvte_tensor_element_size(tensor) * nvte_tensor_numel(tensor); + const size_t size_in_bytes = nvte_tensor_size_bytes(tensor); (void)cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream); } // Set amax to 0 if allocated @@ -441,6 +592,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigNoopTensor: std::memcpy(buf, &config_.noop_tensor, attr_size); break; + case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: + std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -473,6 +627,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigNoopTensor: std::memcpy(&config_.noop_tensor, buf, attr_size); break; + case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: + std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index 345e8ee2b..fdf92938c 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -442,15 +442,15 @@ void nvte_cast_transpose(const NVTETensor input, NVTETensor output, cudaStream_t NVTE_API_CALL(nvte_cast_transpose); using namespace transformer_engine; auto noop = Tensor(); - transformer_engine::detail::cast_transpose(*reinterpret_cast(input), noop, - reinterpret_cast(output), stream); + transformer_engine::detail::cast_transpose(*convertNVTETensorCheck(input), noop, + convertNVTETensor(output), stream); } void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_with_noop); using namespace transformer_engine; - transformer_engine::detail::cast_transpose(*reinterpret_cast(input), - *reinterpret_cast(noop), - reinterpret_cast(output), stream); + transformer_engine::detail::cast_transpose(*convertNVTETensorCheck(input), + *convertNVTETensorCheck(noop), + convertNVTETensor(output), stream); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index 3148b4f72..a73723926 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -31,25 +31,27 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor // enum class for rowwise usage enum class FP8BlockwiseRowwiseOption { - // No rowwise data + // No rowwise data, skip rowwise quantization NONE, // Rowwise data, scales in GEMM format - ROWWISE - // TODO: FP8 all gather requires some changes. - // 1. Compact scales are better for gathering than the GEMM format. + ROWWISE_GEMM_READY, + // Rowwise data, scales in compact format, needs extra processing (padding, transposing) before GEMM + ROWWISE_COMPACT }; // enum class for columnwise usage // For Hopper sm90 with only TN fp8 gemm, there is need to do columnwise transpose when doing 1D block scaling enum class FP8BlockwiseColumnwiseOption { - // No columnwise data + // No columnwise data, skip columnwise quantization NONE, // Columnwise data transposed from original shape. // Scales in GEMM format corresponding to GEMM ingesting transposed column data. - COLUMNWISE_TRANSPOSE - // TODO: FP8 all gather requires some changes. - // 1. The transpose gets in the way of the all gather. - // 2. Compact scales are better for gathering than the GEMM format. + // On Hopper sm90, GEMM_READY means that columnwise quantization also fuses transpose op + // On higher sm versions with TN,NT,NN fp8 gemm, GEMM_READY doesn't fuse transpose + COLUMNWISE_GEMM_READY, + // Columnwise data in original shape + // Scales in compact format, needs extra processing (padding, transposing) before GEMM + COLUMNWISE_COMPACT }; void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv, diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index 17506e143..8a2c39def 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -19,6 +19,7 @@ #include "../util/string.h" #include "../utils.cuh" #include "cast_transpose.h" +#include "common/common.h" namespace transformer_engine { @@ -197,17 +198,18 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, / workspace->data.dtype = DType::kFloat32; } else { // Check that workspace matches expected size - const size_t workspace_size = + const size_t workspace_size = get_buffer_size_bytes( std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1, - std::multiplies()) * - typeToSize(workspace->data.dtype); - const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32); + std::multiplies()), + workspace->data.dtype); + const size_t required_size = + get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32); NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", num_rows_partial_dbias, ",", row_length, "), found ())"); NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), "; found dims=", workspace->data.shape, - ", dtype=", typeToSize(workspace->data.dtype), ")"); + ", dtype=", typeToNumBits(workspace->data.dtype), " bits)"); } } @@ -1384,9 +1386,8 @@ void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETe constexpr const NVTETensor activation_input = nullptr; cast_transpose_fused( - *reinterpret_cast(input), reinterpret_cast(activation_input), - reinterpret_cast(output), reinterpret_cast(dbias), - reinterpret_cast(workspace), stream); + *convertNVTETensorCheck(input), convertNVTETensor(activation_input), + convertNVTETensor(output), convertNVTETensor(dbias), convertNVTETensor(workspace), stream); } void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, @@ -1401,9 +1402,9 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor ac constexpr bool IS_ACT = false; cast_transpose_fused>( - *reinterpret_cast(input), reinterpret_cast(act_input), - reinterpret_cast(output), reinterpret_cast(dbias), - reinterpret_cast(workspace), stream); + *convertNVTETensorCheck(input), convertNVTETensorCheck(act_input), + convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace), + stream); } void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor silu_input, @@ -1418,9 +1419,9 @@ void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor si constexpr bool IS_ACT = false; cast_transpose_fused>( - *reinterpret_cast(input), reinterpret_cast(silu_input), - reinterpret_cast(output), reinterpret_cast(dbias), - reinterpret_cast(workspace), stream); + *convertNVTETensorCheck(input), convertNVTETensorCheck(silu_input), + convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace), + stream); } void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor relu_input, @@ -1435,9 +1436,9 @@ void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor re constexpr bool IS_ACT = false; cast_transpose_fused>( - *reinterpret_cast(input), reinterpret_cast(relu_input), - reinterpret_cast(output), reinterpret_cast(dbias), - reinterpret_cast(workspace), stream); + *convertNVTETensorCheck(input), convertNVTETensorCheck(relu_input), + convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace), + stream); } void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor srelu_input, @@ -1452,9 +1453,9 @@ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor s constexpr bool IS_ACT = false; cast_transpose_fused>( - *reinterpret_cast(input), reinterpret_cast(srelu_input), - reinterpret_cast(output), reinterpret_cast(dbias), - reinterpret_cast(workspace), stream); + *convertNVTETensorCheck(input), convertNVTETensorCheck(srelu_input), + convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace), + stream); } void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor qgelu_input, @@ -1469,9 +1470,9 @@ void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor q constexpr bool IS_ACT = false; cast_transpose_fused>( - *reinterpret_cast(input), reinterpret_cast(qgelu_input), - reinterpret_cast(output), reinterpret_cast(dbias), - reinterpret_cast(workspace), stream); + *convertNVTETensorCheck(input), convertNVTETensorCheck(qgelu_input), + convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace), + stream); } void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, @@ -1481,8 +1482,8 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_a using namespace transformer_engine::detail; dgated_act_cast_transpose, gelu>( - *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(output), stream); + *convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input), + convertNVTETensorCheck(output), stream); } void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu_input, @@ -1492,8 +1493,8 @@ void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu using namespace transformer_engine::detail; dgated_act_cast_transpose, silu>( - *reinterpret_cast(input), *reinterpret_cast(swiglu_input), - reinterpret_cast(output), stream); + *convertNVTETensorCheck(input), *convertNVTETensorCheck(swiglu_input), + convertNVTETensorCheck(output), stream); } void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, @@ -1503,8 +1504,8 @@ void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_a using namespace transformer_engine::detail; dgated_act_cast_transpose, relu>( - *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(output), stream); + *convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input), + convertNVTETensorCheck(output), stream); } void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, @@ -1514,8 +1515,8 @@ void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_ using namespace transformer_engine::detail; dgated_act_cast_transpose, srelu>( - *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(output), stream); + *convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input), + convertNVTETensorCheck(output), stream); } void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, @@ -1525,6 +1526,6 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_ using namespace transformer_engine::detail; dgated_act_cast_transpose, qgelu>( - *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(output), stream); + *convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input), + convertNVTETensorCheck(output), stream); } diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index 5cf316f45..2be365465 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -237,8 +237,8 @@ void multi_cast_transpose(const std::vector input_list, std::vector input_list_, output_list_; for (size_t i = 0; i < num_tensors; ++i) { - input_list_.push_back(reinterpret_cast(const_cast(input_list[i]))); - output_list_.push_back(reinterpret_cast(output_list[i])); + input_list_.push_back(convertNVTETensorCheck(input_list[i])); + output_list_.push_back(convertNVTETensorCheck(output_list[i])); } multi_cast_transpose(input_list_, output_list_, stream); } diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 663c61a1c..79d8d215f 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -460,7 +460,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size CUtensorMap tensor_map_output_trans{}; create_2D_tensor_map(tensor_map_output_trans, tensor, global_dim_y, global_dim_x, /*shmemY=*/BLOCK_TILE_DIM, /*shmemX=*/BLOCK_TILE_DIM, - /*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType)); + /*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType) * 8); return tensor_map_output_trans; } diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 91f73dea1..6f5c0f3a6 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -99,14 +99,14 @@ Step 2: Cast and store to output_c | ... | +-------------------------------+-------------------------------+-------------------------------+-------------------------------+ -Step 3: Transpose, cast and store to output_t +Step 3 (if columnwise transpose is True, GEMM_READY): Transpose, cast and store to output_t * shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) * 8 warps * Loop 2 times * What each thread does in each loop: * 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times * Every 8 consecutive threads do reduction and calculate the amax of each column - * 16 elements are quantized and write to output_c at a time, for a total of 2 times + * 16 elements are quantized and write to output_t at a time, for a total of 2 times +------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+ | T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | | | T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | | @@ -118,6 +118,29 @@ Step 3: Transpose, cast and store to output_t | T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | | +-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+ +Step 3 (if columnwise transpose is False, COMPACT format): Skip Transpose, cast and store to output_t +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 1 times +* What each thread does in each loop: + * 16 elements (in a row) are read from the shared memory, for a total of 4 rows, + * it needs 8 reads in smem to get 16 elements in a row, thread tile shape is 16x4 + * Every 32 consecutive threads in a warp do reduction and calculate the amax of each column, + * so each thread will do warp shuffle 16 times to get the amax of each column + * 16 elements are quantized and write to output_t at a time, for a total of 4 times ++------16 elements-------+------16 elements-------+-----80 elements-----+------16 elements------+ +| T0 | | | | +| T1 | | | | +| T2 | | | | +| T3 | | | | +| T4 | | | | +| T5 | | | | +| T6 | | | | +| T7 | | | | +| ... | | | | +| T31 | | | | ++-----------------------+-----------------------+-----------------------+-----------------------+ + */ // clang-format on @@ -140,6 +163,7 @@ constexpr int kNumThreadsLoad = kTileDim / kNVecIn; constexpr int kNumThreadsStore = kTileDim / kNVecOut; static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); +constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; template __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( @@ -149,9 +173,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, const bool pow_2_scaling) { - bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE; - bool return_columnwise_transpose = - columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE; + bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE; + bool return_columnwise_gemm_ready = + columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + bool return_columnwise_compact = + columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT; using SMemVec = Vec; using OVec = Vec; @@ -299,8 +325,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } - // Step 3: Transpose, cast and store to output_t - if (return_columnwise_transpose) { + // Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t + if (return_columnwise_gemm_ready) { constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); @@ -385,6 +411,103 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } } + + // Step 4 (return_columnwise_compact): cast in 128x1 style and store to output, skip transpose + if (return_columnwise_compact) { + // thread tile should be 4x16, 16 means 8 smem reads + constexpr int kThreadTileRow = kTileDim / kThreadsPerWarp; + constexpr int kThreadTileCol = kNVecOut; + using RegVec = Vec; + using RegScaleVec = Vec; + constexpr int num_smem_reads = kNVecOut / kNVecSMem; + // c_stride will not be used here because we only have one iteration + // constexpr int c_stride = kThreadTileCol * kNumWarps / kNVecSMem; + constexpr int num_iterations = + kTileDim / (kNumWarps * kThreadTileCol); // should be only one iteration + static_assert(num_iterations == 1, + "num_iterations should be 1 for columnwise non-transpose case"); + const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp; + const int warp_idx = threadIdx.x / kThreadsPerWarp; + const int r_s = thr_idx_in_warp * kThreadTileRow; // Row in shared memory + int c_s = warp_idx * num_smem_reads; // Column in shared memory + size_t r_g = static_cast(blockIdx.y) * kTileDim + r_s; // Row in global memory + const size_t c_g = + static_cast(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory + const size_t num_ele = c_g < row_length + ? min(static_cast(kThreadTileCol), row_length - c_g) + : 0; // For not aligned case +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + RegVec reg_vec[kThreadTileRow]; + RegScaleVec thr_scale; + + // Step 3.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kThreadTileRow; ++i) { + int r = r_s + i; +#pragma unroll + for (int j = 0; j < num_smem_reads; ++j) { + int c = c_s + j; + SMemVec smem_vec = smem[r * kSMemCol + c]; + // copy smem_vec to reg vec with its elements +#pragma unroll + for (int k = 0; k < kNVecSMem; ++k) { + reg_vec[i].data.elt[j * kNVecSMem + k] = smem_vec.data.elt[k]; + } + } + } +#pragma unroll + for (int reg_idx = 0; reg_idx < kThreadTileCol; ++reg_idx) { + // Step 3.2: Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kThreadTileRow; ++i) { + amax = fmaxf(amax, fabsf(reg_vec[i].data.elt[reg_idx])); + } + // Step 3.3: Reduce amax + const bool is_src_lane = thr_idx_in_warp == 0; + amax = warp_reduce_max(amax); + constexpr int lane_zero = 0; + amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero); + // Step 3.4: Compute scale + CType scale; + scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); + thr_scale.data.elt[reg_idx] = scale; + // Step 3.5: Write scale_inv_t + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (c_g + reg_idx < row_length); + } + if (write_scale_inv) { + CType scale_inv = 1.0 / scale; + size_t row_idx = static_cast(blockIdx.y); + size_t col_idx = static_cast(blockIdx.x) * kTileDim + c_s * kNVecSMem + reg_idx; + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + // Step 3.6: Quantize + for (int row_idx = 0; row_idx < kThreadTileRow; ++row_idx) { + OType* output_g = + &output_t[(r_g + row_idx) * row_length + c_g]; // Output address in global memory + OVec output_vec; +#pragma unroll + for (int i = 0; i < kThreadTileCol; ++i) { + output_vec.data.elt[i] = static_cast( + static_cast(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]); + } + // Step 3.7: Store output_t + if constexpr (kAligned) { + output_vec.store_to(output_g); + } else { + if (r_g + row_idx < num_rows) { + output_vec.store_to_elts(output_g, 0, num_ele); + } + } + } + // Step 3.8: Update output address, column index of shared memory + // this section shouldn't matter since we only have one iteration + } + } } } // namespace @@ -400,11 +523,6 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor const bool pow2_scale, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); - // assert that rowwise_option and columnwise_option are not both NONE - NVTE_CHECK(rowwise_option != FP8BlockwiseRowwiseOption::NONE || - columnwise_option != FP8BlockwiseColumnwiseOption::NONE, - "rowwise_option and columnwise_option cannot both be NONE"); - const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; size_t num_rows = 1; @@ -425,32 +543,43 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor size_t scale_t_stride_y = 0; if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) { - NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE, + NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY || + rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_COMPACT, "Unexpected rowwise enum value"); NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2."); size_t scale_k = scale_inv.shape[1]; - scale_stride_x = scale_k; - scale_stride_y = 1; + bool rowwise_compact = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_COMPACT; + scale_stride_x = rowwise_compact ? 1 : scale_k; + scale_stride_y = rowwise_compact ? scale_k : 1; } if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) { - NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE, - "Unexpected columnwise enum value"); NVTE_CHECK(output_t.shape.size() == input.shape.size(), "output_t must have same number of dimensions as input."); + if (output_t.shape.size() > 0) { - NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t."); - for (size_t i = 1; i < output_t.shape.size(); ++i) { - NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t"); + if (columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY) { + NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t."); + for (size_t i = 1; i < output_t.shape.size(); ++i) { + NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t"); + } + } else { + NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT, + "Unexpected columnwise option enum value"); + NVTE_CHECK(output_t.shape[0] == input.shape[0], "Wrong dimension 0 of output_t."); + NVTE_CHECK( + input.shape == output_t.shape, + "Input and output_t must have the same shape for columnwise non-transpose case."); } } NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype."); - NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2."); - scale_t_stride_x = scale_inv_t.shape[1]; - scale_t_stride_y = 1; + bool columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT; + size_t scale_t_k = scale_inv_t.shape[1]; + scale_t_stride_x = columnwise_compact ? 1 : scale_t_k; + scale_t_stride_y = columnwise_compact ? scale_t_k : 1; } const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu index 26740a383..103f45cf1 100644 --- a/transformer_engine/common/transpose/transpose.cu +++ b/transformer_engine/common/transpose/transpose.cu @@ -288,14 +288,13 @@ void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stre NVTE_API_CALL(nvte_transpose); using namespace transformer_engine; auto noop = Tensor(); - transpose(*reinterpret_cast(input), noop, reinterpret_cast(output), - stream); + transpose(*convertNVTETensorCheck(input), noop, convertNVTETensor(output), stream); } void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_transpose_with_noop); using namespace transformer_engine; - transpose(*reinterpret_cast(input), *reinterpret_cast(noop), - reinterpret_cast(output), stream); + transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(noop), + convertNVTETensor(output), stream); } diff --git a/transformer_engine/common/transpose/transpose_fusion.cu b/transformer_engine/common/transpose/transpose_fusion.cu index cf326e924..af859471e 100644 --- a/transformer_engine/common/transpose/transpose_fusion.cu +++ b/transformer_engine/common/transpose/transpose_fusion.cu @@ -388,17 +388,18 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ workspace->data.dtype = DType::kFloat32; } else { // Check that workspace matches expected size - const size_t workspace_size = + const size_t workspace_size = get_buffer_size_bytes( std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1, - std::multiplies()) * - typeToSize(workspace->data.dtype); - const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32); + std::multiplies()), + workspace->data.dtype); + const size_t required_size = + get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32); NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", num_rows_partial_dbias, ",", row_length, "), found ())"); NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), "; found dims=", workspace->data.shape, - ", dtype=", typeToSize(workspace->data.dtype), ")"); + ", dtype=", typeToNumBits(workspace->data.dtype), " bits)"); } } @@ -505,7 +506,6 @@ void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_outp NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_fp8_transpose_dbias); using namespace transformer_engine; - fp8_transpose_dbias( - *reinterpret_cast(input), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + fp8_transpose_dbias(*convertNVTETensorCheck(input), convertNVTETensor(transposed_output), + convertNVTETensor(dbias), convertNVTETensor(workspace), stream); } diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index fe11120b3..dce878c8c 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -12,13 +12,16 @@ #endif //#ifndef __HIP_PLATFORM_AMD__ #include #include +#include #include #include +#include #include #include "../common.h" #include "../transpose/cast_transpose.h" +#include "../util/multi_stream.h" #include "../util/vectorized_pointwise.h" #include "../utils.cuh" #include "cast_kernels.cuh" @@ -158,6 +161,45 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dequantize); using namespace transformer_engine; - detail::dequantize_helper(*reinterpret_cast(input), - reinterpret_cast(output), stream); + detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream); +} + +void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, + const NVTEQuantizationConfig quant_configs, + const size_t num_tensors, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_quantize); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + constexpr const NVTETensor grad = nullptr; + + const size_t num_streams = nvte_get_num_compute_streams(); + + int num_stream_used = std::min(num_streams, num_tensors); + // wait for current stream to finish + NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream)); + for (int s = 0; s < num_stream_used; s++) { + NVTE_CHECK_CUDA( + cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0))); + } + + for (int i = 0; i < num_tensors; i++) { + detail::quantize_helper( + inputs[i], grad, outputs[i], dbias, workspace, nullptr, + detail::get_compute_stream(i % num_streams)); + } + + // record events on compute streams + for (int s = 0; s < num_stream_used; s++) { + NVTE_CHECK_CUDA( + cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s))); + } + // wait for all compute streams to finish + for (int s = 0; s < num_stream_used; s++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s))); + } } diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 4dfd45b3e..a5a23c1c0 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -762,19 +762,20 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu if constexpr (IS_DGATED) { create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, - cols, 0, sizeof(IType)); + cols, 0, typeToNumBits(gated_input.dtype())); } const uint32_t tensor_stride_elems = output_cols; create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, 0, sizeof(IType)); + SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, cols, sizeof(IType)); + SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, 0, sizeof(OType)); + SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, cols, sizeof(OType)); + SHMEM_DIM_X, tensor_stride_elems, cols, + typeToNumBits(output->dtype())); const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; const size_t buff_size_aligned_in = @@ -868,31 +869,33 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out if constexpr (IS_DGATED) { create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols, 0, sizeof(IType)); + SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype())); } const uint32_t tensor_stride_elems = output_cols; create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, sizeof(IType)); + SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, + typeToNumBits(gated_input.dtype())); create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, sizeof(IType)); + SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, + typeToNumBits(gated_input.dtype())); if (USE_ROWWISE_SCALING) { create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0, - sizeof(OType)); + typeToNumBits(output->dtype())); create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols, - sizeof(OType)); + typeToNumBits(output->dtype())); } if (USE_COLWISE_SCALING) { create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, - 0, sizeof(OType)); + 0, typeToNumBits(output->dtype())); create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, - cols, sizeof(OType)); + cols, typeToNumBits(output->dtype())); } #endif // #ifdef __HIP_PLATFORM_AMD__ @@ -950,17 +953,20 @@ template void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { CheckInputTensor(input, "gated_act_input"); CheckOutputTensor(*output, "gated_act_output"); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(input.data.shape[0] == output->data.shape[0], - "Input shape[0] must be equal to output shape[0]."); - NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2, - "Input shape[1] must be 2x larger than output shape[1]."); + NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(), + "Wrong output shape. Expected (after flattening) [", input.flat_first_dim(), + ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(input.flat_last_dim() % 2 == 0, + "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", + input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == input.flat_last_dim() / 2, + "Wrong output shape. Expected (after flattening) [*, ", input.flat_last_dim() / 2, + "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, + input.dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, + output->dtype(), OType, if (!is_fp8_dtype(output->data.dtype) || is_delayed_tensor_scaling(output->scaling_mode)) { @@ -970,8 +976,8 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), output->data.shape[0], - output->data.shape[1], {}, stream); + reinterpret_cast(output->scale_inv.dptr), input.flat_first_dim(), + output->flat_last_dim(), {}, stream); } else { NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); }); // NOLINT(*) @@ -1098,10 +1104,9 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, cudaStream_t stream) { using namespace gated_kernels; Tensor grad_empty_tensor; - const Tensor &grad_tensor = - IS_DGATED ? *(reinterpret_cast(grad)) : grad_empty_tensor; - const Tensor gated_input_tensor = *reinterpret_cast(gated_input); - Tensor *output_tensor = reinterpret_cast(output); + const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor; + const Tensor gated_input_tensor = *convertNVTETensorCheck(gated_input); + Tensor *output_tensor = convertNVTETensorCheck(output); #ifdef __HIP_PLATFORM_AMD__ if (1) { diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 0618a7e30..468d31690 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -905,15 +905,15 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T alignas(64) CUtensorMap tensor_map_output{}; create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); if constexpr (IS_DACT) { create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); } create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, sizeof(OType)); + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype)); cast_fp8_2D_kernel <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, @@ -1018,24 +1018,24 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, alignas(64) CUtensorMap tensor_map_output_colwise{}; create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y, - MXFP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype())); if constexpr (IS_DACT) { create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - sizeof(IType)); + typeToNumBits(input.dtype())); } if (use_rowwise_scaling) { create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - sizeof(OType)); + typeToNumBits(output->dtype())); } if (use_colwise_scaling) { create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - sizeof(OType)); + typeToNumBits(output->dtype())); } cast_mxfp8_2D_kernelflat_last_dim(); constexpr int TMA_bytes = 16; - const int alignment_requirement = TMA_bytes / typeToSize(t->dtype()); + const int alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); return cols % alignment_requirement == 0; } #endif //#ifndef __HIP_PLATFORM_AMD__ @@ -1283,23 +1283,23 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o const Tensor *activation_input_tensor; if constexpr (IS_DBIAS || IS_DACT) { // backward - input is incoming gradient - input_tensor = reinterpret_cast(grad); - activation_input_tensor = reinterpret_cast(input); + input_tensor = convertNVTETensorCheck(grad); + activation_input_tensor = convertNVTETensor(input); } else { // forward = input is activation input - input_tensor = reinterpret_cast(input); + input_tensor = convertNVTETensorCheck(input); activation_input_tensor = nullptr; } - auto output_tensor = reinterpret_cast(output); - auto dbias_tensor = reinterpret_cast(dbias); - auto workspace_tensor = reinterpret_cast(workspace); + auto output_tensor = convertNVTETensorCheck(output); + auto dbias_tensor = convertNVTETensor(dbias); + auto workspace_tensor = convertNVTETensor(workspace); const QuantizationConfig *quant_config_cpp = reinterpret_cast(quant_config); // extract noop tensor from quant_config_cpp if it's not null const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr; - const auto noop_tensor = noop != nullptr ? *(reinterpret_cast(noop)) : Tensor(); + const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor(); switch (output_tensor->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { @@ -1345,12 +1345,25 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; - FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data() - ? FP8BlockwiseRowwiseOption::ROWWISE - : FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = - output_tensor->has_columnwise_data() ? FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE - : FP8BlockwiseColumnwiseOption::NONE; + FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + if (output_tensor->has_data()) { + bool rowwise_compact = quant_config_cpp + ? quant_config_cpp->float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT + : false; + rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT + : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + } + if (output_tensor->has_columnwise_data()) { + bool columnwise_compact = quant_config_cpp + ? quant_config_cpp->float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT + : false; + columnwise_option = columnwise_compact + ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT + : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + } quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, diff --git a/transformer_engine/common/util/cuda_driver.cpp b/transformer_engine/common/util/cuda_driver.cpp index bacc56c31..d4835b611 100644 --- a/transformer_engine/common/util/cuda_driver.cpp +++ b/transformer_engine/common/util/cuda_driver.cpp @@ -14,22 +14,48 @@ namespace transformer_engine { namespace cuda_driver { - -void *get_symbol(const char *symbol) { - void *entry_point; +//TODO: hipGetDriverEntryPoint is supported in rocm 7.1 #ifdef __HIP_PLATFORM_AMD__ +void *get_symbol(const char *symbol, int cuda_version) { + void *entry_point; hipDriverProcAddressQueryResult driver_result; NVTE_CHECK_CUDA(hipGetProcAddress(symbol, &entry_point, HIP_VERSION_MAJOR*100+HIP_VERSION_MINOR, 0, &driver_result)); NVTE_CHECK(driver_result == HIP_GET_PROC_ADDRESS_SUCCESS, "Could not find CUDA driver entry point for ", symbol); + return entry_point; +} #else +typedef cudaError_t (*VersionedGetEntryPoint)(const char *, void **, unsigned int, + unsigned long long, // NOLINT(*) + cudaDriverEntryPointQueryResult *); +typedef cudaError_t (*GetEntryPoint)(const char *, void **, unsigned long long, // NOLINT(*) + cudaDriverEntryPointQueryResult *); + +void *get_symbol(const char *symbol, int cuda_version) { + constexpr char driver_entrypoint[] = "cudaGetDriverEntryPoint"; + constexpr char driver_entrypoint_versioned[] = "cudaGetDriverEntryPointByVersion"; + // We link to the libcudart.so already, so can search for it in the current context + static GetEntryPoint driver_entrypoint_fun = + reinterpret_cast(dlsym(RTLD_DEFAULT, driver_entrypoint)); + static VersionedGetEntryPoint driver_entrypoint_versioned_fun = + reinterpret_cast(dlsym(RTLD_DEFAULT, driver_entrypoint_versioned)); + cudaDriverEntryPointQueryResult driver_result; - NVTE_CHECK_CUDA(cudaGetDriverEntryPoint(symbol, &entry_point, cudaEnableDefault, &driver_result)); + void *entry_point = nullptr; + if (driver_entrypoint_versioned_fun != nullptr) { + // Found versioned entrypoint function + NVTE_CHECK_CUDA(driver_entrypoint_versioned_fun(symbol, &entry_point, cuda_version, + cudaEnableDefault, &driver_result)); + } else { + NVTE_CHECK(driver_entrypoint_fun != nullptr, "Error finding the CUDA Runtime-Driver interop."); + // Versioned entrypoint function not found + NVTE_CHECK_CUDA(driver_entrypoint_fun(symbol, &entry_point, cudaEnableDefault, &driver_result)); + } NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess, "Could not find CUDA driver entry point for ", symbol); -#endif return entry_point; } +#endif } // namespace cuda_driver diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index 36638995c..f131bab45 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -21,7 +21,7 @@ namespace transformer_engine { namespace cuda_driver { /*! \brief Get pointer corresponding to symbol in CUDA driver library */ -void *get_symbol(const char *symbol); +void *get_symbol(const char *symbol, int cuda_version = 12010); /*! \brief Call function in CUDA driver library * diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 9453c2f86..896f09e50 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -126,8 +126,11 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id) bool supports_multicast(int device_id) { #if CUDART_VERSION >= 12010 - // NOTE: This needs to be guarded at compile time because the + // NOTE: This needs to be guarded at compile-time and run-time because the // CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions. + if (cudart_version() < 12010) { + return false; + } static std::vector cache(num_devices(), false); static std::vector flags(num_devices()); if (device_id < 0) { @@ -155,7 +158,6 @@ bool supports_multicast(int device_id) { #endif } - const std::string &include_directory(bool required) { static std::string path; @@ -220,6 +222,16 @@ const std::string &include_directory(bool required) { // Return cached path return path; } + +int cudart_version() { + auto get_version = []() -> int { + int version; + NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&version)); + return version; + }; + static int version = get_version(); + return version; +} #endif // __HIP_PLATFORM_AMD__ } // namespace cuda diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index 0f0f730be..6e25459c5 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -81,6 +81,12 @@ bool supports_multicast(int device_id = -1); const std::string &include_directory(bool required = false); #endif +/* \brief CUDA Runtime version number at run-time + * + * Versions may differ between compile-time and run-time. + */ +int cudart_version(); + } // namespace cuda } // namespace transformer_engine diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index 76e05c2d9..8f0a9730b 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -320,9 +320,9 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s alignas(64) CUtensorMap tensor_map_output{}; create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols, 0, sizeof(IType)); + SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype())); create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols, 0, sizeof(OType)); + SHMEM_DIM_X, cols, 0, typeToNumBits(output->dtype())); dequantize_mxfp8_kernel <<>>(tensor_map_input, tensor_map_output, scales_ptr, diff --git a/transformer_engine/common/util/handle_manager.h b/transformer_engine/common/util/handle_manager.h index 1f50af9c3..a63cd61c3 100644 --- a/transformer_engine/common/util/handle_manager.h +++ b/transformer_engine/common/util/handle_manager.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -9,13 +11,13 @@ #include +#ifndef __HIP_PLATFORM_AMD__ #include "cuda_runtime.h" #include "logging.h" - -namespace transformer_engine::cuda { -int current_device(); -int num_devices(); -} // namespace transformer_engine::cuda +#else +#include "util/cuda_runtime.h" +#include "util/logging.h" +#endif namespace transformer_engine::detail { diff --git a/transformer_engine/common/util/multi_stream.cpp b/transformer_engine/common/util/multi_stream.cpp new file mode 100644 index 000000000..70d7376af --- /dev/null +++ b/transformer_engine/common/util/multi_stream.cpp @@ -0,0 +1,69 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_ +#define TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_ + +#include "multi_stream.h" + +#include + +#include +#include + +#include "cuda_runtime.h" +#include "logging.h" + +namespace transformer_engine::detail { + +cudaStream_t get_compute_stream(int idx) { + const size_t num_streams = nvte_get_num_compute_streams(); + NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx, + ", but there are ", num_streams, " streams)"); + static std::vector streams(num_streams); + static std::once_flag stream_init_flag; + auto init = [&]() { + for (size_t i = 0; i < num_streams; i++) { + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&streams[i], cudaStreamNonBlocking, -1)); + } + }; + std::call_once(stream_init_flag, init); + return streams[idx]; +} + +cudaEvent_t get_compute_stream_event(int idx) { + const size_t num_streams = nvte_get_num_compute_streams(); + NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx, + ", but there are ", num_streams, " streams)"); + static std::vector events(num_streams); + static std::once_flag event_init_flag; + auto init = [&]() { + for (size_t i = 0; i < num_streams; i++) { + NVTE_CHECK_CUDA(cudaEventCreate(&events[i])); + } + }; + std::call_once(event_init_flag, init); + return events[idx]; +} + +int get_num_compute_streams() { + static constexpr int num_compute_streams = 4; + return num_compute_streams; +} + +} // namespace transformer_engine::detail + +int nvte_get_num_compute_streams() { return transformer_engine::detail::get_num_compute_streams(); } + +cudaStream_t nvte_get_compute_stream(const int idx) { + return transformer_engine::detail::get_compute_stream(idx); +} + +cudaEvent_t nvte_get_compute_stream_event(const int idx) { + return transformer_engine::detail::get_compute_stream_event(idx); +} + +#endif // TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_ diff --git a/transformer_engine/common/util/multi_stream.h b/transformer_engine/common/util/multi_stream.h new file mode 100644 index 000000000..26f2d19df --- /dev/null +++ b/transformer_engine/common/util/multi_stream.h @@ -0,0 +1,20 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_ +#define TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_ + +namespace transformer_engine::detail { + +int get_num_compute_streams(); + +cudaStream_t get_compute_stream(int idx); + +cudaEvent_t get_compute_stream_event(int idx); + +} // namespace transformer_engine::detail + +#endif // TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_ diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index e90d2de55..a1899d5b1 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -126,6 +126,83 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP } } +template +__global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(MultiPaddingArgs args) { + using Vec = Vec; + + // Thread indices + // Note: Block is interpreted as a warp_size x num_warps grid + constexpr int bdimx = THREADS_PER_WARP; + constexpr int bdimy = n_warps_per_tile; + const int tid = threadIdx.x; + const int tidx = tid % bdimx; + const int tidy = tid / bdimx; + const int bid = blockIdx.x; + + // Input tensors are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles + constexpr int tile_dim_m = THREADS_PER_WARP * nvec; + constexpr int tile_dim_n = THREADS_PER_WARP * nvec; + + // Number of nvec x nvec subtiles for each thread to + // load/store + constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + + // Find tensor corresponding to block + int tensor_id = 0; + while (args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + const Type* input = reinterpret_cast(args.input_list[tensor_id]); + Type* output = reinterpret_cast(args.output_list[tensor_id]); + const int num_rows = args.num_rows_list[tensor_id]; + const int row_length = args.row_length_list[tensor_id]; + + // Find position of tile within tensor + const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; + const int tile_id = bid - args.block_range[tensor_id]; + const int tile_id_m = tile_id / num_tiles_n; + const int tile_id_n = tile_id % num_tiles_n; + const int tile_row = tile_id_m * tile_dim_m; + const int tile_col = tile_id_n * tile_dim_n; + + // Load input and store to registers + // Note: Each thread loads n_iterations subtiles, casts to output + // type, and transposes in registers. + Type local_zero = static_cast(0.f); +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) { + const int i1 = tidy + iter * bdimy; + const int j1 = tidx; +#pragma unroll + for (int i2 = 0; i2 < nvec; ++i2) { + const int row = tile_row + i1 * nvec + i2; + const int col = tile_col + j1 * nvec; + Vec local_input; + Vec local_output; + local_input.clear(); + if (row < num_rows) { + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + local_input.data.elt[j2] = input[row * row_length + col + j2]; + } + } + } +#pragma unroll + for (int j2 = 0; j2 < nvec; ++j2) { + local_output.data.elt[j2] = local_input.data.elt[j2]; + } + if (row < num_rows) { + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + output[row * row_length + col + j2] = local_output.data.elt[j2]; + } + } + } + } + } +} + } // namespace void multi_padding(const std::vector input_list, std::vector output_list, @@ -155,8 +232,8 @@ void multi_padding(const std::vector input_list, std::vector o // Input matrices are divided into tiles // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles - const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); - const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); + const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type); + const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type); // Add tensors to kernel argument struct MultiPaddingArgs kernel_args; @@ -202,6 +279,78 @@ void multi_padding(const std::vector input_list, std::vector o } } +void multi_unpadding(const std::vector input_list, std::vector output_list, + const std::vector unpadded_num_rows_list, cudaStream_t stream) { + // Check that number of tensors is valid + NVTE_CHECK(output_list.size() == input_list.size(), + "Number of input and output tensors must match"); + if (input_list.empty()) { + return; + } + + // Check that tensor properties are valid + DType type = input_list[0]->data.dtype; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + const auto& input = *input_list[tensor_id]; + const auto& output = *output_list[tensor_id]; + CheckInputTensor(input, "multi_unpadding_input_" + std::to_string(tensor_id)); + CheckInputTensor(output, "multi_unpadding_output_" + std::to_string(tensor_id)); + + NVTE_CHECK(input.data.dtype == type, "Input tensor types do not match."); + NVTE_CHECK(output.data.dtype == type, "Output tensor types do not match."); + + NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions."); + NVTE_CHECK(output.data.shape[0] == unpadded_num_rows_list[tensor_id], + "output tensor shape does not match padded input shape."); + } + + // Input matrices are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles + const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); + const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); + + // Add tensors to kernel argument struct + MultiPaddingArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + // Launch kernel if argument struct is full + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type); + const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; + multi_unpadding_kernel + <<>>(kernel_args);); // NOLINT(*) + kernel_args.num_tensors = 0; + } + + // Calculate number of thread blocks needed for tensor + const int num_rows = unpadded_num_rows_list[tensor_id]; + const int row_length = input_list[tensor_id]->data.shape[1]; + const int num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m; + const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; + const int num_tiles = num_tiles_m * num_tiles_n; + + // Add tensor to kernel argument struct + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input_list[tensor_id]->data.dptr); + kernel_args.output_list[pos] = output_list[tensor_id]->data.dptr; + kernel_args.num_rows_list[pos] = num_rows; + kernel_args.row_length_list[pos] = row_length; + kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; + kernel_args.num_tensors++; + } + + // Launch kernel + if (kernel_args.num_tensors > 0) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type); + const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; + multi_unpadding_kernel + <<>>(kernel_args);); // NOLINT(*) + } +} + } // namespace transformer_engine void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, @@ -211,9 +360,23 @@ void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETe std::vector input_list_, output_list_; std::vector padded_num_rows_list_; for (size_t i = 0; i < num_tensors; ++i) { - input_list_.push_back(reinterpret_cast(const_cast(input_list[i]))); - output_list_.push_back(reinterpret_cast(output_list[i])); + input_list_.push_back(convertNVTETensorCheck(input_list[i])); + output_list_.push_back(convertNVTETensorCheck(output_list[i])); padded_num_rows_list_.push_back(padded_num_rows_list[i]); } multi_padding(input_list_, output_list_, padded_num_rows_list_, stream); } + +void nvte_multi_unpadding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, + const int* unpadded_num_rows_list, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_unpadding); + using namespace transformer_engine; + std::vector input_list_, output_list_; + std::vector unpadded_num_rows_list_; + for (size_t i = 0; i < num_tensors; ++i) { + input_list_.push_back(convertNVTETensorCheck(input_list[i])); + output_list_.push_back(convertNVTETensorCheck(output_list[i])); + unpadded_num_rows_list_.push_back(unpadded_num_rows_list[i]); + } + multi_unpadding(input_list_, output_list_, unpadded_num_rows_list_, stream); +} diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index fbc6dd1e1..be06e807e 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -37,6 +37,7 @@ // Define comm overlap handles if not using ROCm #ifndef USE_ROCM + #define NVTE_DECLARE_COMM_OVERLAP_HANDLES(m) \ pybind11::enum_(m, "CommOverlapType", \ pybind11::module_local()) \ @@ -154,6 +155,10 @@ .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \ NVTE_DECLARE_FUSED_ATTENTION_HANDLES(m) \ + pybind11::enum_( \ + m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \ + .value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \ + .value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \ NVTE_DECLARE_COMM_OVERLAP_HANDLES(m) #endif diff --git a/transformer_engine/common/util/rtc.cpp b/transformer_engine/common/util/rtc.cpp index 641fde964..054531169 100644 --- a/transformer_engine/common/util/rtc.cpp +++ b/transformer_engine/common/util/rtc.cpp @@ -159,7 +159,7 @@ void KernelManager::compile(const std::string& kernel_label, const std::string& // Choose whether to compile to PTX or cubin const int sm_arch_ = cuda::sm_arch(device_id); const int compile_sm_arch = std::min(sm_arch_, max_supported_sm_arch()); - const bool compile_ptx = (CUDA_VERSION <= 11000) || (sm_arch_ != compile_sm_arch); + const bool compile_ptx = sm_arch_ != compile_sm_arch; #endif // __HIP_PLATFORM_AMD__ // Compilation flags diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index a524fbbd4..0ed5bca6b 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -26,6 +26,10 @@ using namespace __hip_internal; typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2))); #else +#if CUDA_VERSION >= 12080 +#include +#endif + #if !defined(__CUDACC_RTC__) #include #else diff --git a/transformer_engine/debug/features/api.py b/transformer_engine/debug/features/api.py index 887043c42..13ab6040d 100644 --- a/transformer_engine/debug/features/api.py +++ b/transformer_engine/debug/features/api.py @@ -12,7 +12,7 @@ import torch from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS -from transformer_engine.pytorch.tensor import all_tensor_types +from transformer_engine.pytorch.tensor import get_all_tensor_types from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor @@ -424,7 +424,7 @@ def output_assertions_hook(self, api_name, ret, **kwargs): if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]: assert ret is None if api_name == "modify_tensor": - assert type(ret) in all_tensor_types + assert type(ret) in get_all_tensor_types() if ( type(ret) == torch.Tensor # pylint: disable=unidiomatic-typecheck and "dtype" in kwargs @@ -438,4 +438,4 @@ def step(self): def end_debug(self): """This function is called by the nvidia-dlframework-inspect after every debug_api.end_debug()""" - TEDebugState.reset() + TEDebugState._reset() diff --git a/transformer_engine/debug/features/fake_quant.py b/transformer_engine/debug/features/fake_quant.py index bab4b4dcf..4a5b6c34a 100644 --- a/transformer_engine/debug/features/fake_quant.py +++ b/transformer_engine/debug/features/fake_quant.py @@ -49,7 +49,7 @@ def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None): fp8_dtype = tex.DType.kFloat8E5M2 amax = tensor.abs().max().float() one = torch.ones(1, device=tensor.device) - scale = _default_sf_compute(amax, one, fp8_max) + scale = _default_sf_compute(amax, one, fp8_max, 0) quantizer = Float8Quantizer(scale, amax, fp8_dtype) else: diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index 4ca2a8ed3..e5c84a9bd 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -120,7 +120,6 @@ def inspect_tensor_postquantize( if not rowwise: return # tensor was already seen rowwise in the other gemm - tensor = tensor._data options = ( config.get("start_step", None), config.get("end_step", None), diff --git a/transformer_engine/debug/features/per_tensor_scaling.py b/transformer_engine/debug/features/per_tensor_scaling.py index eabb6304a..7b4de0a18 100644 --- a/transformer_engine/debug/features/per_tensor_scaling.py +++ b/transformer_engine/debug/features/per_tensor_scaling.py @@ -15,6 +15,7 @@ from transformer_engine.pytorch.tensor import Quantizer from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Tensor, + Float8Quantizer, Float8CurrentScalingQuantizer, ) from transformer_engine.debug.features.api import TEConfigAPIMapper @@ -39,7 +40,7 @@ def per_tensor_cast( }, "[NVTORCH INSPECT ERROR] Only 2 FP8 types: E4M3 and E5M2 are supported in TE." tensor = tensor.contiguous() - quantizer = Float8CurrentScalingQuantizer(fp8_dtype) + quantizer = Float8CurrentScalingQuantizer(fp8_dtype, device=tensor.device) if out is not None: quantizer.update_quantized(tensor, out) @@ -81,7 +82,6 @@ class PerTensorScaling(TEConfigAPIMapper): transformer_engine: PerTensorScaling: enabled: True - margin: 1 gemms: [dgrad] tensors: [weight, activation] """ @@ -118,7 +118,7 @@ def modify_tensor( if key not in ["gemm", "tensor"]: raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".') - assert isinstance(default_quantizer, Float8CurrentScalingQuantizer), ( + assert isinstance(default_quantizer, Float8Quantizer), ( f"[NVTORCH INSPECT ERROR] Feature={self.__class__.__name__}, API=process_tensor: " "Per-tensor current scaling can be used only within `DelayedScaling` recipe autocast." f" {layer_name}" diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py index 231348405..4be465f8e 100644 --- a/transformer_engine/debug/features/utils/stats_buffer.py +++ b/transformer_engine/debug/features/utils/stats_buffer.py @@ -85,6 +85,13 @@ def feed(self, tensor, iteration): if self.modified[0] and not self.reduce_within_microbatch: return + if ( + tensor.numel() == 0 + if hasattr(tensor, "numel") + else all((t is None or t.numel() == 0) for t in tensor.get_data_tensors()) + ): + return + # save stats for tensor to tmp buffer for stat_name in self.stats_to_compute: fn, _ = STATS[stat_name] diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 84a740161..ed32de1ae 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -17,6 +17,8 @@ def _compute_dynamic_range_top(tensor): """Computes the log2 of the amax of the tensor""" tensor_abs = tensor.abs() tensor_abs = tensor_abs[tensor_abs != 0] + if tensor_abs.numel() == 0: + return torch.inf amax = tensor_abs.max().float() if not amax.all(): amax = torch.tensor(1, device=tensor.device).to(torch.float) @@ -96,7 +98,10 @@ def _get(buffers, stat_name): "max": (torch.max, lambda buffers: max(_get(buffers, "max"))), "sum": (torch.sum, lambda buffers: sum(_get(buffers, "sum"))), "mean": (torch.mean, lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel"))), - "numel": (lambda x: x.numel(), lambda buffers: sum(_get(buffers, "numel"))), + "numel": ( + lambda x: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(), + lambda buffers: sum(_get(buffers, "numel")), + ), "l1_norm": (lambda x: torch.norm(x, p=1), lambda buffers: sum(_get(buffers, "l1_norm"))), "l2_norm_square": ( lambda x: torch.sum(x**2), @@ -122,7 +127,7 @@ def _get(buffers, stat_name): lambda buffers: min(_get(buffers, "dynamic_range_bottom")), ), "underflows_num": ( - lambda x: (x._data == 0).sum(), + lambda x: (x.get_data_tensors()[0] == 0).sum(), lambda buffers: sum(_get(buffers, "underflows_num")), ), "std": ( @@ -137,7 +142,7 @@ def _get(buffers, stat_name): - min(_get(buffers, "dynamic_range_bottom")), ), "underflows%": ( - lambda x: (x == 0).sum() / x.numel() * 100, + lambda x: (x.get_data_tensors()[0] == 0).sum() / x.get_data_tensors()[0].numel() * 100, lambda buffers: 100 * sum(_get(buffers, "underflows_num")) / sum(_get(buffers, "numel")), ), } diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 4a7a156a0..2b859800a 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -14,10 +14,11 @@ import transformer_engine_torch as tex - +from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch.tensor.quantized_tensor import ( QuantizedTensor, Quantizer, + QuantizedTensorBase, prepare_for_saving, restore_from_saved, ) @@ -61,6 +62,12 @@ def __init__( self.tp_group = tp_group # used in inspect_tensor calls self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count + # .internal = True is slightly faster, but results + # in errors when caching the weights. + # Setting .internal = False is safer. + if parent_quantizer is not None: + parent_quantizer.internal = False + self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name] # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled, @@ -299,8 +306,9 @@ def quantize( iteration=self.iteration, dtype=dtype, ) - if columnwise_gemm_tensor.dtype != dtype: - raise ValueError("Dtype does not match the output of the modify_tensor call") + if dtype is not None: + if columnwise_gemm_tensor.dtype != dtype: + raise ValueError("Dtype does not match the output of the modify_tensor call") if self.rowwise_tensor_plan == API_CALL_MODIFY: rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor( layer_name=self.layer_name, @@ -311,8 +319,9 @@ def quantize( iteration=self.iteration, dtype=dtype, ) - if rowwise_gemm_tensor.dtype != dtype: - raise ValueError("Dtype does not match the output of the modify_tensor call") + if dtype is not None: + if rowwise_gemm_tensor.dtype != dtype: + raise ValueError("Dtype does not match the output of the modify_tensor call") # 3. If some tensors still are not defined we use high precision tensor. if self.rowwise_tensor_plan == HIGH_PRECISION: @@ -332,6 +341,7 @@ def quantize( quantizer=self, layer_name=self.layer_name, tensor_name=self.tensor_name, + original_tensor=tensor, ) def process_gemm_output(self, tensor: torch.Tensor): @@ -455,8 +465,12 @@ def any_feature_enabled(self) -> bool: return True return False + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + """Probably not needed for debug quantizer""" + return None + -class DebugQuantizedTensor: +class DebugQuantizedTensor(QuantizedTensorBase): """ Class containing quantized tensors after debug. Depending on configuration it can contain one or two different objects. These objects can be accessed by the method @@ -470,6 +484,7 @@ def __init__( quantizer, layer_name=None, tensor_name=None, + original_tensor=None, ): self.rowwise_gemm_tensor = rowwise_gemm_tensor @@ -477,6 +492,7 @@ def __init__( self.quantizer = quantizer self._layer_name = layer_name self._tensor_name = tensor_name + self._original_tensor = original_tensor def prepare_for_saving(self): """ " Prepare for saving method override""" @@ -524,5 +540,5 @@ def size(self): """Size of the tensor.""" return self.rowwise_gemm_tensor.size() - def update_usage(self, rowwise_usage: bool, columnwise_usage: bool): + def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None): """Update usage of the tensor.""" diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index b25c46429..fe4109cee 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -277,6 +277,7 @@ def canonicalize_attn_mask_type(attn_mask_type: str): def is_fused_attn_kernel_available( + is_training, q_dtype, kv_dtype, qkv_layout, @@ -297,6 +298,7 @@ def is_fused_attn_kernel_available( def make_helper(attn_mask_type): return tex.FusedAttnHelper( + is_training, q_dtype, kv_dtype, qkv_layout, diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 1ad96a1eb..0cd8f5a36 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -420,37 +420,35 @@ def shardy_sharding_rule( if version.parse(jax.__version__) < version.parse("0.5.0"): raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types - + prefix = "ActLuPrimitive_" x_rank = len(value_types[0].shape) scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - x_rank - 1, unique_var="ActLuPrimitive_i", flatten_axis=-2 + x_rank - 1, unique_var=prefix + "x", flatten_axis=-2 ) - x_axes = scale_rules.input_spec + (f"x{x_rank-1}",) + x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",) out = (*x_axes[:-2], x_axes[-1]) scale_inv = scale_rules.rowwise_rule - colwise_scale_inv = scale_rules.colwise_rule + colwise_out = (prefix + "out_colwise",) + colwise_scale_inv = (prefix + "scale_inv_colwise",) if is_2x: + colwise_scale_inv = scale_rules.colwise_rule if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out = tuple( multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2) ) else: colwise_out = out - else: - colwise_out = ("j",) - colwise_scale_inv = ("k",) # amax is always a unit tensor. - amax = ("l",) + amax = (prefix + "amax",) return SdyShardingRule( ( x_axes, - "…1", + ("…1",), ), (out, colwise_out, scale_inv, colwise_scale_inv, amax), - **scale_rules.factor_sizes, ) @@ -458,7 +456,7 @@ def shardy_sharding_rule( # TODO(Jeremy): replace is_2x with q_layout -class DActLuDBiasQuantizePrimitive(BasePrimitive): +class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): """ DActLu DBias Cast Transpose Primitive """ @@ -566,7 +564,7 @@ def outer_abstract(*args, **kwargs): te_dact_dbias_quantize_p outer abstract """ (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = ( - DActLuDBiasQuantizePrimitive.abstract(*args, **kwargs) + BaseDActLuDBiasQuantizePrimitive.abstract(*args, **kwargs) ) return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias @@ -594,7 +592,7 @@ def lowering( assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype == dz_aval.dtype assert scale_aval.dtype == jnp.float32 - return ffi.ffi_lowering(DActLuDBiasQuantizePrimitive.name)( + return ffi.ffi_lowering(BaseDActLuDBiasQuantizePrimitive.name)( ctx, dz, x, @@ -623,9 +621,9 @@ def impl( te_dact_dbias_quantize_p impl """ del is_outer - assert DActLuDBiasQuantizePrimitive.inner_primitive is not None + assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = ( - DActLuDBiasQuantizePrimitive.inner_primitive.bind( + BaseDActLuDBiasQuantizePrimitive.inner_primitive.bind( dz, x, scale, @@ -671,7 +669,7 @@ def batcher( """ del is_outer check_valid_batch_dims(batch_dims) - assert DActLuDBiasQuantizePrimitive.outer_primitive is not None + assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None dz, x, scale = batched_args _, x_bdim, scale_bdim = batch_dims @@ -684,7 +682,7 @@ def batcher( x_bdim, # dbias ) return ( - DActLuDBiasQuantizePrimitive.outer_primitive.bind( + BaseDActLuDBiasQuantizePrimitive.outer_primitive.bind( dz, x, scale, @@ -723,7 +721,7 @@ def infer_sharding_from_operands( ), "Partitioned current tensor scaling is not yet supported." out_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" + mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out" ) if is_2x: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: @@ -733,14 +731,16 @@ def infer_sharding_from_operands( else: colwise_x_spec = (None,) colwise_out_sharding = NamedSharding( - mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out" + mesh, + PartitionSpec(*colwise_x_spec), + desc="BaseDActLuDBiasQuantizePrimitive.colwise_out", ) dbias_spec = x_spec[-2:] if is_dbias else (None,) dbias_sharding = NamedSharding( mesh, PartitionSpec(*dbias_spec), - desc="DActLuDBiasQuantizePrimitive.dbias", + desc="BaseDActLuDBiasQuantizePrimitive.dbias", ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) @@ -753,15 +753,15 @@ def infer_sharding_from_operands( colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv" ) amax_sharding = NamedSharding( - mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax" + mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax" ) colwise_scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*colwise_scale_inv_spec), - desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv", + desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv", ) return ( out_sharding, @@ -791,7 +791,7 @@ def partition( scale_spec = get_padded_spec(arg_infos[2]) out_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" + mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out" ) if is_2x: @@ -802,14 +802,16 @@ def partition( else: colwise_x_spec = (None,) colwise_out_sharding = NamedSharding( - mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out" + mesh, + PartitionSpec(*colwise_x_spec), + desc="BaseDActLuDBiasQuantizePrimitive.colwise_out", ) dbias_spec = x_spec[-2:] if is_dbias else (None,) dbias_sharding = NamedSharding( mesh, PartitionSpec(*dbias_spec), - desc="DActLuDBiasQuantizePrimitive.dbias", + desc="BaseDActLuDBiasQuantizePrimitive.dbias", ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) @@ -832,7 +834,9 @@ def partition( arg_shardings = list(arg_i.sharding for arg_i in arg_infos) # Ensure dz and x are partitioned the same way. arg_shardings[0] = NamedSharding( - mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]), desc="DActLuDBiasQuantizePrimitive.dz" + mesh, + PartitionSpec(*x_spec[:-2], x_spec[-1]), + desc="BaseDActLuDBiasQuantizePrimitive.dz", ) arg_shardings = tuple(arg_shardings) out_shardings = ( @@ -846,7 +850,7 @@ def partition( def sharded_impl(dz, x, scale): (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = ( - DActLuDBiasQuantizePrimitive.impl( + BaseDActLuDBiasQuantizePrimitive.impl( dz, x, scale, @@ -891,32 +895,38 @@ def shardy_sharding_rule( if version.parse(jax.__version__) < version.parse("0.5.0"): raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types - - x_rank = len(value_types[1].shape) + prefix = "BaseDActLuDBiasQuantizePrimitive_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - x_rank, unique_var="DActLuDbiasQuantizePrimitive_i", flatten_axis=-2 + len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2 ) x_axes = scale_rules.input_spec + dz_axes = (*x_axes[:-2], x_axes[-1]) out = x_axes + colwise_out = (prefix + "out_colwise",) if is_2x: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) else: - colwise_out = tuple(x_axes) - else: - colwise_out = ("j",) + colwise_out = out - dbias = x_axes[-2:] if is_dbias else ("k",) - amax = ("…4",) + dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) + amax = (prefix + "amax",) return SdyShardingRule( - (("…0",), tuple(x_axes), ("…2",)), + (dz_axes, x_axes, ("…2",)), (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), - **scale_rules.factor_sizes, ) -register_primitive(DActLuDBiasQuantizePrimitive) +register_primitive(BaseDActLuDBiasQuantizePrimitive) + + +class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): + """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" + + +class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): + """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]: @@ -978,6 +988,7 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + noop_scaled_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -986,6 +997,7 @@ def act_lu( Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: If quantizer is None: @@ -1030,6 +1042,10 @@ def act_lu( is_outer=True, ) out = out.reshape(output_shape) + if noop_scaled_tensor: + return ScaledTensorFactory.create_2x( + out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype + ) return out if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -1083,6 +1099,7 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1093,6 +1110,7 @@ def quantize_dact_dbias( activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). is_dbias: If True, compute bias gradient. Defaults to True. quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: Tuple[ScaledTensor, jnp.ndarray]: A tuple containing: @@ -1106,12 +1124,49 @@ def quantize_dact_dbias( f" {x.shape} and act_len {act_len}" ) - if not DActLuDBiasQuantizePrimitive.enabled(): + scale = jnp.empty((), jnp.float32) + act_type_id = ActivationEnum[activation_type] + PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive + if not PrimitiveClass.enabled() or ( + quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE + ): return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) - # TE/common does not support colwise-only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: - return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) + if quantizer is None: + output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( + dz, + x, + scale, + # outputs float32 for dbias accumulation + out_dtype=(jnp.float32 if is_dbias else x.dtype), + # default value for no scaling, TE/common ignore this value when scale is unset + scaling_mode=ScalingMode.NO_SCALING.value, + is_2x=False, # unused + scale_dtype=jnp.float32, # unused + is_dbias=False, + act_enum=act_type_id, + act_len=act_len, + is_outer=True, + ) + output = output.astype(x.dtype) + dbias = None + if is_dbias: + dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) + + if noop_scaled_tensor: + return ( + ScaledTensorFactory.create_2x( + output, + None, + output, + None, + ScalingMode.NO_SCALING, + dq_dtype=output.dtype, + ), + dbias, + ) + + return output, dbias # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): @@ -1137,31 +1192,6 @@ def quantize_dact_dbias( if war_output is not None: return war_output - scale = jnp.empty((), jnp.float32) - - act_type_id = ActivationEnum[activation_type] - - if quantizer is None: - output, _, _, _, _, _ = DActLuDBiasQuantizePrimitive.outer_primitive.bind( - dz, - x, - scale, - # outputs float32 for dbias accumulation - out_dtype=(jnp.float32 if is_dbias else x.dtype), - # default value for no scaling, TE/common ignore this value when scale is unset - scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, # unused - scale_dtype=jnp.float32, # unused - is_dbias=False, - act_enum=act_type_id, - act_len=act_len, - is_outer=True, - ) - dbias = None - if is_dbias: - dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) - return output.astype(x.dtype), dbias - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = dact_lu( @@ -1175,7 +1205,7 @@ def quantize_dact_dbias( ) return out, dbias - if isinstance(quantizer, DelayedScaleQuantizer): + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale # TE/common dact_dbias_quantize does not support gated act yet @@ -1195,7 +1225,7 @@ def quantize_dact_dbias( colwise_scale_inv, updated_amax, dbias, - ) = DActLuDBiasQuantizePrimitive.outer_primitive.bind( + ) = PrimitiveClass.outer_primitive.bind( dz, x, scale, @@ -1235,6 +1265,7 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + noop_scale_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1244,6 +1275,7 @@ def dact_lu( x: Input tensor that was used in forward pass. activation_type: Type of activation function that was applied. quantizer: Optional quantizer for FP8 quantization of the output gradient. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: The gradient of the activation with respect to the input. @@ -1254,5 +1286,6 @@ def dact_lu( activation_type=activation_type, is_dbias=False, quantizer=quantizer, + noop_scaled_tensor=noop_scale_tensor, ) return output diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 495284e11..04fcf1a8d 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -30,7 +30,7 @@ SequenceDescriptor, ) -from .misc import is_hip_extension +from ..util import is_hip_extension from .base import BasePrimitive, register_primitive from .misc import ( check_valid_batch_dims, @@ -45,6 +45,7 @@ all_reduce_sum_along_dp_fsdp, get_mesh_axis_size, get_mesh_axis_rank, + get_mesh_axis_rank_host, get_all_mesh_axes, num_of_devices, with_sharding_constraint, @@ -78,6 +79,7 @@ "window_size", "context_parallel_load_balanced", "cp_axis", + "cp_striped_window_size", ], ) @dataclass(frozen=True) @@ -96,6 +98,7 @@ class _FusedAttnConfig: window_size: Tuple[int, int] context_parallel_load_balanced: bool cp_axis: str + cp_striped_window_size: Tuple[int, int] # Only for CP + Ring + THD + SWA @dataclass(frozen=True) @@ -104,6 +107,7 @@ class FusedAttnHelper: Helper for the fused attention backend """ + is_training: bool q_dtype: jnp.dtype kv_dtype: jnp.dtype qkv_layout: QKVLayout @@ -125,6 +129,7 @@ def is_fused_attn_kernel_available(self): def get_fused_attn_backend(self): """Get the fused attention kernel backend""" return transformer_engine_jax.get_fused_attn_backend( + self.is_training, jax_dtype_to_te_dtype(self.q_dtype), jax_dtype_to_te_dtype(self.kv_dtype), self.qkv_layout.value, @@ -200,6 +205,7 @@ def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): v_head_dim, ) + @dataclass(frozen=True) class _FusedAttnRNGStateChecker: """ @@ -316,6 +322,7 @@ def abstract( # backend determines the softmax buffer shape/dtype backend = FusedAttnHelper( + config.is_training, q_dtype, k_dtype, config.qkv_layout, @@ -462,6 +469,13 @@ def lowering( *bias_batch_shape, bias_heads, _, _ = bias_aval.shape bias_batch = reduce(operator.mul, bias_batch_shape) + if config.cp_striped_window_size is not None: + window_size_left = config.cp_striped_window_size[0] + window_size_right = config.cp_striped_window_size[1] + else: + window_size_left = config.window_size[0] + window_size_right = config.window_size[1] + return ffi.ffi_lowering(FusedAttnFwdPrimitive.name)( ctx, q, @@ -494,8 +508,8 @@ def lowering( qkv_layout=int(config.qkv_layout.value), is_training=config.is_training, deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), - window_size_left=config.window_size[0], - window_size_right=config.window_size[1], + window_size_left=window_size_left, + window_size_right=window_size_right, ) @staticmethod @@ -774,7 +788,6 @@ def abstract( qk_head_dim, v_head_dim, ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) - if config.attn_bias_type == AttnBiasType.NO_BIAS: bias_batch = bias_heads = 0 @@ -871,6 +884,13 @@ def lowering( *bias_batch_shape, bias_heads, _, _ = bias_aval.shape bias_batch = reduce(operator.mul, bias_batch_shape) + if config.cp_striped_window_size is not None: + window_size_left = config.cp_striped_window_size[0] + window_size_right = config.cp_striped_window_size[1] + else: + window_size_left = config.window_size[0] + window_size_right = config.window_size[1] + return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)( ctx, q, @@ -906,8 +926,8 @@ def lowering( qkv_layout=int(config.qkv_layout.value), is_training=config.is_training, deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), - window_size_left=config.window_size[0], - window_size_right=config.window_size[1], + window_size_left=window_size_left, + window_size_right=window_size_right, ) @staticmethod @@ -1261,6 +1281,7 @@ def get_step_config(self) -> _FusedAttnConfig: window_size=self.config.window_size, context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, + cp_striped_window_size=None, ) def all_gather_kv(self, k, v): @@ -1700,6 +1721,16 @@ def check_supported(self): " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment" ) + # If using scanloop, idx in scan_kv_block() will be a traced device value, but + # _normalize_window_size_for_cp_striped() requires all parameters to be host values + is_context_parallel = get_mesh_axis_size(self.config.cp_axis, self.mesh) > 1 + is_thd_layout = self.config.qkv_layout.is_thd() + is_sliding_window = self.config.window_size[0] != -1 + if is_context_parallel and is_thd_layout and is_sliding_window and self.use_scanloop(): + raise ValueError( + f"{header} with THD format and sliding window does not support using scan loop" + ) + def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call to fused attention.""" return _FusedAttnConfig( @@ -1713,6 +1744,7 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: window_size=self.config.window_size, context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, + cp_striped_window_size=None, ) def stack_kv(self, k, v): @@ -2184,6 +2216,67 @@ def jax_cond_wrap(): register_primitive(FusedRingAttnBwdPrimitive) +def adjust_cp_striped_window_size(q_pos0, kv_pos0, cp_size, window_size): + """ + Adjust window size with cp_size for striped sharding, where both q_pos and + kv_pos are arithmetic sequences like [x, x+cp_size, x+2*cp_size, ...]. + Example 1: + q_pos = kv_pos = [0, 8, 16, 24, 32], cp_size = 8, window_size = (15, 0). + q_pos = 32 can look at kv_pos at [24, 32]. The effective mask is: + 0 8 16 24 32 + ---------------- + 0 | 1 0 0 0 0 + 8 | 1 1 0 0 0 + 16 | 0 1 1 0 0 + 24 | 0 0 1 1 0 + 32 | 0 0 0 1 1 + SequenceDescriptor outputs: {q,kv}_seqlen = [5, ...], {q,kv}_seq_offsets = [0, ...]. + Adjusted window size = (1, 0). + Example 2: + q_pos = [0, 8, 16, 24, 32], kv_pos = [1, 9, 17, 25, 33], cp_size = 8, + window_size = (15, 0). The effective mask is: + 1 9 17 25 33 + ---------------- + 0 | 0 0 0 0 0 + 8 | 1 0 0 0 0 + 16 | 1 1 0 0 0 + 24 | 0 1 1 0 0 + 32 | 0 0 1 1 0 + SequenceDescriptor outputs: + q_seqlen = [4, ...], q_seq_offsets = [1, ...], + kv_seqlen = [4, ...], kv_seq_offsets = [0, ...]. + If diagonal are all 1, left window size = 2. Now since diagonal are all 0, + we need to use left window size = 2 - 1 = 1 to make cuDNN work. + Example 3: + q_pos = [7, 15, 23, 31, 39], kv_pos = [0, 8, 16, 24, 32], cp_size = 8, + window_size = (22, 0). The effective mask is: + 0 8 16 24 32 + ---------------- + 7 | 1 0 0 0 0 + 15 | 1 1 0 0 0 + 23 | 0 1 1 0 0 + 31 | 0 0 1 1 0 + 39 | 0 0 0 1 1 + SequenceDescriptor outputs: {q,kv}_seqlen = [5, ...], {q,kv}_seq_offsets = [0, ...]. + Adjust window size = (1, 0). + """ + + left_limit = q_pos0 - window_size[0] + right_limit = q_pos0 + window_size[1] + + # Count how many left/right steps of size cp_size we can take from kv_pos0 -/+ cp_size + left_steps = (kv_pos0 - cp_size - left_limit) // cp_size + 1 + right_steps = (right_limit - kv_pos0 - cp_size) // cp_size + 1 + left_steps = max(left_steps, 0) + right_steps = max(right_steps, 0) + + # If kv_pos0 > q_pos0, we must reduce left window size by 1 + shift = 1 if kv_pos0 > q_pos0 else 0 + left_steps = left_steps - shift + + return left_steps, right_steps + + class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): """ Fused Striped Ring Attention Forward Primitive @@ -2192,9 +2285,6 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): @staticmethod def partition(config, mesh, arg_infos, result_infos): is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 - assert ( - not is_context_parallel or config.window_size[0] == -1 - ), "Sliding window attention is not supported when context parallelism is enabled" if not is_context_parallel: return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) @@ -2240,6 +2330,7 @@ def fwd_impl( subblock_config = config cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_rank = get_mesh_axis_rank_host(config.cp_axis, mesh) cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)] batch, q_max_seqlen, head, _ = q.shape @@ -2260,22 +2351,36 @@ def scan_kv_block(idx, carry): kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm) kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm) - output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl( - q, - kv, - _not_used, - bias, - seed, - q_seqlen, - kv_seqlen, - q_seq_offsets, - k_seq_offsets, - q_segment_ids, - kv_segment_ids, - q_segment_pos, - kv_segment_pos, - subblock_config, - ) + def compute(config): + return FusedAttnFwdPrimitive.impl( + q, + kv, + _not_used, + bias, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + config, + ) + + if config.window_size != (-1, -1): + kv_src_rank = (cp_size + cp_rank - idx) % cp_size + # Note: all inputs of adjust_cp_striped_window_size should be host values + cp_striped_window_size = adjust_cp_striped_window_size( + cp_rank, kv_src_rank, cp_size, config.window_size + ) + current_config = replace( + subblock_config, cp_striped_window_size=cp_striped_window_size + ) + else: + current_config = subblock_config + output_per_step, softmax_aux_per_step, _ = compute(current_config) softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1)) @@ -2328,9 +2433,6 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): @staticmethod def partition(config, mesh, arg_infos, result_infos): is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 - assert ( - not is_context_parallel or config.window_size[0] == -1 - ), "Sliding window attention is not supported when context parallelism is enabled" if not is_context_parallel: return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) @@ -2374,13 +2476,15 @@ def bwd_impl( subblock_config = config cp_size = get_mesh_axis_size(config.cp_axis, mesh) + # We need cp_rank to be a host value for adjust_cp_striped_window_size() + cp_rank = get_mesh_axis_rank_host(config.cp_axis, mesh) cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)] dq = jnp.zeros_like(q) dkv = jnp.zeros_like(kv) dbias = jnp.zeros_like(bias) - def scan_kv_block(_idx, carry): + def scan_kv_block(idx, carry): kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias = carry # Start communication that feeds the next iteration. @@ -2390,7 +2494,7 @@ def scan_kv_block(_idx, carry): kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm) kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm) - def compute(): + def compute(config): dq_per_step, dkv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( q, kv, @@ -2408,11 +2512,22 @@ def compute(): kv_segment_ids, q_segment_pos, kv_segment_pos, - config=subblock_config, + config=config, ) return dq_per_step, dkv_per_step, dbias_per_step - dq_per_step, dkv_per_step, dbias_per_step = compute() + if config.window_size != (-1, -1): + kv_src_rank = (cp_size + cp_rank - idx) % cp_size + # Note: all inputs of adjust_cp_striped_window_size should be host values + cp_striped_window_size = adjust_cp_striped_window_size( + cp_rank, kv_src_rank, cp_size, config.window_size + ) + current_config = replace( + subblock_config, cp_striped_window_size=cp_striped_window_size + ) + else: + current_config = subblock_config + dq_per_step, dkv_per_step, dbias_per_step = compute(current_config) kv_next, dkv = jnp.unstack(kv_dkv) dq += dq_per_step @@ -2546,6 +2661,7 @@ def fused_attn_fwd( window_size=(-1, -1) if window_size is None else window_size, context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), + cp_striped_window_size=None, ) primitive = None @@ -2667,6 +2783,7 @@ def fused_attn_bwd( window_size=(-1, -1) if window_size is None else window_size, context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), + cp_striped_window_size=None, ) primitive = None diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 5f7e8f35c..bf3b3b7fd 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -36,14 +36,15 @@ class BasePrimitive(metaclass=ABCMeta): @classmethod def enabled(cls): """ - A custom call is marked as disabled if the `cls.name` does not fully match the + A custom call is marked as disabled if the `cls.__name__` does not fully match the `NVTE_JAX_CUSTOM_CALLS_RE` pattern. + This uses the Python class name of the primitive definitions that inherit from BasePrimitive. By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names. - For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!te_act_lu$).+$'` to disable `te_act_lu`. + For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!DBiasQuantizePrimitive$).+$'` to disable `DBiasQuantizePrimitive`. """ pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*") pattern = re.compile(pattern) - is_enabled = pattern.fullmatch(cls.name) is not None + is_enabled = pattern.fullmatch(cls.__name__) is not None return is_enabled @staticmethod diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index e85df943b..39d8c89b1 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -3,45 +3,985 @@ # See LICENSE for license information. """JAX te modules""" -from typing import Tuple, Sequence, Union, Dict -from functools import partial, reduce +import math import operator +from collections.abc import Iterable +from typing import Tuple, Sequence, Union +from functools import partial, reduce + import jax import jax.numpy as jnp -from transformer_engine_jax import get_device_compute_capability +from jax import dtypes +from jax.sharding import NamedSharding, PartitionSpec +from jax.experimental.custom_partitioning import SdyShardingRule + +import transformer_engine_jax as tex +from transformer_engine_jax import get_num_compute_streams from .base import BasePrimitive, register_primitive +from .quantization import grouped_quantize -from ..util import is_hip_extension +from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type from ..quantize import ( ScaledTensor, + ScaledTensor2x, + GroupedScaledTensor1x, ScalingMode, Quantizer, + GroupedQuantizer, QuantizeConfig, + QuantizerSet, + QuantizeLayout, noop_quantizer_set, + is_fp8_gemm_with_all_layouts_supported, + apply_padding_to_scale_inv, ) +from .misc import get_padded_spec +__all__ = [ + "gemm", + "grouped_gemm", + "gemm_uses_jax_dot", + "sanitize_dims", + "get_non_contracting_dims", + "transpose_dims", +] -__all__ = ["gemm"] +jnp_float8_e4m3_type = get_jnp_float8_e4m3_type() +jnp_float8_e5m2_type = get_jnp_float8_e5m2_type() - -num_cublas_streams = 4 +num_cublas_streams = get_num_compute_streams() def get_cublas_workspace_size_bytes() -> None: """Return workspace size needed for current architecture""" if is_hip_extension(): """Return 64 MiB for gfx50x, 32 MiB for all other architectures.""" - if get_device_compute_capability() == (9, 5): + if tex.get_device_compute_capability(0) == 95: return 67_108_864 return 33_554_432 """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if get_device_compute_capability(0) >= 90: + if tex.get_device_compute_capability(0) >= 90: return 33_554_432 return 4_194_304 +def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]: + """Convert relative (negative) indexes to absolute dimension numbers.""" + dims_ = dims if isinstance(dims, Iterable) else (dims,) + if len(dims_) == 0: + return dims_ + return tuple(ndim + dim if dim < 0 else dim for dim in dims_ if dim is not None) + + +def get_non_contracting_dims(ndim, contracting_dims): + """Return a tuple of dimensions not included in the contracting dimensions.""" + contracting_dims = sanitize_dims(ndim, contracting_dims) + return tuple(dim for dim in range(ndim) if dim not in contracting_dims) + + +def transpose_dims(ndim, dims_to_transpose, flatten_axis=-1): + """Compute the new dimension numbers after transpose.""" + if len(dims_to_transpose) == 0: + return dims_to_transpose + flatten_axis = ndim - flatten_axis if flatten_axis > 0 else flatten_axis + transposed_dims = (*range(flatten_axis, ndim), *range(flatten_axis)) + return tuple(transposed_dims.index(dim) for dim in dims_to_transpose) + + +def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool: + lhs, rhs, e4m3, e5m2 = map( + dtypes.canonicalize_dtype, + ( + lhs_dtype, + rhs_dtype, + jnp_float8_e4m3_type, + jnp_float8_e5m2_type, + ), + ) + + # FP8 GEMM supports (e4m3 x e4m3), (e4m3 x e5m2) and (e5m2 x e4m3) + if (lhs is e4m3 and rhs in (e4m3, e5m2)) or (lhs in (e4m3, e5m2) and rhs is e4m3): + return True + + # Any other combination of data types is not supported + return False + + +def _get_gemm_layout( + operand_ndims: Tuple[int, int], contracting_dims: Tuple[Sequence[int], Sequence[int]] +) -> Tuple[bool, bool]: + lhs_contracting, rhs_contracting = map(sanitize_dims, operand_ndims, contracting_dims) + lhs_is_transposed = operand_ndims[0] - 1 not in lhs_contracting + rhs_is_transposed = operand_ndims[1] - 1 in rhs_contracting + return lhs_is_transposed, rhs_is_transposed + + +def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims): + lhs_q = lhs + rhs_q = rhs + + if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: + lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0]) + lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims + need_lhs_colwise = lhs_is_transposed and ( + lhs_quantizer.scaling_mode.is_1d_block_scaling() + or not is_fp8_gemm_with_all_layouts_supported() + ) + flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims) + lhs_q = lhs_quantizer.quantize( + lhs, + is_rowwise=not need_lhs_colwise, + is_colwise=need_lhs_colwise, + flatten_axis=flatten_axis, + ) + + if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: + rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) + rhs_is_transposed = rhs.ndim - 1 in rhs_cdims + need_rhs_colwise = not rhs_is_transposed and ( + rhs_quantizer.scaling_mode.is_1d_block_scaling() + or not is_fp8_gemm_with_all_layouts_supported() + ) + flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1 + rhs_q = rhs_quantizer.quantize( + rhs, + is_rowwise=not need_rhs_colwise, + is_colwise=need_rhs_colwise, + flatten_axis=flatten_axis, + ) + + assert not isinstance(lhs_q, ScaledTensor2x) + assert not isinstance(rhs_q, ScaledTensor2x) + + return lhs_q, rhs_q + + +class GemmPrimitive(BasePrimitive): + """ + Primitive for cuBLAS GEMM + """ + + name = "te_gemm_ffi" + multiple_results = True + impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + batched_dims, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ): + del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator + + def _dims_are_consecutive(dims): + if len(dims) <= 1: + return True + return sorted(dims) == list(range(min(dims), max(dims) + 1)) + + # Sanity-check operand layouts and types + operand_ndims = (lhs.ndim, rhs.ndim) + + ( + lhs_contracting_dims, + rhs_contracting_dims, + ) = map(sanitize_dims, operand_ndims, contracting_dims) + assert _dims_are_consecutive(lhs_contracting_dims), ( + "cuBLAS GEMM expected consecutive contracting dimensions for LHS operand, but got " + f"{lhs_contracting_dims}." + ) + assert _dims_are_consecutive(rhs_contracting_dims), ( + "cuBLAS GEMM expected consecutive contracting dimensions for RHS operand, but got " + f"{rhs_contracting_dims}." + ) + + ( + lhs_batch_dims, + rhs_batch_dims, + ) = map(sanitize_dims, operand_ndims, batched_dims) + assert _dims_are_consecutive(lhs_batch_dims), ( + "cuBLAS GEMM expected consecutive batch dimensions for LHS operand, but got " + f"{lhs_batch_dims}." + ) + assert _dims_are_consecutive(rhs_batch_dims), ( + "cuBLAS GEMM expected consecutive batch dimensions for RHS operand, but got " + f"{rhs_batch_dims}." + ) + if len(lhs_batch_dims) == 0: + assert ( + len(rhs_batch_dims) == 0 + ), "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched." + elif len(rhs_batch_dims) != 0: + assert all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) and all( + bdim in rhs_contracting_dims for bdim in rhs_batch_dims + ), "cuBLAS GEMM batched dimensions must be contracting when both operands are batched." + + lhs_contracting_size, rhs_contracting_size = map( + lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), + (lhs.shape, rhs.shape), + (lhs_contracting_dims, rhs_contracting_dims), + ) + assert lhs_contracting_size == rhs_contracting_size, ( + "cuBLAS GEMM operands have incompatible contracting dimensions: " + f"{lhs.shape} @ idx {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}." + ) + + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims) + if scaling_mode != ScalingMode.NO_SCALING: + assert _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype), ( + "cuBLAS GEMM quantized operands have incompatible data types: " + f"{lhs.dtype} x {rhs.dtype}." + ) + assert ( + lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0 + ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." + if ( + scaling_mode != ScalingMode.MXFP8_1D_SCALING + and not tex.is_non_nt_fp8_gemm_supported() + ): + assert not lhs_is_transposed and rhs_is_transposed, ( + "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) " + "require non-transposed LHS and transposed RHS operands " + "(`contracting_dims=((-1, ), (-1, ))`)." + ) + + # Determine output shape and dtype + assert ( + dtypes.canonicalize_dtype(out_dtype).itemsize > 1 + ), "cuBLAS GEMM custom op does not support 8-bit quantized output types." + lhs_non_contracting_shape, rhs_non_contracting_shape = map( + lambda shape, dims: [shape[dim] for dim in range(len(shape)) if dim not in dims], + (lhs.shape, rhs.shape), + (lhs_contracting_dims, rhs_contracting_dims), + ) + out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape) + output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + + # Validate bias + bias_shape = (0,) + bias_dtype = out_dtype + if fuse_bias: + expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape) + if not grad: + assert bias.size == expected_bias_size, ( + "cuBLAS GEMM bias tensor has incorrect shape, " + f"expected ({expected_bias_size}, ) but found {bias.shape}." + ) + assert bias.dtype == out_dtype, ( + "cuBLAS GEMM bias tensor has incorrect data type, " + f"expected {bias_dtype} but found {bias.dtype}." + ) + bias_shape = bias.shape + else: + bias_shape = rhs_non_contracting_shape + bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype) + + # Validate pre-GeLU + pre_gelu_shape = (0,) + pre_gelu_dtype = out_dtype + if fuse_gelu: + pre_gelu_shape = out_shape + if grad: + pre_gelu_ndim = len(pre_gelu_shape) + assert gelu_input.ndim == pre_gelu_shape and all( + gelu_input.shape[i] == pre_gelu_shape[i] for i in range(pre_gelu_ndim) + ), ( + "cuBLAS GEMM pre-GeLU tensor has incorrect shape, " + f"expected {pre_gelu_shape} but found {gelu_input.shape}." + ) + assert gelu_input.dtype == out_dtype, ( + "cuBLAS GEMM pre-GeLU tensor has incorrect data type, " + f"expected {pre_gelu_dtype} but found {gelu_input.dtype}." + ) + pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) + + # Need extra workspace for swizzled scale factors + lhs_swizzle_size = 0 + rhs_swizzle_size = 0 + swizzle_dtype = jnp.uint8 + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + lhs_swizzle_size = lhs_scale_inv.size + rhs_swizzle_size = rhs_scale_inv.size + lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype) + rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype) + + # Declare cuBLAS workspace + # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not + # necessarily 256 bytes aligned, we add some padding to ensure alignment. + workspace_size = get_cublas_workspace_size_bytes() + 256 + workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) + + return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace + + @staticmethod + def outer_abstract(*args, **kwargs): + outputs = GemmPrimitive.abstract(*args, **kwargs) + return outputs[:-3] # discard workspace arrays + + @staticmethod + def lowering( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + batched_dims, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ): + del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype + lhs_aval, _, rhs_aval, *_ = ctx.avals_in + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) + lhs_transposed, rhs_transposed = _get_gemm_layout( + (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) + ) + + args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input) + kwargs = { + "scaling_mode": int(scaling_mode.value), + "lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), + "rhs_axis_boundary": min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, + "lhs_transposed": lhs_transposed, + "rhs_transposed": rhs_transposed, + "fuse_bias": fuse_bias, + "fuse_gelu": fuse_gelu, + "grad": grad, + "use_split_accumulator": use_split_accumulator, + } + + operand_output_aliases = {} + if fuse_bias and not grad: + operand_output_aliases.update({4: 1}) # bias <-> bias_grad + if fuse_gelu and grad: + operand_output_aliases.update({5: 2}) # gelu_input <-> pre_gelu_out + + return jax.ffi.ffi_lowering( + GemmPrimitive.name, + operand_output_aliases=operand_output_aliases, + )(ctx, *args, **kwargs) + + @staticmethod + def impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + batched_dims, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ): + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) + lhs_transposed, rhs_transposed = _get_gemm_layout( + (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) + ) + lhs_scale_inv = apply_padding_to_scale_inv( + lhs_scale_inv, + scaling_mode, + lhs.shape, + is_colwise=lhs_quantized_colwise, + flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), + ) + rhs_scale_inv = apply_padding_to_scale_inv( + rhs_scale_inv, + scaling_mode, + rhs.shape, + is_colwise=rhs_quantized_colwise, + flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, + ) + + outputs = GemmPrimitive.inner_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + batched_dims=batched_dims, + lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + ) + return outputs[:-3] # discard workspace arrays + + @staticmethod + def batcher( + batched_args, + jax_batch_dims, + out_dtype, + contracting_dims, + batched_dims, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ): + assert GemmPrimitive.outer_primitive is not None + lhs, _, rhs, *_ = batched_args + lhs_bdims, _, rhs_bdims, *_ = jax_batch_dims + arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims) + arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims + assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), ( + "User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch " + f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}." + ) + arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims + assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), ( + "User-specified batch dimension(s) for cuBLAS GEMM RHS operand does not match batch " + f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}." + ) + + # Output is batched like the non-contracting batch dimensions of the LHS operand + lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims) + lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims) + out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims + + # Bias gradient is never batched + bias_bdims = (None,) + + # Pre-GeLU output, if exists, is batched like GEMM output + pre_gelu_bdims = (None,) + if fuse_gelu and not grad: + pre_gelu_bdims = out_bdims + + return ( + GemmPrimitive.outer_primitive.bind( + *batched_args, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + batched_dims=batched_dims, + lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + ), + (out_bdims, bias_bdims, pre_gelu_bdims), + ) + + @staticmethod + def _decompose_operand_specs(specs, contracting_dims, batch_dims): + ndims = len(specs) + cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims)) + + # Batch specs + bspecs = tuple(specs[i] for i in bdims) + + # Non-batch leading dimension specs + lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims) + + # Non-batch contracting dimension specs + cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims) + + return bspecs, lspecs, cspecs + + @staticmethod + def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims): + lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) + lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) + lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( + sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batched_dims + ) + ( + (lhs_bspecs, lhs_lspecs, lhs_cspecs), + (rhs_bspecs, rhs_lspecs, rhs_cspecs), + ) = map( + GemmPrimitive._decompose_operand_specs, + (lhs_specs, rhs_specs), + (lhs_cdims, rhs_cdims), + (lhs_bdims, rhs_bdims), + ) + + # Batched dimensions must have the same sharding + if len(lhs_bdims) > 0 and len(rhs_bdims) > 0: + assert all( + lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs) + ), ( + "cuBLAS GEMM operand batch dimensions must have the same sharding: " + f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}." + ) + + # Only one each of the non-batched leading dimensions and non-batched contracting + # dimensions can be sharded + lhs_ldims, rhs_ldims = map( + lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude), + (lhs_ndim, rhs_ndim), + (lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims), + ) + (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map( + lambda specs: tuple(spec for spec in specs if spec is not None), + (lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs), + ) + assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, ( + "cuBLAS GEMM operands can have only one sharded non-batched leading dimension: " + f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}." + ) + assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, ( + "cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: " + f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}." + ) + + # Extract single leading and contracting dimension specs + (lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map( + lambda specs: None if len(specs) == 0 else specs[0], + (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none), + ) + + # Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts + # with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands. + # 1. K1 == K2 != None and N == None + # LHS: (B, M, K) + # RHS: (B, None, K) + # OUT: (B, M, None) --(AR)-> (B, M, None) + # 2. K1 == K2 != None and M == N != None + # LHS: (B, M, K) + # RHS: (B, N, K)--(AG)->(B, None, K) + # OUT: (B, M, None) --(RS)--> (B, M, N) + # 3. M == N + # LHS: (B, M, K)--(AG)->(B, M, None) + # RHS: (B, M, K)--(AG)->(B, None, None) + # OUT: (B, M, None) + # 4. M != N + # LHS: (B, M, K)--(AG)->(B, M, None) + # RHS: (B, N, K)--(AG)->(B, N, None) + # OUT: (B, M, N) + reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec + all_reduce_output = reduce_flag and rhs_lspec is None + reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec + all_reduce_spec = reduce_scatter_spec = scatter_dim = None + + lhs_non_contracting_specs, rhs_non_contracting_specs = map( + lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims), + (lhs_specs, rhs_specs), + (lhs_cdims, rhs_cdims), + ) + out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs) + if reduce_scatter_output: + # All-gather (if necessary) the non-batch non-contracting dimension of RHS + # (B, N, K) --(AG)-> (B, None, K) + # (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N) + rhs_spec = tuple( + rhs_spec[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim) + ) + reduce_scatter_spec = lhs_cspec + scatter_dim = out_specs.index(rhs_lspec) + + elif all_reduce_output: + # Set all output trailing dimensions to zero + out_specs = ( + *lhs_non_contracting_specs, + *[None for _ in range(len(rhs_non_contracting_specs))], + ) + all_reduce_spec = lhs_cspec + else: + # All-gather (if necessary) the non-batch contracting dimensions + # (B, M, K) --(AG)-> (B, M, None) + # (B, N, K) --(AG)-> (B, N, None) + # (B, M, None) x (B, N, None)^T = (B, M, N) + lhs_specs = tuple( + None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i] + for i in range(lhs_ndim) + ) + rhs_specs = tuple( + None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i] + for i in range(rhs_ndim) + ) + # Check if RHS non-contracting spec also appears in the LHS non-contracting specs + if rhs_lspec is not None and rhs_lspec in tuple( + lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims + ): + # All-gather (if necessary) the non-batch non-contracting dimensions of RHS + # (B, N, None) --(AG)-> (B, None, None) + # (B, M, None) x (B, None, None)^T = (B, M, None) + rhs_specs = tuple( + None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i] + for i in range(rhs_ndim) + ) + # Set all output trailing dimensions to zero + out_specs = ( + *lhs_non_contracting_specs, + *[None for _ in range(len(rhs_non_contracting_specs))], + ) + + # Bias and Pre-GeLU sharding is based on GEMM output + bias_specs = out_specs[len(lhs_non_contracting_specs) :] + gelu_specs = out_specs + + return ( + (lhs_specs, rhs_specs, bias_specs, gelu_specs), + (out_specs, bias_specs, gelu_specs), + all_reduce_spec, + reduce_scatter_spec, + scatter_dim, + ) + + @staticmethod + def infer_sharding_from_operands( + out_dtype, + contracting_dims, + batched_dims, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + mesh, + arg_infos, + result_infos, + ): + del ( + out_dtype, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + grad, + ) + del use_split_accumulator, result_infos + + (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( + GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) + ) + out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) + + # Discard bias gradient spec if there is no bias fusion + if not fuse_bias: + dbias_specs = (None,) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs)) + + # Discard pre-GeLU output spec if there is no GeLU fusion + if not fuse_gelu: + pre_gelu_specs = (None,) + pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)) + + return [out_sharding, dbias_sharding, pre_gelu_sharding] + + @staticmethod + def partition( + out_dtype, + contracting_dims, + batched_dims, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + mesh, + arg_infos, + result_infos, + ): + del result_infos + + ( + (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), + (out_specs, dbias_specs, pre_gelu_specs), + all_reduce_spec, + reduce_scatter_spec, + scatter_dim, + ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) + + # Assemble argument shardings + # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. + none_sharding = NamedSharding(mesh, PartitionSpec(None)) + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs)) + rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs)) + arg_shardings = ( + lhs_sharding, + lhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + rhs_sharding, + rhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + ) + + # Discard bias input spec if there is no bias fusion + if not fuse_bias: + bias_input_specs = (None,) + arg_shardings += (NamedSharding(mesh, PartitionSpec(*bias_input_specs)),) + + # Discard pre-GeLU input spec if there is no GeLU fusion + if not fuse_gelu: + gelu_input_specs = (None,) + arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),) + + # Assemble output shardings + out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))] + + # Discard bias gradient spec if there is no bias fusion + if not fuse_bias: + dbias_specs = (None,) + out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs))) + + # Discard pre-GeLU output spec if there is no GeLU fusion + if not fuse_gelu: + pre_gelu_specs = (None,) + out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))) + + def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): + outputs = GemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + batched_dims=batched_dims, + lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + ) + + # All-Reduce/Reduce-Scatter GEMM output + if all_reduce_spec is not None: + outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec) + if fuse_gelu and not grad: + outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec) + elif reduce_scatter_spec is not None: + outputs[0] = jax.lax.psum_scatter( + outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True + ) + if fuse_gelu and not grad: + outputs[2] = jax.lax.psum_scatter( + outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True + ) + + return outputs + + return mesh, _sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule( + out_dtype, + contracting_dims, + batched_dims, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + mesh, + operand_types, + result_types, + ): + del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator + del mesh, result_types + + prefix = "GemmPrimitive_" + + def _generate_operand_rules(name, ndim, cdims, bdims): + specs = [] + ldims = tuple(i for i in range(ndim) if i not in bdims + cdims) + for i in range(ndim): + dim_name = None + if i in bdims: + dim_idx = bdims.index(i) if len(bdims) > 1 else "" + dim_name = f"b{dim_idx}" + elif i in cdims: + dim_idx = cdims.index(i) if len(cdims) > 1 else "" + dim_name = f"k{dim_idx}" + else: + dim_idx = ldims.index(i) if len(ldims) > 1 else "" + dim_name = f"{name}_l{dim_idx}" + specs.append(prefix + dim_name) + return specs + + lhs, _, rhs, *_ = operand_types + operand_ndims = (len(lhs.shape), len(rhs.shape)) + (lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = map( + lambda dims: map(sanitize_dims, operand_ndims, dims), + (contracting_dims, batched_dims), + ) + lhs_specs, rhs_specs = map( + _generate_operand_rules, + ("lhs", "rhs"), + operand_ndims, + (lhs_cdims, rhs_cdims), + (lhs_bdims, rhs_bdims), + ) + lhs_scale_specs = ("…1",) + rhs_scale_specs = ("…2",) + if scaling_mode.is_1d_block_scaling(): + # Shardy rules for MXFP8 scales cannot be related to the operands because of the + # global-unpadding and local-padding workflow. This can potentially insert expensive + # re-shards in the partition call later if the scales are not already sharded correctly. + lhs_scale_specs, rhs_scale_specs = map( + lambda specs: tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs), + (lhs_specs, rhs_specs), + ) + + lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims) + rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) + out_spec = (*lhs_non_cspec, *rhs_non_cspec) + bias_spec = rhs_non_cspec if fuse_bias else ("…4",) + gelu_spec = out_spec if fuse_gelu else ("…5",) + + return SdyShardingRule( + operand_mappings=( + lhs_specs, + lhs_scale_specs, + rhs_specs, + rhs_scale_specs, + bias_spec, + gelu_spec, + ), + result_mappings=( + out_spec, + bias_spec, + gelu_spec, + ), + ) + + +register_primitive(GemmPrimitive) + + +def gemm_uses_jax_dot() -> bool: + """Check if the GEMM call directs to the TE custom cuBLAS call or native JAX dot.""" + return not GemmPrimitive.enabled() + + +def _te_gemm( + lhs: Union[jax.Array, ScaledTensor], + rhs: Union[jax.Array, ScaledTensor], + bias: jax.Array = None, + gelu_input: jax.Array = None, + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), + batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()), + fuse_bias: bool = False, + fuse_gelu: bool = False, + grad: bool = False, + use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, +) -> Tuple[jax.Array, ...]: + + # Prepare non-quantized GEMM operands + lhs_data = lhs + rhs_data = rhs + lhs_scale_inv = jnp.empty(0, dtype=jnp.float32) + rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) + scaling_mode = ScalingMode.NO_SCALING + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) + lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims) + + # Quantize operands (if necessary) + lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) + + # Extract GEMM custom op inputs from quantized operands + if isinstance(lhs_q, ScaledTensor): + assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, ( + "cuBLAS GEMM with quantized LHS and non-quantized RHS operands requires a valid " + "`Quantizer` object to quantize the RHS operand." + ) + if isinstance(lhs_q, ScaledTensor2x): + # Choose the quantization of the contracting dimension(s) + lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() + scaling_mode = lhs_q.scaling_mode + lhs_data = lhs_q.data + lhs_scale_inv = lhs_q.scale_inv + if lhs_q.data_layout == "T": + lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) + lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis) + + if isinstance(rhs_q, ScaledTensor): + assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( + "cuBLAS GEMM with non-quantized LHS and quantized RHS operands requires a valid " + "`Quantizer` object to quantize the LHS operand." + ) + if isinstance(rhs_q, ScaledTensor2x): + # Choose the quantization of the contracting dimension(s) + rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() + assert rhs_q.scaling_mode == lhs_q.scaling_mode, ( + "cuBLAS GEMM quantized operands have mismatched scaling types, " + f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." + ) + rhs_data = rhs_q.data + rhs_scale_inv = rhs_q.scale_inv + if rhs_q.data_layout == "T": + rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) + rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis) + + # Dummy empties for bias and gelu + out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype + if bias is None or not (fuse_bias and not grad): + bias = jnp.empty(0, dtype=out_dtype) + if gelu_input is None or not (fuse_gelu and grad): + gelu_input = jnp.empty(0, dtype=out_dtype) + + return GemmPrimitive.outer_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + gelu_input, + out_dtype=out_dtype, + contracting_dims=(lhs_cdims, rhs_cdims), + batched_dims=(lhs_bdims, rhs_bdims), + lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False, + rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + ) + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM @@ -49,73 +989,157 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = () + impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15) inner_primitive = None outer_primitive = None @staticmethod - def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias): + def abstract( + lhs_data_aval, + lhs_scale_inv_aval, + rhs_data_aval, + rhs_scale_inv_aval, + bias_aval, + group_sizes_aval, + group_offset_aval, + *, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + ): """ + Grouped GEMM operation. + Args: - *args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias: - args[ 0 : num_gemms] are the lhs tensors, - args[ num_gemms : 2*num_gemms] are the rhs tensors, - args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors, - args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors, - args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True. - num_gemms: Number of GEMM operations to perform. - scaling_mode: Scaling mode for the GEMM operations. - out_dtype: Data type of the output tensors. - has_bias: Boolean indicating if bias tensors are provided. + lhs_data: Left-hand side input matrix data, 1D flattened array + lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array + rhs_data: Right-hand side input matrix data, 1D flattened array + rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array + bias: Bias matrix of shape (G, N) + group_sizes: 1D array containing the sizes of each group + group_offset: 1D array containing offsets for each group (not yet implemented) + M: Number of rows in the output matrix + N: Number of columns in the output matrix + K: Number of columns in the left-hand side matrix + lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed + rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed + scaling_mode: Scaling mode for the GEMM operations + out_dtype: Data type of the output tensors + has_bias: Boolean indicating if bias tensors are provided + is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation + where both lhs and rhs are 2D matrices and output is (G, M, N) Returns: - A tuple of ShapedArray objects of size num_gemms+1: - ret[0 : num_gemms]: GEMM output tensors, - ret[num_gemms]:workspace tensor. + A jnp.ndarray containing the result of the grouped GEMM operation """ - del scaling_mode - expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms - assert ( - len(args) == expected_num_args - ), f"Expected {expected_num_args} input arguments, but got {len(args)}" - A_list = args[0:num_gemms] - B_list = args[num_gemms : 2 * num_gemms] - # A and B have shapes [1, m, k] and [1, n, k] - out_list_aval = tuple( - jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype) - for A, B in zip(A_list, B_list) - ) + del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval + del K, lhs_is_trans, rhs_is_trans, has_bias + # TODO(Phuong): move some shape checks from Cpp to here workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams + workspace_alignment_padding = 256 + tensor_scaling_sinv_aligment = 16 + mxfp8_scaling_sinv_alignment_padding = 256 + # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not + # necessarily 256 bytes aligned, we add some padding to ensure alignment. + workspace_size += workspace_alignment_padding + if scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING.value, + ScalingMode.CURRENT_TENSOR_SCALING.value, + ): + # For tensor scaling, each matrix has a single scale value, but it + # needs to be aligned to 16 bytes for CUDA 12.9.1 and later. + workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_aligment + workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_aligment + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + # We also pad scale_inv swizzle buffers size for 256 bytes alignment. + workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) - return (*out_list_aval, workspace_aval) + + out_shape = (M, N) + if is_grouped_dense_wgrad: + out_shape = (group_sizes_aval.size, M, N) + out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + return (out_aval, workspace_aval) @staticmethod def outer_abstract(*args, **kwargs): (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs) - return out_aval + return (out_aval,) @staticmethod - def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias): + def lowering( + ctx, + *args, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + ): del out_dtype return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( ctx, *args, - num_gemms=num_gemms, - scaling_mode=int(scaling_mode), + M=M, + N=N, + K=K, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode.value, has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) @staticmethod - def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias): + def impl( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + group_offset, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + ): assert GroupedGemmPrimitive.inner_primitive is not None - out = GroupedGemmPrimitive.inner_primitive.bind( - *args, - num_gemms=num_gemms, - scaling_mode=scaling_mode.value, + (out, _) = GroupedGemmPrimitive.inner_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + group_offset, + M=M, + N=N, + K=K, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode, out_dtype=out_dtype, has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) - return out[:-1] # out is [out_list, wkspace], only return out_list + return (out,) register_primitive(GroupedGemmPrimitive) @@ -147,68 +1171,33 @@ def _shape_normalization(x, dimension_numbers, already_transposed: bool = False) def _calculate_remaining_shape(shape, contracting_dims): - return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims) - - -def _dequantize(x, scale_inv, dq_dtype): - return x.astype(dq_dtype) * scale_inv.astype(dq_dtype) + contracting_dims_ = sanitize_dims(len(shape), contracting_dims) + return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims_) # Apply jit to guarantee correctness of FP8 GEMM. -@partial( - jax.jit, - static_argnums=( - 2, - 3, - 4, - ), -) -def __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision): - # Need to hard-code the dequantize here instead of calling lhs.dequantize() for pattern matching - lhs_dq = _dequantize(lhs.data, lhs.scale_inv, lhs.dq_dtype) - rhs_dq = _dequantize(rhs.data, rhs.scale_inv, rhs.dq_dtype) - - # Reshape + Transpose - # [..., M, K] -> [B, M, K] - # [..., K, M] -> [B, M, K] - lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.data_layout == "N") - rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.data_layout == "T") - - dim_nums = (((2,), (2,)), ((0,), (0,))) - out_3d = jax.lax.dot_general( - lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype - ) - return out_3d - - -def _jax_gemm_tensor_scaling_fp8( - lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]] -): - """FP8 GEMM for XLA pattern match""" - assert rhs.scaling_mode.is_tensor_scaling(), "rhs does not have tensor scaling mode" - +@partial(jax.jit, static_argnums=(2, 3)) +def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums if lhs.data_layout == "T": - lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract) + lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis) + lhs_batch = transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis) if rhs.data_layout == "T": - rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract) - - lhs_dn = (lhs_contract, lhs_batch) - rhs_dn = (rhs_contract, rhs_batch) + rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis) + rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis) - lhs_remain_shape = _calculate_remaining_shape(lhs.data.shape, lhs_contract) - rhs_remain_shape = _calculate_remaining_shape(rhs.data.shape, rhs_contract) + dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) - precision = ( - jax.lax.Precision.HIGHEST if QuantizeConfig.FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT + out_fp8 = jax.lax.dot_general( + lhs.data, rhs.data, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype ) - out_3d = __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision) + scale_inv = lhs.scale_inv * rhs.scale_inv + out = (out_fp8 * scale_inv).astype(lhs.dq_dtype) - # Reshape [B, M, N] -> [..., M, N] - out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape) return out +@partial(jax.jit, static_argnums=(2,)) def _jax_gemm_mxfp8_1d( lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]] ): @@ -218,7 +1207,6 @@ def _jax_gemm_mxfp8_1d( assert ( rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING ), "rhs does not have MXFP8 1D scaling mode" - from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums @@ -241,15 +1229,11 @@ def _jax_gemm_mxfp8_1d( lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch)) rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch)) - # Slice out the padding as scaled_matmul does not support padded scales yet - lhs_scale_3d = jnp.asarray(lhs_scale_3d[:, : lhs_3d.shape[1], : int(lhs_3d.shape[2] / 32)]) - rhs_scale_3d = jnp.asarray(rhs_scale_3d[:, : rhs_3d.shape[1], : int(rhs_3d.shape[2] / 32)]) - # JAX scaled_matmul only supports NT now (TN-gemm) # * Expected shape: # * lhs_data (B, M, K) * rhs_data (B, N, K) # * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block) - out_3d = scaled_matmul_wrapper( + out_3d = jax.nn.scaled_matmul( lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype ) # Reshape [1, reduce(..., M), N] -> [..., M, N] @@ -267,50 +1251,41 @@ def _jax_gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - quantizer_set: Dict["str", Quantizer] = noop_quantizer_set, + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, ) -> jnp.ndarray: """ FP8 GEMM via JAX """ - dim_nums = (contracting_dims, ((), ())) def _jax_gemm_fp8_impl(lhs, rhs): - if lhs.scaling_mode.is_tensor_scaling(): - return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums) + assert ( + rhs.scaling_mode == lhs.scaling_mode + ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}" + precision = ( + jax.lax.Precision.HIGHEST + if QuantizeConfig.FP8_2X_ACC_FPROP + else jax.lax.Precision.DEFAULT + ) + return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision) if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") - if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): - return _jax_gemm_fp8_impl(lhs, rhs) - - if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor): - if quantizer_set != noop_quantizer_set: - assert type(quantizer_set.x) is type(quantizer_set.kernel) - (((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums - lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1 - rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1 - # Call JAX quantization so that XLA can do pattern matching (QDQ --> FP8 gemm) - lhs_q = quantizer_set.x.quantize( - lhs, - is_rowwise=lhs_is_rowwise, - is_colwise=not lhs_is_rowwise, - ) - rhs_q = quantizer_set.kernel.quantize( - rhs, - is_rowwise=rhs_is_rowwise, - is_colwise=not rhs_is_rowwise, - ) - return _jax_gemm_fp8_impl(lhs_q, rhs_q) + lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) + + if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor): + return _jax_gemm_fp8_impl(lhs_q, rhs_q) if ( isinstance(lhs, jnp.ndarray) and isinstance(rhs, jnp.ndarray) - and quantizer_set == noop_quantizer_set + and lhs_quantizer is None + and rhs_quantizer is None ): return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) @@ -320,156 +1295,303 @@ def _jax_gemm_fp8_impl(lhs, rhs): def gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - quantizer_set: Dict["str", Quantizer] = noop_quantizer_set, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), + batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()), + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, + **kwargs, +) -> Tuple[jnp.ndarray, ...]: + r"""General matrix multiplication with optional quantization. + + Parameters + ---------- + lhs: Union[jax.Array, ScaledTensor] + Left-hand side operand in the matrix multiplication. + rhs: Union[jax.Array, ScaledTensor] + Right-hand side operand in the matrix multiplication. + lhs_quantizer: Quantizer, default = None + Object for down-casting the LHS operand for quantized GEMM. + rhs_quantizer: Quantizer, default = None + Object for down-casting the RHS operand for quantized GEMM. + contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, )) + Tuple of sequences representing the contracting dimensions of the operands. + batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()), + Tuple of sequences representing the batched dimensions of the operands. This is *not* used + to perform a batched matrix multiplication, but it is required to avoid a potentially + undesirable reduction in any batched contracting dimensions when invoked with sharded + operands (e.g. when computing weight gradients in a Flax module). + bias: jax.Array, default = None + Optional additive bias term, required for forward GEMM with bias fusion. Only supported + with TE's custom call to cuBLAS GEMM. + gelu_input: jax.Array, default = None + Pre-GeLU output from forward GEMM, required for backward/grad GEMM with dGeLU fusion. Only + supported with TE's custom call to cuBLAS GEMM. + fuse_bias: bool, default = False + Enable bias addition in forward GEMM or bias gradient in backward GEMM. Only supported with + TE's custom call to cuBLAS GEMM. + fuse_gelu: bool, default = False + Enable GeLU activation in forward GEMM or GeLU gradient in backward GEMM. Only supported + with TE's custom call to cuBLAS GEMM. + grad: bool, default = False + Flag for switching bias and GeLU fusions from forward to backward mode. Only supported with + TE's custom call to cuBLAS GEMM. + use_split_accumulator: bool, default = True + Enable promoting some intermediate sums to higher precision when accumulating the result in + the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. + + Returns + ------- + jax.Array: + Result of the operation. For TE's custom call to cuBLAS GEMM, this result can include the + GeLU application when `fuse_gelu=True` and `grad=False`, the GeLU gradient contribution + when `fuse_gelu=True` and `grad=True`, and the additive bias when `fuse_bias=True` and + `grad=False`. + Optional[jax.Array]: + Bias gradient when `fuse_bias=True` and `grad=True`. Only supported with TE's custom call + to cuBLAS GEMM. + Optional[jax.Array]: + Pre-GeLU GEMM output when `fuse_gelu=True` and `grad=False`. This is required as an input + to `_te_gemm()` with `fuse_gelu=True` and `grad=True` in the backward pass in order to + compute the GeLU contribution to the gradient. Only supported with TE's custom call to + cuBLAS GEMM. + """ + # Try to get LHS and RHS quantizers from a quantizer set for backward compatibility + if lhs_quantizer is None or rhs_quantizer is None: + quantizer_set = kwargs.get("quantizer_set", None) + if quantizer_set is not None: + lhs_quantizer = quantizer_set.x + rhs_quantizer = quantizer_set.kernel + + # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled + fuse_bias = kwargs.get("fuse_bias", False) + fuse_gelu = kwargs.get("fuse_gelu", False) + if not GemmPrimitive.enabled(): + assert kwargs.get("bias", None) is None and not fuse_gelu, ( + "TE GEMM was invoked with bias fusion options that are not supported by the " + "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " + "GEMM primitive is disabled." + ) + assert kwargs.get("gelu_input", None) is None and not fuse_bias, ( + "TE GEMM was invoked with GeLU fusion options that are not supported by the " + "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " + "GEMM primitive is disabled." + ) + return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) + + outputs = _te_gemm( + lhs, + rhs, + lhs_quantizer=lhs_quantizer, + rhs_quantizer=rhs_quantizer, + contracting_dims=contracting_dims, + batched_dims=batched_dims, + **kwargs, + ) + + # Discard empty outputs + grad = kwargs.get("grad", False) + clean_outputs = outputs[0] # first output is the final result and is never empty + if (fuse_bias and grad) or (fuse_gelu and not grad): + clean_outputs = (outputs[0],) + if fuse_bias and grad: # only return bias gradient if it exists + clean_outputs += (outputs[1],) + if fuse_gelu and not grad: # only return pre-GeLU output if it exists + clean_outputs += (outputs[2],) + return clean_outputs + + +def grouped_gemm( + lhs: Union[jnp.ndarray, GroupedScaledTensor1x], + rhs: Union[jnp.ndarray, GroupedScaledTensor1x], + group_sizes: jnp.ndarray, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), + bias: jnp.ndarray = None, + precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, + preferred_element_type: jnp.dtype = None, + group_offset: jnp.array = None, + quantizer_set: QuantizerSet = noop_quantizer_set, ) -> jnp.ndarray: - """General matrix multiplication with optional quantization. + """ + Grouped GEMM operation. Args: - lhs: First input matrix. - rhs: Second input matrix. - contracting_dims: Tuple of two sequences representing the contracting dimensions. - The first sequence represents the contracting dimensions of the first matrix, - and the second sequence represents the contracting dimensions of the second matrix. - quantizer_set: Set of quantizers for FP8 quantization of the output. - If None, no quantization is applied and the output has the same dtype as the inputs. + lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x + rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x + group_sizes: 1D array containing the sizes of each group + contracting_dims: Tuple of two sequences representing the contracting dimensions + bias: Bias tensor of shape (G, N) + precision: JAX precision for the GEMM operation + preferred_element_type: Preferred data type for the output tensor + group_offset: 1D array containing offsets for each group (not yet implemented) + quantizer_set: Set of quantizers for FP8 quantization of the input and output Returns: - If quantizer_set is None: - The matrix multiplication result. - Shape: (M, N) - Dtype: Same as input dtype - If quantizer_set is provided: - A ScaledTensor containing the quantized matrix multiplication result. + A jnp.ndarray containing the result of the grouped GEMM operation + + Note: + Tested shapes: + lhs: [M, K] or [K, N] + rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] """ + # TODO(Phuong): implement the group_offset + group_offset = group_offset or jnp.zeros((1,), jnp.int32) + + # TODO(Phuong): implement the precision + del precision + + if isinstance(lhs, jnp.ndarray): + assert isinstance(rhs, jnp.ndarray) + out_dtype = lhs.dtype + lhs_shape = lhs.shape + rhs_shape = rhs.shape + lhs_data = lhs + rhs_data = rhs + lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32) + scaling_mode = ScalingMode.NO_SCALING + elif isinstance(lhs, GroupedScaledTensor1x): + assert isinstance(rhs, GroupedScaledTensor1x) + out_dtype = lhs.dq_dtype + lhs_shape = lhs.original_shape + rhs_shape = rhs.original_shape + lhs_data = lhs.data + rhs_data = rhs.data + lhs_scale_inv = lhs.scale_inv + rhs_scale_inv = rhs.scale_inv + assert lhs.scaling_mode == rhs.scaling_mode + scaling_mode = lhs.scaling_mode + else: + raise TypeError("Unsupported lhs type object!") - return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set) + out_dtype = preferred_element_type or out_dtype + lhs_contract_dim, rhs_contract_dim = contracting_dims -""" -def swizzled_scale(scales): - # Swizzle the scale tensor for FP8 GEMM - assert scales.ndim == 2 - rows, cols = scales.shape - scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4) - scales = jnp.transpose(scales, (0, 3, 2, 1, 4)) - scales = scales.reshape(rows, cols) - return scales + lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1 + lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1) + # rhs_shape [G, K, N] + rhs_is_trans = rhs_contract_dim[0] != 1 + rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) -def grouped_gemm( - lhs_list: List[Union[jnp.ndarray, ScaledTensor]], - rhs_list: List[Union[jnp.ndarray, ScaledTensor]], - contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]], - bias_list: List[jnp.ndarray] = None, -) -> List[jnp.ndarray]: - # Grouped GEMM for multiple pairs of tensors. - assert ( - len(lhs_list) == len(rhs_list) == len(contracting_dims_list) - ), "lhs_list, rhs_list, contracting_dims_list must have the same length" - - num_gemms = len(lhs_list) - lhs_list_ = [] - rhs_list_ = [] - lhs_sinv_list_ = [] - rhs_sinv_list_ = [] - bias_list_ = [] - for i in range(num_gemms): - lhs = lhs_list[i] - rhs = rhs_list[i] - contracting_dims = contracting_dims_list[i] - dim_nums = (contracting_dims, ((), ())) - if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): - scaling_mode = lhs.scaling_mode - lhs_shape = lhs.data.shape - rhs_shape = rhs.data.shape - out_dtype = lhs.dq_dtype - # For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout - if lhs.scaling_mode.is_tensor_scaling(): - assert not ( - lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 - ), "FP8 GEMM does not support E5M2 * E5M2" - ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims - if lhs.data_layout == "T": - lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim - if rhs.data_layout == "T": - rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim - dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) + is_grouped_dense_wgrad = False + if len(rhs_shape) == 2: + rhs_is_trans = rhs_contract_dim[0] != 0 + is_grouped_dense_wgrad = True + + # TODO(Hua): thses are for fp16 dense wgrad, any better way to handle this? + if ( + is_grouped_dense_wgrad + and not isinstance(lhs, ScaledTensor) + and not isinstance(rhs, ScaledTensor) + ): + lhs_is_trans = True + rhs_is_trans = False + lhs_flatten_axis = 1 + rhs_flatten_axis = 1 + + if ( + not isinstance(lhs, ScaledTensor) + and not isinstance(rhs, ScaledTensor) + and quantizer_set != noop_quantizer_set + ): + assert isinstance(quantizer_set.x, GroupedQuantizer) + assert type(quantizer_set.x) is type(quantizer_set.kernel) + scaling_mode = quantizer_set.x.scaling_mode + if ( + quantizer_set.x.scaling_mode.is_tensor_scaling() + and is_fp8_gemm_with_all_layouts_supported() + ): + lhs_is_rowwise = rhs_is_rowwise = True else: - # For jnp.ndarray, only consider contracting_dims, data_layout is always NN - scaling_mode = ScalingMode.NO_SCALING - lhs_shape = lhs.shape - rhs_shape = rhs.shape - out_dtype = lhs.dtype - - (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums - lhs_dn = (lhs_contract, lhs_batch) - rhs_dn = (rhs_contract, rhs_batch) - - lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) - rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) - - # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy - if scaling_mode == ScalingMode.NO_SCALING: - lhs_3d = _shape_normalization(lhs, lhs_dn) - rhs_3d = _shape_normalization(rhs, rhs_dn) - elif scaling_mode.is_tensor_scaling(): - lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") - rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") - elif scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_3d = _shape_normalization(lhs.data, lhs_dn) - rhs_3d = _shape_normalization(rhs.data, rhs_dn) - lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) - rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn) - # swizzled_scale requires a matrix - lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze()) - rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze()) + lhs_is_rowwise = not lhs_is_trans + rhs_is_rowwise = rhs_is_trans + quantizer_set.x.q_layout = ( + QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE + ) + quantizer_set.kernel.q_layout = ( + QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE + ) + lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis) + rhs_q = grouped_quantize( + rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis + ) + lhs_data = lhs_q.data + rhs_data = rhs_q.data + lhs_scale_inv = lhs_q.scale_inv + rhs_scale_inv = rhs_q.scale_inv + lhs_shape = lhs_q.original_shape + rhs_shape = rhs_q.original_shape + + assert not ( + lhs_data.dtype == jnp_float8_e5m2_type and rhs_data.dtype == jnp_float8_e5m2_type + ), "FP8 GEMM does not support E5M2 * E5M2" + + # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs + # thus additional transpose is required + if scaling_mode.is_tensor_scaling() and not is_fp8_gemm_with_all_layouts_supported(): + if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): + lhs_layout_is_T = lhs.data_layout == "T" + rhs_layout_is_T = rhs.data_layout == "T" else: - raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") - - # Note: already_transposed doesn't matter for the output shape - # x.shape = [B, D1, D2] - # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2] - # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1] - # x.shape = [D1, D2] - # contracting_dims = (1, ) --> output.shape = [1, D1, D2] - # contracting_dims = (0, ) --> output.shape = [1, D2, D1] - bm = lhs_remain_shape[0] - bn = rhs_remain_shape[0] - kl = lhs_3d.shape[-1] - kr = rhs_3d.shape[-1] - assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}" - if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0): - print("grouped_gemm input pair {i} has invalid problem shape for lowering: ") - print(f"m = {bm}, n = {bn}, k = {kl}; ") - print("cuBLAS requires the problem shapes being multiples of 16") - assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0) - - lhs_list_.append(lhs_3d) - rhs_list_.append(rhs_3d) - if scaling_mode == ScalingMode.NO_SCALING: - lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) - rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) - if scaling_mode.is_tensor_scaling(): - lhs_sinv_list_.append(lhs.scale_inv) - rhs_sinv_list_.append(rhs.scale_inv) - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_sinv_list_.append(lhs_scale_inv) - rhs_sinv_list_.append(rhs_scale_inv) - if bias_list is not None: - bias_list_.append(bias_list[i]) - - out_list = GroupedGemmPrimitive.outer_primitive.bind( - *lhs_list_, - *rhs_list_, - *lhs_sinv_list_, - *rhs_sinv_list_, - *bias_list_, - num_gemms=num_gemms, - scaling_mode=scaling_mode, + lhs_layout_is_T = lhs_q.data_layout == "T" + rhs_layout_is_T = rhs_q.data_layout == "T" + # we can't apply _shape_normalization on the grouped input + # thus we need to ensure that lhs is in N and rhs is in T + assert ( + lhs_is_trans == lhs_layout_is_T + ), "lhs input must be transposed before calling grouped_gemm" + assert ( + not rhs_is_trans == rhs_layout_is_T + ), "rhs input must be transposed before calling grouped_gemm" + lhs_is_trans = False + rhs_is_trans = True + lhs_ndim = len(lhs_shape) + rhs_ndim = len(rhs_shape) + if lhs_layout_is_T: + lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) + if rhs_layout_is_T: + # For rhs [G, K, N], need to exclude the G dim from contract_dim + if group_sizes.size == rhs_shape[0]: + rhs_contract_dim = tuple( + (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim + ) + else: + rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) + + # Calling GroupedGEMM Custom Call + K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) + K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim) + assert K_lhs == K_rhs + M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim)) + N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G + + if is_grouped_dense_wgrad: + N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)) + else: + assert group_sizes.size == rhs_shape[0] + + assert group_offset.size == 1 + + has_bias = bias is not None + assert not has_bias or bias.shape == (group_sizes.size, N) + bias = jnp.empty((), jnp.float32) if bias is None else bias + + (out,) = GroupedGemmPrimitive.outer_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + group_offset, + M=M, + N=N, + K=K_lhs, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode.value, out_dtype=out_dtype, - has_bias=1 if bias_list is not None else 0, + has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) - - return out_list -""" + return out diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index afb3caabf..1971ccfa3 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -190,19 +190,34 @@ def get_xla_flag(flag: str, default=None, cast=str): return default +def get_min_device_compute_capability(): + """ + Returns the minimum compute capability of all local devices. + """ + return min( + transformer_engine_jax.get_device_compute_capability(local_gpu_id) + for local_gpu_id in range(len(jax.local_devices())) + ) + + def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None): """ Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to calculate dbias separately. This function checks if the workaround should be applied. """ + if quantizer is None: + return False + arch_l_100 = False for local_gpu_id in range(len(jax.local_devices())): if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100: arch_l_100 = True break + # _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE, + # but this fails when bias fusion is turned on with arch < 100. + force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() return ( - quantizer is not None - and quantizer.q_layout == QuantizeLayout.ROWWISE + (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE) and arch_l_100 and is_dbias ) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 79d951452..8885ae2ea 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -593,16 +593,17 @@ def shardy_sharding_rule( result_types, ) + prefix = "NormFwdPrimitive_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[0].shape), unique_var="NormFwdPrimitive_i", flatten_axis=-1 + len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1 ) x_axes = scale_rules.input_spec - out = x_axes[:-1] + ("k",) - colwise_out = out if is_2x else ("…4",) + out = x_axes + colwise_out = out if is_2x else (prefix + "out_colwise",) rsigma = x_axes[:-1] - mu = ("…5",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma - amax = ("…6",) + mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma + amax = (prefix + "amax",) return SdyShardingRule( (x_axes, ("…1",), ("…2",), ("…3",)), @@ -615,7 +616,6 @@ def shardy_sharding_rule( mu, rsigma, ), - **scale_rules.factor_sizes, ) @@ -1282,6 +1282,7 @@ def normalization_fwd( epsilon: float, norm_type: str, quantizer: Optional[Quantizer], + noop_scaled_tensor: bool = False, ): """Common wrapper for normalization forward pass. @@ -1298,6 +1299,7 @@ def normalization_fwd( - 'layernorm': Layer normalization - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: A tuple containing: @@ -1325,6 +1327,15 @@ def normalization_fwd( else: raise ValueError(f"{norm_type=} is not supported.") + if quantizer is None and noop_scaled_tensor: + return ( + ScaledTensorFactory.create_2x( + output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype + ), + mu, + rsigma, + ) + return output, mu, rsigma diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index a516b8d2b..7a5b31ad7 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -7,6 +7,7 @@ import operator from functools import reduce from typing import Tuple, Optional +import math from packaging import version import jax @@ -26,14 +27,18 @@ jax_dtype_to_te_dtype, multidim_transpose, should_apply_1x_fused_dbias_war_for_arch_l_100, + get_min_device_compute_capability, NamedSharding, ) from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp -from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory from ..quantize import ( + ScaledTensor2x, + ScaledTensor, + ScaledTensorFactory, + GroupedScaledTensor1x, Quantizer, + GroupedQuantizer, QuantizeLayout, - DelayedScaleQuantizer, ScalingMode, compute_scale_from_amax, ) @@ -44,10 +49,10 @@ from jax.extend import ffi # pylint: disable=ungrouped-imports -__all__ = ["quantize", "quantize_dbias"] +__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] -class DBiasQuantizePrimitive(BasePrimitive): +class BaseDBiasQuantizePrimitive(BasePrimitive): """ Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias """ @@ -158,7 +163,7 @@ def outer_abstract(*args, **kwargs): updated_amax, dbias, _, - ) = DBiasQuantizePrimitive.abstract(*args, **kwargs) + ) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs) return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias @staticmethod @@ -182,7 +187,7 @@ def lowering( x_aval, scale_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval.dtype == jnp.float32 - return ffi.ffi_lowering(DBiasQuantizePrimitive.name)( + return ffi.ffi_lowering(BaseDBiasQuantizePrimitive.name)( ctx, x, scale, @@ -208,7 +213,7 @@ def impl( te_dbias_quantize_p implementation """ del is_outer - assert DBiasQuantizePrimitive.inner_primitive is not None + assert BaseDBiasQuantizePrimitive.inner_primitive is not None ( out, colwise_out, @@ -217,7 +222,7 @@ def impl( updated_amax, dbias, _, - ) = DBiasQuantizePrimitive.inner_primitive.bind( + ) = BaseDBiasQuantizePrimitive.inner_primitive.bind( x, scale, out_dtype=out_dtype, @@ -265,14 +270,14 @@ def batcher( """ del is_outer check_valid_batch_dims(batch_dims) - assert DBiasQuantizePrimitive.outer_primitive is not None + assert BaseDBiasQuantizePrimitive.outer_primitive is not None x, scale = batched_args x_bdim, scale_bdim = batch_dims amax_bdim = scale_bdim out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim return ( - DBiasQuantizePrimitive.outer_primitive.bind( + BaseDBiasQuantizePrimitive.outer_primitive.bind( x, scale, out_dtype=out_dtype, @@ -305,7 +310,7 @@ def infer_sharding_from_operands( out_sharding = NamedSharding( mesh, PartitionSpec(*x_spec), - desc="DBiasQuantizePrimitive.out_sharding", + desc="BaseDBiasQuantizePrimitive.out_sharding", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if ScalingMode(scaling_mode).is_tensor_scaling(): @@ -317,14 +322,14 @@ def infer_sharding_from_operands( colwise_out_sharding = NamedSharding( mesh, PartitionSpec(*colwise_out_spec), - desc="DBiasQuantizePrimitive.colwise_out_sharding", + desc="BaseDBiasQuantizePrimitive.colwise_out_sharding", ) dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,) dbias_sharding = NamedSharding( mesh, PartitionSpec(*dbias_spec), - desc="DBiasQuantizePrimitive.dbias_sharding", + desc="BaseDBiasQuantizePrimitive.dbias_sharding", ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) @@ -337,15 +342,15 @@ def infer_sharding_from_operands( colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv" ) amax_sharding = NamedSharding( - mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax" + mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax" ) colwise_scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*colwise_scale_inv_spec), - desc="DBiasQuantizePrimitive.colwise_scale_inv", + desc="BaseDBiasQuantizePrimitive.colwise_scale_inv", ) return ( @@ -377,7 +382,7 @@ def partition( out_sharding = NamedSharding( mesh, PartitionSpec(*x_spec), - desc="DBiasQuantizePrimitive.out_sharding", + desc="BaseDBiasQuantizePrimitive.out_sharding", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if ScalingMode(scaling_mode).is_tensor_scaling(): @@ -389,14 +394,14 @@ def partition( colwise_out_sharding = NamedSharding( mesh, PartitionSpec(*colwise_out_spec), - desc="DBiasQuantizePrimitive.colwise_out_sharding", + desc="BaseDBiasQuantizePrimitive.colwise_out_sharding", ) dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,) dbias_sharding = NamedSharding( mesh, PartitionSpec(*dbias_spec), - desc="DBiasQuantizePrimitive.dbias_sharding", + desc="BaseDBiasQuantizePrimitive.dbias_sharding", ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) @@ -409,15 +414,15 @@ def partition( colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv" ) amax_sharding = NamedSharding( - mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax" + mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax" ) colwise_scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*colwise_scale_inv_spec), - desc="DBiasQuantizePrimitive.colwise_scale_inv", + desc="BaseDBiasQuantizePrimitive.colwise_scale_inv", ) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) @@ -438,7 +443,7 @@ def sharded_impl(x, scale): local_colwise_scale_inv, local_amax, local_dbias, - ) = DBiasQuantizePrimitive.impl( + ) = BaseDBiasQuantizePrimitive.impl( x, scale, out_dtype=out_dtype, @@ -488,9 +493,10 @@ def shardy_sharding_rule( raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") del out_dtype, scale_dtype, is_outer, mesh, result_types + prefix = "BaseDBiasQuantizePrimitive_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( len(value_types[0].shape), - unique_var="DBiasQuantizePrimitive_i", + unique_var=prefix + "x", flatten_axis=flatten_axis, ) @@ -498,26 +504,31 @@ def shardy_sharding_rule( colwise_scale_inv = scale_rules.colwise_rule out = x_axes + colwise_out = (prefix + "out_colwise",) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if ScalingMode(scaling_mode).is_tensor_scaling(): colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis)) else: colwise_out = x_axes - else: - colwise_out = ("j",) - colwise_scale_inv = ("k",) - dbias = x_axes[flatten_axis:] if is_dbias else ("l",) - amax = ("m",) + dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",) + amax = (prefix + "amax",) return SdyShardingRule( (x_axes, ("…1",)), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), - **scale_rules.factor_sizes, ) -register_primitive(DBiasQuantizePrimitive) +register_primitive(BaseDBiasQuantizePrimitive) + + +class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive): + """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" + + +class QuantizePrimitive(BaseDBiasQuantizePrimitive): + """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" def _jax_quantize( @@ -529,11 +540,12 @@ def _jax_quantize( def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1): - assert flatten_axis < 0 + sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis + assert sum_axis < dx.ndim, "Flatten axis out of bounds!" dtype = dtype or dx.dtype dbias = jnp.sum( dx.astype(jnp.float32), - axis=tuple(range(dx.ndim + flatten_axis)), + axis=tuple(range(sum_axis)), keepdims=False, ) return dbias.astype(dtype) @@ -559,6 +571,7 @@ def _quantize_dbias_impl( is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -568,23 +581,34 @@ def _quantize_dbias_impl( quantizer is not None ), "quantizer must be provided if dq_dtype is provided" + # Early-exit for non-quantized call dq_dtype = dq_dtype or x.dtype - - if not DBiasQuantizePrimitive.enabled(): + if quantizer is None: + dbias = None if is_dbias: - return _jax_quantize_dbias( - x, - quantizer=quantizer, - dq_dtype=dq_dtype, - flatten_axis=flatten_axis, + dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) + if noop_scaled_tensor: + # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor() + # always works. + return ( + ScaledTensorFactory.create_2x( + x, + None, + x, + None, + ScalingMode.NO_SCALING, + dq_dtype=x.dtype, + data_layout="NN", + flatten_axis=flatten_axis, + ), + dbias, ) - return ( - _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), - None, - ) + return x, dbias - # TE/common doesn't support colwise only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: + # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, + # fall back on the native-JAX quantize implementation + PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive + if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled(): if is_dbias: return _jax_quantize_dbias( x, @@ -596,9 +620,8 @@ def _quantize_dbias_impl( _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), None, ) - scale = jnp.empty((), jnp.float32) - # TE/common dbias_quantize does not support 1x on arch < 100 + # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100 if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): out, _ = _quantize_dbias_impl( x=x, @@ -610,21 +633,27 @@ def _quantize_dbias_impl( dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias - if quantizer is None: - if is_dbias: - return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) - return x, None - + scale = jnp.empty((), jnp.float32) if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Globally reduce amax across all devices for current scaling so we have a single global scale. # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this # until the tensor is dequantized (e.g. in the GEMM). amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32) scale = compute_scale_from_amax(amax, quantizer.q_dtype) - - if isinstance(quantizer, DelayedScaleQuantizer): + elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale + # It is faster to use 1x quantization for tensor scaling + is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) + force_1x_quantization = ( + quantizer.scaling_mode.is_tensor_scaling() + and quantizer.is_2x2x() + and is_1x_kernel_supported + ) + q_layout = quantizer.q_layout + if force_1x_quantization: + q_layout = QuantizeLayout.ROWWISE + ( rowwise_casted_output, colwise_casted_output, @@ -632,12 +661,12 @@ def _quantize_dbias_impl( colwise_scale_inv, updated_amax, dbias, - ) = DBiasQuantizePrimitive.outer_primitive.bind( + ) = PrimitiveClass.outer_primitive.bind( x, scale, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - q_layout=quantizer.q_layout.value, + q_layout=q_layout.value, flatten_axis=flatten_axis, scale_dtype=quantizer.get_scale_dtype(), is_dbias=is_dbias, @@ -647,6 +676,15 @@ def _quantize_dbias_impl( if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): colwise_scale_inv = rowwise_scale_inv + if q_layout == QuantizeLayout.ROWWISE: + # Quantizer requires 2x quantization, but we are using 1x quantization + # for performance reasons, so we need to generate the colwise data in JAX + if flatten_axis < 0: + flatten_axis += x.ndim + colwise_casted_output = jnp.transpose( + rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis)) + ) + quantizer.update(updated_amax) out = ScaledTensorFactory.create( @@ -667,6 +705,7 @@ def quantize( x: jnp.ndarray, quantizer: Quantizer, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -676,6 +715,8 @@ def quantize( quantizer: Quantizer for FP8 quantization of the output. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. + noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer + is None. Returns: A ScaledTensor containing the quantized input tensor. @@ -684,6 +725,7 @@ def quantize( x, quantizer=quantizer, flatten_axis=flatten_axis, + noop_scaled_tensor=noop_scaled_tensor, ) return out @@ -693,6 +735,7 @@ def quantize_dbias( quantizer: Quantizer, is_dbias: bool = True, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -703,6 +746,8 @@ def quantize_dbias( is_dbias: If True, compute bias gradient. Defaults to True. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. + noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when + quantizer is None. Returns: A tuple containing: @@ -712,5 +757,319 @@ def quantize_dbias( Shape: (K,) or empty if is_dbias is False. """ return _quantize_dbias_impl( - dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis + dz, + quantizer=quantizer, + is_dbias=is_dbias, + flatten_axis=flatten_axis, + noop_scaled_tensor=noop_scaled_tensor, + ) + + +class GroupedQuantizePrimitive(BasePrimitive): + """ + Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias + """ + + name = "te_grouped_quantize_ffi" + multiple_results = True + impl_static_args = ( + 3, + 4, + 5, + 6, + 7, + 8, + ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + x_aval, + scale_aval, + group_sizes_aval, + *, + out_dtype, + scaling_mode, + q_layout, + flatten_axis, + group_axis, + scale_dtype, + ): + """ + te_dbias_quantize_p abstract + """ + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + out_shape = math.prod(x_aval.shape) + # TODO(Phuong): can scale_aval be None? + assert scale_aval is None or scale_aval.dtype == jnp.float32 + + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( + scaling_mode + ).get_grouped_scale_shape_2x( + x_aval.shape, + group_sizes_aval.size, + group_axis, + is_padded=True, + flatten_axis=flatten_axis, + ) + + if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + rowwise_out_shape = out_shape + else: + rowwise_out_shape = (1,) + rowwise_scale_inv_shape = (1,) + rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) + + amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) + + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + colwise_out_shape = out_shape + else: + colwise_out_shape = (1,) + colwise_scale_inv_shape = (1,) + colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) + rowwise_scale_inv_aval = jax.core.ShapedArray( + shape=rowwise_scale_inv_shape, dtype=scale_dtype + ) + colwise_scale_inv_aval = jax.core.ShapedArray( + shape=colwise_scale_inv_shape, dtype=scale_dtype + ) + + return ( + rowwise_out_aval, + colwise_out_aval, + rowwise_scale_inv_aval, + colwise_scale_inv_aval, + amax_aval, + ) + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + te_dbias_quantize_p outer primitive abstract + """ + # Phuong: keeping outer abstract so that we can add fuse dbias later + ( + rowwise_out, + colwise_out, + scale_inv, + colwise_scale_inv, + updated_amax, + ) = GroupedQuantizePrimitive.abstract(*args, **kwargs) + return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax + + @staticmethod + def lowering( + ctx, + x, + scale, + group_sizes, + *, + out_dtype, + scaling_mode, + q_layout, + flatten_axis, + group_axis, + scale_dtype, + ): + """ + te_dbias_quantize_p lowering rules + """ + del out_dtype, scale_dtype + x_aval, scale_aval, group_sizes_aval = ctx.avals_in + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert scale_aval.dtype == jnp.float32 + assert group_sizes_aval.dtype == jnp.int32 + assert group_axis == 0 + return ffi.ffi_lowering(GroupedQuantizePrimitive.name)( + ctx, + x, + scale, + group_sizes, + scaling_mode=scaling_mode.value, + q_layout=q_layout, + flatten_axis=flatten_axis, + ) + + @staticmethod + def impl( + x, + scale, + group_sizes, + out_dtype, + scaling_mode, + q_layout, + flatten_axis, + group_axis, + scale_dtype, + ): + """ + te_dbias_quantize_p implementation + """ + assert GroupedQuantizePrimitive.inner_primitive is not None + ( + rowwise_out, + colwise_out, + rowwise_scale_inv, + colwise_scale_inv, + updated_amax, + ) = GroupedQuantizePrimitive.inner_primitive.bind( + x, + scale, + group_sizes, + out_dtype=out_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + flatten_axis=flatten_axis, + group_axis=group_axis, + scale_dtype=scale_dtype, + ) + return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax) + + +register_primitive(GroupedQuantizePrimitive) + + +def grouped_quantize( + x: jnp.ndarray, + quantizer: GroupedQuantizer, + group_sizes: jnp.ndarray = None, + flatten_axis: int = -1, +) -> GroupedScaledTensor1x: + """Quantize a tensor in grouped manner. + + This function quantizes a tensor by splitting it into groups along a specified axis + and applying quantization to each group separately. The groups can be either specified + explicitly through group_sizes or automatically split along the group_axis. + + Args: + x: Input tensor to quantize + quantizer: The quantizer to use for quantization + group_sizes: Array of ints containing the size of each group (default: None) + flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) + + Returns: + A GroupedScaledTensor1x containing the quantized data + + Note: + - If group_sizes is not provided, the tensor will be split into equal-sized groups + along the group_axis + - The group_axis is currently fixed to 0 + - The quantizer's q_layout determines whether row-wise, column-wise, or both + quantization is applied + """ + + if quantizer is None: + return x + + # TODO(Phuong): add support for flatten_axis = -2 + assert flatten_axis in ( + -1, + x.ndim - 1, + ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}" + group_axis = 0 + + if group_sizes is None: + group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32) + + if not GroupedQuantizePrimitive.enabled(): + return quantizer.quantize( + x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis + ) + n_groups = group_sizes.size + original_shape = x.shape + assert n_groups == len( + quantizer.quantizers + ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}" + scale = jnp.empty((n_groups,), jnp.float32) + + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: + for i, quantizer_i in enumerate(quantizer.quantizers): + scale = scale.at[i].set(quantizer_i.scale[0]) + + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) + segment_ids = jnp.repeat( + jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] + ) + grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) + for i in range(n_groups): + tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype) + scale = scale.at[i].set(tmp_scale[0]) + + is_tensor_scaling = quantizer.scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING, + ScalingMode.CURRENT_TENSOR_SCALING, + ) + # WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet + # So we performance ROWWISE_COLWISE and use the colwise_tensor_output + apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE + q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout + ( + rowwise_casted_output, + colwise_casted_output, + rowwise_scale_inv, + colwise_scale_inv, + updated_amax, + ) = GroupedQuantizePrimitive.outer_primitive.bind( + x, + scale, + group_sizes, + out_dtype=quantizer.q_dtype, + scaling_mode=quantizer.scaling_mode.value, + q_layout=q_layout.value, + flatten_axis=flatten_axis, + group_axis=group_axis, + scale_dtype=quantizer.get_scale_dtype(), + ) + + # For DelayedScaling2x and CurrentScaling2x, the scale buffer + # is shared between rowwise and colwise + if is_tensor_scaling and quantizer.is_2x2x() or apply_colwise_war: + colwise_scale_inv = rowwise_scale_inv + + # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead? + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: + for i, quantizer_i in enumerate(quantizer.quantizers): + quantizer_i.update(updated_amax[i].reshape((1,))) + + out = ScaledTensorFactory.create( + data=rowwise_casted_output, + scale_inv=rowwise_scale_inv, + colwise_data=colwise_casted_output, + colwise_scale_inv=colwise_scale_inv, + scaling_mode=quantizer.scaling_mode, + dq_dtype=x.dtype, + q_layout=quantizer.q_layout, + data_layout=quantizer.get_data_layout(), + flatten_axis=flatten_axis, + group_sizes=group_sizes, + original_shape=original_shape, + group_axis=group_axis, + ) + return out + + +def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray: + """ + Compute the grouped bias gradient. + + Args: + grad: jnp.ndarray of shape (M, N) + group_sizes: jnp.ndarray of shape(num_groups,), sum(group_sizes) == M + + Returns: + dbias: jnp.ndarray of shape (num_groups, N) + """ + assert grad.ndim == 2, "Input grad must be a 2D tensor." + assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor." + + segment_ids = jnp.repeat( + jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0] ) + grad_fp32 = grad.astype(jnp.float32) + dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0]) + dbias = dbias_fp32.astype(grad.dtype) + return dbias diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index c78bf3f1b..eee61d126 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -812,13 +812,7 @@ def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_fact """ JAX based implementation of scaled and masked softmax """ - if mask is not None: - logits += jax.lax.select( - mask > 0, - jnp.full(mask.shape, -1e10).astype(logits.dtype), - jnp.full(mask.shape, 0.0).astype(logits.dtype), - ) - return jax.nn.softmax(logits * scale_factor) + return jax.nn.softmax(logits * scale_factor, where=mask != 1) def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float): @@ -826,12 +820,7 @@ def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: fl JAX based implementation of scaled and upper triangle masked softmax """ mask = 1 - jnp.tril(jnp.ones_like(logits)) - logits += jax.lax.select( - mask > 0, - jnp.full(mask.shape, -1e10).astype(logits.dtype), - jnp.full(mask.shape, 0.0).astype(logits.dtype), - ) - return jax.nn.softmax(logits * scale_factor) + return jax_scaled_masked_softmax(logits, mask, scale_factor) def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 678fd2e01..453a4202b 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -34,6 +34,7 @@ #include "extensions/misc.h" #include "extensions/utils.h" #include "transformer_engine/activation.h" +#include "transformer_engine/multi_stream.h" // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); @@ -72,6 +73,8 @@ pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_s // Quantization XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler); + XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, @@ -97,7 +100,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); -NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, +NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float dropout_probability, size_t q_num_heads, size_t kv_num_heads, @@ -107,19 +110,22 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right); pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right); +// GEMM +XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); + // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 063710aa7..cf75c850b 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -7,7 +7,7 @@ #include -#include "extensions.h" +#include "../extensions.h" #include "transformer_engine/cast.h" #include "xla/ffi/api/c_api.h" diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index af1fcb493..bc2ac164c 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -6,27 +6,27 @@ * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" +#include "../extensions.h" #include "transformer_engine/fused_attn.h" +#include "transformer_engine/transformer_engine.h" namespace transformer_engine { namespace jax { -NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, +NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t qk_head_dim, size_t v_head_dim, - int64_t window_size_left, int64_t window_size_right){ + size_t qk_head_dim, size_t v_head_dim, + int64_t window_size_left, int64_t window_size_right) { auto backend = nvte_get_fused_attn_backend( - static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, - mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, - qk_head_dim, v_head_dim, window_size_left, window_size_right); + is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, + bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, + kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); return backend; } -#ifndef USE_ROCM /* NOTE: PrepareFusedAttnForwardAuxTensors unifies the auxiliary tensor pack logic from the fused attention forward kernels in: @@ -43,69 +43,54 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t // all backends need softmax but expect different shapes/dtypes // start with the max512 sequence length softmax shape/dtype and correct later tensor_pack->size = 1; - Tensor *softmax_aux = reinterpret_cast(tensor_pack->tensors[0]); - softmax_aux->data.dptr = softmax_buf; - softmax_aux->data.shape = - std::vector{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen}; - softmax_aux->data.dtype = dtype; + NVTETensor &softmax_aux = tensor_pack->tensors[0]; + NVTEBasicTensor softmax_aux_data; + softmax_aux_data.data_ptr = softmax_buf; + softmax_aux_data.shape.ndim = 4; + softmax_aux_data.shape.data[0] = input_batch; + softmax_aux_data.shape.data[1] = attn_heads; + softmax_aux_data.shape.data[2] = q_max_seqlen; + softmax_aux_data.shape.data[3] = kv_max_seqlen; + softmax_aux_data.dtype = static_cast(dtype); // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax +#ifndef USE_ROCM if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { +#endif +// ROCm fused attn has two backends: aotriton and ck +// They both have the same shape and stride for softmax and rng aux tensors +// CK now support bias features tensor_pack->size = 2; - Tensor *rng_state_aux = reinterpret_cast(tensor_pack->tensors[1]); - rng_state_aux->data.dptr = rng_state_buf; - rng_state_aux->data.shape = std::vector{2}; - rng_state_aux->data.dtype = DType::kInt64; + NVTETensor &rng_state_aux = tensor_pack->tensors[1]; + NVTEBasicTensor rng_state_aux_data; + rng_state_aux_data.data_ptr = rng_state_buf; + rng_state_aux_data.shape = {}; + rng_state_aux_data.shape.ndim = 2; + rng_state_aux_data.dtype = static_cast(DType::kInt64); + nvte_set_tensor_param(&rng_state_aux, kNVTERowwiseData, &rng_state_aux_data); // correct softmax shape/dtype - softmax_aux->data.shape.at(3) = 1; // {B,H,Qs,Ks} -> {B,H,Qs,1} - softmax_aux->data.dtype = DType::kFloat32; + softmax_aux_data.shape.data[3] = 1; // {B,H,Qs,Ks} -> {B,H,Qs,1} + softmax_aux_data.dtype = static_cast(DType::kFloat32); // include bias if enabled if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) { tensor_pack->size = 3; - Tensor *bias_aux = reinterpret_cast(tensor_pack->tensors[2]); - bias_aux->data.dptr = bias_buf; - bias_aux->data.shape = - std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - bias_aux->data.dtype = dtype; + NVTETensor &bias_aux = tensor_pack->tensors[2]; + NVTEBasicTensor bias_aux_data; + bias_aux_data.data_ptr = bias_buf; + bias_aux_data.shape.ndim = 4; + bias_aux_data.shape.data[0] = bias_batch; + bias_aux_data.shape.data[1] = bias_heads; + bias_aux_data.shape.data[2] = q_max_seqlen; + bias_aux_data.shape.data[3] = kv_max_seqlen; + bias_aux_data.dtype = static_cast(dtype); + nvte_set_tensor_param(&bias_aux, kNVTERowwiseData, &bias_aux_data); } +#ifndef USE_ROCM } -} -#else -// ROCm fused attn has two backends: aotriton and ck -// They both have the same shape and stride for softmax and rng aux tensors -// CK now support bias features -void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch, - const size_t bias_batch, const size_t attn_heads, - const size_t bias_heads, const size_t q_max_seqlen, - const size_t kv_max_seqlen, DType dtype, - NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend, - void *softmax_buf, void *rng_state_buf = nullptr, - void *bias_buf = nullptr) { - tensor_pack->size = 2; - Tensor *softmax_aux = reinterpret_cast(tensor_pack->tensors[0]); - softmax_aux->data.dptr = softmax_buf; - softmax_aux->data.shape = - std::vector{input_batch, attn_heads, q_max_seqlen, 1}; - softmax_aux->data.dtype = DType::kFloat32; - - Tensor *rng_state_aux = reinterpret_cast(tensor_pack->tensors[1]); - rng_state_aux->data.dptr = rng_state_buf; - rng_state_aux->data.shape = std::vector{2}; - rng_state_aux->data.dtype = DType::kInt64; - - // include bias if enabled - if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) { - tensor_pack->size = 3; - Tensor *bias_aux = reinterpret_cast(tensor_pack->tensors[2]); - bias_aux->data.dptr = bias_buf; - bias_aux->data.shape = - std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - bias_aux->data.dtype = dtype; - } -} #endif - + nvte_set_tensor_param(&softmax_aux, kNVTERowwiseData, &softmax_aux_data); +} /* NOTE: Backward fused attention kernels accept auxiliary tensors as explicit function arguments instead of an NVTETensorPack and nvte_fused_attn_bwd() API does all the logic for pulling the @@ -135,16 +120,18 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_ #ifndef USE_ROCM // correct softmax shape for max512 sequence length kernel if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { - Tensor *softmax_aux = reinterpret_cast(tensor_pack->tensors[0]); - softmax_aux->data.shape.at(3) = kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks} - softmax_aux->data.dtype = dtype; + NVTEBasicTensor softmax_aux_data = + nvte_get_tensor_param(tensor_pack->tensors[0], kNVTERowwiseData); + softmax_aux_data.shape.data[3] = kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks} + softmax_aux_data.dtype = static_cast(dtype); + nvte_set_tensor_param(&(tensor_pack->tensors[0]), kNVTERowwiseData, &softmax_aux_data); } #endif } pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) { @@ -226,6 +213,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( } } + nvte_tensor_pack_destroy(&aux_output_tensors); + auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape()); return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype()); } @@ -288,14 +277,10 @@ static void FusedAttnForwardImpl( /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - qk_head_dim, v_head_dim, window_size_left, window_size_right); -#ifndef USE_ROCM + is_training, static_cast(dtype), static_cast(dtype), qkv_layout, + bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, + kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); -#else - nvte_populate_rng_state_async(rng_state, seed, input_batch, attn_heads, q_max_seqlen, kv_max_seqlen, stream); -#endif /* Auxiliary tensors (to be propagated to the backward pass later) */ NVTETensorPack aux_output_tensors; @@ -422,7 +407,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI, pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, @@ -518,6 +503,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( } } + nvte_tensor_pack_destroy(&aux_input_tensors); + auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); } @@ -527,7 +514,7 @@ static void FusedAttnBackwardImpl( void *output, void *doutput, void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *dq, void *dk, void *dv, void *dbias, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, @@ -547,9 +534,9 @@ static void FusedAttnBackwardImpl( NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - qk_head_dim, v_head_dim, window_size_left, window_size_right); + is_training, static_cast(dtype), static_cast(dtype), qkv_layout, + bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, + kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias); diff --git a/transformer_engine/jax/csrc/extensions/cublas.cpp b/transformer_engine/jax/csrc/extensions/cublas.cpp index 2d2c4499f..995fd4e67 100644 --- a/transformer_engine/jax/csrc/extensions/cublas.cpp +++ b/transformer_engine/jax/csrc/extensions/cublas.cpp @@ -6,7 +6,7 @@ * See LICENSE for license information. ************************************************************************/ #ifndef USE_ROCM -#include "extensions.h" +#include "../extensions.h" #include "transformer_engine/gemm.h" #include "xla/ffi/api/c_api.h" diff --git a/transformer_engine/jax/csrc/extensions/cudnn.cpp b/transformer_engine/jax/csrc/extensions/cudnn.cpp index b4af7e26d..d710b6c47 100644 --- a/transformer_engine/jax/csrc/extensions/cudnn.cpp +++ b/transformer_engine/jax/csrc/extensions/cudnn.cpp @@ -7,7 +7,7 @@ ************************************************************************/ #ifndef USE_ROCM #include "transformer_engine/cudnn.h" -#include "extensions.h" +#include "../extensions.h" #include "xla/ffi/api/c_api.h" namespace transformer_engine { diff --git a/transformer_engine/jax/csrc/extensions/ffi.cpp b/transformer_engine/jax/csrc/extensions/ffi.cpp index 7058a55ea..26764e8af 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.cpp +++ b/transformer_engine/jax/csrc/extensions/ffi.cpp @@ -46,12 +46,11 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { #endif return DType::kFloat8E4M3; break; - // case xla::ffi::DataType::F8E8M0FNU: - // return DType::kFloat8E8M0; - // break; + case xla::ffi::DataType::F8E8M0FNU: + return DType::kFloat8E8M0; + break; default: auto type_num = static_cast(type); - if (type_num == 33) return DType::kFloat8E8M0; NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d", static_cast(type_num)); break; diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index e1f2a1bdc..ba2d65e3e 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -6,53 +6,361 @@ #include "transformer_engine/gemm.h" #include +#include +#include +#include "../extensions.h" #include "common/util/cuda_runtime.h" +#include "common/util/string.h" #include "common/util/system.h" -#include "extensions.h" +#include "transformer_engine/swizzle.h" #include "xla/ffi/api/c_api.h" +#define MXFP8_BLOCK_SIZE 32 + namespace transformer_engine { namespace jax { -Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, - Variadic_Result_Type output_list, int64_t num_gemms, - JAXX_Scaling_Mode scaling_mode, int64_t has_bias) { +static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { + // Move the pointer to the next 256B aligned address + return reinterpret_cast((reinterpret_cast(ptr) + 255) & + ~static_cast(255)); +} + +std::tuple> xla_buffer_to_nvte_gemm_operand( + cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv, + JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) { + // Set tensor data with collapsed 2D shape + auto buffer_dims = buffer.dimensions(); + std::vector input_shape = {product(buffer_dims, 0, axis_boundary), + product(buffer_dims, axis_boundary, buffer_dims.size())}; + auto input_dtype = convert_ffi_datatype_to_te_dtype(buffer.element_type()); + TensorWrapper input(get_nvte_scaling_mode(scaling_mode)); + + if (rowwise) { + input.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); + } else { + input.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); + } + + // Set scaling factor for quantized tensors + if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { + NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands."); + NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM."); + + std::vector scale_shape = {1}; + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + // Block scaling also needs to be collapsed to match 2D data + scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary), + product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())}; + } + + auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + if (rowwise) { + input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } else { + input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } + + // Swizzle scaling factors for MXFP8 + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + // Get the swizzle buffer + NVTE_CHECK(swizzled_scale_inv->element_count() > 0, + "Missing swizzled inverse scale buffer in the JAX primitive."); + auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + auto swizzled_scale_inv_dtype = + convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type()); + NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1, + "Inverse scale factors need to have an 8-bit data type."); + + // Create tensor to hold swizzled scale factor + TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); + if (rowwise) { + output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); + } else { + output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, + scale_shape); + } + + // Launch swizzle kernel + nvte_swizzle_scaling_factors(input.data(), output.data(), stream); + + // Set swizzled scales into the input tensor + if (rowwise) { + input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); + } else { + input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, + scale_shape); + } + } + } + + return std::make_tuple(std::move(input), input_shape); +} + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, + Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace, + JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, + int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, + bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { + // Operands (this includes swizzling MXFP8 scaling factors) + // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when + // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) + bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); + bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; + bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed; + auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( + stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise); + auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand( + stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); + + // Output tensor + std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], + (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; + auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, " + "expected ", + out_.numel(), " elements ", to_string_like(out_shape), " but got ", + output->element_count(), " elements ", to_string_like(output->dimensions())); + + // Bias input to forward pass or bias gradient output from backward pass + void *bias_ptr = nullptr; + std::vector bias_shape = {0}; + DType bias_dtype = out_dtype; + if (fuse_bias) { + if (!grad) { + NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(), + "Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad"); + } + bias_ptr = bias_grad->untyped_data(); + bias_shape.at(0) = bias_grad->dimensions().front(); + bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type()); + } + auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + + // Pre-GeLU output from forward pass or input to backward pass + void *pre_gelu_ptr = nullptr; + std::vector pre_gelu_shape = {0}; + DType pre_gelu_dtype = out_dtype; + if (gelu_input.element_count() > 0) { + if (grad) { + NVTE_CHECK(pre_gelu_out->untyped_data() == gelu_input.untyped_data(), + "Missing operand-output aliasing in GemmPrimitive: gelu_input <-> pre_gelu_out"); + } + pre_gelu_ptr = pre_gelu_out->untyped_data(); + pre_gelu_shape = {product(pre_gelu_out->dimensions(), 0, pre_gelu_out->dimensions().size() - 1), + static_cast(pre_gelu_out->dimensions().back())}; + pre_gelu_dtype = convert_ffi_datatype_to_te_dtype(pre_gelu_out->element_type()); + } + auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype); + + // cuBLAS workspace + 256 alignment enforcement + auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); + workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); + std::vector workspace_shape = {static_cast(workspace->element_count()) - 256}; + auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte); + + // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), + rhs_transposed, lhs_transposed, grad, workspace_.data(), false, + use_split_accumulator, num_math_sm, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Ret() // output + .Ret() // bias_grad + .Ret() // pre_gelu_out + .Ret() // lhs_swizzled + .Ret() // rhs_swizzled + .Ret() // workspace + .Attr("scaling_mode") + .Attr("lhs_axis_boundary") + .Attr("rhs_axis_boundary") + .Attr("lhs_transposed") + .Attr("rhs_transposed") + .Attr("fuse_bias") + .Attr("fuse_gelu") + .Attr("grad") + .Attr("use_split_accumulator"), + FFI_CudaGraph_Traits); + +Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, + Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, + Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, + Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, + bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, + bool is_grouped_dense_wgrad) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: - // A: row-major with size [m, k], - // B: row-major with size [n, k], needs transpose, + // A: row-major [m, k] for N - [k, m] for T + // B: row-major [k, n] for N - [n, k] for T // on exiting this function, JAX expect: // C: row-major with size [m, n]. // cuBLAS uses column-major data_layout, in this view, each input matrix pair: - // A: column-major with size [k, m], needs transpose, - // B: column-major with size [k, n]. + // A: column-major with size [k, m] for T - [m, k] for N + // B: column-major with size [n, k] for T - [k, n] for N + // // If we call cuBLAS GEMM for A * B, the output will be: // C: column-major with size [m, n] --> row-major with size [n, m]. // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. - if (num_gemms <= 0) { - return ffi_with_cuda_error_check(); + int num_streams = nvte_get_num_compute_streams(); + + // Inputs + auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); + auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); + auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv.untyped_data()); + auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv.untyped_data()); + auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type()); + auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type()); + auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type()); + auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type()); + auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; + auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); + + NVTE_CHECK(group_sizes.dimensions().size() == 1); + size_t num_gemms = group_sizes.dimensions()[0]; + + // It is weird that TE/Common GEMM only use colwise for MXFP8 + const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); + const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; + const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans; + const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans; + + // Outputs + auto out_ptr = reinterpret_cast(output->untyped_data()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned + auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); + workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); + auto workspace_total_size = product(workspace->dimensions()); + + auto lhs_sinv_size = product(lhs_sinv.dimensions()); + auto rhs_sinv_size = product(rhs_sinv.dimensions()); + const size_t workspace_alignment_padding = 256; + const size_t tensor_scaling_sinv_aligment = 16; + const size_t mxfp8_scaling_sinv_alignment_padding = 256; + auto workspace_size = workspace_total_size - workspace_alignment_padding; + if (is_mxfp8_scaling) { + // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded by 128x4. + workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding); + } else if (is_tensor_scaling) { + // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned + // by 16 bytes to meet the requirement of CUDA 12.9.1 and later. + workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size); + } + workspace_size = workspace_size / num_streams; + auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams; + swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr); + auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; + swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr); + auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr; // Already 256B aligned + auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment; + + size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); + size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); + size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); + size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); + size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); + size_t out_dtype_bytes = te_dtype_bytes(out_dtype); + + if (is_tensor_scaling) { + cudaStream_t stream_0 = nvte_get_compute_stream(0); + size_t dpitch = tensor_scaling_sinv_aligment; + size_t spitch = lhs_sinv_dtype_bytes; + size_t width = lhs_sinv_dtype_bytes; + size_t height = lhs_sinv_size; + cudaMemcpy2DAsync(lhs_scatter_aligned_ptr, dpitch, lhs_sinv_ptr, spitch, width, height, + cudaMemcpyDeviceToDevice, stream_0); + spitch = rhs_sinv_dtype_bytes; + width = rhs_sinv_dtype_bytes; + height = rhs_sinv_size; + cudaMemcpy2DAsync(rhs_scatter_aligned_ptr, dpitch, rhs_sinv_ptr, spitch, width, height, + cudaMemcpyDeviceToDevice, stream_0); + lhs_sinv_ptr = lhs_scatter_aligned_ptr; + rhs_sinv_ptr = rhs_scatter_aligned_ptr; + } + + NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); + NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, + "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); + + size_t expected_lhs_size = m * k; + size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); + size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); + size_t actual_lhs_size = product(lhs_data.dimensions()); + size_t actual_rhs_size = product(rhs_data.dimensions()); + size_t actual_out_size = product(output->dimensions()); + NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", + expected_lhs_size, ", got ", actual_lhs_size); + if (!is_grouped_dense_wgrad) { + NVTE_CHECK(expected_rhs_size == actual_rhs_size, + "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, + " = ", expected_rhs_size, ", got ", actual_rhs_size); + NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m, + " * ", n, " = ", expected_out_size, ", got ", actual_out_size); + } else { + NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k, + " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size); + NVTE_CHECK(expected_out_size == actual_out_size, + "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n, + " = ", expected_out_size, ", got ", actual_out_size); } - size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms; - size_t expected_output_size = num_gemms + 1; - size_t actual_input_size = input_list.size(); - size_t actual_output_size = output_list.size(); - NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu", - expected_input_size, actual_input_size); - NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu", - expected_output_size, actual_output_size); - - bool trans_lhs = true; - bool trans_rhs = false; + + size_t dim_list_bytes = sizeof(int32_t) * num_gemms; + std::vector dim_list_host(num_gemms); + auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + if (!is_grouped_dense_wgrad) { + NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, + ", got sum(group_sizes)=", sum_group_sizes); + } else { + NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, + ", got sum(group_sizes)=", sum_group_sizes); + } + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); bool grad = false; bool accumulate = false; bool use_split_accumulator = false; + auto bias_shape = std::vector{has_bias ? n : 0}; + const int arch = cuda::sm_arch(); + + if (arch < 100 && is_fp8_gemm) { + NVTE_CHECK(!lhs_is_trans && rhs_is_trans, + "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", + "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); + } // These lists are to keep the TensorWrapper objects alive std::vector lhs_wrapper_list; std::vector rhs_wrapper_list; + std::vector lhs_swizzle_wrapper_list; // For MXFP8 scale_inv swizzling + std::vector rhs_swizzle_wrapper_list; std::vector bias_wrapper_list; std::vector pre_gelu_wrapper_list; std::vector out_wrapper_list; @@ -61,101 +369,148 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM std::vector lhs_list; std::vector rhs_list; + std::vector lhs_swizzle_list; + std::vector rhs_swizzle_list; std::vector bias_list; std::vector pre_gelu_list; std::vector out_list; std::vector workspace_list; - int lhs_list_offset = 0; - int rhs_list_offset = num_gemms; - int lhs_sinv_list_offset = 2 * num_gemms; - int rhs_sinv_list_offset = 3 * num_gemms; - int bias_list_offset = 4 * num_gemms; - int out_list_offset = 0; - for (int i = 0; i < num_gemms; i++) { - Buffer_Type lhs_i = input_list.get(lhs_list_offset + i).value(); - Buffer_Type rhs_i = input_list.get(rhs_list_offset + i).value(); - Buffer_Type lhs_sinv_i = input_list.get(lhs_sinv_list_offset + i).value(); - Buffer_Type rhs_sinv_i = input_list.get(rhs_sinv_list_offset + i).value(); - Result_Type out_i = output_list.get(out_list_offset + i).value(); - - DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type()); - DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type()); - DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type()); - - void *lhs_ptr = lhs_i.untyped_data(); - void *rhs_ptr = rhs_i.untyped_data(); - void *lhs_sinv_ptr = lhs_sinv_i.untyped_data(); - void *rhs_sinv_ptr = rhs_sinv_i.untyped_data(); - void *out_ptr = out_i->untyped_data(); - - // Placeholder for bias since it can be empty - DType bias_dtype = DType::kFloat32; - void *bias_ptr = nullptr; - - auto lhs_shape_ = lhs_i.dimensions(); - auto rhs_shape_ = rhs_i.dimensions(); - - // lhs and rhs has shape [1, m, k] and [1, n, k] - size_t m = lhs_shape_[1]; - size_t n = rhs_shape_[1]; - size_t k = lhs_shape_[2]; - - auto lhs_shape = std::vector{m, k}; - auto rhs_shape = std::vector{n, k}; - auto out_shape = std::vector{n, m}; - auto lhs_sinv_shape = std::vector{1, 1}; - auto rhs_sinv_shape = std::vector{1, 1}; - - if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || - scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || - scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) { - float *amax_dptr = nullptr; - float *scale_dptr = nullptr; - auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr, - reinterpret_cast(lhs_sinv_ptr)); - auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr, - reinterpret_cast(rhs_sinv_ptr)); - lhs_wrapper_list.push_back(std::move(lhs_i_)); - rhs_wrapper_list.push_back(std::move(rhs_i_)); - } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { - // Note: the scale_inv array should have been swizzled in Python before lowering - auto lhs_sinv_shape_ = lhs_sinv_i.dimensions(); - auto rhs_sinv_shape_ = rhs_sinv_i.dimensions(); - for (int i = 0; i < 2; i++) { - lhs_sinv_shape[i] = lhs_sinv_shape_[i]; - rhs_sinv_shape[i] = rhs_sinv_shape_[i]; + size_t lhs_sinv_total_size = 0; + size_t rhs_sinv_total_size = 0; + + std::vector zero_out_dptr_list; + std::vector zero_out_size_list; + + for (size_t i = 0; i < num_gemms; i++) { + // Matrix data shapes + size_t m_i = dim_list_host[i]; + auto lhs_shape_i = std::vector{m_i, k}; + auto rhs_shape_i = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; + auto out_shape_i = std::vector{m_i, n}; + if (is_grouped_dense_wgrad) { + size_t k_i = dim_list_host[i]; + lhs_shape_i[0] = lhs_is_trans ? k_i : m; + lhs_shape_i[1] = lhs_is_trans ? m : k_i; + rhs_shape_i[0] = rhs_is_trans ? n : k_i; + rhs_shape_i[1] = rhs_is_trans ? k_i : n; + out_shape_i[0] = m; + out_shape_i[1] = n; + } + + size_t lhs_size = lhs_shape_i[0] * lhs_shape_i[1]; + size_t rhs_size = rhs_shape_i[0] * rhs_shape_i[1]; + size_t out_size = out_shape_i[0] * out_shape_i[1]; + bool is_empty_gemm = lhs_size == 0 || rhs_size == 0; + if (is_empty_gemm && out_size > 0) { + zero_out_dptr_list.push_back(out_ptr); + zero_out_size_list.push_back(out_size * out_dtype_bytes); + } + + // Set matrix data pointers + auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto out_i = TensorWrapper(static_cast(out_ptr), out_shape_i, out_dtype); + void *lhs_vptr = static_cast(lhs_ptr); + void *rhs_vptr = static_cast(rhs_ptr); + if (rhs_use_colwise) // MatA to enter cuBLAS + rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); + else + rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); + if (lhs_use_colwise) // MatB to enter cuBLAS + lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); + else + lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); + + // Set scale_inv shapes and pointers + void *rhs_sinv_vptr = static_cast(rhs_sinv_ptr); + void *lhs_sinv_vptr = static_cast(lhs_sinv_ptr); + size_t lhs_sinv_size_i = 0; + size_t rhs_sinv_size_i = 0; + if (is_tensor_scaling) { + auto tensor_scaling_sinv_shape = std::vector{1}; + // If is_empty_gemm, scale_inv does not have the corresponding value, do not move the pointers + if (!is_empty_gemm) { + lhs_sinv_size_i = tensor_scaling_sinv_aligment / lhs_sinv_dtype_bytes; + rhs_sinv_size_i = tensor_scaling_sinv_aligment / rhs_sinv_dtype_bytes; } + if (rhs_use_colwise) // MatA to enter cuBLAS + rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape); + else + rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape); + if (lhs_use_colwise) // MatB to enter cuBLAS + lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); + else + lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); + } else if (is_mxfp8_scaling) { + auto lhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto rhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + void *swizzled_lhs_sinv_vptr = static_cast(swizzled_lhs_sinv_ptr); + void *swizzled_rhs_sinv_vptr = static_cast(swizzled_rhs_sinv_ptr); - NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode); - TensorWrapper lhs_i_(nvte_scaling_mode); - TensorWrapper rhs_i_(nvte_scaling_mode); - lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape); - rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape); - lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape); - rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape); + // {lhs, rhs}_swizzle_i point to unswizzled scale_inv data as input, while {lhs, rhs}_i + // point to swizzled scale_inv data (store on workspace, only used for GEMM). + // Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers + auto lhs_sinv_shape_i = + get_mxfp8_scale_shape(lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); + auto rhs_sinv_shape_i = + get_mxfp8_scale_shape(rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise); + lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1]; + rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1]; + if (lhs_use_colwise) { + lhs_swizzle_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); + lhs_swizzle_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + lhs_i.set_columnwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + } else { + lhs_swizzle_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); + lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + } + if (rhs_use_colwise) { + rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); + rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + rhs_i.set_columnwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + } else { + rhs_swizzle_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); + rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + } - lhs_wrapper_list.push_back(std::move(lhs_i_)); - rhs_wrapper_list.push_back(std::move(rhs_i_)); + if (!is_empty_gemm) { + lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i)); + rhs_swizzle_wrapper_list.push_back(std::move(rhs_swizzle_i)); + lhs_swizzle_list.push_back(lhs_swizzle_wrapper_list.back().data()); + rhs_swizzle_list.push_back(rhs_swizzle_wrapper_list.back().data()); + } } else { - NVTE_ERROR("Unsupported scaling mode: ", static_cast(scaling_mode)); + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, + "Unsupported scaling mode: ", static_cast(scaling_mode)); } - auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype); - void *pre_gelu_ptr = nullptr; - auto bias_shape = std::vector{0}; - auto pre_gelu_shape = std::vector{0}; - if (has_bias) { - auto bias_i_get = input_list.get(bias_list_offset + i); - Buffer_Type bias_i = bias_i_get.value(); - bias_ptr = bias_i.untyped_data(); - bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type()); - bias_shape[0] = n; - } auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); - auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype); + auto pre_gelu_i = TensorWrapper(nullptr, std::vector{0}, out_dtype); + + // Update pointer for the next GEMM pair + lhs_ptr += lhs_size * lhs_dtype_bytes; + rhs_ptr += rhs_size * rhs_dtype_bytes; + out_ptr += out_size * out_dtype_bytes; + if (is_fp8_gemm) { + lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; + rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; + lhs_sinv_total_size += lhs_sinv_size_i; + rhs_sinv_total_size += rhs_sinv_size_i; + if (is_mxfp8_scaling) { + swizzled_lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; + swizzled_rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; + } + } + if (has_bias) bias_ptr += n * bias_dtype_bytes; - out_wrapper_list.push_back(std::move(out_i_)); + // Move objects to the lists to keep them alive + if (is_empty_gemm) continue; + lhs_wrapper_list.push_back(std::move(lhs_i)); + rhs_wrapper_list.push_back(std::move(rhs_i)); + out_wrapper_list.push_back(std::move(out_i)); bias_wrapper_list.push_back(std::move(bias_i)); pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); @@ -166,10 +521,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, out_list.push_back(out_wrapper_list.back().data()); } - auto workspace_get = output_list.get(num_gemms); - Result_Type workspace = workspace_get.value(); - uint8_t *workspace_ptr = reinterpret_cast(workspace->untyped_data()); - size_t workspace_size = workspace->dimensions()[0] / num_streams; auto workspace_shape = std::vector{workspace_size}; for (int i = 0; i < num_streams; i++) { auto workspace_i = @@ -179,10 +530,45 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, workspace_ptr += workspace_size; } + if (is_fp8_gemm) { + if (is_tensor_scaling) { + lhs_sinv_size *= tensor_scaling_sinv_aligment; + rhs_sinv_size *= tensor_scaling_sinv_aligment; + } + NVTE_CHECK(lhs_sinv_total_size <= lhs_sinv_size, "Actual total lhs_sinv size ", + lhs_sinv_total_size, " exceeds estimated upper bound ", lhs_sinv_size); + NVTE_CHECK(rhs_sinv_total_size <= rhs_sinv_size, "Actual total rhs_sinv size ", + rhs_sinv_total_size, " exceeds estimated upper bound ", rhs_sinv_size); + } + + size_t num_non_empty_gemms = lhs_list.size(); + + if (is_mxfp8_scaling) { + for (int i = 0; i < num_non_empty_gemms; i++) { + // The i-th GEMM will use the (i % num_streams)-th stream to compute, + // use the same stream to swizzle the scaling factors to make sure that + // the swizzling is done before the GEMM computation starts. + int stream_id = i % num_streams; + cudaStream_t stream_i = nvte_get_compute_stream(stream_id); + nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i); + nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); + } + } + + // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM + size_t num_zero_outs = zero_out_dptr_list.size(); + for (int i = 0; i < num_zero_outs; i++) { + int stream_id = i % num_streams; + cudaStream_t stream_i = nvte_get_compute_stream(stream_id); + void *dptr = zero_out_dptr_list[i]; + size_t count = zero_out_size_list[i]; + NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i)); + } + nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad, - workspace_list.data(), accumulate, use_split_accumulator, - num_math_sm, stream); + pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, + lhs_is_trans, grad, workspace_list.data(), accumulate, + use_split_accumulator, num_math_sm, stream); return ffi_with_cuda_error_check(); } @@ -190,11 +576,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, FFI::Bind() .Ctx() // stream - .RemainingArgs() // input list - .RemainingRets() // output list - .Attr("num_gemms") + .Arg() // lhs_data + .Arg() // lhs_sinv + .Arg() // rhs_data + .Arg() // rhs_sinv + .Arg() // bias + .Arg() // group_sizes + .Arg() // group_offset + .Ret() // output + .Ret() // workspace + .Attr("M") + .Attr("N") + .Attr("K") + .Attr("lhs_is_trans") + .Attr("rhs_is_trans") .Attr("scaling_mode") - .Attr("has_bias"), + .Attr("has_bias") + .Attr("is_grouped_dense_wgrad"), FFI_CudaGraph_Traits); } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/misc.cpp b/transformer_engine/jax/csrc/extensions/misc.cpp index b1445e5be..ee81b5ad7 100644 --- a/transformer_engine/jax/csrc/extensions/misc.cpp +++ b/transformer_engine/jax/csrc/extensions/misc.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" +#include "../extensions.h" namespace transformer_engine { namespace jax { @@ -26,5 +26,19 @@ std::vector Shape::to_vector() const { return shape; } +std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise) { + auto block_x = is_colwise ? MXFP8_BLOCK_SIZE.y : MXFP8_BLOCK_SIZE.x; + auto block_y = is_colwise ? MXFP8_BLOCK_SIZE.x : MXFP8_BLOCK_SIZE.y; + auto alignment_x = is_colwise ? MXFP8_ALIGNMENT.y : MXFP8_ALIGNMENT.x; + auto alignment_y = is_colwise ? MXFP8_ALIGNMENT.x : MXFP8_ALIGNMENT.y; + + NVTE_CHECK(M % block_x == 0, "M must be divisble by %zu (got %zu)", block_x, M); + NVTE_CHECK(N % block_y == 0, "N must be divisble by %zu (got %zu)", block_y, N); + size_t scale_x = DIVUP((M / block_x), alignment_x) * alignment_x; + size_t scale_y = DIVUP((N / block_y), alignment_y) * alignment_y; + + return {scale_x, scale_y}; +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 4c3d29ef0..af7f54feb 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -47,6 +47,15 @@ enum class JAXX_Scaling_Mode : int64_t { CURRENT_TENSOR_SCALING = 3, }; +inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) { + return (mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING || + mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING); +} + +inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) { + return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING); +} + static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { switch (mode) { case JAXX_Scaling_Mode::NO_SCALING: @@ -67,5 +76,16 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { } } +constexpr struct BlockSize { + size_t x; + size_t y; +} MXFP8_BLOCK_SIZE{1, 32}; +constexpr struct Alignment { + size_t x; + size_t y; +} MXFP8_ALIGNMENT{128, 4}; + +std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index d16bc51f8..b07404eb7 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -7,7 +7,7 @@ #include -#include "extensions.h" +#include "../extensions.h" namespace transformer_engine { namespace jax { diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index b48b2be57..563675988 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -6,7 +6,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" +#include "../extensions.h" namespace transformer_engine { namespace jax { @@ -27,6 +27,7 @@ pybind11::dict Registrations() { // Quantization dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler); + dict["te_grouped_quantize_ffi"] = EncapsulateFFI(GroupedQuantizeHandler); dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler); // Softmax @@ -57,6 +58,12 @@ pybind11::dict Registrations() { pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler)); + // GEMM + dict["te_gemm_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); + + // Grouped GEMM dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); @@ -69,6 +76,7 @@ pybind11::dict Registrations() { dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler); dict["te_fused_attn_backward_ffi"] = EncapsulateFFI(FusedAttnBackwardHandler); + dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler); dict["te_grouped_gemm_ffi"] = EncapsulateFFI(GroupedGemmHandler); #endif return dict; @@ -82,6 +90,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { #endif m.def("get_cudnn_version", &GetCudnnRuntimeVersion); m.def("get_device_compute_capability", &GetDeviceComputeCapability); + m.def("get_num_compute_streams", &nvte_get_num_compute_streams); #ifndef USE_ROCM m.def("get_cublasLt_version", &cublasLtGetVersion); #endif @@ -92,6 +101,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); + m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index e59600994..a92934193 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -5,9 +5,10 @@ ************************************************************************/ #include -#include "extensions.h" +#include "../extensions.h" #include "transformer_engine/cast.h" #include "transformer_engine/recipe.h" +#include "transformer_engine/transformer_engine.h" #include "xla/ffi/api/c_api.h" namespace transformer_engine { @@ -226,5 +227,191 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI, .Ret(), // output FFI_CudaGraph_Traits); +Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scales, + Buffer_Type group_sizes, Result_Type outputs, + Result_Type colwise_outputs, Result_Type scale_invs, + Result_Type colwise_scale_invs, Result_Type amaxs, + JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, + int64_t flatten_axis) { + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING, + "Unsupported scaling mode: ", static_cast(scaling_mode)); + + auto in_dtype = convert_ffi_datatype_to_te_dtype(inputs.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(outputs->element_type()); + NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for quantization."); + + auto scale_dtype = convert_ffi_datatype_to_te_dtype(scales.element_type()); + auto group_size_dtype = convert_ffi_datatype_to_te_dtype(group_sizes.element_type()); + auto sinv_dtype = convert_ffi_datatype_to_te_dtype(scale_invs->element_type()); + auto amax_dtype = convert_ffi_datatype_to_te_dtype(amaxs->element_type()); + auto const quantize_layout = static_cast(quantize_layout_enum); + + auto *input_ptr = reinterpret_cast(inputs.untyped_data()); + auto *scale_ptr = reinterpret_cast(scales.untyped_data()); + auto *output_ptr = reinterpret_cast(outputs->untyped_data()); + auto *colwise_output_ptr = reinterpret_cast(colwise_outputs->untyped_data()); + auto *sinv_ptr = reinterpret_cast(scale_invs->untyped_data()); + auto *colwise_sinv_ptr = reinterpret_cast(colwise_scale_invs->untyped_data()); + auto *amax_ptr = reinterpret_cast(amaxs->untyped_data()); + + bool has_rowwise = quantize_layout == QuantizeLayout::ROWWISE || + quantize_layout == QuantizeLayout::ROWWISE_COLWISE; + bool has_colwise = quantize_layout == QuantizeLayout::COLWISE || + quantize_layout == QuantizeLayout::ROWWISE_COLWISE; + bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING; + bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; + bool const is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + + size_t input_dtype_bytes = te_dtype_bytes(in_dtype); + size_t output_dtype_bytes = te_dtype_bytes(out_dtype); + size_t sinv_dtype_bytes = te_dtype_bytes(sinv_dtype); + size_t group_size_dtype_bytes = te_dtype_bytes(group_size_dtype); + size_t colwise_output_dtype_bytes = has_colwise ? output_dtype_bytes : 0; + size_t colwise_sinv_dtype_bytes = has_colwise ? sinv_dtype_bytes : 0; + size_t scale_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(scale_dtype) : 0; + size_t amax_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(amax_dtype) : 0; + + auto input_dims = inputs.dimensions(); + int64_t input_ndim = input_dims.size(); + if (flatten_axis < 0) flatten_axis += input_ndim; + NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!"); + + auto m = product(input_dims, 0, flatten_axis); + auto n = product(input_dims, flatten_axis, input_ndim); + auto input_shape = std::vector{m, n}; + auto output_shape = std::vector{m * n}; + + // These lists are to keep the TensorWrapper objects alive + std::vector input_holders; + std::vector output_holders; + + // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM + std::vector input_list; + std::vector output_list; + + size_t num_groups = group_sizes.dimensions()[0]; + size_t dim_list_bytes = group_size_dtype_bytes * num_groups; + std::vector dim_list_host(num_groups); + auto *group_size_ptr = reinterpret_cast(group_sizes.untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), group_size_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + + size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes, + "Unexpected group_sizes! Got %zu (M=%zu, input_dims[0] = %zu)", sum_group_sizes, m, + input_dims[0]); + + if (is_delayed_scaling) { + NVTE_CHECK(amaxs->dimensions()[0] == num_groups, "Unexpected amax size, Expected ", num_groups, + ", got ", amaxs->dimensions()[0]); + NVTE_CHECK(amax_dtype == DType::kFloat32 && scale_dtype == DType::kFloat32); + cudaMemsetAsync(amax_ptr, 0, sizeof(float) * num_groups, stream); + } + + size_t sinv_size = 0; + size_t colwise_sinv_size = 0; + size_t non_group_m = flatten_axis > 1 ? product(input_dims, 1, flatten_axis) : 1; + size_t num_non_empty_groups = 0; + size_t total_rowwise_sinv_size = 0; + size_t total_colwise_sinv_size = 0; + for (size_t i = 0; i < num_groups; i++) { + size_t m_i = dim_list_host[i] * non_group_m; + // Skip for zero-size input + shiff the scale ptr + if (m_i == 0) { + if (is_tensor_scaling) scale_ptr += scale_dtype_bytes; + continue; + } + num_non_empty_groups++; + auto shape_i = std::vector{m_i, n}; + auto shape_trans_i = std::vector{n, m_i}; + + auto inp_i = TensorWrapper(static_cast(input_ptr), shape_i, in_dtype); + auto out_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + + if (has_rowwise) { + out_i.set_rowwise_data(static_cast(output_ptr), out_dtype, shape_i); + + if (is_fp8_dtype(out_dtype)) { + if (is_tensor_scaling) { + out_i.set_scale(static_cast(scale_ptr), DType::kFloat32, std::vector{1}); + out_i.set_amax(static_cast(amax_ptr), DType::kFloat32, std::vector{1}); + out_i.set_rowwise_scale_inv(static_cast(sinv_ptr), sinv_dtype, + std::vector{1}); + sinv_size = 1; + } else { + const bool is_colwise = false; + auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise); + out_i.set_rowwise_scale_inv(static_cast(sinv_ptr), sinv_dtype, sinv_shape_i); + sinv_size = product(sinv_shape_i); + } + } + } + + if (has_colwise) { + auto &tmp_shape = is_tensor_scaling ? shape_trans_i : shape_i; + out_i.set_columnwise_data(static_cast(colwise_output_ptr), out_dtype, tmp_shape); + // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling + auto &tmp_sinv_ptr = is_tensor_scaling ? sinv_ptr : colwise_sinv_ptr; + + if (is_tensor_scaling) { + out_i.set_columnwise_scale_inv(static_cast(tmp_sinv_ptr), sinv_dtype, + std::vector{1}); + colwise_sinv_size = 1; + } else { + const bool is_colwise = true; + auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise); + out_i.set_columnwise_scale_inv(static_cast(colwise_sinv_ptr), sinv_dtype, + sinv_shape_i); + colwise_sinv_size = product(sinv_shape_i); + } + } + + input_holders.push_back(std::move(inp_i)); + output_holders.push_back(std::move(out_i)); + + input_list.push_back(input_holders.back().data()); + output_list.push_back(output_holders.back().data()); + + input_ptr += m_i * n * input_dtype_bytes; + scale_ptr += scale_dtype_bytes; + output_ptr += m_i * n * output_dtype_bytes; + colwise_output_ptr += m_i * n * colwise_output_dtype_bytes; + sinv_ptr += sinv_size * sinv_dtype_bytes; + colwise_sinv_ptr += colwise_sinv_size * colwise_sinv_dtype_bytes; + amax_ptr += amax_dtype_bytes; + total_rowwise_sinv_size += sinv_size; + total_colwise_sinv_size += colwise_sinv_size; + } + if (is_mxfp8_scaling) { + nvte_memset(scale_invs->untyped_data(), 0, total_rowwise_sinv_size, stream); + nvte_memset(colwise_scale_invs->untyped_data(), 0, total_colwise_sinv_size, stream); + } + + QuantizationConfigWrapper quant_config; + nvte_multi_tensor_quantize(input_list.data(), output_list.data(), quant_config, + num_non_empty_groups, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // scale + .Arg() // group_sizes + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Attr("scaling_mode") + .Attr("q_layout") + .Attr("flatten_axis"), + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/softmax.cpp b/transformer_engine/jax/csrc/extensions/softmax.cpp index 8691bf35a..ee3e5b35e 100644 --- a/transformer_engine/jax/csrc/extensions/softmax.cpp +++ b/transformer_engine/jax/csrc/extensions/softmax.cpp @@ -6,7 +6,7 @@ #include "transformer_engine/softmax.h" -#include "extensions.h" +#include "../extensions.h" #include "xla/ffi/api/c_api.h" namespace transformer_engine { diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 55d60e418..a0fc7b7af 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -8,7 +8,7 @@ It implements matrix multiplication with optional bias addition and supports customizable contracting dimensions for flexible tensor operations. """ - +import warnings from typing import Tuple, Sequence from functools import partial import jax @@ -19,9 +19,20 @@ QuantizerSet, noop_quantizer_set, with_sharding_constraint_by_logical_axes, + TensorUsage, ) +DENSE_BATCH_FIRST_WARNING_ISSUED = False + + +def _issue_batch_first_warning(msg): + global DENSE_BATCH_FIRST_WARNING_ISSUED + if not DENSE_BATCH_FIRST_WARNING_ISSUED: + warnings.warn(msg, UserWarning) + DENSE_BATCH_FIRST_WARNING_ISSUED = True + + def dense( x: jnp.ndarray, kernel: jnp.ndarray, @@ -29,6 +40,7 @@ def dense( contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, + batch_first: bool = True, quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. @@ -42,25 +54,28 @@ def dense( kernel: Weight matrix for the dense layer transformation bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract + batch_first: Assume that X is batched in the first dimension. quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor """ # Remove when tex.quantize() can handle quantizer=None - if quantizer_set == noop_quantizer_set: + if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot(): x = with_sharding_constraint_by_logical_axes(x, input_axes) - output = tex.gemm(x, kernel, contracting_dims) + output = tex.gemm(x, kernel, contracting_dims=contracting_dims) if bias is not None: bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) else: - output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set) + output = _dense( + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set + ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) -def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) +def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set): """Internal implementation of dense layer transformation with custom VJP. This function implements the core dense layer transformation logic with support @@ -74,107 +89,156 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix quantizer_set: QuantizerSet which contains quantizers for different tensor types + batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. Returns: Transformed output tensor """ output, _ = _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set ) return output -def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): +def _dense_fwd_rule( + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set +): """Forward pass rule for dense layer transformation. Returns: Tuple of (output, context) for backward pass """ - x_contracting_dims, k_contracting_dims = contracting_dims + x_contracting_dims, k_contracting_dims = map( + tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims + ) + + # Check supported input layout + x_is_transposed = x.ndim - 1 not in x_contracting_dims + k_is_transposed = kernel.ndim - 1 in k_contracting_dims + assert ( + not x_is_transposed and not k_is_transposed + ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel." + + # Determine X batch dimension + # - If `batch_first=True` -> (batch, leading..., contracting...) + # - Otherwise -> (leading..., batch, contracting...) + # NOTE: Always assume a single batch dimension + x_bdim = None + num_cdims = len(x_contracting_dims) + if x.ndim >= num_cdims + 2: + # Assume X is batched if it has at least +2 dimensions more than the number of contracting + # dimensions. + if not batch_first: + _issue_batch_first_warning( + "TE/JAX `dense()` layer implementation does not officially support sequence-first " + "inputs and may produce incorrect results when `batch_first=False`. Use " + "sequence-first inputs at your own discretion.", + ) + x_bdim = 0 if batch_first else x.ndim - num_cdims - 1 flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) - casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x) + casted_x = tex.quantize( + x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True + ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) casted_kernel = tex.quantize( - kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel + kernel, + flatten_axis=flatten_axis_k, + quantizer=quantizer_set.kernel, + noop_scaled_tensor=True, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) # GEMM NN + use_bias = bias is not None output = tex.gemm( - casted_x.get_rowwise_tensor(), - casted_kernel.get_colwise_tensor(), - (x_contracting_dims, k_contracting_dims), + casted_x.get_tensor(usage=TensorUsage.LHS), + casted_kernel.get_tensor(usage=TensorUsage.RHS), + contracting_dims=(x_contracting_dims, k_contracting_dims), + batched_dims=((x_bdim,), ()), + bias=bias if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) - use_bias = bias is not None - if use_bias: + if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) ctx = ( - casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None, - casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None, + casted_x.get_tensor(usage=TensorUsage.LHS_TRANS), + casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS), x.shape, kernel.shape, use_bias, quantizer_set, flatten_axis_k, + x_bdim, ) return output, ctx def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, ctx, grad + contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad ): # pylint: disable=unused-argument """Backward pass rule for dense layer transformation. Returns: Tuple of gradients with respect to inputs """ - fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims - ( - colwise_casted_x, - rowwise_casted_kernel, + casted_x_lhs, + casted_kernel_rhs, x_shape, kernel_shape, use_bias, quantizer_set, flatten_axis_k, + x_bdim, ) = ctx + fwd_x_contracting_dims, fwd_k_contracting_dims = map( + tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims + ) + casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad + grad, + is_dbias=use_bias, + flatten_axis=flatten_axis_k, + quantizer=quantizer_set.dgrad, + noop_scaled_tensor=True, ) # GEMM NT # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim - g_constracting_dim = tuple( + g_contracting_dim = tuple( range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) ) # k_non_contracting_dims - k_constracting_dim = tuple( + k_contracting_dim = tuple( dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims ) dgrad = tex.gemm( - casted_grad.get_rowwise_tensor(), - rowwise_casted_kernel, - (g_constracting_dim, k_constracting_dim), + casted_grad.get_tensor(usage=TensorUsage.LHS), + casted_kernel_rhs, + contracting_dims=(g_contracting_dim, k_contracting_dim), + batched_dims=((x_bdim,), ()), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) # GEMM TN # x_non_contracting_dims - g_constracting_dim = x_constracting_dim = tuple( + g_contracting_dim = x_contracting_dim = tuple( range(0, len(x_shape) - len(fwd_x_contracting_dims)) ) wgrad = tex.gemm( - colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim) + casted_x_lhs, + casted_grad.get_tensor(usage=TensorUsage.RHS), + contracting_dims=(x_contracting_dim, g_contracting_dim), + batched_dims=((x_bdim,), (x_bdim,)), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) @@ -184,135 +248,238 @@ def _dense_bwd_rule( _dense.defvjp(_dense_fwd_rule, _dense_bwd_rule) -""" def grouped_dense( - x_list, - kernel_list, - bias_list, - contracting_dims_list, - quantizer_set_list=None, + x: jnp.ndarray, + kernel: jnp.ndarray, + group_sizes: jnp.ndarray, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)), + bias: jnp.ndarray = None, + precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, + preferred_element_type: jnp.dtype = None, + group_offset: jnp.array = None, + quantizer_set: QuantizerSet = noop_quantizer_set, ): - # Perform grouped_dense layer transformation with optional quantization. + """ + Perform grouped dense (linear) layer transformation with optional quantization. - output_list = _grouped_dense( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list + Args: + x: Input tensor of shape (M, K) + kernel: Weight matrix of shape (G, K, N) + group_sizes: 1D array of shape (G,) specifying the size of each group + contracting_dims: Tuple of sequences specifying which dimensions to contract + (currently only supports ((1,), (1,))) + bias: Bias tensor of shape (G, N) + precision: JAX precision for the GEMM operation + preferred_element_type: Preferred data type for the output tensor + group_offset: 1D array containing offsets for each group (not yet implemented) + quantizer_set: Set of quantizers for FP8 quantization of the input and output + + Returns: + A jnp.ndarray containing the result of the grouped linear operation + """ + output = _grouped_dense( + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, ) - return output_list + return output -@partial(jax.custom_vjp, nondiff_argnums=(3,)) -def _grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list): - output_list, _ = _grouped_dense_fwd_rule( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list +@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7)) +def _grouped_dense( + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, +): + output, _ = _grouped_dense_fwd_rule( + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, ) - return output_list + return output def _grouped_dense_fwd_rule( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, ): - use_bias = bias_list is not None - output_list = [] - x_rowwise_list = [] - x_colwise_list = [] - kernel_colwise_list = [] - kernel_rowwise_list = [] - x_shape_list = [] - kernel_shape_list = [] - if quantizer_set_list is None: - x_rowwise_list = x_list - x_colwise_list = x_list - kernel_colwise_list = kernel_list - kernel_rowwise_list = kernel_list - x_shape_list = [x.shape for x in x_list] - kernel_shape_list = [kernel.shape for kernel in kernel_list] + use_bias = bias is not None + is_noop_quantizer_set = quantizer_set == noop_quantizer_set + + if is_noop_quantizer_set: + grouped_gemm_x = x + grouped_gemm_kernel = kernel + ctx_x = x + ctx_kernel = kernel + flatten_axis_k = None else: - for i in range(len(x_list)): # pylint: disable=consider-using-enumerate - q_x = tex.quantize(x_list[i], quantizer_set_list[i].x) - q_kernel = tex.quantize(kernel_list[i], quantizer_set_list[i].kernel) - x_rowwise_list.append(q_x.get_rowwise_tensor()) - x_colwise_list.append(q_x.get_colwise_tensor()) - kernel_colwise_list.append(q_kernel.get_colwise_tensor()) - kernel_rowwise_list.append(q_kernel.get_rowwise_tensor()) - x_shape_list.append(x_rowwise_list[-1].data.shape) - kernel_shape_list.append(kernel_rowwise_list[-1].data.shape) - - output_list = tex.grouped_gemm( - x_rowwise_list, kernel_colwise_list, contracting_dims_list, bias_list + x_contracting_dims, k_contracting_dims = contracting_dims + flatten_axis_x = -len(x_contracting_dims) + flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis + + assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)" + assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)" + # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose + # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? + assert x_contracting_dims == (1,) and k_contracting_dims == (1,), ( + "grouped_dense for FP8 can only handle x_contracting_dims=(1,) " + "and k_contracting_dims=(1,) for now, " + f"got {x_contracting_dims=} and {k_contracting_dims=}" + ) + + casted_x = tex.grouped_quantize( + x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x + ) + casted_kernel = tex.grouped_quantize( + kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k + ) + contracting_dims = (x_contracting_dims, k_contracting_dims) + + # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have + # rowwise_casted_x.original_shape == (M, K) + # colwise_casted_kernel.original_shape == (G, N, K) + grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS) + grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS) + ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS) + ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS) + + output = tex.grouped_gemm( + grouped_gemm_x, + grouped_gemm_kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, ) ctx = ( - x_colwise_list, - kernel_rowwise_list, - x_shape_list, - kernel_shape_list, + group_sizes, + ctx_x, + ctx_kernel, + x.shape, + kernel.shape, use_bias, - quantizer_set_list, + is_noop_quantizer_set, + quantizer_set, + flatten_axis_k, ) - return output_list, ctx + return output, ctx + +def _grouped_dense_bwd_rule( + contracting_dims, precision, preferred_element_type, group_offset, ctx, grad +): + fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims -def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list): ( - colwise_x_list, - rowwise_kernel_list, - x_shape_list, - kernel_shape_list, + group_sizes, + ctx_x, + ctx_kernel, + x_shape, + kernel_shape, use_bias, - quantizer_set_list, + is_noop_quantizer_set, + quantizer_set, + flatten_axis_k, ) = ctx - group_size = len(grad_list) - dbias_list = [] - grad_rowwise_list = [] - grad_colwise_list = [] - dgrad_contracting_dims_list = [] - wgrad_contracting_dims_list = [] - for i in range(group_size): - grad = grad_list[i] - x_shape = x_shape_list[i] - kernel_shape = kernel_shape_list[i] - fwd_contracting_dims = contracting_dims_list[i] - - if quantizer_set_list is None: - casted_grad = grad - dbias = tex.quantization._jax_dbias(grad) - grad_rowwise_list.append(grad) - grad_colwise_list.append(grad) - else: - quantizer_set = quantizer_set_list[i] - casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad - ) - grad_rowwise_list.append(casted_grad.get_rowwise_tensor()) - grad_colwise_list.append(casted_grad.get_colwise_tensor()) - dbias_list.append(dbias) - - # GEMM NT - fwd_x_contracting_dims, fwd_k_contracting_dims = fwd_contracting_dims + if is_noop_quantizer_set: + # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?) + # g_contracting_dim = (1, ) + # k_contracting_dim = (2, ) g_contracting_dim = tuple( - range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) + range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) ) k_contracting_dim = tuple( - dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims + dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims ) dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) - dgrad_contracting_dims_list.append(dgrad_contracting_dims) + dgrad_grad = grad + dgrad_kernel_T = ctx_kernel - # GEMM TN + # g_contracting_dim = (0, ) + # x_contracting_dim = (0, ) g_contracting_dim = x_contracting_dim = tuple( range(0, len(x_shape) - len(fwd_x_contracting_dims)) ) wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) - wgrad_contracting_dims_list.append(wgrad_contracting_dims) + wgrad_x_T = ctx_x + wgrad_grad = grad + else: + casted_grad = tex.grouped_quantize( + grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k + ) - dgrad_list = tex.grouped_gemm( - grad_rowwise_list, rowwise_kernel_list, dgrad_contracting_dims_list + # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use + # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the + # extra transpose for FP8 in grouped_gemm + # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? + g_contracting_dim = (1,) + k_contracting_dim = (2,) + dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) + dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS) + dgrad_kernel_T = ctx_kernel + + # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work + # after the extra transpose for FP8 in grouped_gemm + # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? + g_contracting_dim = (0,) + x_contracting_dim = (0,) + wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) + wgrad_x_T = ctx_x + wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS) + + dgrad = tex.grouped_gemm( + dgrad_grad, + dgrad_kernel_T, + group_sizes, + dgrad_contracting_dims, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset, ) - wgrad_list = tex.grouped_gemm(colwise_x_list, grad_colwise_list, wgrad_contracting_dims_list) - return dgrad_list, wgrad_list, dbias_list, quantizer_set_list + wgrad = tex.grouped_gemm( + wgrad_x_T, + wgrad_grad, + group_sizes, + wgrad_contracting_dims, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset, + ) + + group_sizes_grad = None + dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None + + return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule) -""" diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index bd311472f..5992d3607 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -6,7 +6,7 @@ """ from functools import reduce import operator -from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union +from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType import numpy as np import jax.numpy as jnp @@ -15,12 +15,12 @@ from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name -from ..dense import dense +from ..dense import dense, _issue_batch_first_warning as _dense_warning from ..layernorm import canonicalize_norm_type from ..layernorm import layernorm -from ..layernorm_dense import layernorm_dense -from ..layernorm_mlp import layernorm_mlp +from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning +from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning from ..activation import activation from ..softmax import softmax, SoftmaxType from ..sharding import with_sharding_constraint_by_logical_axes @@ -35,8 +35,8 @@ PRNGKey = Any Shape = Tuple[int, ...] -DType = jnp.dtype -Array = jnp.ndarray +DType = NewType("DType", jnp.dtype) +Array = NewType("Array", jnp.ndarray) PrecisionLike = Union[ None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] ] @@ -441,6 +441,12 @@ class DenseGeneral(TransformerEngineBase): input_axes: Tuple[str, ...] = () def __post_init__(self): + if self.transpose_batch_sequence: + _dense_warning( + "TE/JAX DenseGeneral() module does not officially support sequence-first inputs " + "and may produce incorrect results when `transpose_batch_sequence=True`. Use " + "sequence-first inputs at your own discretion." + ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, "fan_in", "truncated_normal", dtype=self.dtype @@ -657,6 +663,12 @@ class LayerNormDenseGeneral(TransformerEngineBase): depth_scaling: float = None def __post_init__(self): + if self.transpose_batch_sequence: + _ln_dense_warning( + "TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first " + "inputs and may produce incorrect results when `transpose_batch_sequence=True`. " + "Use sequence-first inputs at your own discretion." + ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, @@ -967,6 +979,12 @@ class LayerNormMLP(TransformerEngineBase): dot_2_input_axes: Tuple[str, ...] = None def __post_init__(self): + if self.transpose_batch_sequence: + _ln_mlp_warning( + "TE/JAX LayerNormMLP() module does not officially support sequence-first inputs " + "and may produce incorrect results when `transpose_batch_sequence=True`. Use " + "sequence-first inputs at your own discretion." + ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, "fan_in", "truncated_normal", dtype=self.dtype diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index e77974187..6161c59e9 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -182,8 +182,9 @@ def __call__( attn_weights_without_groups_shape = (b, h * g, q, k) attn_weights = attn_weights.reshape(attn_weights_without_groups_shape) + # (b, h, q, k): Last two axes are always replicated attn_weights = with_sharding_constraint_by_logical_axes( - attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES) + attn_weights, (BATCH_AXES, HEAD_AXES, None, None) ) # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias) @@ -596,6 +597,12 @@ def __call__( seqlen_kv = seqlen_q else: seqlen_kv = key.shape[sequence_dim] + if qkv_layout.is_separate(): + head_dim_qk = query.shape[-1] + head_dim_v = value.shape[-1] + else: + head_dim_qk = self.head_dim + head_dim_v = self.head_dim if qkv_layout.is_separate(): head_dim_qk = query.shape[-1] @@ -605,6 +612,8 @@ def __call__( head_dim_v = self.head_dim has_fused_attn_kernel = is_fused_attn_kernel_available( + # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode. + not deterministic, self.dtype, self.dtype, qkv_layout, diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 727ff78c2..5ccfc71c2 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -9,6 +9,7 @@ distributed training through sharding constraints. """ +import warnings from functools import partial from typing import Tuple @@ -21,9 +22,20 @@ QuantizerSet, noop_quantizer_set, with_sharding_constraint_by_logical_axes, + TensorUsage, ) +LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False + + +def _issue_batch_first_warning(msg): + global LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED + if not LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED: + warnings.warn(msg, UserWarning) + LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = True + + def layernorm_dense( x: jnp.ndarray, kernel: jnp.ndarray, @@ -36,6 +48,7 @@ def layernorm_dense( layernorm_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, + batch_first: bool = True, quantizer_set: QuantizerSet = noop_quantizer_set, ) -> jnp.ndarray: """Apply layer normalization followed by dense layer transformation. @@ -56,6 +69,7 @@ def layernorm_dense( layernorm_input_axes: Logical axes for sharding the layernorm input dot_input_axes: Logical axes for sharding the matrix multiplication input kernel_axes: Logical axes for sharding the weight matrix + batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. quantizer_set: Set of quantizers for different tensor types Returns: @@ -79,6 +93,7 @@ def layernorm_dense( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, quantizer_set, ) return output @@ -93,6 +108,7 @@ def layernorm_dense( 8, 9, 10, + 11, ), ) def _layernorm_dense( @@ -107,6 +123,7 @@ def _layernorm_dense( layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...], + batch_first: bool, quantizer_set, ): """Internal implementation of layernorm_dense with custom VJP. @@ -126,6 +143,7 @@ def _layernorm_dense( epsilon: Small constant for numerical stability layernorm_input_axes: Logical axes for layernorm sharding dot_input_axes: Logical axes for matrix multiplication sharding + batch_first: Assume that X is batched in the first dimension. quantizer_set: Set of quantizers Returns: @@ -143,6 +161,7 @@ def _layernorm_dense( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, quantizer_set, ) return output @@ -160,6 +179,7 @@ def _layernorm_dense_fwd_rule( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, quantizer_set, ): """Forward pass rule for layernorm_dense. @@ -177,6 +197,17 @@ def _layernorm_dense_fwd_rule( k_contracting_dims = (0,) assert x.shape[-1] == kernel.shape[0] + x_bdim = None + if x.ndim > 2: + if not batch_first: + _issue_batch_first_warning( + "TE/JAX `layernorm_dense()` fused-layer implementation does not officially " + "support sequence-first inputs and may produce incorrect results when " + "`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first " + "inputs at your own discretion." + ) + x_bdim = 0 if batch_first else x.ndim - 2 + x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) casted_ln_out, mu, rsigma = tex.normalization_fwd( @@ -186,31 +217,37 @@ def _layernorm_dense_fwd_rule( zero_centered_gamma, epsilon, norm_type, - quantizer_set.x, + quantizer=quantizer_set.x, + noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) # Kernel in (hidden_in, hidden_out...) flatten_axis = 1 - len(kernel.shape) - casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel) + casted_kernel = tex.quantize( + kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True + ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) # NN GEMM # (batch..., hidden_in) x (hidden_in, hidden_out...) + use_bias = bias is not None output = tex.gemm( - casted_ln_out.get_rowwise_tensor(), - casted_kernel.get_colwise_tensor(), - (x_contracting_dims, k_contracting_dims), + casted_ln_out.get_tensor(TensorUsage.LHS), + casted_kernel.get_tensor(TensorUsage.RHS), + contracting_dims=(x_contracting_dims, k_contracting_dims), + batched_dims=((x_bdim,), ()), + bias=bias if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) - use_bias = bias is not None - if use_bias: + if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) ctx = ( - casted_ln_out.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None, - casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None, + casted_ln_out.get_tensor(TensorUsage.LHS_TRANS), + casted_kernel.get_tensor(TensorUsage.RHS_TRANS), x.shape, kernel.shape, mu, @@ -223,6 +260,7 @@ def _layernorm_dense_fwd_rule( use_bias, quantizer_set, flatten_axis, + x_bdim, ) return output, ctx @@ -235,6 +273,7 @@ def _layernorm_dense_bwd_rule( layernorm_input_axes, dot_input_axes, # pylint: disable=unused-argument kernel_axes, + batch_first, # pylint: disable=unused-argument ctx, grad, ): @@ -250,8 +289,8 @@ def _layernorm_dense_bwd_rule( Tuple of gradients for all input parameters """ ( - colwise_casted_ln_out, - rowwise_casted_kernel, + casted_ln_out, + casted_kernel, x_shape, kernel_shape, mu, @@ -264,10 +303,15 @@ def _layernorm_dense_bwd_rule( use_bias, quantizer_set, flatten_axis, + x_bdim, ) = ctx casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad + grad, + is_dbias=use_bias, + flatten_axis=flatten_axis, + quantizer=quantizer_set.dgrad, + noop_scaled_tensor=True, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim @@ -281,9 +325,10 @@ def _layernorm_dense_bwd_rule( # NT GEMM dgrad = tex.gemm( - casted_grad.get_rowwise_tensor(), - rowwise_casted_kernel, - (g_constracting_dim, k_constracting_dim), + casted_grad.get_tensor(TensorUsage.LHS), + casted_kernel, + contracting_dims=(g_constracting_dim, k_constracting_dim), + batched_dims=((x_bdim,), ()), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) @@ -294,9 +339,10 @@ def _layernorm_dense_bwd_rule( # TN GEMM wgrad = tex.gemm( - colwise_casted_ln_out, - casted_grad.get_colwise_tensor(), - (x_constracting_dim, g_constracting_dim), + casted_ln_out, + casted_grad.get_tensor(TensorUsage.RHS), + contracting_dims=(x_constracting_dim, g_constracting_dim), + batched_dims=((x_bdim,), (x_bdim,)), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index e04b93023..507c49c7e 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -13,6 +13,7 @@ quantization, and distributed training through sharding constraints. """ +import warnings from typing import List, Tuple, Sequence, Union, Callable from functools import partial @@ -22,10 +23,25 @@ from . import cpp_extensions as tex from .layernorm import canonicalize_norm_type -from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set +from .quantize import ( + with_sharding_constraint_by_logical_axes, + QuantizerSet, + noop_quantizer_set, + TensorUsage, +) from .sharding import get_non_contracting_logical_axes +LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False + + +def _issue_batch_first_warning(msg): + global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED + if not LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED: + warnings.warn(msg, UserWarning) + LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True + + def layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -43,6 +59,7 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + batch_first: bool = True, quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> jnp.ndarray: """Apply layer normalization followed by MLP block. @@ -74,6 +91,7 @@ def layernorm_mlp( ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation + batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations Returns: @@ -119,12 +137,13 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -144,6 +163,7 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + batch_first: bool, quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -169,6 +189,7 @@ def _layernorm_mlp( ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing activation_type: Activation function(s) + batch_first: Assume that X is batched in the first dimension. quantizer_sets: Tuple of quantizer sets Returns: @@ -193,6 +214,7 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, quantizer_sets, ) return output @@ -217,6 +239,7 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, quantizer_sets, ): """Forward pass rule for layernorm_mlp. @@ -249,6 +272,17 @@ def _layernorm_mlp_fwd_rule( assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] + x_bdim = None + if x.ndim > 2: + if not batch_first: + _issue_batch_first_warning( + "TE/JAX `layernorm_mlp()` fused-layer implementation does not officially " + "support sequence-first inputs and may produce incorrect results when " + "`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first " + "inputs at your own discretion." + ) + x_bdim = 0 if batch_first else x.ndim - 2 + use_bias_1 = bias_1 is not None use_bias_2 = bias_1 is not None @@ -262,17 +296,23 @@ def _layernorm_mlp_fwd_rule( epsilon, norm_type, quantizer=ffn1_quantizer_set.x, + noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) - casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel) + casted_kernel_1 = tex.quantize( + kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True + ) # NN GEMM # (batch..., hidden_in) x (hidden_in, hidden_out) dot_1_output = tex.gemm( - casted_ln_out.get_rowwise_tensor(), - casted_kernel_1.get_colwise_tensor(), - (x_contracting_dims, k_contracting_dims), + casted_ln_out.get_tensor(TensorUsage.LHS), + casted_kernel_1.get_tensor(TensorUsage.RHS), + contracting_dims=(x_contracting_dims, k_contracting_dims), + batched_dims=((x_bdim,), ()), + bias=bias_1 if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, ) if dot_1_input_axes is not None and kernel_1_axes is not None: @@ -282,7 +322,7 @@ def _layernorm_mlp_fwd_rule( ) dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) - if use_bias_1: + if use_bias_1 and tex.gemm_uses_jax_dot(): bias_1_shape = bias_1.shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape dot_1_output += jnp.reshape(bias_1, bias_1_new_shape) @@ -290,21 +330,28 @@ def _layernorm_mlp_fwd_rule( dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) # (batch..., hidden_in) -> (batch..., hidden) - casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x) + casted_act_out = tex.act_lu( + dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True + ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) - casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel) + casted_kernel_2 = tex.quantize( + kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True + ) # NN GEMM # (batch..., hidden_in) x (hidden_out, hidden_in) dot_2_output = tex.gemm( - casted_act_out.get_rowwise_tensor(), - casted_kernel_2.get_colwise_tensor(), - (x_contracting_dims, k_contracting_dims), + casted_act_out.get_tensor(TensorUsage.LHS), + casted_kernel_2.get_tensor(TensorUsage.RHS), + contracting_dims=(x_contracting_dims, k_contracting_dims), + batched_dims=((x_bdim,), ()), + bias=bias_2 if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, ) - if use_bias_2: + if use_bias_2 and tex.gemm_uses_jax_dot(): bias_2_shape = bias_2.shape bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape dot_2_output += jnp.reshape(bias_2, bias_2_new_shape) @@ -317,11 +364,11 @@ def _layernorm_mlp_fwd_rule( rsigma, gamma, beta, - casted_ln_out.get_colwise_tensor(), - casted_kernel_1.get_rowwise_tensor(), + casted_ln_out.get_tensor(TensorUsage.LHS_TRANS), + casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS), dot_1_output, - casted_act_out.get_colwise_tensor(), - casted_kernel_2.get_rowwise_tensor(), + casted_act_out.get_tensor(TensorUsage.LHS_TRANS), + casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS), x_contracting_dims, k_contracting_dims, kernel_1.shape, @@ -329,6 +376,7 @@ def _layernorm_mlp_fwd_rule( use_bias_1, use_bias_2, quantizer_sets, + x_bdim, ) return dot_2_output, ctx @@ -346,6 +394,7 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, ctx, grad, ): @@ -362,18 +411,18 @@ def _layernorm_mlp_bwd_rule( Returns: Tuple of gradients for all input parameters """ - del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name + del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first ( x, mu, rsigma, gamma, beta, - colwise_casted_ln_out, - rowwise_casted_kernel_1, + casted_ln_out, + casted_kernel_1, dot_1_output, - colwise_casted_act_out, - rowwise_casted_kernel_2, + casted_act_out, + casted_kernel_2, x_contracting_dims_in_fwd, k_contracting_dims_in_fwd, kernel_1_shape, @@ -381,6 +430,7 @@ def _layernorm_mlp_bwd_rule( use_bias_1, use_bias_2, quantizer_sets, + x_bdim, ) = ctx ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets @@ -389,7 +439,7 @@ def _layernorm_mlp_bwd_rule( grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) casted_grad, dbias_2 = tex.quantize_dbias( - grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad + grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -404,9 +454,10 @@ def _layernorm_mlp_bwd_rule( # NT GEMM # (batch..., hidden_out) x (hidden_in, hidden_out) dgrad_2 = tex.gemm( - casted_grad.get_rowwise_tensor(), - rowwise_casted_kernel_2, - (g_contracting_dims_2, k_contracting_dims_2), + casted_grad.get_tensor(TensorUsage.LHS), + casted_kernel_2, + contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), + batched_dims=((x_bdim,), ()), ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) @@ -418,9 +469,10 @@ def _layernorm_mlp_bwd_rule( # TN GEMM # (hidden, batch...,) x (hidden, batch...) wgrad_2 = tex.gemm( - colwise_casted_act_out, - casted_grad.get_colwise_tensor(), - (x_contracting_dims, g_contracting_dims), + casted_act_out, + casted_grad.get_tensor(TensorUsage.RHS), + contracting_dims=(x_contracting_dims, g_contracting_dims), + batched_dims=((x_bdim,), (x_bdim,)), ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -430,10 +482,11 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, + noop_scaled_tensor=True, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim - dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim + dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim g_contracting_dims_1 = tuple( range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim) ) @@ -444,9 +497,10 @@ def _layernorm_mlp_bwd_rule( # NT GEMM dgrad_1 = tex.gemm( - casted_dact_out.get_rowwise_tensor(), - rowwise_casted_kernel_1, - (g_contracting_dims_1, k_contracting_dims_1), + casted_dact_out.get_tensor(TensorUsage.LHS), + casted_kernel_1, + contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), + batched_dims=((x_bdim,), ()), ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) @@ -454,9 +508,10 @@ def _layernorm_mlp_bwd_rule( # TN GEMM # (hidden, batch...) x (hidden, batch...) wgrad_1 = tex.gemm( - colwise_casted_ln_out, - casted_dact_out.get_colwise_tensor(), - (x_contracting_dims, g_contracting_dims), + casted_ln_out, + casted_dact_out.get_tensor(TensorUsage.RHS), + contracting_dims=(x_contracting_dims, g_contracting_dims), + batched_dims=((x_bdim,), (x_bdim,)), ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) diff --git a/transformer_engine/jax/pyproject.toml b/transformer_engine/jax/pyproject.toml new file mode 100755 index 000000000..d664eb9b1 --- /dev/null +++ b/transformer_engine/jax/pyproject.toml @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +[build-system] +requires = ["setuptools>=61.0", "pybind11[global]", "pip", "jax", "flax>=0.7.1"] + +# Use legacy backend to import local packages in setup.py +build-backend = "setuptools.build_meta:__legacy__" + diff --git a/transformer_engine/jax/quantize/__init__.py b/transformer_engine/jax/quantize/__init__.py index aa36df7a2..11f692917 100644 --- a/transformer_engine/jax/quantize/__init__.py +++ b/transformer_engine/jax/quantize/__init__.py @@ -15,3 +15,4 @@ from .scaling_modes import * from .metadata import * from .helper import * +from .device_utils import * diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index d43c61c9f..9d46c3c30 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -7,24 +7,70 @@ This module provides utilities for dequantizing tensors that have been quantized using various scaling modes, including delayed scaling and block scaling. """ +import math +from dataclasses import dataclass +from abc import ABC, abstractmethod + import jax import jax.numpy as jnp from .scaling_modes import ScalingMode -__all__ = ["Dequantizer"] +__all__ = ["ScalingModeToDequantizerMap"] + + +@dataclass +class Dequantizer(ABC): + """ + Base Dequantizer Class + """ + @staticmethod + @abstractmethod + def _dequantize_func(data, scale_inv, dq_dtype, **kwargs): + pass -class Dequantizer: - """Encapsulation class for dequantization helpers. + @staticmethod + @abstractmethod + def dequantize(scaled_tensor): + """Dequantizing given tensor to higher precision.""" + + +@dataclass +class NoopDequantizer(Dequantizer): + """No-op Dequantizer Class""" + + @staticmethod + def _dequantize_func(data, *args, **kwargs): + """A no-op dequantize function that returns the data without any changes.""" + del args, kwargs + return data + + @staticmethod + def dequantize(scaled_tensor): + """A no-op dequantize function that simply returns the data array in the ScaledTensor.""" + return scaled_tensor.data + + +class TensorScaleDequantizer(Dequantizer): + """ + TensorScaling Dequantizer Class This class provides static methods for dequantizing tensors that have been - quantized using different scaling modes. It supports both delayed scaling - and block scaling modes. + quantized using different tensor scaling modes. It supports both delayed scaling + and current scaling modes. """ @staticmethod - def _dq_func_tensor_scaling(scaled_tensor): + def _dequantize_func(data, scale_inv, dq_dtype, **kwargs): + del kwargs + return jnp.asarray( + data.astype(jnp.float32) * scale_inv.astype(jnp.float32), + dq_dtype, + ) + + @staticmethod + def dequantize(scaled_tensor): """Dequantize a tensor using delayed scaling. This function dequantizes a tensor that was quantized using delayed scaling @@ -36,36 +82,45 @@ def _dq_func_tensor_scaling(scaled_tensor): Returns: The dequantized tensor in the specified data type """ - return jnp.asarray( - scaled_tensor.data.astype(jnp.float32) * scaled_tensor.scale_inv.astype(jnp.float32), - scaled_tensor.dq_dtype, + return TensorScaleDequantizer._dequantize_func( + scaled_tensor.data, scaled_tensor.scale_inv, scaled_tensor.dq_dtype ) + +class BlockScaleDequantizer(Dequantizer): + """BlockScaling Dequantizer Class. + + This class provides static methods for dequantizing tensors that have been + quantized using block scaling modes. + """ + @staticmethod - def _dq_func_block_scaling(scaled_tensor): + def _dequantize_func(data, scale_inv, dq_dtype, scaling_mode, is_colwise, flatten_axis): """Dequantize a tensor using block scaling. - This function dequantizes a tensor that was quantized using block scaling - by applying the inverse scaling factor to each block of data. - Args: - scaled_tensor: The quantized tensor to dequantize + data: The quantized tensor data + scale_inv: The inverse scaling factors + dq_dtype: The data type for dequantized values + scaling_mode: The scaling mode used for quantization + is_colwise: Whether the scaling is column-wise + flatten_axis: The axis along which the tensor could be flattened to 2D Returns: - The dequantized tensor in the specified data type + The dequantized tensor """ - data = scaled_tensor.data.astype(jnp.float32) + + data = data.astype(jnp.float32) + scale_inv = scale_inv.view(jnp.uint8).astype(jnp.float32) + data_shape = data.shape - scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32) - flatten_axis = scaled_tensor.flatten_axis flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis assert ( 0 < flatten_axis < len(data_shape) ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" - scale_shape = scaled_tensor.scaling_mode.get_scale_shape( - data_shape, scaled_tensor.is_colwise, is_padded=False, flatten_axis=flatten_axis + scale_shape = scaling_mode.get_scale_shape( + data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis ) - scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding data = data.reshape( *data_shape[: flatten_axis - 1], @@ -76,31 +131,117 @@ def _dq_func_block_scaling(scaled_tensor): int(data_shape[-1] / scale_shape[-1]), ) - # E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers. - scale = jnp.expand_dims(scale, axis=(flatten_axis + 2 - 2, -1)) - # E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers. - return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape( - data_shape - ) + scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis + 2 - 2, -1)) - funcs = { - ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling, - ScalingMode.CURRENT_TENSOR_SCALING: _dq_func_tensor_scaling, - ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling, - } + # E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers. + return jnp.asarray(data * jnp.power(2, scale_inv - 127), dq_dtype).reshape(data_shape) @staticmethod def dequantize(scaled_tensor): - """Dequantize a scaled tensor using the appropriate scaling mode. - - This method selects the appropriate dequantization function based on the - scaling mode used for quantization and applies it to the tensor. + """Dequantize a tensor using block scaling. Args: - scaled_tensor: The quantized tensor to dequantize + data: The quantized tensor data + scale_inv: The inverse scaling factors + dq_dtype: The data type for dequantized values + scaling_mode: The scaling mode used for quantization + is_colwise: Whether the scaling is column-wise + flatten_axis: The axis along which the tensor could be flattened to 2D Returns: - The dequantized tensor in the specified data type + The dequantized tensor """ - dq_func = Dequantizer.funcs[scaled_tensor.scaling_mode] - return dq_func(scaled_tensor) + return BlockScaleDequantizer._dequantize_func( + scaled_tensor.data, + scaled_tensor.scale_inv, + scaled_tensor.dq_dtype, + scaled_tensor.scaling_mode, + scaled_tensor.is_colwise, + scaled_tensor.flatten_axis, + ) + + +ScalingModeToDequantizerMap = { + ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer, + ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer, + ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer, + ScalingMode.NO_SCALING: NoopDequantizer, +} + + +@staticmethod +def _grouped_dequantize(grouped_scaled_tensor): + """Dequantize a grouped tensor. + + Args: + grouped_scaled_tensor: The grouped scaled tensor to dequantize + + Returns: + List of dequantized tensors for each group + """ + data = grouped_scaled_tensor.data + scale_inv = grouped_scaled_tensor.scale_inv + group_sizes = grouped_scaled_tensor.group_sizes + flatten_axis = grouped_scaled_tensor.flatten_axis + scaling_mode = grouped_scaled_tensor.scaling_mode + original_shape = grouped_scaled_tensor.original_shape + group_axis = grouped_scaled_tensor.group_axis + + flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis + + output = [] + non_group_shape = tuple( + original_shape[i] for i in range(len(original_shape)) if i != group_axis + ) + matrix_sizes = group_sizes * math.prod(non_group_shape) + + data = jnp.split(data, jnp.cumulative_sum(matrix_sizes)[:-1]) + + scale_inv_ptr = 0 + for i, data_i in enumerate(data): + data_shape_i = ( + *original_shape[:group_axis], + group_sizes[i], + *original_shape[group_axis + 1 :], + ) + assert math.prod(data_shape_i) == data_i.size, ( + f"math.prod({data_shape_i}) = {math.prod(data_shape_i)} which is not equal to" + f" {data_i.size}" + ) + padded_scale_shape_i = scaling_mode.get_scale_shape( + data_shape_i, + grouped_scaled_tensor.is_colwise, + is_padded=True, + flatten_axis=flatten_axis, + ) + unpadded_scale_shape_i = scaling_mode.get_scale_shape( + data_shape_i, + grouped_scaled_tensor.is_colwise, + is_padded=False, + flatten_axis=flatten_axis, + ) + scale_inv_i = scale_inv[ + scale_inv_ptr : scale_inv_ptr + math.prod(padded_scale_shape_i) + ].reshape(padded_scale_shape_i) + scale_inv_i = jax.lax.slice( + scale_inv_i, [0] * len(unpadded_scale_shape_i), unpadded_scale_shape_i + ) + dequantizer_type = ScalingModeToDequantizerMap.get(grouped_scaled_tensor.scaling_mode) + if len(data_i) == 0: + out_i = [] + else: + out_i = dequantizer_type._dequantize_func( + data_i.reshape(data_shape_i), + scale_inv_i, + grouped_scaled_tensor.dq_dtype, + scaling_mode=grouped_scaled_tensor.scaling_mode, + is_colwise=grouped_scaled_tensor.is_colwise, + flatten_axis=grouped_scaled_tensor.flatten_axis, + ) + output.append(out_i) + scale_inv_ptr += math.prod(padded_scale_shape_i) + + return output + + +Dequantizer.grouped_dequantize = _grouped_dequantize diff --git a/transformer_engine/jax/quantize/device_utils.py b/transformer_engine/jax/quantize/device_utils.py new file mode 100644 index 000000000..ca90ba9fb --- /dev/null +++ b/transformer_engine/jax/quantize/device_utils.py @@ -0,0 +1,41 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Device utility functions for JAX quantization. + +This module provides utility functions for checking device capabilities and compatibility +for quantization operations in JAX. +""" + +import functools + +import transformer_engine_jax + +from ..util import is_hip_extension + +__all__ = [ + "get_device_compute_capability", + "is_fp8_gemm_with_all_layouts_supported", +] + + +@functools.lru_cache(maxsize=None) +def get_device_compute_capability(gpu_id: int = 0) -> int: + """ + Get the compute capability of the device. + """ + return transformer_engine_jax.get_device_compute_capability(gpu_id) + + +@functools.lru_cache(maxsize=None) +def is_fp8_gemm_with_all_layouts_supported() -> bool: + """Return True if using Blackwell architecture, False otherwise.""" + compute_capability = get_device_compute_capability() + if is_hip_extension(): + # gfx950 --> NV blackwell + return compute_capability == 95 + return 100 <= compute_capability < 120 diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index ec5a31c6b..0b9659a46 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -11,7 +11,9 @@ """ from contextlib import contextmanager from enum import Enum -from typing import Optional, Tuple, Dict, Union +from typing import Optional, Tuple, Dict, Union, Sequence +from functools import reduce +import operator import jax import jax.numpy as jnp @@ -20,19 +22,17 @@ from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type from transformer_engine_jax import DType -if is_hip_extension(): - from transformer_engine_jax import get_device_compute_capability -else: +if not is_hip_extension(): from transformer_engine_jax import ( get_cublasLt_version, get_cuda_version, - get_device_compute_capability, ) from transformer_engine.common import recipe from transformer_engine.jax.sharding import global_shard_guard, MeshResource from .scaling_modes import ScalingMode from .. import cpp_extensions as tex +from .device_utils import get_device_compute_capability __all__ = [ "QuantizeConfig", @@ -40,6 +40,8 @@ "is_fp8_available", "update_collections", "get_delayed_scaling", + "apply_padding_to_scale_inv", + "remove_padding_from_scale_inv", "NVTE_FP8_COLLECTION_NAME", ] @@ -217,7 +219,7 @@ class QuantizeConfig: FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients - IF_QUANTIZE_2X: Whether 2x quantization is enabled + INFERENCE_MODE: Whether to enable optimization for inference SCALING_MODE: Scaling mode AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling AMAX_COMPUTE_ALGO: Algorithm for AMAX computation @@ -232,7 +234,7 @@ class QuantizeConfig: FP8_2X_ACC_FPROP: bool = False FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False - IF_QUANTIZE_2X: bool = False + INFERENCE_MODE: bool = False SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING # DelayedScaling @@ -260,7 +262,6 @@ def initialize(cls, fp8_recipe: recipe.Recipe) -> None: cls.FP8_FORMAT = fp8_recipe.fp8_format cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) cls.SCALING_MODE = _get_scaling_mode(fp8_recipe) - cls.IF_QUANTIZE_2X = True @classmethod def finalize(cls) -> None: @@ -274,7 +275,7 @@ def finalize(cls) -> None: cls.FP8_2X_ACC_DGRAD = False cls.FP8_2X_ACC_WGRAD = False cls.SCALING_MODE = ScalingMode.NO_SCALING - cls.IF_QUANTIZE_2X = False + cls.INFERENCE_MODE = False # DelayedScaling cls.AMAX_HISTORY_LEN = 1024 cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX @@ -490,4 +491,115 @@ def update_collections(new: Collection, original: Collection) -> Collection: return new_coll +def remove_padding_from_scale_inv( + scale_inv: jax.Array, + scaling_mode: ScalingMode, + data_shape: Sequence[int], + is_colwise: bool = False, + flatten_axis: int = -1, +): + """ + Slice padding out of padded inverse scale factors. + + Args: + scale_inv: Inverse scale factor. + data_shape: Shape of the quantized data the inverse scale belongs to. + scaling_mode: ScalingMode representing the quantization method. + is_colwise: Whether the data was quantized column-wise. + flatten_axis: The axis along with the data could be flattened to 2D. + + Returns: + Inverse scale factor without padding. + """ + # Get expected unpadded scale shape and check if inverse scale already matches + unpadded_scale_shape = scaling_mode.get_scale_shape( + data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis + ) + if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == unpadded_scale_shape: + return scale_inv + + # Get the padded scale shape and make sure inverse scale matches + padded_scale_shape = scaling_mode.get_scale_shape( + data_shape, + is_colwise=is_colwise, + is_padded=True, + flatten_axis=flatten_axis, + ) + assert scale_inv.shape == padded_scale_shape, ( + f"Padded inverse scale factor has wrong shape, expected {padded_scale_shape} but got " + f"{scale_inv.shape} instead." + ) + + # Reshape scale inverse to 2D in two stages to preserve the flatten axis + padded_scale_shape_2d = ( + reduce(operator.mul, padded_scale_shape[:flatten_axis]), + reduce(operator.mul, padded_scale_shape[flatten_axis:]), + ) + scale_inv_2d = jnp.reshape( + jnp.reshape(scale_inv, (padded_scale_shape_2d[0], *scale_inv.shape[flatten_axis:])), + padded_scale_shape_2d, + ) + + # Slice reshaped 2D scale inverse using collapsed 2D unpadded_scale_shape + unpadded_scale_shape_2d = ( + reduce(operator.mul, unpadded_scale_shape[:flatten_axis]), + reduce(operator.mul, unpadded_scale_shape[flatten_axis:]), + ) + scale_inv_2d_unpadded = jnp.asarray( + scale_inv_2d[: unpadded_scale_shape_2d[0], : unpadded_scale_shape_2d[1]] + ) + + # Reshape 2D scale inverse back in two stages in order to preserve the flatten axis + scale_inv_unpadded = jnp.reshape( + jnp.reshape( + scale_inv_2d_unpadded, + (*unpadded_scale_shape[:flatten_axis], scale_inv_2d_unpadded.shape[1]), + ), + unpadded_scale_shape, + ) + return scale_inv_unpadded + + +def apply_padding_to_scale_inv( + scale_inv: jax.Array, + scaling_mode: ScalingMode, + data_shape: Sequence[int], + is_colwise: bool = False, + flatten_axis: int = -1, +): + """ + Pad the scale inverse with zeros to match the necessary padded shape for this scaling + mode. + + Args: + scale_inv: Inverse scale factor. + data_shape: Shape of the quantized data the inverse scale belongs to. + scaling_mode: ScalingMode representing the quantization method. + is_colwise: Whether the data was quantized column-wise. + flatten_axis: The axis along with the data could be flattened to 2D. + + Returns: + Padded inverse scale factor. + """ + # Get the expected padded scale shape and check if inverse scale already matches + padded_scale_shape = scaling_mode.get_scale_shape( + data_shape, is_colwise=is_colwise, is_padded=True, flatten_axis=flatten_axis + ) + if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == padded_scale_shape: + return scale_inv + + # Get the expected unpadded scale shape and make sure inverse scales match + unpadded_scale_shape = scaling_mode.get_scale_shape( + data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis + ) + assert scale_inv.shape == unpadded_scale_shape, ( + f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " + f"{scale_inv.shape}." + ) + + # Pad the scales with the lowest representable value (2^-127) and return + pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape)) + return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127) + + NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index a764f0710..881f3a74b 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -9,7 +9,8 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from functools import partial -from typing import Union, Optional +from typing import Union, Optional, Tuple +import warnings import jax import jax.numpy as jnp @@ -17,11 +18,12 @@ from transformer_engine_jax import QuantizeLayout from .scaling_modes import ScalingMode -from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory +from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory from .helper import ( QuantizeConfig, AmaxComputeAlgo, ) +from .device_utils import is_fp8_gemm_with_all_layouts_supported __all__ = [ "QuantizeLayout", @@ -30,6 +32,7 @@ "CurrentScaleQuantizer", "DelayedScaleQuantizer", "BlockScaleQuantizer", + "GroupedQuantizer", "QuantizerFactory", "noop_quantizer_set", "compute_scale_from_amax", @@ -74,6 +77,7 @@ class Quantizer(ABC): q_dtype: jnp.dtype scaling_mode: ScalingMode q_layout: QuantizeLayout + data_layout: str def tree_flatten(self): """Flatten the quantizer for JAX tree operations. @@ -82,7 +86,7 @@ def tree_flatten(self): Tuple of (children, aux_data) for tree operations """ children = () - aux_data = (self.q_dtype, self.scaling_mode, self.q_layout) + aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout) return (children, aux_data) @classmethod @@ -110,13 +114,22 @@ def is_2x2x(self) -> bool: """ return self.q_layout == QuantizeLayout.ROWWISE_COLWISE - @abstractmethod def get_data_layout(self) -> str: - """Get the data data_layout. + """Get the data data_layout string. Returns: Data data_layout in string format + + Raises: + ValueError: If quantization axis is invalid """ + if self.q_layout == QuantizeLayout.ROWWISE_COLWISE: + return self.data_layout + if self.q_layout == QuantizeLayout.ROWWISE: + return self.data_layout[0] + if self.q_layout == QuantizeLayout.COLWISE: + return self.data_layout[1] + raise ValueError(f"Invalid q_layout: {self.q_layout}") @abstractmethod def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x: @@ -132,7 +145,9 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> A ScaledTensor1x containing the quantized data """ - def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1): + def quantize( + self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1, **kwargs + ) -> ScaledTensor: """Quantize a tensor using the internal _quantize_func(). Args: @@ -145,6 +160,7 @@ def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten Returns: A ScaledTensor1x or ScaledTensor2x containing the quantized data """ + del kwargs if (is_rowwise and is_colwise) or self.is_2x2x(): rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) colwise_tensor = self._quantize_func( @@ -159,7 +175,7 @@ def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten return self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) - def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1): + def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1, **kwargs): """Get shapes for scale tensors. Args: @@ -169,6 +185,7 @@ def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1): Returns: Tuple of (rowwise_scale_shape, colwise_scale_shape) """ + del kwargs return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded, flatten_axis) def get_scale_dtype(self): @@ -194,24 +211,7 @@ class CurrentScaleQuantizer(Quantizer): scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE - - def get_data_layout(self) -> str: - """Get the data data_layout string. - - Returns: - Data data_layout in string format - - Raises: - ValueError: If quantization axis is invalid - """ - data_layout = "NT" - if self.q_layout == QuantizeLayout.ROWWISE_COLWISE: - return data_layout - if self.q_layout == QuantizeLayout.ROWWISE: - return data_layout[0] - if self.q_layout == QuantizeLayout.COLWISE: - return data_layout[1] - raise ValueError(f"Invalid q_layout: {self.q_layout}") + data_layout: str = "NT" def _quantize_func( self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1 @@ -230,16 +230,11 @@ def _quantize_func( compute_dtype = jnp.float32 dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) - amax = jnp.max(jnp.abs(x)).reshape((1,)).astype(compute_dtype) + amax = jnp.max(jnp.abs(x)).reshape((1,)) fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32) scale = (fp8_max / amax) / (2**QuantizeConfig.MARGIN) scaled_x = x.astype(compute_dtype) * scale - # quantize() in the old dot.py do this way, leave this code block here for future debugging - # compute_dtype = x.dtype - # dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) - # scaled_x = x * self.scale.astype(compute_dtype) - clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) scale_inv = 1.0 / scale return ScaledTensorFactory.create_1x( @@ -295,6 +290,7 @@ def quantize( data_layout="T", flatten_axis=flatten_axis, ) + if is_colwise and is_rowwise: return ScaledTensor2x(rowwise_tensor, colwise_tensor) if is_colwise: @@ -332,7 +328,7 @@ def tree_flatten(self): Tuple of (children, aux_data) for tree operations """ children = (self.scale, self.amax_history) - aux_data = (self.q_dtype, self.scaling_mode, self.q_layout) + aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout) return (children, aux_data) def _quantize_func( @@ -447,16 +443,7 @@ class BlockScaleQuantizer(Quantizer): scaling_mode: ScalingMode = ScalingMode.MXFP8_1D_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE - - def get_data_layout(self) -> str: - """Get the data data_layout string. - - Returns: - Data data_layout in string format - """ - if self.is_2x2x(): - return "NN" - return "N" + data_layout: str = "NN" def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x: """Quantize function helper for block scaling FP8. @@ -591,6 +578,190 @@ def tree_unflatten(cls, aux_data, children): return cls(*aux_data, *children) +@register_pytree_node_class +@dataclass +class GroupedQuantizer(Quantizer): + """Quantizer for grouped arrays. + + This class extends Quantizer to support quantization of arrays in grouped manner, + where elements are grouped along a specified axis then quantized separately. + + Attributes: + data_layout: The data layout specification + n_groups: Number of groups for quantization + quantizers: Tuple of quantizers for each group + """ + + data_layout: str = None + n_groups: int = 1 + quantizers: Tuple[Quantizer] = field(default_factory=lambda: (None,)) + + def tree_flatten(self): + """Flatten the quantizer for JAX tree operations. + + Returns: + Tuple of (children, aux_data) for tree operations + """ + children = (self.quantizers,) + aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout, self.n_groups) + return (children, aux_data) + + def __post_init__(self): + if self.quantizers[0] is None: + quantizers = QuantizerFactory.create( + self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout + ) + self.quantizers = (quantizers,) if not isinstance(quantizers, tuple) else quantizers + self.data_layout = self.quantizers[0].data_layout + + def _create_grouped_tensor_from_tensor_list( + self, tensor_list, group_sizes, original_shape, group_axis, mode + ): + # mode 0 = concate, mode 1 = add + # TODO(Ming Huang): Consider to apply Enum for mode. + assert mode in [0, 1] + grouped_data = ( + [] if mode == 0 else jnp.zeros(tensor_list[0].data.shape, tensor_list[0].data.dtype) + ) + grouped_scale_inv = [] + + for tensor in tensor_list: + if mode == 0: + grouped_data.append(tensor.data.flatten()) + else: + grouped_data += tensor.data + grouped_scale_inv.append(tensor.scale_inv.flatten()) + + grouped_data = jnp.concatenate(grouped_data) if mode == 0 else grouped_data.flatten() + grouped_scale_inv = jnp.concatenate(grouped_scale_inv) + + return ScaledTensorFactory.create_1x( + grouped_data, + grouped_scale_inv, + self.scaling_mode, + tensor_list[0].dq_dtype, + tensor_list[0].is_colwise, + tensor_list[0].data_layout, + tensor_list[0].flatten_axis, + group_sizes=group_sizes, + original_shape=original_shape, + group_axis=group_axis, + ) + + def _quantize_func(self, *args, **kwargs): + pass + + def quantize( + self, + x, + is_rowwise: bool = None, + is_colwise: bool = None, + dq_dtype=None, + flatten_axis=-1, + group_sizes=None, + group_axis=0, + ): + """Quantize a tensor in grouped manner. + + Expected input shape: [M, K] or [G, K, N] + Split to x.shape[group_axis] number of groups if group_sizes is not given + + Args: + x: Input tensor to quantize + is_rowwise: Whether to use row-wise quantization + is_colwise: Whether to use column-wise quantization + dq_dtype: Data type for dequantized values + flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) + group_sizes: Array of ints containing the size of each group (default: None) + group_axis: The axis along which grouping is performed (default: 0) + + Returns: + A ScaledTensor1x or ScaledTensor2x containing the quantized data + """ + assert group_axis == 0, "Only group_axis == 0 is supported now!" + + dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if flatten_axis < 0: + flatten_axis += x.ndim + assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" + + is_rowwise = ( + is_rowwise + if is_rowwise is not None + else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x()) + ) + is_colwise = ( + is_colwise + if is_colwise is not None + else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x()) + ) + assert is_rowwise or is_colwise, "No quantization layout is specified" + + original_shape = x.shape + + if group_sizes is not None: + assert not is_colwise, "Not yet implememted!" + assert group_sizes.ndim == 1, ( + "GroupedQuantizer only support 1D group_sizes, got group_sizes.ndim =" + f" {group_sizes.ndim}" + ) + + _zeros = partial(jax.lax.full_like, fill_value=0) + + x_iota = jax.lax.broadcasted_iota(group_sizes.dtype, x.shape, 0) + group_ends = jnp.cumulative_sum(group_sizes) + group_starts = jax.lax.concatenate( + [_zeros(group_sizes)[:1], group_ends[:-1]], + dimension=0, + ) + x_zero = _zeros(x) + + tensor_list = [] + for i in range(len(group_sizes)): + mask = jax.lax.bitwise_and(group_starts[i] <= x_iota, x_iota < group_ends[i]) + x_selected = jax.lax.select(mask, x, x_zero) + tensor = self.quantizers[i].quantize( + x_selected, is_rowwise, is_colwise, dq_dtype, flatten_axis + ) + tensor_list.append(tensor) + combine_mode = 1 # Add + else: + group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32) + x = jnp.split(x, x.shape[group_axis], axis=group_axis) + + tensor_list = [] + for i in range(len(group_sizes)): + tensor = self.quantizers[i].quantize( + x[i], is_rowwise, is_colwise, dq_dtype, flatten_axis + ) + tensor_list.append(tensor) + combine_mode = 0 # Concate + + grouped_rowwise_tensor = grouped_colwise_tensor = None + if is_rowwise: + rowwise_tensor_list = [tensor.get_rowwise_tensor() for tensor in tensor_list] + grouped_rowwise_tensor = self._create_grouped_tensor_from_tensor_list( + rowwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode + ) + if is_colwise: + colwise_tensor_list = [tensor.get_colwise_tensor() for tensor in tensor_list] + grouped_colwise_tensor = self._create_grouped_tensor_from_tensor_list( + colwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode + ) + + if is_colwise and is_rowwise: + return ScaledTensor2x(grouped_rowwise_tensor, grouped_colwise_tensor) + if is_colwise: + return grouped_colwise_tensor + return grouped_rowwise_tensor + + def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1, group_sizes=None): + assert group_sizes, "Empty group_sizes was given!" + return self.scaling_mode.get_grouped_scale_shape_2x( + data_shape, group_sizes, is_padded, flatten_axis + ) + + @dataclass class QuantizerFactory: """Factory class for creating quantizers. @@ -611,6 +782,7 @@ def create( scaling_mode: ScalingMode = None, q_dtype: jnp.dtype = None, q_layout: QuantizeLayout = None, + n_groups: int = None, **kwargs, ) -> Quantizer: """Create one or more quantizers with specified parameters. @@ -621,6 +793,7 @@ def create( q_dtype: Quantization data type q_layout: Quantization axis flatten_axis: The quantization axis for the tensor + n_groups: Number of quantizers if GroupedQuantizer **kwargs: Additional arguments for quantizer initialization Returns: @@ -628,13 +801,21 @@ def create( """ # (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type" - # import pdb; pdb.set_trace() + if n_groups: + if n_quantizers != 1: + warnings.warn( + "Using more than one GroupedQuantizer for a grouped input is not recommended" + ) + quantizer_type = GroupedQuantizer + kwargs["n_groups"] = n_groups + else: + quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode) + if scaling_mode == ScalingMode.NO_SCALING: quantizers = [None] * n_quantizers else: quantizers = [] for _ in range(n_quantizers): - quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode) quantizers.append( quantizer_type( q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs @@ -643,7 +824,9 @@ def create( return quantizers[0] if len(quantizers) == 1 else tuple(quantizers) @staticmethod - def _create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs) -> QuantizerSet: + def _create_set( + scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs + ) -> QuantizerSet: """Create a set of quantizers for forward and backward passes. Args: @@ -651,6 +834,7 @@ def _create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs) -> Quanti fwd_dtype: Data type for forward pass bwd_dtype: Data type for backward pass is_2x2x: Whether to use 2x2x quantization + n_groups **kwargs: Additional arguments for quantizer initialization Returns: @@ -659,9 +843,11 @@ def _create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs) -> Quanti if is_2x2x: q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE else: - q_layout_x = QuantizeLayout.ROWWISE - q_layout_kernel = QuantizeLayout.COLWISE - q_layout_dgrad = None + q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE + if scaling_mode.is_1d_block_scaling(): + q_layout_kernel = QuantizeLayout.COLWISE + if QuantizeConfig.INFERENCE_MODE: + q_layout_dgrad = None if "quantize_meta_set" in kwargs: quantize_meta_set = kwargs.get("quantize_meta_set") @@ -680,11 +866,13 @@ def _create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs) -> Quanti else: args_x = args_kernel = args_grad = {} - q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, **args_x) + q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x) q_kernel = QuantizerFactory.create( - 1, scaling_mode, fwd_dtype, q_layout_kernel, **args_kernel + 1, scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel + ) + q_dgrad = QuantizerFactory.create( + 1, scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad ) - q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_layout_dgrad, **args_grad) return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad) @staticmethod @@ -694,6 +882,7 @@ def create_set( fwd_dtype: jnp.dtype = None, bwd_dtype: jnp.dtype = None, is_2x2x: bool = None, + n_groups: int = None, **kwargs, ) -> tuple[Union[tuple[Quantizer], None]]: """Create one or more sets of quantizers. @@ -704,6 +893,7 @@ def create_set( fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X + n_groups: **kwargs: Additional arguments for quantizer initialization Returns: @@ -712,15 +902,25 @@ def create_set( scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE - is_2x2x = is_2x2x or QuantizeConfig.IF_QUANTIZE_2X + if is_2x2x is None: + if scaling_mode.is_1d_block_scaling(): + is_2x2x = True + elif scaling_mode.is_tensor_scaling(): + is_2x2x = not is_fp8_gemm_with_all_layouts_supported() + else: # NO_SCALING ignores is_2x2x for now + is_2x2x = False + is_inference_mode = QuantizeConfig.INFERENCE_MODE + assert not is_inference_mode, "Inference mode is not supported yet!" q_set = [] for _ in range(n_quantizer_sets): q_set.append( - QuantizerFactory._create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs) + QuantizerFactory._create_set( + scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs + ) ) return q_set[0] if len(q_set) == 1 else tuple(q_set) -noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING) +noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING, is_2x2x=False) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 303f5ffbb..fc4fd1353 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -13,20 +13,57 @@ from dataclasses import dataclass from enum import Enum from typing import Tuple, Dict -from functools import reduce +from functools import reduce, lru_cache import operator +import numpy as np -from packaging import version -import jax -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax.experimental.custom_partitioning import CompoundFactor +from jax.experimental.custom_partitioning import BATCHING from jax.tree_util import register_pytree_node_class import jax.numpy as jnp -from transformer_engine_jax import JAXX_Scaling_Mode +from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout +from .device_utils import is_fp8_gemm_with_all_layouts_supported -__all__ = ["QuantizeShardyRules", "ScalingMode"] +__all__ = [ + "QuantizeShardyRules", + "ScalingMode", + "TensorUsage", +] + + +class TensorUsage(Enum): + """Enum indicating tensor usage in GEMM operations. + + Given a GEMM operation: C = A * B in which A and B can be in the normal or transposed form. + The tensor usage can be: + - LHS: A is in the normal form + - LHS_TRANS: A is in the transposed form + - RHS: B is in the normal form + - RHS_TRANS: B is in the transposed form + + The tensor usage is used in the ScaledTensor.get_tensor() method. + """ + + # LHS: Left-hand side, RHS: Right-hand side + # LHS_TRANS: Left-hand side transposed, RHS_TRANS: Right-hand side transposed + LHS = 0 + LHS_TRANS = 1 + RHS = 2 + RHS_TRANS = 3 + + def __eq__(self, other): + if not isinstance(other, TensorUsage): + return False + return self.value == other.value + + def __hash__(self): + return hash(self.value) + + +def DIVUP(a, b): + "Divide a by b and then round up" + return -(a // -b) @dataclass @@ -77,11 +114,42 @@ def get_scale_shape( data_shape: The shape of the tensor being quantized is_colwise: Whether the scaling is column-wise is_padded: Whether to return padded shape - flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) + Returns: The shape for scale tensors """ + @abstractmethod + def get_grouped_scale_shape( + self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + ) -> Tuple[int]: + """Get the shape for scale tensors in this mode. + + Args: + data_shape: Original shape of the data tensor + n_groups: Number of groups in grouped quantization + group_axis: The axis along which grouping is performed + is_colwise: Whether to use column-wise scaling + is_padded: Whether to use padded shapes + flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) + + Returns: + The shape for scale tensors + """ + + @lru_cache(maxsize=4) + @abstractmethod + def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: + """Get the quantize layout for the tensor usage. + + Args: + usage: The usage of the tensor + + Returns: + The quantize layout for the tensor usage + """ + @abstractmethod def get_shardy_sharding_rules( self, input_rank, unique_var, flatten_axis @@ -130,9 +198,46 @@ def get_scale_shape( Returns: The shape for scale tensors - (1,) """ - del data_shape, is_colwise + del is_colwise + if np.prod(data_shape) == 0: + return (0,) return (1,) + @lru_cache(maxsize=4) + def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: + """Get the quantize layout for the tensor usage. + + Args: + usage: The usage of the tensor + + Returns: + The quantize layout for the tensor usage + """ + if is_fp8_gemm_with_all_layouts_supported(): + return QuantizeLayout.ROWWISE + + if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS): + return QuantizeLayout.ROWWISE + return QuantizeLayout.COLWISE + + def get_grouped_scale_shape( + self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + ) -> Tuple[int]: + """Get the shape for scale tensors in this mode. + + Args: + data_shape: Original shape of the data tensor + is_colwise: Whether to use column-wise scaling + is_padded: Whether to use padded shapes + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + + Returns: + The shape for scale tensors + """ + del data_shape, group_axis, is_colwise + assert isinstance(n_groups, int) + return (n_groups,) + def get_shardy_sharding_rules( self, input_rank, unique_var, flatten_axis ) -> QuantizeShardyRules: @@ -147,8 +252,9 @@ def get_shardy_sharding_rules( The Shardy rules for the scaling mode """ del flatten_axis - input_spec = tuple(f"x{i}" for i in range(input_rank)) - return QuantizeShardyRules(input_spec, (unique_var,), (unique_var,), {}) + input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) + scale_var = BATCHING + unique_var + "_scale_inv" + return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl): @@ -279,6 +385,98 @@ def get_scale_shape( return (*first_dim_scale_shape, *last_dim_scale_shape) + @lru_cache(maxsize=4) + def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: + """Get the quantize layout for the tensor usage. + + Args: + usage: The usage of the tensor + + Returns: + The quantize layout for the tensor usage + """ + # If we need to support 1x1x for inference in the future + # if QuantizeConfig.INFERENCE_MODE: + # assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!") + # if usage == TensorUsage.LHS: + # return QuantizeLayout.ROWWISE + # return QuantizeLayout.COLWISE + + if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS): + return QuantizeLayout.ROWWISE + return QuantizeLayout.COLWISE + + def get_grouped_scale_shape( + self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + ) -> Tuple[int]: + """Get the shape for grouped scale tensors in this mode. + If padded: The estimiated maximal possible shape for grouped scale tensor is return instead. + + Args: + data_shape: Original shape of the data tensor + is_colwise: Whether to use column-wise scaling + is_padded: Whether to use padded shapes + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + + Returns: + The shape for scale tensors + """ + assert isinstance(n_groups, int) + block_alignment = self._block_alignment if is_padded else (1, 1) + + if is_colwise: + block_y, block_x = self._block_dims + alignment_y, alignment_x = block_alignment + else: + block_x, block_y = self._block_dims + alignment_x, alignment_y = block_alignment + + if flatten_axis < 0: + flatten_axis = len(data_shape) + flatten_axis + assert ( + 0 < flatten_axis < len(data_shape) + ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" + + assert data_shape[flatten_axis - 1] % block_x == 0, ( + f"Data shape {data_shape} should be divisible by block_x {block_x} in axis" + f" {flatten_axis - 1}" + ) + assert ( + data_shape[-1] % block_y == 0 + ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1" + + flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1) + flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1) + + assert flattened_first_dim % block_x == 0, ( + f"Flattened first dim - mutiplication of axes={tuple(range(0, flatten_axis))} of shape" + f" {data_shape} - should be divisible by block_x {block_x}" + ) + assert flattened_last_dim % block_y == 0, ( + "Flattened last dim - mutiplication of" + f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be" + f" divisible by block_y {block_y}" + ) + + n_block_x = int(flattened_first_dim // block_x) + n_block_y = int(flattened_last_dim // block_y) + + """ + Given the scale shape of [M, N], and G groups, and padding alignment (128, 4), + The worst scenario is when we have (G-1) groups with 1 rows and 1 group with (M-G+1) rows. + Then: + max_padded_rows = (G-1) * 128 + DIVUP(M-G+1, 128) * 128 + max_padded_cols = DIVUP(N, 4) * 4 + max_scale_size = max_padded_rows * max_padded_cols + """ + if is_padded: + n_block_x = (n_groups - 1) * alignment_x + DIVUP( + n_block_x - n_groups + 1, alignment_x + ) * alignment_x + n_block_y = DIVUP(n_block_y, alignment_y) * alignment_y + + return (n_block_x * n_block_y,) + def get_shardy_sharding_rules( self, input_rank, unique_var, flatten_axis ) -> QuantizeShardyRules: @@ -291,33 +489,41 @@ def get_shardy_sharding_rules( Returns: The Shardy rules for the scaling mode """ - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") - input_spec = [f"x{i}" for i in range(input_rank)] - - # We have to use two different factors in the two CompoundFactors because of Shardy - # verifier requirements, even though they are the same. - rowwise_var = unique_var - colwise_var = f"{unique_var}_" - input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise") - input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise") - - # The rowwise and colwise scale tensors should be sharded the same way as the input. - # However, we need to adjust the dimensions where the block scaling factor applies. - rowwise = input_spec.copy() - rowwise[-1] = rowwise_var - - colwise = input_spec.copy() - colwise[flatten_axis - 1] = colwise_var - - # This implementation needs to be updated for different block dims. - assert self._block_dims == (1, 32) + del flatten_axis + input_spec = [f"{unique_var}{i}" for i in range(input_rank)] + rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)] + colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)] + + # NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors. + # Unfortunately, because Shardy rules are applied to the inner primitive, the + # only way to preserve the relationship is to lower unpadded scales to the + # underlying custom call and pad them in C++. Until that's implemented, the + # Shardy rules for block scales have to be completely disconnected from the + # Shardy rules for the tensor they belong to. + + # # We have to use two different factors in the two CompoundFactors because of Shardy + # # verifier requirements, even though they are the same. + # rowwise_var = unique_var + # colwise_var = f"{unique_var}_" + # input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise") + # input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise") + + # # The rowwise and colwise scale tensors should be sharded the same way as the input. + # # However, we need to adjust the dimensions where the block scaling factor applies. + # rowwise = input_spec.copy() + # rowwise[-1] = rowwise_var + + # colwise = input_spec.copy() + # colwise[flatten_axis - 1] = colwise_var + + # # This implementation needs to be updated for different block dims. + # assert self._block_dims == (1, 32) return QuantizeShardyRules( tuple(input_spec), tuple(rowwise), tuple(colwise), - {"block_size_rowwise": 32, "block_size_colwise": 32}, + {}, # {"block_size_rowwise": 32, "block_size_colwise": 32}, ) @@ -395,6 +601,17 @@ def get_scale_shape( """ return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis) + def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: + """Get the quantize layout for the tensor usage. + + Args: + usage: The usage of the tensor + + Returns: + The quantize layout for the tensor usage + """ + return self._get_impl().get_quantize_layout(usage) + def get_shardy_sharding_rules( self, input_rank, unique_var, flatten_axis=-1 ) -> Tuple[Tuple[str]]: @@ -409,6 +626,61 @@ def get_shardy_sharding_rules( """ return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis) + def get_grouped_scale_shape_2x( + self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1 + ) -> Tuple[Tuple[int]]: + """Get shapes for both row-wise and column-wise scaling. + + Args: + data_shape: Shape of the data tensor + n_groups: Number of groups for grouped quantization + group_axis: The axis along which grouping is performed + is_padded: Whether to use padded shapes + flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) + + Returns: + Tuple of (rowwise_scale_shape, colwise_scale_shape) + """ + rowwise_scale_shape = self.get_grouped_scale_shape( + data_shape, + n_groups, + group_axis, + is_colwise=False, + is_padded=is_padded, + flatten_axis=flatten_axis, + ) + colwise_scale_shape = self.get_grouped_scale_shape( + data_shape, + n_groups, + group_axis, + is_colwise=True, + is_padded=is_padded, + flatten_axis=flatten_axis, + ) + return (rowwise_scale_shape, colwise_scale_shape) + + def get_grouped_scale_shape( + self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + ) -> Tuple[Tuple[int]]: + """Get shapes for both row-wise and column-wise scaling. + + Args: + data_shape: Shape of the data tensor + is_padded: Whether to use padded shapes + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + + Returns: + Tuple of (rowwise_scale_shape, colwise_scale_shape) + """ + return self._get_impl().get_grouped_scale_shape( + data_shape, + n_groups, + group_axis, + is_colwise=is_colwise, + is_padded=is_padded, + flatten_axis=flatten_axis, + ) + def is_tensor_scaling(self) -> bool: """Check if this scaling mode is per-tensor scaling. diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 0ef30f472..97e127269 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -17,16 +17,18 @@ from transformer_engine_jax import QuantizeLayout -from .scaling_modes import ScalingMode -from .dequantizer import Dequantizer +from .scaling_modes import ScalingMode, TensorUsage +from .dequantizer import ScalingModeToDequantizerMap from ..sharding import ( with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes, ) __all__ = [ + "TensorUsage", "ScaledTensor", "ScaledTensor1x", "ScaledTensor2x", + "GroupedScaledTensor1x", "ScaledTensorFactory", "with_sharding_constraint_by_logical_axes", ] @@ -54,6 +56,11 @@ def tree_unflatten(cls, aux_data, children): """ return cls(*children, *aux_data) + @property + @abstractmethod + def ndim(self): + """Number of dimensions of the underlying quantized array.""" + @abstractmethod def dequantize(self): """Dequantizes the tensor back to its original precision. @@ -63,25 +70,15 @@ def dequantize(self): """ @abstractmethod - def get_rowwise_tensor(self): - """Returns the row-wise component of the tensor. - - Returns: - The row-wise tensor component + def get_tensor(self, usage: TensorUsage): + """Returns the appropriate tensor based on the tensor usage and the scaling mode. + If the tensor usage is not valid for the scaling mode, an error is raised. - Raises: - ValueError: If called on a tensor that doesn't support row-wise access - """ - - @abstractmethod - def get_colwise_tensor(self): - """Returns the column-wise component of the tensor. + Args: + usage: The usage of the tensor Returns: - The column-wise tensor component - - Raises: - ValueError: If called on a tensor that doesn't support column-wise access + The tensor based on the usage """ @abstractmethod @@ -122,7 +119,7 @@ class ScaledTensor1x(ScaledTensor): _dq_func: Callable is_colwise: bool data_layout: str - flatten_axis: int = -1 + flatten_axis: int def __post_init__(self): """Validates and adjusts the scale_inv shape after initialization. @@ -130,35 +127,23 @@ def __post_init__(self): Ensures the scale_inv shape matches the expected shape based on the scaling mode and quantization direction. Pads the scale_inv if necessary. """ - flatten_axis = ( - len(self.data.shape) + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis - ) + assert self.flatten_axis > 0 assert ( - 0 < flatten_axis < len(self.data.shape) - ), f"flatten_axis {flatten_axis} is out of bounds for shape {self.data.shape}" - - if self.data_layout == "T": - flatten_axis = self.data.ndim - flatten_axis - self.flatten_axis = flatten_axis + 0 < self.flatten_axis < len(self.data.shape) + ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" - expected_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=True, flatten_axis=flatten_axis - ) - expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=False, flatten_axis=flatten_axis - ) - if self.scale_inv.shape != expected_scale_shape: - assert self.scale_inv.shape == expected_unpadded_scale_shape, ( - f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" - f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" - f" {self.scale_inv.shape}" - ) - pad_width = tuple( - (0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) + if self.scaling_mode == ScalingMode.NO_SCALING: + self.scale_inv = jnp.empty((0,), dtype=jnp.float32) + else: + unpadded_scale_shape = self.scaling_mode.get_scale_shape( + self.data.shape, + is_colwise=self.is_colwise, + is_padded=False, + flatten_axis=self.flatten_axis, ) - # This actually pad scale_inv with nan, should we pad it with 127 directly instead? - self.scale_inv = jnp.pad( - self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0 + assert self.scale_inv.shape == unpadded_scale_shape, ( + "Unpadded inverse scale factor has wrong shape, expected" + f" {unpadded_scale_shape} but got {self.scale_inv.shape}." ) def tree_flatten(self): @@ -178,6 +163,10 @@ def tree_flatten(self): ) return (children, aux_data) + @property + def ndim(self): + return self.data.ndim + def dequantize(self): """Dequantizes the tensor using the stored dequantization function. @@ -186,33 +175,19 @@ def dequantize(self): """ return self._dq_func(self) - def get_rowwise_tensor(self): - """Returns the tensor if it's row-wise quantized. - - Returns: - The row-wise tensor - - Raises: - ValueError: If called on a column-wise quantized tensor - """ - if not self.is_colwise: - return self - - raise ValueError("Calling get_rowwise_tensor() from a colwise ScaledTensor1x!") - - def get_colwise_tensor(self): - """Returns the tensor if it's column-wise quantized. + def get_tensor(self, usage: TensorUsage): + """Returns the tensor based on the tensor usage.""" + q_layout = self.scaling_mode.get_quantize_layout(usage) + colwise_usage_valid = q_layout == QuantizeLayout.COLWISE and self.is_colwise + rowwise_usage_valid = q_layout == QuantizeLayout.ROWWISE and not self.is_colwise - Returns: - The column-wise tensor - - Raises: - ValueError: If called on a row-wise quantized tensor - """ - if self.is_colwise: + if colwise_usage_valid or rowwise_usage_valid: return self - raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!") + raise ValueError( + f"Calling get_tensor() with usage {usage} is not valid for this tensor as" + f" self.is_colwise={self.is_colwise}!" + ) def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): """Applies sharding constraints to a tensor based on logical axis names. @@ -229,8 +204,12 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st # axis_names were given for N layout, so needs to be transpose for T layout if self.data_layout == "T": assert self.flatten_axis > 0 - flatten_axis = -self.flatten_axis - axis_names = (*logical_axis_names[flatten_axis:], *logical_axis_names[:flatten_axis]) + assert len(logical_axis_names) == self.data.ndim + flatten_axis = self.data.ndim - self.flatten_axis + axis_names = ( + *logical_axis_names[flatten_axis:], + *logical_axis_names[:flatten_axis], + ) else: axis_names = logical_axis_names @@ -254,6 +233,98 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st ) +@register_pytree_node_class +@dataclass +class GroupedScaledTensor1x(ScaledTensor1x): + """Grouped Quantizer for an array. + + This class extends ScaledTensor1x to support quantization of an array in grouped manner, + where elements are grouped along a specified axis. + + Attributes: + group_sizes: Array containing the size of each group + original_shape: The original shape of the tensor before grouping + group_axis: The axis along which grouping is performed (default: 0) + """ + + group_sizes: jnp.ndarray + original_shape: Tuple + group_axis: int + + def __init__( + self, + data, + scale_inv, + group_sizes, + scaling_mode, + dq_dtype, + _dq_func, + is_colwise, + data_layout, + flatten_axis, + original_shape, + group_axis=0, + ): + self.flatten_axis = flatten_axis + self.group_sizes = group_sizes + self.original_shape = original_shape + self.group_axis = group_axis + super().__init__( + data, scale_inv, scaling_mode, dq_dtype, _dq_func, is_colwise, data_layout, flatten_axis + ) + + def __post_init__(self): + assert self.scale_inv.ndim == 1, "Only support flattened scale_inv" + assert self.data.ndim == 1, "Only support flattened data" + assert self.group_axis >= 0 + assert self.flatten_axis > 0 + + data_ndim = len(self.original_shape) + assert ( + 0 < self.flatten_axis < data_ndim + ), f"flatten_axis {self.flatten_axis} is out of bounds for data.ndim = {data_ndim}" + + assert ( + 0 <= self.group_axis < data_ndim + ), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}" + + expected_scale_shape = self.scaling_mode.get_grouped_scale_shape( + self.original_shape, + self.group_sizes.size, + self.group_axis, + self.is_colwise, + is_padded=True, + flatten_axis=self.flatten_axis, + ) + + assert self.scale_inv.shape == expected_scale_shape, ( + f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" + f" scale_inv, got {self.scale_inv.shape}" + ) + + def tree_flatten(self): + """Flattens the tensor for JAX tree operations. + + Returns: + A tuple containing (children, aux_data) for tree operations + """ + children = (self.data, self.scale_inv, self.group_sizes) + aux_data = ( + self.scaling_mode, + self.dq_dtype, + self._dq_func, + self.is_colwise, + self.data_layout, + self.flatten_axis, + self.original_shape, + self.group_axis, + ) + return (children, aux_data) + + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): + raise NotImplementedError + + @register_pytree_node_class @dataclass class ScaledTensor2x(ScaledTensor): @@ -279,6 +350,11 @@ def tree_flatten(self): aux_data = () return (children, aux_data) + @property + def ndim(self): + """Number of dimensions of the underlying row-wise tensor.""" + return self.rowwise_tensor.ndim + def dequantize(self): """Dequantizes the tensor using the row-wise component's dequantization. @@ -287,21 +363,21 @@ def dequantize(self): """ return self.rowwise_tensor.dequantize() - def get_rowwise_tensor(self): - """Returns the row-wise quantized component. + def get_tensor(self, usage: TensorUsage): + """Returns the tensor based on the tensor usage.""" + q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage) + q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage) - Returns: - The row-wise tensor component - """ - return self.rowwise_tensor + if q_layout_rowwise == QuantizeLayout.ROWWISE: + return self.rowwise_tensor - def get_colwise_tensor(self): - """Returns the column-wise quantized component. + if q_layout_colwise == QuantizeLayout.COLWISE: + return self.colwise_tensor - Returns: - The column-wise tensor component - """ - return self.colwise_tensor + raise ValueError( + f"Calling get_tensor() with usage {usage} is not valid for this tensor as" + f" q_layout_rowwise={q_layout_rowwise} and q_layout_colwise={q_layout_colwise}!" + ) def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): """Applies sharding constraints to a tensor based on logical axis names. @@ -342,6 +418,9 @@ def create_1x( is_colwise=False, data_layout="N", flatten_axis=-1, + group_sizes=None, + original_shape=None, + group_axis=0, ): """Creates a single-scale quantized tensor. @@ -353,13 +432,67 @@ def create_1x( is_colwise: Whether to use column-wise quantization (default: False) data_layout: The data_layout specification (default: "N") flatten_axis: The quantization axis for the tensor + group_sizes: Arra of ints containing the size of each group (default: None) + original_shape: The original shape of the tensor before grouping (default: None) + group_axis: The axis along which grouping is performed (default: 0) Returns: - A ScaledTensor1x instance + A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided """ - dq_func = Dequantizer.funcs.get(scaling_mode) + dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) + + if group_sizes is not None: + flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis + assert ( + original_shape is not None + ), "original_shape is not given for GroupedScaledTensor1x" + + # Handling attrs of transposed tensors + group_axis = len(original_shape) + group_axis if group_axis < 0 else group_axis + if data_layout == "T": + if original_shape[0] == group_sizes.size: + original_shape = ( + original_shape[0], + *original_shape[flatten_axis:], + *original_shape[1:flatten_axis], + ) + flatten_axis = len(original_shape) - flatten_axis + 1 + else: + original_shape = ( + *original_shape[flatten_axis:], + *original_shape[:flatten_axis], + ) + group_axis = flatten_axis + flatten_axis = len(original_shape) - flatten_axis + + return GroupedScaledTensor1x( + data=data, + scale_inv=scale_inv, + scaling_mode=scaling_mode, + dq_dtype=dq_dtype, + _dq_func=dequantizer.grouped_dequantize, + is_colwise=is_colwise, + data_layout=data_layout, + flatten_axis=flatten_axis, + group_sizes=group_sizes, + original_shape=original_shape, + group_axis=group_axis, + ) + + # Handling attrs of transposed tensors + flatten_axis = data.ndim + flatten_axis if flatten_axis < 0 else flatten_axis + if data_layout == "T": + flatten_axis = data.ndim - flatten_axis + return ScaledTensor1x( - data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, data_layout, flatten_axis + data, + scale_inv, + scaling_mode, + dq_dtype, + dequantizer.dequantize, + is_colwise, + data_layout, + flatten_axis, ) @staticmethod @@ -372,6 +505,9 @@ def create_2x( dq_dtype=jnp.bfloat16, data_layout="NN", flatten_axis=-1, + group_sizes=None, + original_shape=None, + group_axis=0, ): """Creates a double-scale quantized tensor. @@ -384,30 +520,37 @@ def create_2x( dq_dtype: The data type for dequantized values (default: bfloat16) data_layout: The data_layout specification (default: "NN") flatten_axis: The quantization axis for the tensor + group_sizes: Array containing the size of each group (default: None) + original_shape: The original shape of the tensor before grouping (default: None) + group_axis: The axis along which grouping is performed (default: 0) Returns: A ScaledTensor2x instance """ - dq_func = Dequantizer.funcs.get(scaling_mode) - rowwise_tensor = ScaledTensor1x( + assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}" + rowwise_tensor = ScaledTensorFactory.create_1x( data, scale_inv, scaling_mode, dq_dtype, - dq_func, is_colwise=False, data_layout=data_layout[0], flatten_axis=flatten_axis, + group_sizes=group_sizes, + original_shape=original_shape, + group_axis=group_axis, ) - colwise_tensor = ScaledTensor1x( + colwise_tensor = ScaledTensorFactory.create_1x( colwise_data, colwise_scale_inv, scaling_mode, dq_dtype, - dq_func, is_colwise=True, data_layout=data_layout[1], flatten_axis=flatten_axis, + group_sizes=group_sizes, + original_shape=original_shape, + group_axis=group_axis, ) return ScaledTensor2x(rowwise_tensor, colwise_tensor) @@ -422,6 +565,9 @@ def create( data_layout: str = "NN", q_layout: QuantizeLayout = QuantizeLayout.ROWWISE, flatten_axis: int = -1, + group_sizes: jnp.ndarray = None, + original_shape: Tuple[int] = None, + group_axis: int = 0, ): """Creates a scaled tensor based on the quantization axis. @@ -434,6 +580,10 @@ def create( dq_dtype: The data type for dequantized values (default: bfloat16) data_layout: The data_layout specification (default: "NN") q_layout: The quantization axis (default: ROWWISE) + flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) + group_sizes: Array containing the size of each group (default: None) + original_shape: The original shape of the tensor before grouping (default: None) + group_axis: The axis along which grouping is performed (default: 0) Returns: Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout @@ -448,9 +598,26 @@ def create( dq_dtype, data_layout=data_layout, flatten_axis=flatten_axis, + group_sizes=group_sizes, + original_shape=original_shape, + group_axis=group_axis, ) is_colwise = q_layout == QuantizeLayout.COLWISE + if is_colwise: + return ScaledTensorFactory.create_1x( + colwise_data, + colwise_scale_inv, + scaling_mode, + dq_dtype, + is_colwise=is_colwise, + data_layout=data_layout[0], + flatten_axis=flatten_axis, + group_sizes=group_sizes, + original_shape=original_shape, + group_axis=group_axis, + ) + return ScaledTensorFactory.create_1x( data, scale_inv, @@ -459,6 +626,9 @@ def create( is_colwise=is_colwise, data_layout=data_layout[0], flatten_axis=flatten_axis, + group_sizes=group_sizes, + original_shape=original_shape, + group_axis=group_axis, ) @@ -472,6 +642,9 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, . Returns: The tensor with applied sharding constraints """ + if isinstance(x, GroupedScaledTensor1x): + raise NotImplementedError + if isinstance(x, ScaledTensor): return x.apply_sharding_constraint_by_logical_axes(logical_axis_names) diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index 8234c6aa6..b58d2df7f 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -47,11 +47,10 @@ from build_tools.build_ext import get_build_ext from build_tools.utils import ( rocm_build, copy_common_headers, copy_hipify_tools, - clear_hipify_tools_copy, install_and_import ) + clear_hipify_tools_copy) from build_tools.te_version import te_version -from build_tools.jax import setup_jax_extension, jax_install_requires +from build_tools.jax import setup_jax_extension, install_requirements, test_requirements -install_and_import("pybind11") from pybind11.setup_helpers import build_ext as BuildExtension os.environ["NVTE_PROJECT_BUILDING"] = "1" @@ -105,10 +104,8 @@ description="Transformer acceleration library - Jax Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, - install_requires=( - jax_install_requires(["flax>=0.7.1"]) if rocm_build() else ["jax", "flax>=0.7.1"] - ), - tests_require=[] if rocm_build() else ["numpy"], + install_requires=install_requirements(), + tests_require=test_requirements(), ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index f89fe60f9..e59c9de12 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -14,10 +14,12 @@ from dataclasses import dataclass from enum import Enum from typing import Callable, Optional +import warnings from jax.interpreters import pxla import jax import jax.numpy as jnp from jax.sharding import PartitionSpec +import numpy as np _PXLA_THREAD_RESOURCES = pxla.thread_resources @@ -116,7 +118,9 @@ def with_sharding_constraint_by_logical_axes( x: jnp.array, logical_axis_names: Optional[tuple | list] ): """ - A wrapper function to jax.lax.with_sharding_constraint to accept logical axes. + A wrapper function to flax.linen.with_logical_constraint. + + DEPRECATED USE CASE: If no Flax logical axis rules are available, this function falls back to jax.lax.with_sharding_constraint using a hardcoded logical axis rule table from TE rules, such as BATCH_AXES. This functionality will be removed in the future. If logical_axis_names = None, this means no sharding constraint is applied. @@ -132,6 +136,28 @@ def with_sharding_constraint_by_logical_axes( if not logical_axis_names: return x + try: + # Check if Flax logical axis rules are available, if so use them + import flax + + flax_rules = flax.linen.get_logical_axis_rules() + if len(flax_rules) > 0: + return flax.linen.with_logical_constraint( + x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT + ) + except ImportError: + pass + + warnings.warn( + "TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and" + " will be removed in a future version. Please use Flax logical axes with a" + " flax.linen.logical_axis_rules context and optionally use" + " transformer_engine.jax.flax.extend_logical_axis_rules to add BATCH_AXES, etc. to your" + " rules.", + DeprecationWarning, + ) + + # If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table assert len(x.shape) == len(logical_axis_names) pspec = generate_pspec(logical_axis_names) return with_sharding_constraint(x, pspec) @@ -201,6 +227,31 @@ def get_mesh_axis_rank(axis: str, mesh=None): return jax.lax.axis_index(axis_name) +def get_mesh_axis_rank_host(axis, mesh) -> int: + """ + Same as get_mesh_axis_rank(), but return a host value instead of a + traced device value. + """ + if axis not in mesh.axis_names: + raise ValueError(f"Axis {axis} not found in mesh axis names: {mesh.axis_names}") + + axis_index = mesh.axis_names.index(axis) + + # Convert mesh.devices (ndarray of Device objects) to flat list + devices = mesh.devices + local_device = jax.devices()[jax.process_index()] # Pick one device on this host + + # Find index of local_device in mesh.devices + coords = np.argwhere(devices == local_device) + if coords.size == 0: + raise ValueError(f"Local device {local_device} not found in mesh.devices.") + coords = tuple(coords[0]) # Coordinates in the mesh array + + # Get the mesh rank along the specified axis + rank = coords[axis_index] + return int(rank) + + @dataclass class MeshResource: """A data container for managing mesh resources in distributed training. diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index c9eb57021..c1c21fdc5 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -60,6 +60,7 @@ def torch_version() -> tuple[int, ...]: from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context from transformer_engine.pytorch import ops from transformer_engine.pytorch import optimizers +from transformer_engine.pytorch.export import onnx_export from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy try: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index d3cc4a9f2..a823379f1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -60,6 +60,8 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import ( AttentionLogging as attn_log, ) +from transformer_engine.pytorch import export +from transformer_engine.pytorch.export import is_in_onnx_export_mode # Global vars for flash attn v2 and v3 imports flash_attn_cuda_bwd = None @@ -155,7 +157,14 @@ def __init__( self.attention_dropout_ctx = attention_dropout_ctx self.layer_number = layer_number - self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func) + def mask_func(x, y): + return ( + export.onnx_attention_mask_func(x, y) + if is_in_onnx_export_mode() + else attention_mask_func(x, y) + ) + + self.scale_mask_softmax = FusedScaleMaskSoftmax(mask_func) # Dropout. Note that for a single iteration, this layer will generate # different outputs on different number of parallel partitions but @@ -222,7 +231,12 @@ def forward( if "padding" in attn_mask_type and attention_mask is None: attention_mask = dpa_utils.get_padding_mask( - batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv + batch_size, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + self.attention_type, ) attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = ( dpa_utils.get_full_mask( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index b52d1003f..9fc2342e7 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -464,6 +464,7 @@ def forward( ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") + enable_mla = k.shape[-1] != v.shape[-1] if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -501,7 +502,10 @@ def forward( cu_seqlens_q_half, cu_seqlens_kv_half = None, None if qkv_format in ["bshd", "sbhd"]: seq_dim = qkv_format.index("s") - qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] + if enable_mla: + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + else: + qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None if use_fused_attention: batch_dim = qkv_format.index("b") @@ -679,9 +683,16 @@ def forward( fwd_results_correction_done = torch.cuda.Event() p2p_comm_buffers = [None for _ in range(cp_size)] - if qkv_format in ["bshd", "sbhd"]: + if enable_mla: + # If MLA, the shape of k and v does not match, so we flatten them + # and split them after receiving them. + k_shape = k.shape + k_numel = k.numel() + v_shape = v.shape + p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) + elif qkv_format in ["bshd", "sbhd"]: p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) - else: + else: # qkv_format == "thd" p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) send_recv_reqs = [[], []] @@ -710,6 +721,10 @@ def forward( else: # KV exchange is in BF16/FP16, cast received KV in each step kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data + if enable_mla: + # If MLA, k and v are flattened, so split them after receiving. + k_part = kv_inputs[i % 2][:k_numel].view(*k_shape) + v_part = kv_inputs[i % 2][k_numel:].view(*v_shape) if causal: if i == 0: if pad_between_seqs: @@ -728,17 +743,27 @@ def forward( if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) + if enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + k.shape[0], -1, 2, *k.shape[-2:] + ) elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) + if enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[2:]) + v_part = v_part.view(-1, *v_part.shape[2:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) elif qkv_format == "thd": q_inputs[i % 2] = q if use_fused_attention: @@ -753,16 +778,19 @@ def forward( ).contiguous() q_part = q_inputs[i % 2] - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( @@ -813,6 +841,7 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, ) + # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], ( @@ -861,36 +890,60 @@ def forward( if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] + if enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn] + k_part = k_part[:, 0, ...] + v_part = v_part[:, 0, ...] + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][0] + if enable_mla: + # [2, sk//2, b, np, hn] -> [sk//2, b, np, hn] + k_part = k_part[0] + v_part = v_part[0] + else: + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][0] elif qkv_format == "thd": q_inputs[i % 2] = q - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_kv_padded, 0 - ) + if enable_mla: + # [t, np, hn] -> [t/2, np, hn] + k_part = tex.thd_read_half_tensor( + k_part, cu_seqlens_kv_padded, 0 + ) + v_part = tex.thd_read_half_tensor( + v_part, cu_seqlens_kv_padded, 0 + ) + else: + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_inputs[i % 2] = tex.thd_read_half_tensor( + kv_inputs[i % 2], cu_seqlens_kv_padded, 0 + ) if use_fused_attention: - kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() + if enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + else: + kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() q_part = q_inputs[i % 2] - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) + if not enable_mla: + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( @@ -951,6 +1004,7 @@ def forward( elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 + # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], ( @@ -999,17 +1053,27 @@ def forward( if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] q_inputs[i % 2] = q[:, 1, ...] - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) + if enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + k.shape[0], -1, 2, *k.shape[-2:] + ) elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] q_inputs[i % 2] = q[1] - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) + if enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[2:]) + v_part = v_part.view(-1, *v_part.shape[2:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) elif qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] q_inputs[i % 2] = tex.thd_read_half_tensor( @@ -1028,16 +1092,17 @@ def forward( ).contiguous() q_part = q_inputs[i % 2] - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) + if not enable_mla: + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( @@ -1098,6 +1163,7 @@ def forward( elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 + # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], ( @@ -1155,16 +1221,17 @@ def forward( ).contiguous() q_part = q - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) + if not enable_mla: + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( @@ -1214,6 +1281,7 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, ) + # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q, ( @@ -1260,7 +1328,15 @@ def forward( if i == 1: softmax_lse = torch.clone(softmax_lse_per_step[0]) if qkv_format == "thd": - out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) + if enable_mla: + out = torch.zeros_like(v if not fp8 else out_per_step[0]).view( + v_shape + ) + else: + # MHA or GQA + out = torch.zeros_like(q if not fp8 else out_per_step[0]).view( + q.shape + ) elif (i - 1) <= rank or not causal: flash_attn_fwd_softmax_lse_correction( softmax_lse, softmax_lse_per_step[i - 1] @@ -1298,7 +1374,10 @@ def forward( softmax_lse_per_step[0], seq_dim, ) - out = out.view(q.shape) + if enable_mla: + out = out.view(v_shape) + else: + out = out.view(q.shape) else: flash_attn_fwd_out_correction( out.view(*out_per_step[i].shape), @@ -1420,6 +1499,12 @@ def forward( ctx.is_output_fp8 = is_output_fp8 ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.enable_mla = enable_mla + if enable_mla: + ctx.k_numel = k_numel + ctx.k_shape = k_shape + ctx.v_shape = v_shape + ctx.qkv_dtype = qkv_dtype ctx.dQKV_quantizer = dQKV_quantizer ctx.dQKV_CP_quantizer = dQKV_CP_quantizer @@ -1469,7 +1554,10 @@ def backward(ctx, dout): seq_dim = None if ctx.qkv_format in ["bshd", "sbhd"]: seq_dim = ctx.qkv_format.index("s") - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] + if ctx.enable_mla: + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + else: + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] else: qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format @@ -1598,8 +1686,13 @@ def backward(ctx, dout): ) dout = dout.dequantize(dtype=dout_dtype) - out = out.view(*q.shape) - dout = dout.view(*q.shape) + if ctx.enable_mla: + out = out.view(*ctx.v_shape) + dout = dout.view(*ctx.v_shape) + else: + # MHA or GQA + out = out.view(*q.shape) + dout = dout.view(*q.shape) send_recv_reqs = [] flash_attn_bwd = None @@ -1675,6 +1768,9 @@ def backward(ctx, dout): kv = p2p_comm_buffers[i % 2][0] q_, kv_, out_, dout_ = None, None, None, None dq_, dk_, dv_ = None, None, None + if ctx.enable_mla: + k_part = kv[: ctx.k_numel].view(*ctx.k_shape) + v_part = kv[ctx.k_numel :].view(*ctx.v_shape) # In reversed order of fwd if causal: if i == (cp_size - 1): @@ -1683,13 +1779,23 @@ def backward(ctx, dout): q_, out_, dout_ = [ x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] ] - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[-3:]) + v_part = v_part.view(-1, *v_part.shape[-3:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) elif ctx.qkv_format == "thd": q_, kv_, out_, dout_ = q, kv, out, dout if ctx.use_fused_attention: @@ -1704,8 +1810,13 @@ def backward(ctx, dout): if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] q_part = q_ - k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + if not ctx.enable_mla: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) out_part = out_ dout_part = dout_ @@ -1787,6 +1898,7 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = 0 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, @@ -1804,19 +1916,38 @@ def backward(ctx, dout): q_, out_, dout_ = [ x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] ] - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_ = kv[:, 0] + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part[:, 0] + v_part = v_part[:, 0] + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_ = kv[:, 0] elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_ = kv[0] + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part[0] + v_part = v_part[0] + else: + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_ = kv[0] elif ctx.qkv_format == "thd": q_, out_, dout_ = q, out, dout - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) + if ctx.enable_mla: + # [t, np, hn] -> [t/2, np, hn] + k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) + v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) + else: + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) if ctx.use_fused_attention: - kv_ = kv_.contiguous() + if ctx.enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + else: + kv_ = kv_.contiguous() if ctx.fp8: aux_ctx_tensors = [ softmax_lse, @@ -1828,8 +1959,13 @@ def backward(ctx, dout): if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] q_part = q_ - k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + if not ctx.enable_mla: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) out_part = out_ dout_part = dout_ @@ -1913,6 +2049,7 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, @@ -1928,13 +2065,23 @@ def backward(ctx, dout): if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1] - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] q_, out_, dout_ = q[1], out[1], dout[1] - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[-3:]) + v_part = v_part.view(-1, *v_part.shape[-3:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) elif ctx.qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] q_, out_, dout_ = [ @@ -1956,8 +2103,13 @@ def backward(ctx, dout): aux_ctx_tensors += [attn_biases[cp_size - i - 1]] q_part = q_ - k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + if not ctx.enable_mla: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) out_part = out_ dout_part = dout_ @@ -2041,6 +2193,7 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, @@ -2061,8 +2214,9 @@ def backward(ctx, dout): if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] q_part = q - k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] - v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] + if not ctx.enable_mla: + k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] + v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] out_part = out dout_part = dout @@ -2136,6 +2290,7 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout, q, @@ -2228,15 +2383,18 @@ def backward(ctx, dout): else: dkv = p2p_comm_buffers[(i + 1) % 2][1] if ctx.use_fused_attention: - if ctx.qkv_format in ["bshd", "sbhd"]: + if ctx.enable_mla: + dkv_ = None + elif ctx.qkv_format in ["bshd", "sbhd"]: dkv_ = combine_tensors([dk_, dv_], -2) elif ctx.qkv_format == "thd": dkv_ = torch.cat( (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 ) # pylint: disable=used-before-assignment - if ctx.qkv_format in ["bshd", "sbhd"]: + if not ctx.enable_mla and ctx.qkv_format in ["bshd", "sbhd"]: # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] + # dkv is a buffer, so we do not need to transpose it, but only need to reshape it. dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) dkv_ = dkv_.movedim(-3, 0) if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): @@ -2244,91 +2402,225 @@ def backward(ctx, dout): # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn] dkv_ = dkv_.view(*dkv.shape) - if ctx.fp8: - if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].copy_(dkv_) - dkv[:, :, 1, ...].fill_(0) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].copy_(dkv_) - dkv[:, 1, ...].fill_(0) - else: - dkv.copy_(dkv_) - elif causal: - if i == (cp_size - 1): - if rank == 0: + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] or + # [2, sk//2, b, np, hn] + dk = dkv[: ctx.k_numel].view(*ctx.k_shape) + dv = dkv[ctx.k_numel :].view(*ctx.v_shape) + if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): + dk_ = dk_.view(*ctx.k_shape) + dv_ = dv_.view(*ctx.v_shape) + + if ctx.fp8: + # enable_mla and fp8 + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) - dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...]) + dk[:, 0, ...].copy_(dk_) + dk[:, 1, ...].fill_(0) + dv[:, 0, ...].copy_(dv_) + dv[:, 1, ...].fill_(0) elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].add_(dkv_[:, 0, ...]) - dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy") - else: - dkv.add_(dkv_) - elif i >= (cp_size - rank - 1): - if i == 0 and rank == (cp_size - 1): + dk[0].copy_(dk_) + dk[1].fill_(0) + dv[0].copy_(dv_) + dv[1].fill_(0) + else: + dk.copy_(dk_) + dv.copy_(dv_) + elif causal: + # enable_mla and not fp8 and causal + if i == (cp_size - 1): + if rank == 0: + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_[:, 0, ...]) + dk[:, 1, ...].copy_(dk_[:, 1, ...]) + dv[:, 0, ...].add_(dv_[:, 0, ...]) + dv[:, 1, ...].copy_(dv_[:, 1, ...]) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_[0, ...]) + dk[1, ...].copy_(dk_[1, ...]) + dv[0, ...].add_(dv_[0, ...]) + dv[1, ...].copy_(dv_[1, ...]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dk, dk_, cu_seqlens_kv_padded, "add", "copy" + ) + tex.thd_grad_correction( + dv, dv_, cu_seqlens_kv_padded, "add", "copy" + ) + else: + dk.add_(dk_) + dv.add_(dv_) + elif i >= (cp_size - rank - 1): + if i == 0 and rank == (cp_size - 1): + if ctx.qkv_format == "bshd": + dk[:, 0, ...].copy_(dk_) + dv[:, 0, ...].copy_(dv_) + elif ctx.qkv_format == "sbhd": + dk[0, ...].copy_(dk_) + dv[0, ...].copy_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dk, dk_, cu_seqlens_kv_padded, "copy", "none" + ) + tex.thd_grad_correction( + dv, dv_, cu_seqlens_kv_padded, "copy", "none" + ) + else: + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_) + dv[:, 0, ...].add_(dv_) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_) + dv[0, ...].add_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dk, dk_, cu_seqlens_kv_padded, "add", "none" + ) + tex.thd_grad_correction( + dv, dv_, cu_seqlens_kv_padded, "add", "none" + ) + elif i > 0: + dk.add_(dk_) + dv.add_(dv_) + else: # i == 0 + dk.copy_(dk_) + dv.copy_(dv_) + else: + # enable_mla and not fp8 and not causal + if i == 0: + dk.copy_(dk_) + dv.copy_(dv_) + else: # i > 0 + dk.add_(dk_) + dv.add_(dv_) + else: + if ctx.fp8: + # fp8 + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): if ctx.qkv_format == "bshd": dkv[:, :, 0, ...].copy_(dkv_) + dkv[:, :, 1, ...].fill_(0) elif ctx.qkv_format == "sbhd": dkv[:, 0, ...].copy_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none") + dkv[:, 1, ...].fill_(0) else: - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].add_(dkv_) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].add_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none") - elif i > 0: - dkv.add_(dkv_) - else: - dkv.copy_(dkv_) - else: - if i == 0: - dkv.copy_(dkv_) + dkv.copy_(dkv_) + elif causal: + # not fp8 and causal + if i == (cp_size - 1): + if rank == 0: + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) + dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...]) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].add_(dkv_[:, 0, ...]) + dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dkv, dkv_, cu_seqlens_kv_padded, "add", "copy" + ) + else: + dkv.add_(dkv_) + elif i >= (cp_size - rank - 1): + if i == 0 and rank == (cp_size - 1): + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].copy_(dkv_) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].copy_(dkv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dkv, dkv_, cu_seqlens_kv_padded, "copy", "none" + ) + else: + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].add_(dkv_) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].add_(dkv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dkv, dkv_, cu_seqlens_kv_padded, "add", "none" + ) + elif i > 0: + dkv.add_(dkv_) + else: # i == 0 + dkv.copy_(dkv_) else: - dkv.add_(dkv_) + # not fp8 and not causal + if i == 0: + dkv.copy_(dkv_) + else: # i > 0 + dkv.add_(dkv_) if ctx.fp8 and ctx.use_fused_attention: amax_cp_bwd = amax_per_step.amax(dim=1) ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1]) - if ctx.qkv_format in ["bshd", "sbhd"]: - # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or - # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] - dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) dq = ctx.dQKV_CP_quantizer.create_tensor_from_data( dq_fp8, fake_dtype=torch.float32, internal=True ) - dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dkv_fp8, fake_dtype=torch.float32, internal=True - ) - dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]] - dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] + + if ctx.enable_mla: + # [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn] + dk_fp8 = dkv_fp8[:, : ctx.k_numel].view(cp_size, *ctx.k_shape) + dv_fp8 = dkv_fp8[:, ctx.k_numel :].view(cp_size, *ctx.v_shape) + dk = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dk_fp8, fake_dtype=torch.float32, internal=True + ) + dv = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dv_fp8, fake_dtype=torch.float32, internal=True + ) + dq, dk, dv = [x.dequantize(dtype=torch.float32) for x in [dq, dk, dv]] + dq, dk, dv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dk, dv]] + else: + if ctx.qkv_format in ["bshd", "sbhd"]: + # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or + # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] + dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) + dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dkv_fp8, fake_dtype=torch.float32, internal=True + ) + dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]] + dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] if causal: if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + dk = dk.view(dk.shape[0], -1, *dk.shape[-2:]) + dv = dv.view(dv.shape[0], -1, *dv.shape[-2:]) + else: + # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] + dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] dq = dq.view(-1, *dq.shape[-3:]) - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + dk = dk.view(-1, *dk.shape[-3:]) + dv = dv.view(-1, *dv.shape[-3:]) + else: + # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] + dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) if ctx.qkv_format == "thd" and not ctx.use_fused_attention: dq[cu_seqlens_q_padded[-1] :].fill_(0) - dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) + if ctx.enable_mla: + dk[cu_seqlens_kv_padded[-1] :].fill_(0) + dv[cu_seqlens_kv_padded[-1] :].fill_(0) + else: + dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: assert torch.uint8 not in [dq.dtype, dkv.dtype] - dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] - dk, dv = dkv[0], dkv[1] + if ctx.enable_mla: + dq, dk, dv = [ctx.dQKV_quantizer(x)._data for x in [dq, dk, dv]] + else: + dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] + if not ctx.enable_mla: + dk, dv = dkv[0], dkv[1] if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) @@ -3487,7 +3779,64 @@ def attn_forward_func_with_cp( use_flash_attn_3=False, ) -> torch.Tensor: """ - Attention implementation with context parallelism. + Attention implementation with context parallelism (CP). CP partitions tensors along the sequence + dimension, and by reducing the memory and computational pressure on each GPU, it enables long-context + LLMs in a distributed fashion. Transformer Engine's PyTorch CP implementation currently utilizes + the DualChunkSwap strategy to ensure load balancing across CP ranks. It is applied to all `attn_mask_type`s + and all `qkv_format`s, and it requires sequence lengths to be, or are padded to be, divisible by + (cp_size * 2). It also requires tokens to be re-ordered before entering this function. + + For qkv_format = {'bshd', 'sbhd'}, the token re-ordering is illustrated as below, for an example + use case of s = 12, attn_mask_type = 'causal', and cp_size = 2. seq_pos indicates each token's position + in their corresponding sequence. + + GPU0 | GPU1 GPU0 | GPU1 + seq_pos | 0 1 2 3 4 5 | 6 7 8 9 10 11 seq_pos | 0 1 2 9 10 11 | 3 4 5 6 7 8 + ---------------------------|----------------- ---------------------------|------------------ + 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + U 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 9 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 1, 1, + 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 10 | 1, 1, 1, 1, 1, 0,| 1, 1, 1, 1, 1, 1, + 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1, + ---------------------------|----------------- ---------------------------|------------------ + 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 3 | 1, 1, 1, 0, 0, 0,| 1, 0, 0, 0, 0, 0, + G 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 4 | 1, 1, 1, 0, 0, 0,| 1, 1, 0, 0, 0, 0, + P 8 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 0, 0, 0, P 5 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 0, 0, 0, + U 9 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 0, 0, U 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + 1 10 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 0, 1 7 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 0, + 11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1, 8 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 1, + + For qkv_format = 'thd', multiple sequences may be packed into the batch, and they may be of different + lengths. DualChunkSwap divides each sequence into (cp_size * 2) chunks and distributes 2 chunks of + every sequence onto a CP rank. The token matrix transformation is shown as follows, for an example of + batch_size = 2, seq_ids = [0, 1], seq_lens = [8, 4], t = 12, attn_mask_type = 'padding_causal', and + cp_size = 2. + + GPU0 | GPU1 GPU0 | GPU1 + seq_id | 0 0 0 0 0 0 | 0 0 1 1 1 1 seq_id | 0 0 0 0 1 1 | 0 0 0 0 1 1 + seq_pos | 0 1 2 3 4 5 | 6 7 0 1 2 3 seq_pos | 0 1 6 7 0 3 | 2 3 4 5 1 2 + ---------------------------|----------------- ---------------------------|------------------ + 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + P 0 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 0 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + U 0 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 0 7 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 0, 0, + 0 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 1 0 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 0, 0, + 0 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 1 3 | 0, 0, 0, 0, 2, 2,| 0, 0, 0, 0, 2, 2, + ---------------------------|----------------- ---------------------------|------------------ + 0 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 0 2 | 1, 1, 0, 0, 0, 0,| 1, 0, 0, 0, 0, 0, + G 0 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 0 3 | 1, 1, 0, 0, 0, 0,| 1, 1, 0, 0, 0, 0, + P 1 0 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 0, 0, 0 P 0 4 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 0, 0, 0, + U 1 1 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 0, 0 U 0 5 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + 1 1 2 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 0 1 1 1 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 0, + 1 3 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 2 1 2 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 2, + + When all transformer layers in a model share the same CP configuration, i.e. cp_group, cp_global_ranks, + cp_comm_type and cp_stream, token re-ordering can take place in the dataloader, i.e. only once for + all the layers. An example of the re-ordering code is `get_batch_on_this_cp_rank + `_ + in Megatron-LM. + """ if cp_comm_type == "a2a+p2p": @@ -3530,6 +3879,12 @@ def attn_forward_func_with_cp( "all_gather", ], "The context parallel running configs cannot support sliding window attetnion!" + enable_mla = k.shape[-1] != v.shape[-1] + assert not enable_mla or cp_comm_type in [ + "p2p", + "a2a+p2p", + ], "The context parallel running configs cannot support MLA!" + args = [ is_training, q, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 7d50b9fa5..893e2d228 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -17,6 +17,7 @@ from transformer_engine.pytorch.fp8 import get_fp8_te_dtype from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule +from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.constants import ( AttnMaskTypes, AttnTypes, @@ -963,47 +964,54 @@ def forward( inference_params=inference_params, ) global _attention_backends - if ( - _attention_backends["attention_params"] is None - or attention_params != _attention_backends["attention_params"] - ): - _attention_backends["attention_params"] = attention_params - _attention_backends["backend_selection_requires_update"] = True - if _attention_backends["backend_selection_requires_update"]: - ( - use_flash_attention, - flash_attention_backend, - use_fused_attention, - fused_attention_backend, - use_unfused_attention, - _, - ) = dpa_utils.get_attention_backend(attention_params) - # Set global _attention_backends var using return value - # from get_attention_backend() - _attention_backends["use_flash_attention"] = use_flash_attention - _attention_backends["flash_attention_backend"] = flash_attention_backend - _attention_backends["use_fused_attention"] = use_fused_attention - _attention_backends["fused_attention_backend"] = fused_attention_backend - _attention_backends["use_unfused_attention"] = use_unfused_attention - _attention_backends["backend_selection_requires_update"] = False - if use_flash_attention: - self.logger.info( - "Running with FlashAttention backend (version %s)", - flash_attention_backend, - ) - elif use_fused_attention: - self.logger.info( - "Running with FusedAttention backend (sub-backend %s)", - int(fused_attention_backend), - ) - elif use_unfused_attention: - self.logger.info("Running with UnfusedDotProductAttention backend") + if is_in_onnx_export_mode(): + # We do not want to call get_attention_backend() in ONNX mode + # and we want to avoid using any global variables like _attention_backends. + use_flash_attention = False + use_fused_attention = False + use_unfused_attention = True else: - use_flash_attention = _attention_backends["use_flash_attention"] - flash_attention_backend = _attention_backends["flash_attention_backend"] - use_fused_attention = _attention_backends["use_fused_attention"] - fused_attention_backend = _attention_backends["fused_attention_backend"] - use_unfused_attention = _attention_backends["use_unfused_attention"] + if ( + _attention_backends["attention_params"] is None + or attention_params != _attention_backends["attention_params"] + ): + _attention_backends["attention_params"] = attention_params + _attention_backends["backend_selection_requires_update"] = True + if _attention_backends["backend_selection_requires_update"]: + ( + use_flash_attention, + flash_attention_backend, + use_fused_attention, + fused_attention_backend, + use_unfused_attention, + _, + ) = dpa_utils.get_attention_backend(attention_params) + # Set global _attention_backends var using return value + # from get_attention_backend() + _attention_backends["use_flash_attention"] = use_flash_attention + _attention_backends["flash_attention_backend"] = flash_attention_backend + _attention_backends["use_fused_attention"] = use_fused_attention + _attention_backends["fused_attention_backend"] = fused_attention_backend + _attention_backends["use_unfused_attention"] = use_unfused_attention + _attention_backends["backend_selection_requires_update"] = False + if use_flash_attention: + self.logger.info( + "Running with FlashAttention backend (version %s)", + flash_attention_backend, + ) + elif use_fused_attention: + self.logger.info( + "Running with FusedAttention backend (sub-backend %s)", + int(fused_attention_backend), + ) + elif use_unfused_attention: + self.logger.info("Running with UnfusedDotProductAttention backend") + else: + use_flash_attention = _attention_backends["use_flash_attention"] + flash_attention_backend = _attention_backends["flash_attention_backend"] + use_fused_attention = _attention_backends["use_fused_attention"] + fused_attention_backend = _attention_backends["fused_attention_backend"] + use_unfused_attention = _attention_backends["use_unfused_attention"] # raise exception if no backend is available if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/softmax.py b/transformer_engine/pytorch/attention/dot_product_attention/softmax.py index 25362e1d5..df10fc790 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/softmax.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/softmax.py @@ -8,6 +8,7 @@ import torch from torch import nn import transformer_engine_torch as tex +from transformer_engine.pytorch.export import is_in_onnx_export_mode THREADS_PER_WARP = 32 @@ -19,12 +20,18 @@ def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor: """Return the causal upper triangular mask for softmax input""" - matrix_identifiers = (mask_type, sq, sk) - if matrix_identifiers not in _default_causal_mask: + + def _get_mask(): diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1 - _default_causal_mask[matrix_identifiers] = torch.triu( + return torch.triu( torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=diagonal_offset ) + + if is_in_onnx_export_mode(): + return _get_mask() + matrix_identifiers = (mask_type, sq, sk) + if matrix_identifiers not in _default_causal_mask: + _default_causal_mask[matrix_identifiers] = _get_mask() return _default_causal_mask[matrix_identifiers] @@ -169,7 +176,11 @@ def forward( self.attn_mask_type = attn_mask_type assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled" + if is_in_onnx_export_mode(): + return self.forward_torch_softmax(inp, mask, scale) + # We do not want to connect this if with previous if, + # because we want to avoid calling is_kernel_available() in ONNX mode. if self.is_kernel_available(mask, *inp.size()): return self.forward_fused_softmax(inp, mask, scale) return self.forward_torch_softmax(inp, mask, scale) @@ -245,15 +256,15 @@ def forward_torch_softmax( if self.attn_mask_type in ["causal", "causal_bottom_right"]: seq_len_q, seq_len_k = inp.size(2), inp.size(3) causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) + if mask is None: mask = causal_mask else: mask = torch.logical_or(mask, causal_mask) - mask_output = inp if mask is not None and self.attn_mask_type != "no_mask": mask_output = self.mask_func(inp, mask) - probs = torch.nn.Softmax(dim=-1)(mask_output) + probs = torch.nn.functional.softmax(mask_output, dim=-1) if self.input_in_float16 and self.softmax_in_fp32: if self.input_in_fp16: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 6353acead..b61f2e152 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -47,6 +47,7 @@ get_device_compute_capability, get_cudnn_version, ) +from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.jit import jit_fuser @@ -108,7 +109,7 @@ class FlashAttentionUtils: version = PkgVersion("0") version_required = PkgVersion("2.1.1") version_required_blackwell = PkgVersion("2.7.3") - max_version = PkgVersion("2.8.0.post2") + max_version = PkgVersion("2.8.1") v2_plus = False v2_1_plus = False v2_3_plus = False @@ -436,8 +437,8 @@ def get_attention_backend( # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: - if device_compute_capability == (8, 9) and cudnn_version < (9, 11, 0): - logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.11") + if device_compute_capability == (8, 9) and cudnn_version < (9, 12, 0): + logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.12") use_fused_attention = False if context_parallel: logger.debug("Disabling all backends for KV caching with context parallelism") @@ -612,9 +613,10 @@ def get_attention_backend( " bias for THD format" ) use_fused_attention = False - elif head_dim_qk != head_dim_v: + elif fp8 and head_dim_qk != head_dim_v: logger.debug( - "Disabling FusedAttention as it does not support context parallelism with MLA" + "Disabling FusedAttention as it does not support context parallelism with FP8" + " MLA attention" ) use_fused_attention = False @@ -772,6 +774,7 @@ def get_attention_backend( q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) kv_type = q_type fused_attention_backend = tex.get_fused_attn_backend( + is_training, q_type, kv_type, QKVLayout[qkv_layout], @@ -959,16 +962,24 @@ def get_attention_backend( @torch.no_grad() def get_padding_mask( batch_size: int, - cu_seqlens_q: torch.Tensor, - cu_seqlens_kv: torch.Tensor, - max_seqlen_q: int, - max_seqlen_kv: int, + cu_seqlens_q: torch.Tensor = None, + cu_seqlens_kv: torch.Tensor = None, + max_seqlen_q: int = None, + max_seqlen_kv: int = None, + attention_type: str = "self", ): """Convert cu_seqlens to attention_mask""" + assert ( + cu_seqlens_q is not None and max_seqlen_q is not None + ), "cu_seqlens_q and max_seqlen_q are required for self-attention and cross-attention" seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] attention_mask_q = torch.Tensor([]).to(dtype=torch.bool) - attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool) + if attention_type == "cross": + assert ( + cu_seqlens_kv is not None and max_seqlen_kv is not None + ), "cu_seqlens_kv and max_seqlen_kv are required for cross-attention" + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool) for i in range(batch_size): attention_mask_q = torch.cat( [ @@ -981,21 +992,26 @@ def get_padding_mask( ], dim=0, ) - attention_mask_kv = torch.cat( - [ - attention_mask_kv, - torch.Tensor([False] * seqlens_kv[i] + [True] * (max_seqlen_kv - seqlens_kv[i])) - .to(dtype=torch.bool) - .unsqueeze(0) - .unsqueeze(0) - .unsqueeze(0), - ], - dim=0, + if attention_type == "cross": + attention_mask_kv = torch.cat( + [ + attention_mask_kv, + torch.Tensor([False] * seqlens_kv[i] + [True] * (max_seqlen_kv - seqlens_kv[i])) + .to(dtype=torch.bool) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + ], + dim=0, + ) + attention_mask_q = attention_mask_q.to(device="cuda") + if attention_type == "self": + attention_mask = attention_mask_q + else: + attention_mask = ( + attention_mask_q, + attention_mask_kv.to(device="cuda"), ) - attention_mask = ( - attention_mask_q.to(device="cuda"), - attention_mask_kv.to(device="cuda"), - ) return attention_mask @@ -1138,9 +1154,7 @@ def get_full_mask( swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( actual_seqlens_kv - actual_seqlens_q + window_size[1] ).view(batch_size, 1, 1, 1) - swa_mask = torch.logical_not( - torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0) - ) + swa_mask = torch.logical_not((swa_left <= 0) & ~(swa_right < 0)) if attention_mask is not None: attention_mask = torch.logical_or(swa_mask, attention_mask) else: @@ -1331,14 +1345,22 @@ def get_full_cu_seqlens( """ global _cu_seqlens_cache - if (batch_size, max_seqlen) not in _cu_seqlens_cache: - _cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange( + + def _get_cu_seqlens(batch_size, max_seqlen, device): + return torch.arange( 0, (batch_size + 1) * max_seqlen, step=max_seqlen, dtype=torch.int32, device=device, ) + + if is_in_onnx_export_mode(): + return _get_cu_seqlens(batch_size, max_seqlen, device) + if (batch_size, max_seqlen) not in _cu_seqlens_cache: + _cu_seqlens_cache[(batch_size, max_seqlen)] = _get_cu_seqlens( + batch_size, max_seqlen, device + ) return _cu_seqlens_cache[(batch_size, max_seqlen)] @@ -1614,11 +1636,16 @@ def get_qkv_layout( def run_iteratively(q, k, v): # check data pointers - data_ptr = q.untyped_storage().data_ptr() - check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v]) - check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k]) - data_ptr = k.untyped_storage().data_ptr() - check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v]) + if is_in_onnx_export_mode(): + check_ptrs_qkv = False + check_ptrs_qk = False + check_ptrs_kv = False + else: + data_ptr = q.untyped_storage().data_ptr() + check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v]) + check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k]) + data_ptr = k.untyped_storage().data_ptr() + check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v]) # check tensor shapes shape = q.shape @@ -1706,7 +1733,10 @@ def run_iteratively(q, k, v): return qkv_layout - qkv_layout = run_iteratively(q, k, v) + if not is_in_onnx_export_mode(): + qkv_layout = run_iteratively(q, k, v) + else: + qkv_layout = "not_supported" if qkv_layout == "not_supported": # force q,k,v to be contiguous and run get_layout again q, k, v = [x.contiguous() for x in [q, k, v]] diff --git a/transformer_engine/pytorch/attention/inference.py b/transformer_engine/pytorch/attention/inference.py index 8267bf63c..8d5417a45 100644 --- a/transformer_engine/pytorch/attention/inference.py +++ b/transformer_engine/pytorch/attention/inference.py @@ -420,6 +420,8 @@ def __init__( dtype=torch.int32, device=torch.cuda.current_device(), ) + # whether reindexing is needed, i.e. when batch seq_ids have changed + self.need_reindex = True def allocate_memory(self, layer_number): """Allocate memory for the cache""" @@ -451,6 +453,7 @@ def pre_step( # step() re-indexes k_cache and v_cache using batch_indices = [0, 2, 3, 1] so that # they are contiguous and match the indexing in q prev_batch_size = len(self.sequences) + prev_seq_ids = set(self.sequences.keys()) unfinished_seqs = self.sequences.keys() & step_dict.keys() finished_seqs = self.sequences.keys() - unfinished_seqs unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] @@ -478,6 +481,9 @@ def pre_step( for i in new_seqs: self.sequences[i] = step_dict[i] + # Whether reindexing is needed + self.need_reindex = set(self.sequences.keys()) != prev_seq_ids + return self.sequences def step( @@ -538,7 +544,7 @@ def step( ctx_len, self.max_seqlen, 1, - True, + self.need_reindex, ) k_cache = k_cache[:batch_size] diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index f018465dc..142044240 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -12,6 +12,7 @@ from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module import LayerNormLinear, Linear +from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization from transformer_engine.pytorch.utils import ( SplitAlongDim, divide, @@ -174,6 +175,22 @@ class MultiheadAttention(torch.nn.Module): parameter for query-key-value. This enables optimizations such as QKV fusion without concatentations/splits and also enables the argument `fuse_wgrad_accumulation`. + use_qk_norm: bool, default = 'False' + if set to `True`, L2 normalization is applied to query and key tensors + after RoPE (if applicable) but before attention computation. + This follows the Llama4 approach for QK normalization to improve + training stability and model performance. + qk_norm_eps: float, default = 1e-6 + epsilon value for L2 normalization of query and key tensors. + Only used when `use_qk_norm` is True. + seq_length: Optional[int], default = `None` + sequence length of input samples. Needed for JIT Warmup, a technique where jit + fused functions are warmed up before training to ensure same kernels are used for + forward propagation and activation recompute phase. + micro_batch_size: Optional[int], default = `None` + batch size per training step. Needed for JIT Warmup, a technique where jit + fused functions are warmed up before training to ensure same kernels are + used for forward propagation and activation recompute phase. """ def __init__( @@ -214,6 +231,10 @@ def __init__( device: Union[torch.device, str] = "cuda", qkv_format: str = "sbhd", name: str = None, + use_qk_norm: bool = False, + qk_norm_eps: float = 1e-6, + seq_length: Optional[int] = None, + micro_batch_size: Optional[int] = None, ) -> None: super().__init__() @@ -267,6 +288,7 @@ def __init__( self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups self.name = name + self.use_qk_norm = use_qk_norm common_gemm_kwargs = { "fuse_wgrad_accumulation": fuse_wgrad_accumulation, @@ -278,6 +300,14 @@ def __init__( "device": device, } + # Initialize L2 normalization modules for query and key if enabled + if self.use_qk_norm: + self.qk_norm = L2Normalization( + eps=qk_norm_eps, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + ) + qkv_parallel_mode = "column" if set_parallel_mode else None if self.attention_type == "self": @@ -482,6 +512,8 @@ def forward( alibi_slopes: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None, + cu_seqlens_q_padded: Optional[torch.Tensor] = None, + cu_seqlens_kv_padded: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_kv: Optional[int] = None, fast_zero_fill: bool = True, @@ -556,6 +588,12 @@ def forward( cu_seqlens_kv: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + cu_seqlens_q_padded: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`, + with shape [batch_size + 1] and dtype torch.int32. + cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (with offset) in a batch for `key_layer` + and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. max_seqlen_q: Optional[int], default = `None` Maximum sequence length in `query_layer`. Calculated from `cu_seqlens_q` if not provided. @@ -714,6 +752,18 @@ def forward( for x in (key_layer, value_layer) ) + if self.qkv_format == "thd": + key_layer, value_layer = ( + x.reshape(x.size(0), -1, self.hidden_size_per_attention_head) + for x in (key_layer, value_layer) + ) + else: + # key, value: -> [sq, b, ng, hn] + key_layer, value_layer = ( + x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head) + for x in (key_layer, value_layer) + ) + # Attention head [sq, b, h] --> [sq, b, hp] if self.input_layernorm: layernorm_query_outputs = self.layernorm_query( @@ -792,6 +842,14 @@ def forward( interleaved=self.rotary_pos_interleaved, ) + # =========================== + # Apply L2 normalization to query and key tensors + # =========================== + + if self.use_qk_norm: + query_layer = self.qk_norm(query_layer) + key_layer = self.qk_norm(key_layer) + # =========================== # Core attention computation # =========================== @@ -803,6 +861,8 @@ def forward( qkv_format=self.qkv_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, attention_mask=attention_mask, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 7864046c9..16fa9f3e8 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -261,7 +261,8 @@ def fused_attn_fwd( ), "Fused attention does not support this input combination." if IS_HIP_EXTENSION: - rng_elts_per_thread = 0 + # Both CK/aiter and aotriton follow the flash-attn rng design + rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS else: # BF16/FP16 fused attention API from fmha_v1 apex if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index ce4594df1..9f3921d36 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -77,6 +77,14 @@ def general_gemm( # There is not use_split_accumulator == False # implementation for Float8BlockwiseQTensorBase GEMM use_split_accumulator = True + + # Check that data format is supported + if ( + A._data_format != tex.Float8BlockScaleTensorFormat.GEMM_READY + or B._data_format != tex.Float8BlockScaleTensorFormat.GEMM_READY + ): + raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format") + args = ( A, transa, # transa diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 5ad113e8b..1c03e3d37 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -9,8 +9,8 @@ import torch +from transformer_engine.debug.pytorch.debug_state import TEDebugState from .tensor.quantized_tensor import QuantizedTensorBase - from .tensor.float8_tensor import Float8Tensor __all__ = ["get_cpu_offload_context"] @@ -20,6 +20,9 @@ def mark_activation_offload(*tensors): """Set the type of the offloading needed for a tensor.""" + if TEDebugState.debug_enabled: + raise RuntimeError("CPU offload is not supported in debug mode.") + for tensor in tensors: if tensor is None: continue @@ -253,12 +256,20 @@ def offload(src_tensor, pin_memory=True): return state @staticmethod - def reload(state, non_blocking=None): + def reload(state, non_blocking=None, copy_buffer=None): """Reload.""" dev, cpu_backup = state if non_blocking is None: non_blocking = cpu_backup.is_pinned() - return cpu_backup.to(dev, non_blocking=non_blocking) + + if copy_buffer is None: + return cpu_backup.to(dev, non_blocking=non_blocking) + + assert cpu_backup.size() == copy_buffer.size(), "Can't copy two buffers of different sizes!" + + copy_buffer.copy_(cpu_backup, non_blocking=non_blocking) + + return copy_buffer def tensor_push(self, tensor: torch.Tensor, **kwargs): """Tensor push.""" @@ -300,6 +311,7 @@ def __init__( num_offload_group, # must be <= actual number of groups (number of commits) num_model_group, tensor_need_offloading_checker=(lambda t: True), + double_buffering=False, debug=False, ) -> None: super().__init__( @@ -314,11 +326,17 @@ def __init__( # Data structure to hold the FP8/MXFP8 tensor objects self.fp8_tensor_object_map = {} self.float8_transpose_cache_valid = {} + self.dereferencing_list = [] # Tracking the number of layers offloaded self.offloaded_group_count = 0 # Core data structure that decides the window for offloading self.layer_window_map = {} + # Data structures fo double buffered reloading + self.double_buffering = double_buffering + self.reload_double_buffer = [[], []] + self.double_buffer_created = False + # Logic to make offloading load balance across computation # for optimal CPU/GPU interconnect usage constant = 0 @@ -360,6 +378,12 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: self.tensor_tag_to_state[tensor_tag] = [] self.tensor_tag_to_buf[tensor_tag] = [] + # Added support for de-duplicating FP8 param tensors + for _, value in self.fp8_tensor_object_map.items(): + if tensor is value: + self.dereferencing_list.append(tensor_tag) + break + self.fp8_tensor_object_map[tensor_tag] = tensor if isinstance(tensor, Float8Tensor): self.float8_transpose_cache_valid[tensor_tag] = getattr( @@ -398,11 +422,18 @@ def tensor_pop(self, tensor_tag, **kwargs): # Handling the quantized tensor case specially here if isinstance(tensor, list): - self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor) + # If it's a duplicated tensor, we don't need to locally + # write back a tensor as it would already be written + if tensor_tag in self.dereferencing_list: + self.dereferencing_list.remove(tensor_tag) + else: + self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor) tensor = self.fp8_tensor_object_map.pop(tensor_tag) - self.tensor_tag_to_buf.pop(tensor_tag, None) + if self.double_buffering: + tensor.do_not_clear = True + self.tensor_tag_to_buf.pop(tensor_tag, None) # the tensor should have been copied back in on_group_commit_backward() # which invokes bulk_reload_group. assert not isinstance(tensor, tuple) @@ -454,6 +485,20 @@ def synchronize_on_group_commit_forward(self, current_group): # the first compute completion if current_group == 0: self.d2h_stream.wait_stream(torch.cuda.current_stream()) + + if not self.double_buffer_created: + # Creating the first copy of double buffer for tensors that are offloaded + for tensor_tag, buf in self.tensor_tag_to_buf.items(): + if isinstance(buf, list): + for b in buf: + self.reload_double_buffer[0].append( + torch.empty_like(b) if self.double_buffering else None + ) + else: + self.reload_double_buffer[0].append( + torch.empty_like(buf) if self.double_buffering else None + ) + self.bulk_offload_group(current_group) # Window map data structure helps us synchronize based on number @@ -483,6 +528,15 @@ def synchronize_on_group_commit_forward(self, current_group): # Increment the offload group count to keep track self.offloaded_group_count += 1 + if not self.double_buffer_created: + # Creating second copy of double buffer for tensors that are offloaded + if current_group == (self.num_layers - 1): + for buf in self.reload_double_buffer[0]: + self.reload_double_buffer[1].append( + torch.empty_like(buf) if self.double_buffering else None + ) + self.double_buffer_created = True + def on_group_commit_forward(self): """This function will cause host device synchronization""" # handle synchronization events @@ -494,28 +548,49 @@ def bulk_reload_group(self, group_to_reload): """Bulk reload group.""" assert group_to_reload < self.num_offload_group + buffer_idx = 0 + double_buffer_idx = group_to_reload % 2 + with torch.cuda.stream(self.h2d_stream): # move back tensors for tensor_label, state in self.tensor_tag_to_state.items(): group_id, _ = tensor_label if group_id == group_to_reload: if isinstance(state, tuple): - recovered_tensor = SynchronizedGroupOffloadHandler.reload(state) + recovered_tensor = SynchronizedGroupOffloadHandler.reload( + state, True, self.reload_double_buffer[double_buffer_idx][buffer_idx] + ) + buffer_idx = buffer_idx + 1 self.tensor_tag_to_state[tensor_label] = recovered_tensor elif isinstance(state, list): tensor_list = [] for state_tuple in state: if isinstance(state_tuple, tuple): tensor_list.append( - SynchronizedGroupOffloadHandler.reload(state_tuple) + SynchronizedGroupOffloadHandler.reload( + state_tuple, + True, + self.reload_double_buffer[double_buffer_idx][buffer_idx], + ) ) + buffer_idx = buffer_idx + 1 else: tensor_list.append(state_tuple) - _ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(tensor_list) + + # No need to write back the duplicated tensor againn + # to the same location, this check ensures that + if tensor_label in self.dereferencing_list: + self.dereferencing_list.remove(tensor_label) + else: + _ = self.fp8_tensor_object_map[tensor_label].restore_from_saved( + tensor_list + ) + if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor): self.fp8_tensor_object_map[tensor_label]._transpose_invalid = ( self.float8_transpose_cache_valid.pop(tensor_label) ) + self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop( tensor_label ) @@ -552,6 +627,7 @@ def get_cpu_offload_context( model_layers: int = 1, offload_activations: bool = True, offload_weights: bool = False, + double_buffering: bool = False, ): """ This function returns the CPU Offload context and the synchronizer function that needs to be @@ -580,6 +656,8 @@ def get_cpu_offload_context( When set to `True`, offloads the activations for the TE layer. offload_weights: bool, default = `True` When set to `True`, offloads the weights for the TE layer. + double_buffering: bool, default = `False` + When set to `True`, uses double buffering for offloading. """ @@ -611,6 +689,7 @@ def tensor_need_offloading_checker_activations(tensor): num_offload_group=num_layers, num_model_group=model_layers, tensor_need_offloading_checker=tensor_need_offloading_checker, + double_buffering=double_buffering, ) def group_prefetch_offload_commit_async(tensor): diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 75e8c14fc..f86b60f61 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -20,6 +20,20 @@ std::vector getTensorShape(at::Tensor t) { return shape; } +NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { + NVTEShape ret; + ret.ndim = torch_shape.size(); + constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t); + NVTE_CHECK(ret.ndim < max_dimensions, + "Torch tensor has too many dimensions. Max supported: ", max_dimensions, " and got ", + ret.ndim, "."); + for (size_t i = 0; i < ret.ndim; ++i) { + const auto& v = torch_shape[i]; + ret.data[i] = static_cast(v); + } + return ret; +} + std::unique_ptr convert_quantizer(py::handle quantizer) { init_extension(); if (quantizer.is_none()) { diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index c15a1ae3c..1d1511c73 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -34,7 +34,9 @@ #endif #include #include +#include #include +#include #include #include #include @@ -182,6 +184,8 @@ class Float8BlockQuantizer : public Quantizer { bool force_pow_2_scales = false; // Amax within quantization tile has a floor of epsilon. float amax_epsilon = 0.0; + // Whether quantized tensor will be used in an all-gather + bool all_gather_usage = false; private: int block_scaling_dim = 2; @@ -203,6 +207,8 @@ class Float8BlockQuantizer : public Quantizer { std::pair create_tensor( const std::vector& shape, DType dtype, std::optional rowwise_data = std::nullopt) const override; + + std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; class MXFP8Quantizer : public Quantizer { @@ -218,6 +224,8 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor( const std::vector& shape, DType dtype, std::optional rowwise_data = std::nullopt) const override; + + std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; std::unique_ptr convert_quantizer(py::handle quantizer); @@ -227,21 +235,23 @@ std::vector getTensorShape(at::Tensor t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); -inline size_t typeToSize(transformer_engine::DType t) { +inline size_t typeToNumBits(transformer_engine::DType t) { switch (t) { case transformer_engine::DType::kInt64: - return 8; + return 64; case transformer_engine::DType::kInt32: case transformer_engine::DType::kFloat32: - return 4; + return 32; case transformer_engine::DType::kInt16: case transformer_engine::DType::kFloat16: case transformer_engine::DType::kBFloat16: - return 2; + return 16; case transformer_engine::DType::kByte: case transformer_engine::DType::kFloat8E4M3: case transformer_engine::DType::kFloat8E5M2: - return 1; + return 8; + case transformer_engine::DType::kFloat4E2M1: + return 4; default: NVTE_ERROR("Invalid type"); } @@ -374,6 +384,7 @@ std::vector convertShape(const NVTEShape& shape); int roundup(const int value, const int multiple); +NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); } // namespace transformer_engine::pytorch namespace std { diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2ff64ae90..4e4e46fd1 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -23,6 +23,38 @@ class CommOverlapType{}; namespace transformer_engine::pytorch { +/*************************************************************************************************** + * Router fusion + **************************************************************************************************/ + +std::tuple fused_topk_with_score_function_fwd( + at::Tensor logits, int topk, bool use_pre_softmax, c10::optional num_groups, + c10::optional group_topk, c10::optional scaling_factor, std::string score_function, + c10::optional expert_bias); + +at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts, + at::Tensor routing_map, + at::Tensor intermediate_output, at::Tensor grad_probs, + int topk, bool use_pre_softmax, + c10::optional scaling_factor, + std::string score_function); + +std::tuple fused_score_for_moe_aux_loss_fwd( + at::Tensor logits, int topk, std::string score_function); + +at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, + at::Tensor intermediate_output, at::Tensor grad_probs, + int topk, std::string score_function); + +std::tuple fused_moe_aux_loss_fwd(at::Tensor probs, + at::Tensor tokens_per_expert, + int total_num_tokens, int num_experts, + int num_rows, int num_cols, int topk, + float coeff); + +at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows, + int num_cols, at::Tensor grad_aux_loss); + /*************************************************************************************************** * Permutation **************************************************************************************************/ @@ -45,13 +77,11 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T * Attention **************************************************************************************************/ -NVTE_Fused_Attn_Backend get_fused_attn_backend(const DType q_dtype, const DType kv_dtype, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, float p_dropout, - size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, - size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right); +NVTE_Fused_Attn_Backend get_fused_attn_backend( + bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, @@ -120,10 +150,6 @@ std::optional> te_general_grouped_gemm( * Transpose **************************************************************************************************/ -std::vector fused_multi_quantize(std::vector input_list, - std::optional> output_list, - std::vector quantizer_list, DType otype); - at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output = std::nullopt); @@ -194,10 +220,17 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w **************************************************************************************************/ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, - std::optional noop); + std::optional noop_flag); py::object dequantize(const py::handle &input, DType otype); +std::vector multi_tensor_quantize(const std::vector &tensor_list, + std::vector quantizer_list); + +std::vector split_quantize(const at::Tensor &tensor, + const std::vector &split_sections, + std::vector quantizer_list); + /*************************************************************************************************** * Bias gradient fusions **************************************************************************************************/ @@ -381,6 +414,9 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, std::vector input_row_list, std::vector padded_input_row_list); +void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, + std::vector input_row_list, + std::vector unpadded_input_row_list); #ifndef USE_ROCM /*************************************************************************************************** * NVSHMEM APIs @@ -447,6 +483,8 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve at::Tensor get_buffer(bool local_chunk = false, std::optional> shape = std::nullopt); + at::Stream get_communication_stream(); + }; // CommOverlap class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { @@ -466,6 +504,8 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm at::Tensor get_buffer(bool local_chunk = false, std::optional> shape = std::nullopt); + at::Stream get_communication_stream(); + }; // CommOverlapP2P #endif // !USE_ROCM diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 189190f68..dfc8a8291 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -4,8 +4,8 @@ * See LICENSE for license information. ************************************************************************/ +#include "../extensions.h" #include "common.h" -#include "extensions.h" #include "pybind.h" namespace transformer_engine::pytorch { diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index fe640f67c..6f6f82725 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -4,8 +4,8 @@ * See LICENSE for license information. ************************************************************************/ +#include "../extensions.h" #include "common.h" -#include "extensions.h" namespace transformer_engine::pytorch { diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 7b84ff6e7..c26614103 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -6,8 +6,8 @@ * See LICENSE for license information. ************************************************************************/ +#include "../extensions.h" #include "common.h" -#include "extensions.h" #include "pybind.h" namespace { @@ -26,12 +26,12 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s NVTE_CHECK(fcd_size % block_size == 0, "input size not aligned to block size"); - size_t element_size = transformer_engine::pytorch::typeToSize(self.dtype()); + size_t element_size_bits = transformer_engine::pytorch::typeToNumBits(self.dtype()); int32_t start_row = start_index.data_ptr()[0]; void *base_ptr = static_cast(self.get_rowwise_data().data_ptr) + - static_cast(start_row) * fcd_size * element_size; + static_cast(start_row) * fcd_size * element_size_bits / 8; size_t num_rows_to_zero = max_tokens - start_row; - size_t total_bytes = num_rows_to_zero * fcd_size * element_size; + size_t total_bytes = num_rows_to_zero * fcd_size * element_size_bits / 8; NVTE_SCOPED_GIL_RELEASE( { nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); }); @@ -59,14 +59,14 @@ namespace transformer_engine::pytorch { // get the fused attention backend NVTE_Fused_Attn_Backend get_fused_attn_backend( - const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right) { + bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, - attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, - head_dim_qk, head_dim_v, window_size_left, window_size_right); + is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, + bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, + max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right); return fused_attention_backend; } @@ -174,16 +174,7 @@ std::vector fused_attn_fwd( // extract rng seed and offset auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); -#ifndef USE_ROCM at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); -#else - const transformer_engine::Tensor *input_cu_seqlens_q = reinterpret_cast(te_cu_seqlens_q.data()); - size_t batch_size = input_cu_seqlens_q->data.shape[0]-1; - const transformer_engine::Tensor *input_Q = reinterpret_cast(te_Q.data()); - size_t ndim = input_Q->data.shape.size(); - size_t num_attn_heads = input_Q->data.shape[ndim-2]; - at::PhiloxCudaState philox_args = init_philox_state(gen, batch_size*num_attn_heads*max_seqlen_q*max_seqlen_kv); -#endif auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); unpack(philox_args, static_cast(rng_state.data_ptr())); auto te_rng_state = makeTransformerEngineTensor(rng_state); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 1edbef8cd..07f2be9df 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -6,60 +6,51 @@ #include "transformer_engine/cast.h" +#include +#include +#include +#include +#include +#include + +#include "../extensions.h" #include "common.h" -#include "extensions.h" #include "pybind.h" #include "transformer_engine/transformer_engine.h" -namespace transformer_engine::pytorch { +namespace transformer_engine { +namespace pytorch { -py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::object& output, - std::optional noop) { - init_extension(); - auto my_quantizer = convert_quantizer(quantizer); - auto input_tensor = tensor.contiguous(); +namespace { - const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); - const auto& te_input_shape = te_input.shape(); - std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); - auto fake_tensor_type = tensor.scalar_type(); - if (!detail::IsFloatingPointType(fake_tensor_type)) { - fake_tensor_type = at::kFloat; - } +std::vector get_tensor_shape(const TensorWrapper &tensor) { + const auto &shape = tensor.shape(); + return std::vector(shape.data, shape.data + shape.ndim); +} - TensorWrapper te_output; - py::object out; - if (output.is_none()) { - DType fake_te_type = GetTransformerEngineDType(fake_tensor_type); - std::tie(te_output, out) = my_quantizer->create_tensor(input_shape, fake_te_type); - } else { - out = output; - te_output = makeTransformerEngineTensor(output, quantizer); +void quantize_impl(const TensorWrapper &input, py::handle &quantizer_py, + std::unique_ptr &quantizer_cpp, TensorWrapper &output, + TensorWrapper &noop_flag) { + // Check tensor dims + NVTE_CHECK(get_tensor_shape(input) == get_tensor_shape(output), + "Input tensor (shape=", get_tensor_shape(input), + ") and output tensor (shape=", get_tensor_shape(output), ") do not match"); + if (input.numel() == 0) { + return; } - TensorWrapper te_noop; - if (noop.has_value()) { - te_noop = makeTransformerEngineTensor(*noop); - } else { - te_noop = TensorWrapper(); - } - - if (te_output.numel() == 0) return out; - + // Recipe-specific configuration QuantizationConfigWrapper quant_config; - quant_config.set_noop_tensor(te_noop.data()); - - if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(my_quantizer.get()); - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); - }); + quant_config.set_noop_tensor(noop_flag.data()); + if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { + auto my_quantizer_cs = static_cast(quantizer_cpp.get()); + NVTE_SCOPED_GIL_RELEASE( + { nvte_compute_amax(input.data(), output.data(), at::cuda::getCurrentCUDAStream()); }); // check if we need to do amax reudction (depending on model parallel configs) if (my_quantizer_cs->with_amax_reduction) { c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; // construct torch tesnor from NVTEBasicTensor without reallocating memory - at::Tensor& amax_tensor_torch = my_quantizer_cs->amax; + at::Tensor &amax_tensor_torch = my_quantizer_cs->amax; std::vector tensors = {amax_tensor_torch}; // allreduce amax tensor c10d::AllreduceOptions allreduce_opts; @@ -72,34 +63,70 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(te_output.data(), quant_config, - at::cuda::getCurrentCUDAStream()); + nvte_compute_scale_from_amax(output.data(), quant_config, at::cuda::getCurrentCUDAStream()); }); - // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel - te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); - } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { - auto my_quantizer_bw = static_cast(my_quantizer.get()); + // set amax ptr to null in output TensorWrapper to avoid atomic amax updates in kernel + output.set_amax(nullptr, DType::kFloat32, output.defaultShape); + } else if (detail::IsFloat8BlockwiseQuantizers(quantizer_py.ptr())) { + auto my_quantizer_bw = static_cast(quantizer_cpp.get()); quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); + if (my_quantizer_bw->all_gather_usage) { + quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); + } } + + // Perform quantization NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(te_input.data(), te_output.data(), quant_config, - at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(input.data(), output.data(), quant_config, at::cuda::getCurrentCUDAStream()); }); +} - return out; +} // namespace + +py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, + std::optional noop_flag) { + // Convert quantizer to C++ object + auto quantizer_cpp = convert_quantizer(quantizer); + + // Convert input tensor to C++ object + auto input_contiguous = tensor.contiguous(); + const auto input_cpp = makeTransformerEngineTensor(input_contiguous); + + // Initialize output tensor + TensorWrapper output_cpp; + py::object output_py; + if (output.is_none()) { + const auto shape = get_tensor_shape(input_cpp); + const auto fake_dtype = input_cpp.dtype(); + std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); + } else { + output_py = output; + output_cpp = makeTransformerEngineTensor(output_py, quantizer); + } + + // Initialize no-op flag + TensorWrapper noop_flag_cpp; + if (noop_flag.has_value()) { + noop_flag_cpp = makeTransformerEngineTensor(*noop_flag); + } + + // Perform quantization + quantize_impl(input_cpp, quantizer, quantizer_cpp, output_cpp, noop_flag_cpp); + + return output_py; } -py::object dequantize(const py::handle& input, transformer_engine::DType otype) { +py::object dequantize(const py::handle &input, transformer_engine::DType otype) { init_extension(); const auto none = py::none(); - const auto& input_tensor = makeTransformerEngineTensor(input, none); + const auto &input_tensor = makeTransformerEngineTensor(input, none); NoneQuantizer q(none); - const auto& shape = convertShape(input_tensor.shape()); + const auto &shape = convertShape(input_tensor.shape()); auto [out_tensor, out] = q.create_tensor(shape, otype); @@ -110,9 +137,522 @@ py::object dequantize(const py::handle& input, transformer_engine::DType otype) return out; } +namespace { + +void multi_tensor_quantize_impl(const std::vector &input_list, + std::vector &quantizer_py_list, + std::vector> &quantizer_cpp_list, + std::vector &output_list) { + // Check number of tensors + const size_t num_tensors = input_list.size(); + NVTE_CHECK(quantizer_py_list.size() == num_tensors, "Expected ", num_tensors, + " Python quantizers, but got ", quantizer_py_list.size()); + NVTE_CHECK(quantizer_cpp_list.size() == num_tensors, "Expected ", num_tensors, + " C++ quantizers, but got ", quantizer_cpp_list.size()); + NVTE_CHECK(output_list.size() == num_tensors, "Expected ", num_tensors, + " output tensors, but got ", output_list.size()); + + // Choose implementation + // Note: Currently only have fused kernel for FP8 delayed scaling + bool with_fused_kernel = true; + for (size_t i = 0; i < num_tensors; i++) { + if (!detail::IsFloat8Quantizers(quantizer_py_list[i].ptr())) { + with_fused_kernel = false; + break; + } + if (nvte_tensor_data(output_list[i].data()) == nullptr || + nvte_tensor_columnwise_data(output_list[i].data()) == nullptr) { + with_fused_kernel = false; + break; + } + } + + // Launch TE kernel + if (with_fused_kernel) { + // Fused kernel for multi-tensor quantize + std::vector nvte_tensor_input_list; + std::vector nvte_tensor_output_list; + for (size_t i = 0; i < num_tensors; ++i) { + nvte_tensor_input_list.push_back(input_list[i].data()); + nvte_tensor_output_list.push_back(output_list[i].data()); + } + NVTE_SCOPED_GIL_RELEASE({ + nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(), + nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); + }); + } else { + // Quantize kernels individually + TensorWrapper dummy_noop_flag; + for (size_t i = 0; i < num_tensors; ++i) { + quantize_impl(input_list[i], quantizer_py_list[i], quantizer_cpp_list[i], output_list[i], + dummy_noop_flag); + } + } +} + +} // namespace + +std::vector multi_tensor_quantize(const std::vector &tensor_list, + std::vector quantizer_list) { + // Check number of tensors + const size_t num_tensors = tensor_list.size(); + NVTE_CHECK(quantizer_list.size() == num_tensors, "Expected ", num_tensors, + " quantizers, but got ", quantizer_list.size()); + + // Convert quantizers to C++ objects + std::vector> quantizer_cpp_list; + for (size_t i = 0; i < num_tensors; i++) { + quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i])); + } + + // Initialize input and output tensors + std::vector input_cpp_list; + std::vector output_cpp_list; + std::vector output_py_list; + for (size_t i = 0; i < num_tensors; ++i) { + // Convert input tensor to C++ object + const auto &input_py = tensor_list[i]; + NVTE_CHECK(input_py.is_contiguous(), "Input tensor ", i, " is not contiguous"); + input_cpp_list.emplace_back(makeTransformerEngineTensor(input_py)); + const auto &input_cpp = input_cpp_list.back(); + const auto input_shape = input_cpp.shape(); + const auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); + + // Construct output tensor + std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype); + output_cpp_list.emplace_back(std::move(output_cpp)); + output_py_list.emplace_back(std::move(output_py)); + } + + // Perform multi-tensor quantization + multi_tensor_quantize_impl(input_cpp_list, quantizer_list, quantizer_cpp_list, output_cpp_list); + + return output_py_list; +} + +namespace { + +std::tuple, std::vector> bulk_allocate_fp8_blockwise_tensors( + std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { + init_extension(); + std::tuple, std::vector> retval; + auto &tensor_py_list = std::get<0>(retval); + auto &tensor_cpp_list = std::get<1>(retval); + + // Number of tensors + const size_t num_tensors = shape_list.size(); + if (num_tensors == 0) { + return retval; + } + + // Quantization parameters + const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; + const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; + const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); + const auto is_2D_scaled = scaling_mode == NVTE_BLOCK_SCALING_2D; + const auto fp8_dtype = quantizer_cpp_list[0]->dtype; + constexpr size_t fp8_elem_size = 1; + constexpr size_t scale_elem_size = 4; + + // Helper function to construct tensor view + // Note: Deleter holds a shared_ptr for the buffer, so the buffer + // will survive until all views are deleted. + auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + size_t offset, at::ScalarType dtype) -> at::Tensor { + std::vector shape_int64(shape.begin(), shape.end()); + // in the case where full buffer is empty because local rank receives no tokens for all the experts + // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob + // but in the case where some experts receive tokens, some not, we want to leverage from_blob + // as much as possible to avoid CPU overhead + if (buffer->data_ptr() == nullptr) { + return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); + } + return at::from_blob( + buffer->data_ptr() + offset, shape_int64, + [buffer](void *) {}, // deleter holds shared_ptr + at::device(at::kCUDA).dtype(dtype)); + }; + + // Allocate row-wise data + std::vector rowwise_data_list, rowwise_scale_list; + std::vector> rowwise_data_shapes, rowwise_scale_shapes; + if (rowwise_usage) { + // Tensor sizes + for (size_t i = 0; i < num_tensors; ++i) { + rowwise_data_shapes.emplace_back(shape_list[i]); + rowwise_scale_shapes.emplace_back( + quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + } + + // Offsets in full buffer + size_t buffer_size = 0; + std::vector data_offsets, scale_offsets; + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 256); // align to 256B + data_offsets.push_back(buffer_size); + buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_offsets.push_back(buffer_size); + buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size; + } + + // Allocate full buffer + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + + // Construct tensor views + for (size_t i = 0; i < num_tensors; ++i) { + rowwise_data_list.emplace_back( + make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8)); + rowwise_scale_list.emplace_back( + make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kFloat32)); + } + } + + // Allocate column-wise data + std::vector columnwise_data_list, columnwise_scale_list; + std::vector> columnwise_data_shapes, columnwise_scale_shapes; + if (columnwise_usage) { + // Tensor sizes + for (size_t i = 0; i < num_tensors; ++i) { + columnwise_data_shapes.emplace_back(); + auto &shape = columnwise_data_shapes.back(); + shape.push_back(shape_list[i].back()); + for (size_t j = 0; j < shape_list[i].size() - 1; ++j) { + shape.push_back(shape_list[i][j]); + } + columnwise_scale_shapes.emplace_back( + quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true)); + } + + // Offsets in full buffer + size_t buffer_size = 0; + std::vector data_offsets, scale_offsets; + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 256); // align to 256B + data_offsets.push_back(buffer_size); + buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_offsets.push_back(buffer_size); + buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size; + } + + // Allocate full buffer + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + + // Construct tensor views + for (size_t i = 0; i < num_tensors; ++i) { + columnwise_data_list.emplace_back( + make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8)); + columnwise_scale_list.emplace_back( + make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kFloat32)); + } + } + + // Construct FP8 block-wise tensors + py::handle Float8BlockwiseQTensorClass( + reinterpret_cast(Float8BlockwiseQTensorBasePythonClass)); + for (size_t i = 0; i < num_tensors; ++i) { + // Create tensor objects with proper reference counting + py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); + py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none(); + py::object columnwise_data = + (columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none()); + py::object columnwise_scale = + (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none()); + + // Construct Python tensor + tensor_py_list.emplace_back(Float8BlockwiseQTensorClass( + rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, + quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY)); + + // Construct C++ tensor + tensor_cpp_list.emplace_back(makeTransformerEngineTensor( + rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, + columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, + rowwise_usage ? rowwise_data_shapes[i] : std::vector{}, + columnwise_usage ? columnwise_data_shapes[i] : std::vector{}, fp8_dtype, nullptr, + nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, + columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, + rowwise_usage ? rowwise_scale_shapes[i] : std::vector{}, + columnwise_usage ? columnwise_scale_shapes[i] : std::vector{}, scaling_mode)); + } + + return retval; +} + +std::tuple, std::vector> bulk_allocate_mxfp8_tensors( + std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { + init_extension(); + std::tuple, std::vector> retval; + auto &tensor_py_list = std::get<0>(retval); + auto &tensor_cpp_list = std::get<1>(retval); + + // Number of tensors + const size_t num_tensors = shape_list.size(); + if (num_tensors == 0) { + return retval; + } + + // Quantization parameters + const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; + const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; + const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); + const auto fp8_dtype = quantizer_cpp_list[0]->dtype; + constexpr size_t fp8_elem_size = 1; + constexpr size_t scale_elem_size = 1; + + // Helper function to construct tensor view + // Note: Deleter holds a shared_ptr for the buffer, so the buffer + // will survive until all views are deleted. + auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + size_t offset, at::ScalarType dtype) -> at::Tensor { + std::vector shape_int64(shape.begin(), shape.end()); + // in the case where full buffer is empty because local rank receives no tokens for all the experts + // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob + // but in the case where some experts receive tokens, some not, we want to leverage from_blob + // as much as possible to avoid CPU overhead + if (buffer->data_ptr() == nullptr) { + return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); + } + return at::from_blob( + buffer->data_ptr() + offset, shape_int64, + [buffer](void *) {}, // deleter holds shared_ptr + at::device(at::kCUDA).dtype(dtype)); + }; + + // Allocate row-wise data + std::vector rowwise_data_list, rowwise_scale_list; + std::vector> rowwise_data_shapes, rowwise_scale_shapes; + if (rowwise_usage) { + // Tensor sizes + for (size_t i = 0; i < num_tensors; ++i) { + rowwise_data_shapes.emplace_back(shape_list[i]); + rowwise_scale_shapes.emplace_back( + quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + } + + // Offsets in full buffer + size_t buffer_size = 0; + std::vector data_offsets, scale_offsets; + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 256); // align to 256B + data_offsets.push_back(buffer_size); + buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_offsets.push_back(buffer_size); + buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size; + } + + // Allocate full buffer + // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel + auto buffer = std::make_shared( + at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + // auto buffer = std::make_shared( + // at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + + // Construct tensor views + for (size_t i = 0; i < num_tensors; ++i) { + rowwise_data_list.emplace_back( + make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8)); + rowwise_scale_list.emplace_back( + make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + } + } + + // Allocate column-wise data + std::vector columnwise_data_list, columnwise_scale_list; + std::vector> columnwise_data_shapes, columnwise_scale_shapes; + if (columnwise_usage) { + // Tensor sizes + for (size_t i = 0; i < num_tensors; ++i) { + // For MXFP8, the columnwise data doesn't need transpose + // because of TN, NT, NN layout support in SM100 + columnwise_data_shapes.emplace_back(shape_list[i]); + columnwise_scale_shapes.emplace_back( + quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true)); + } + + // Offsets in full buffer + size_t buffer_size = 0; + std::vector data_offsets, scale_offsets; + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 256); // align to 256B + data_offsets.push_back(buffer_size); + buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_offsets.push_back(buffer_size); + buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size; + } + + // Allocate full buffer + // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel + auto buffer = std::make_shared( + at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + // auto buffer = std::make_shared( + // at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + + // Construct tensor views + for (size_t i = 0; i < num_tensors; ++i) { + columnwise_data_list.emplace_back( + make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8)); + columnwise_scale_list.emplace_back( + make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + } + } + + // Construct mxfp8 tensors + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); + for (size_t i = 0; i < num_tensors; ++i) { + // Create tensor objects with proper reference counting + py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); + py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none(); + py::object columnwise_data = + (columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none()); + py::object columnwise_scale = + (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none()); + + // Construct Python tensor + tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data, + columnwise_scale, fp8_dtype, + quantizer_py_list[i])); + + // Construct C++ tensor + tensor_cpp_list.emplace_back(makeTransformerEngineTensor( + rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, + columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, + rowwise_usage ? rowwise_data_shapes[i] : std::vector{}, + columnwise_usage ? columnwise_data_shapes[i] : std::vector{}, fp8_dtype, nullptr, + nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, + columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, + rowwise_usage ? rowwise_scale_shapes[i] : std::vector{}, + columnwise_usage ? columnwise_scale_shapes[i] : std::vector{}, scaling_mode)); + } + + return retval; +} + +} // namespace + +std::vector split_quantize(const at::Tensor &tensor, + const std::vector &split_sections, + std::vector quantizer_list) { + init_extension(); + + // Check number of tensors + const size_t num_splits = split_sections.size(); + NVTE_CHECK(quantizer_list.size() == num_splits, "Expected ", num_splits, " quantizers, but got ", + quantizer_list.size()); + if (num_splits == 0) { + return {}; + } + + // Input tensor properties + auto input_py = tensor.contiguous(); + uint8_t *input_dptr = reinterpret_cast(input_py.data_ptr()); + auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); + std::vector input_shape; + size_t input_size = 1; + for (const auto &d : input_py.sizes()) { + input_shape.push_back(d); + input_size *= d; + } + NVTE_CHECK(input_shape.size() > 0, "Input tensor has 0 dims"); + + // Split input tensor along dim 0 + std::vector input_list; + std::vector> split_shapes; + size_t dim0_offset = 0; + const size_t dim0_stride = + input_shape[0] == 0 ? 0 : input_py.element_size() * input_size / input_shape[0]; + for (size_t i = 0; i < num_splits; ++i) { + NVTE_CHECK(split_sections[i] >= 0, "Attempted to split tensor with shape=", input_shape, + " along dim 0 with split_sections=", split_sections); + NVTE_CHECK(dim0_offset + split_sections[i] <= input_shape[0], + "Attempted to split tensor with shape=", input_shape, + " along dim 0 with split_sections=", split_sections); + split_shapes.push_back(input_shape); + auto &split_shape = split_shapes.back(); + split_shape[0] = split_sections[i]; + void *split_dptr = static_cast(input_dptr + dim0_offset * dim0_stride); + input_list.emplace_back(makeTransformerEngineTensor(split_dptr, split_shape, input_dtype)); + dim0_offset += split_sections[i]; + } + + // Convert quantizers to C++ objects + std::vector> quantizer_cpp_list; + for (size_t i = 0; i < num_splits; i++) { + quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i])); + } + + // For FP8 block-scaling, we construct output tensors with bulk allocations + // For MXFP8, we also use bulk allocations + bool use_fused_bulk_alloc = true; + for (size_t i = 0; i < quantizer_list.size(); i++) { + if (!detail::IsFloat8BlockwiseQuantizers(quantizer_list[i].ptr()) && + !detail::IsMXFP8Quantizers(quantizer_list[i].ptr())) { + use_fused_bulk_alloc = false; + break; + } + } + + // Allocate output tensors + std::vector output_cpp_list; + std::vector output_py_list; + if (!use_fused_bulk_alloc) { + // Allocate output tensors individually + for (size_t i = 0; i < num_splits; ++i) { + auto [output_cpp, output_py] = + quantizer_cpp_list[i]->create_tensor(split_shapes[i], input_dtype); + output_cpp_list.emplace_back(std::move(output_cpp)); + output_py_list.emplace_back(std::move(output_py)); + } + } else { + // TODO(zhongbo): make a better api to make this part less hacky + bool is_fp8_blockwise = detail::IsFloat8BlockwiseQuantizers(quantizer_list[0].ptr()); + bool is_mxfp8 = detail::IsMXFP8Quantizers(quantizer_list[0].ptr()); + if (is_fp8_blockwise) { + // FP8 block-scaling: construct output tensors with bulk allocations + std::vector blockwise_quantizers; + for (auto &quantizer : quantizer_cpp_list) { + blockwise_quantizers.push_back(static_cast(quantizer.get())); + } + std::tie(output_py_list, output_cpp_list) = + bulk_allocate_fp8_blockwise_tensors(split_shapes, quantizer_list, blockwise_quantizers); + } else if (is_mxfp8) { + // MXFP8: construct output tensors with bulk allocations + std::vector mxfp8_quantizers; + for (auto &quantizer : quantizer_cpp_list) { + mxfp8_quantizers.push_back(static_cast(quantizer.get())); + } + std::tie(output_py_list, output_cpp_list) = + bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); + } else { + NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer"); + } + } + + // Perform multi-tensor quantization + multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list); + + return output_py_list; +} + template -std::vector dbias_dact(const at::Tensor& grad_output, const at::Tensor& act_input, +std::vector dbias_dact(const at::Tensor &grad_output, const at::Tensor &act_input, py::handle quantizer) { init_extension(); auto my_quantizer = convert_quantizer(quantizer); @@ -122,7 +662,7 @@ std::vector dbias_dact(const at::Tensor& grad_output, const at::Tens auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype()); auto act_input_tensor = makeTransformerEngineTensor(act_input); - const auto& shape = convertShape(grad_tensor.shape()); + const auto &shape = convertShape(grad_tensor.shape()); auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype()); auto dbias_tensor = makeTransformerEngineTensor(grad_bias); @@ -146,29 +686,30 @@ std::vector dbias_dact(const at::Tensor& grad_output, const at::Tens return {py::cast(grad_bias), dact}; } -std::vector dbias_dgelu(const at::Tensor& grad_output, const at::Tensor& act_input, +std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, py::handle quantizer) { return dbias_dact(grad_output, act_input, quantizer); } -std::vector dbias_dsilu(const at::Tensor& grad_output, const at::Tensor& act_input, +std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, py::handle quantizer) { return dbias_dact(grad_output, act_input, quantizer); } -std::vector dbias_drelu(const at::Tensor& grad_output, const at::Tensor& act_input, +std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, py::handle quantizer) { return dbias_dact(grad_output, act_input, quantizer); } -std::vector dbias_dqgelu(const at::Tensor& grad_output, const at::Tensor& act_input, +std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input, py::handle quantizer) { return dbias_dact(grad_output, act_input, quantizer); } -std::vector dbias_dsrelu(const at::Tensor& grad_output, const at::Tensor& act_input, +std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, py::handle quantizer) { return dbias_dact(grad_output, act_input, quantizer); } -} // namespace transformer_engine::pytorch +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index af59d544e..26fb35664 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -218,6 +218,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional #include "../common.h" +#include "../extensions.h" #include "common.h" #include "common/util/cuda_runtime.h" #include "common/util/system.h" -#include "extensions.h" #include "pybind.h" #include "transformer_engine/transformer_engine.h" #include "util.h" diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index 2ef19fe7f..78a44778c 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -6,7 +6,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" +#include "../extensions.h" namespace transformer_engine::pytorch { #ifndef USE_ROCM diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp index 22db8f08f..21d3e0574 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" +#include "../../extensions.h" namespace transformer_engine::pytorch { diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp index 3cbf47682..290f70b57 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" +#include "../../extensions.h" namespace transformer_engine::pytorch { diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp index bcd616bd3..1e8eb44d9 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" +#include "../../extensions.h" namespace transformer_engine::pytorch { diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp index 96e8cde83..ba33f04bf 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" +#include "../../extensions.h" namespace transformer_engine::pytorch { diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp index 978838ecd..de3209535 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" +#include "../../extensions.h" namespace transformer_engine::pytorch { diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 23e415c40..88404a2e1 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -4,8 +4,8 @@ * See LICENSE for license information. ************************************************************************/ +#include "../extensions.h" #include "common/util/system.h" -#include "extensions.h" #include "pybind.h" namespace transformer_engine::pytorch { @@ -170,6 +170,9 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe auto my_quantizer_bw = static_cast(my_quantizer.get()); quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); + if (my_quantizer_bw->all_gather_usage) { + quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); + } } NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, @@ -328,6 +331,9 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w auto my_quantizer_bw = static_cast(my_quantizer.get()); quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); + if (my_quantizer_bw->all_gather_usage) { + quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); + } } NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index dd68429ef..d4b64a485 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" +#include "../extensions.h" #include "pybind.h" namespace transformer_engine::pytorch { @@ -81,4 +81,77 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, }); } +void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, + std::vector input_row_list, + std::vector unpadded_input_row_list) { + using namespace transformer_engine; + using namespace transformer_engine::pytorch; + + NVTE_CHECK(input_row_list.size() == unpadded_input_row_list.size(), + "Number of input row list and padded row list must match."); + NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2."); + NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2."); + + const auto num_tensors = input_row_list.size(); + // Extract properties from PyTorch tensors + std::vector input_dptr_list, output_dptr_list; + std::vector> input_shape_list, output_shape_list; + std::vector input_type_list; + void* d_input_ptr = reinterpret_cast(input.data_ptr()); + void* d_output_ptr = reinterpret_cast(output.data_ptr()); + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + input_dptr_list.push_back(d_input_ptr); + output_dptr_list.push_back(d_output_ptr); + + // Move the input pointer to the next split. + char* input_char_ptr = reinterpret_cast(d_input_ptr); + const size_t input_dptr_offset = + input_row_list[tensor_id] * input.size(1) * input.element_size(); + input_char_ptr += input_dptr_offset; + d_input_ptr = reinterpret_cast(input_char_ptr); + + input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); + + // Move the output pointer to the next split. + char* output_char_ptr = reinterpret_cast(d_output_ptr); + const size_t output_dptr_offset = + unpadded_input_row_list[tensor_id] * output.size(1) * output.element_size(); + output_char_ptr += output_dptr_offset; + d_output_ptr = reinterpret_cast(output_char_ptr); + + output_shape_list.push_back( + {unpadded_input_row_list[tensor_id], static_cast(output.size(1))}); + } + + // Construct TE tensors + std::vector nvte_input_list, nvte_output_list; + std::vector tensor_wrappers; + auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + transformer_engine::DType dtype) -> NVTETensor { + tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); + return tensor_wrappers.back().data(); + }; + + std::vector unpadded_num_rows_list; + for (size_t i = 0; i < input_dptr_list.size(); ++i) { + if (input_dptr_list[i] == nullptr || input_row_list[i] == 0) continue; + nvte_input_list.emplace_back( + make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i])); + nvte_output_list.emplace_back( + make_tensor(output_dptr_list[i], output_shape_list[i], input_type_list[i])); + unpadded_num_rows_list.emplace_back(unpadded_input_row_list[i]); + } + + // Check tensor lists + NVTE_CHECK(nvte_output_list.size() == nvte_input_list.size(), + "Number of input and output tensors must match"); + NVTE_CHECK(unpadded_num_rows_list.size() == nvte_input_list.size() && + "Number of input and padded row list must match"); + + // Launch TE kernel + nvte_multi_unpadding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(), + unpadded_num_rows_list.data(), at::cuda::getCurrentCUDAStream()); +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index c70d929d0..97cf40085 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" +#include "../extensions.h" namespace transformer_engine::pytorch { diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 93a42bcc3..a6f3e8113 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -14,7 +14,9 @@ #include #include -#include +#include +#include +#include #include "../common.h" #include "../extensions.h" @@ -201,10 +203,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma")); m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm"); - m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize, - "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"), - py::arg("quantizer_list"), py::arg("otype")); - + m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize, + "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list")); + m.def("split_quantize", &transformer_engine::pytorch::split_quantize, + "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"), + py::arg("quantizer_list")); m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, "Grouped GEMM"); m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", @@ -231,6 +234,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("out_dtype"), py::call_guard()); m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding, "Fused Multi-tensor padding", py::call_guard()); + m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding, + "Fused Multi-tensor unpadding", py::call_guard()); // attention kernels m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd, @@ -255,6 +260,32 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward, "Fused Apply RoPE BWD", py::call_guard()); + // fused router + m.def("fused_topk_with_score_function_fwd", + &transformer_engine::pytorch::fused_topk_with_score_function_fwd, py::arg("logits"), + py::arg("topk"), py::arg("use_pre_softmax"), py::arg("num_groups"), py::arg("group_topk"), + py::arg("scaling_factor"), py::arg("score_function"), py::arg("expert_bias"), + "Fused topk softmax fwd"); + m.def("fused_topk_with_score_function_bwd", + &transformer_engine::pytorch::fused_topk_with_score_function_bwd, py::arg("num_tokens"), + py::arg("num_experts"), py::arg("routing_map"), py::arg("intermediate_output"), + py::arg("grad_probs"), py::arg("topk"), py::arg("use_pre_softmax"), + py::arg("scaling_factor"), py::arg("score_function"), "Fused topk softmax bwd"); + m.def("fused_score_for_moe_aux_loss_fwd", + &transformer_engine::pytorch::fused_score_for_moe_aux_loss_fwd, py::arg("logits"), + py::arg("topk"), py::arg("score_function"), "Fused topk softmax fwd"); + m.def("fused_score_for_moe_aux_loss_bwd", + &transformer_engine::pytorch::fused_score_for_moe_aux_loss_bwd, py::arg("num_tokens"), + py::arg("num_experts"), py::arg("intermediate_output"), py::arg("grad_scores"), + py::arg("topk"), py::arg("score_function"), "Fused topk softmax bwd"); + m.def("fused_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_moe_aux_loss_fwd, + py::arg("probs"), py::arg("tokens_per_expert"), py::arg("total_num_tokens"), + py::arg("num_experts"), py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"), + py::arg("coeff"), "Fused aux loss fwd"); + m.def("fused_moe_aux_loss_bwd", &transformer_engine::pytorch::fused_moe_aux_loss_bwd, + py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"), + py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd"); + // Misc #ifndef USE_ROCM m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version, @@ -262,7 +293,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version", py::call_guard()); #endif - m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams); + m.def("get_num_cublas_streams", &nvte_get_num_compute_streams, "Get number of compute streams", + py::call_guard()); // Support THD format for Context Parallel m.def("thd_read_half_tensor", &transformer_engine::pytorch::thd_read_half_tensor, @@ -391,7 +423,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"), py::arg("local_chunk") = false) .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false, - py::arg("shape") = std::nullopt); + py::arg("shape") = std::nullopt) + .def("get_communication_stream", &CommOverlap::get_communication_stream); py::class_, transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>( @@ -408,7 +441,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), py::arg("local_chunk") = false) .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, - py::arg("shape") = std::nullopt); + py::arg("shape") = std::nullopt) + .def("get_communication_stream", &CommOverlapP2P::get_communication_stream); #else m.def("CommOverlapHelper", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); m.def("CommOverlap", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index eb4d60bd0..3635d4a9c 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -9,8 +9,8 @@ #include -#include "common/common.h" -#include "extensions.h" +#include "../extensions.h" +#include "transformer_engine/transformer_engine.h" namespace transformer_engine::pytorch { @@ -34,30 +34,35 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio const std::string& amax_compute_algo, DType fp8_dtype, float margin) { size_t num_tensors = amax_histories.size(); - std::vector t_amax_histories(num_tensors); - std::vector t_scales(num_tensors); - std::vector te_amax_histories(num_tensors); - std::vector te_scales(num_tensors); + std::vector te_amax_histories; + std::vector te_scales; + te_amax_histories.reserve(num_tensors); + te_scales.reserve(num_tensors); for (size_t i = 0; i < num_tensors; i++) { - t_amax_histories[i].data.dptr = amax_histories[i].data_ptr(); - auto amax_sizes = amax_histories[i].sizes().vec(); - std::vector amax_shape{amax_sizes.begin(), amax_sizes.end()}; - t_amax_histories[i].data.shape = amax_shape; - t_amax_histories[i].data.dtype = DType::kFloat32; - - t_scales[i].data.dptr = scales[i].data_ptr(); - auto scale_sizes = scales[i].sizes().vec(); - std::vector scale_shape{scale_sizes.begin(), scale_sizes.end()}; - t_scales[i].data.shape = scale_shape; - t_scales[i].data.dtype = DType::kFloat32; - - te_amax_histories[i] = reinterpret_cast(&t_amax_histories[i]); - te_scales[i] = reinterpret_cast(&t_scales[i]); + te_amax_histories.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING)); + NVTETensor& amax_history = te_amax_histories.back(); + NVTEShape amax_shape = convertTorchShape(amax_histories[i].sizes()); + NVTEBasicTensor amax_history_data = {amax_histories[i].data_ptr(), + static_cast(DType::kFloat32), amax_shape}; + nvte_set_tensor_param(&amax_history, kNVTERowwiseData, &amax_history_data); + + te_scales.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING)); + NVTETensor& scale = te_scales.back(); + NVTEShape scale_shape = convertTorchShape(scales[i].sizes()); + NVTEBasicTensor scale_data = {scales[i].data_ptr(), static_cast(DType::kFloat32), + scale_shape}; + nvte_set_tensor_param(&scale, kNVTERowwiseData, &scale_data); } nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales, amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, at::cuda::getCurrentCUDAStream()); + for (auto& t : te_amax_histories) { + nvte_destroy_tensor(t); + } + for (auto& t : te_scales) { + nvte_destroy_tensor(t); + } } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp new file mode 100644 index 000000000..9befe14f8 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -0,0 +1,190 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../extensions.h" +#include "common.h" + +namespace transformer_engine::pytorch { + +static std::map score_function_map = {{"sigmoid", 0}, {"softmax", 1}}; + +std::tuple fused_topk_with_score_function_fwd( + at::Tensor logits, int topk, bool use_pre_softmax, c10::optional num_groups, + c10::optional group_topk, c10::optional scaling_factor, std::string score_function, + c10::optional expert_bias) { + int num_tokens = logits.size(0); + int num_experts = logits.size(1); + // Check if the input is valid + TORCH_CHECK(num_tokens > 0 && num_experts > 0, + "num_tokens and num_experts must be greater than 0"); + // Expert bias only happens at the sigmoid case + if (expert_bias.has_value()) { + TORCH_CHECK(score_function == "sigmoid", + "score_function must be sigmoid when expert_bias is not None"); + } + // Check if the score function is valid + TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid", + "score_function must be softmax or sigmoid for router fusion"); + if (score_function == "sigmoid") { + use_pre_softmax = false; // Pre-softmax only happens at the softmax case + } + + // Reformat the input to make it compatible with the kernel + int group_topk_value = group_topk.has_value() ? group_topk.value() : -1; + int num_groups_value = num_groups.has_value() ? num_groups.value() : -1; + float scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f; + + // Construct the output tensor + at::Tensor probs = + at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA)); + at::Tensor routing_map = + at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA)); + // Intermediate output is used to store the output of the softmax/sigmoid function + at::Tensor intermediate_output = + at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA)); + + auto logits_cu = makeTransformerEngineTensor(logits); + auto probs_cu = makeTransformerEngineTensor(probs); + auto routing_map_cu = makeTransformerEngineTensor(routing_map); + auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output); + auto expert_bias_cu = TensorWrapper(); // empty expert_bias_cu tensor + if (expert_bias.has_value()) { + expert_bias_cu = makeTransformerEngineTensor(expert_bias.value()); + } + + nvte_fused_topk_with_score_function_forward( + logits_cu.data(), num_tokens, num_experts, topk, use_pre_softmax, num_groups_value, + group_topk_value, scaling_factor_value, score_function_map[score_function], + expert_bias_cu.data(), probs_cu.data(), routing_map_cu.data(), intermediate_output_cu.data(), + at::cuda::getCurrentCUDAStream()); + + return std::make_tuple(probs, routing_map, intermediate_output); +} + +at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts, + at::Tensor routing_map, + at::Tensor intermediate_output, at::Tensor grad_probs, + int topk, bool use_pre_softmax, + c10::optional scaling_factor, + std::string score_function) { + // Get the value of the parameters + auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f; + auto score_function_value = score_function_map[score_function]; + // Init the output tensor + at::Tensor grad_logits = at::empty( + {num_tokens, num_experts}, at::dtype(intermediate_output.scalar_type()).device(at::kCUDA)); + + auto routing_map_cu = makeTransformerEngineTensor(routing_map); + auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output); + auto grad_probs_cu = makeTransformerEngineTensor(grad_probs); + auto grad_logits_cu = makeTransformerEngineTensor(grad_logits); + + nvte_fused_topk_with_score_function_backward( + routing_map_cu.data(), intermediate_output_cu.data(), grad_probs_cu.data(), num_tokens, + num_experts, topk, use_pre_softmax, scaling_factor_value, score_function_value, + grad_logits_cu.data(), at::cuda::getCurrentCUDAStream()); + + return grad_logits; +} + +std::tuple fused_score_for_moe_aux_loss_fwd( + at::Tensor logits, int topk, std::string score_function) { + int num_tokens = logits.size(0); + int num_experts = logits.size(1); + // Check if the input is valid + TORCH_CHECK(num_tokens > 0 && num_experts > 0, + "num_tokens and num_experts must be greater than 0"); + TORCH_CHECK(topk > 0, "topk must be greater than 0"); + // Check if the score function is valid + TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid", + "score_function must be softmax or sigmoid for router fusion"); + int score_function_value = score_function_map[score_function]; + + // Construct the output tensor + at::Tensor scores = + at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA)); + at::Tensor routing_map = + at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA)); + at::Tensor intermediate_output = + at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA)); + + auto logits_cu = makeTransformerEngineTensor(logits); + auto scores_cu = makeTransformerEngineTensor(scores); + auto routing_map_cu = makeTransformerEngineTensor(routing_map); + auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output); + + nvte_fused_score_for_moe_aux_loss_forward( + logits_cu.data(), num_tokens, num_experts, topk, score_function_value, scores_cu.data(), + routing_map_cu.data(), intermediate_output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return std::make_tuple(scores, routing_map, intermediate_output); +} + +at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, + at::Tensor intermediate_output, at::Tensor grad_scores, + int topk, std::string score_function) { + // Get the value of the parameters + int score_function_value = score_function_map[score_function]; + // Init the output tensor + at::Tensor grad_logits = at::empty( + {num_tokens, num_experts}, at::dtype(intermediate_output.scalar_type()).device(at::kCUDA)); + + auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output); + auto grad_scores_cu = makeTransformerEngineTensor(grad_scores); + auto grad_logits_cu = makeTransformerEngineTensor(grad_logits); + + nvte_fused_score_for_moe_aux_loss_backward( + intermediate_output_cu.data(), grad_scores_cu.data(), num_tokens, num_experts, topk, + score_function_value, grad_logits_cu.data(), at::cuda::getCurrentCUDAStream()); + + return grad_logits; +} + +std::tuple fused_moe_aux_loss_fwd(at::Tensor probs, + at::Tensor tokens_per_expert, + int total_num_tokens, int num_experts, + int num_rows, int num_cols, int topk, + float coeff) { + TORCH_CHECK(topk > 0, "topk must be greater than 0"); + TORCH_CHECK(total_num_tokens > 0, "total_num_tokens must be greater than 0"); + TORCH_CHECK(num_experts > 0, "num_experts must be greater than 0"); + + // Create the output tensor + at::Tensor aux_loss = at::empty({}, at::dtype(probs.scalar_type()).device(at::kCUDA)); + at::Tensor Const_buf = at::empty({}, at::dtype(at::kFloat).device(at::kCUDA)); + + auto probs_cu = makeTransformerEngineTensor(probs); + auto tokens_per_expert_cu = makeTransformerEngineTensor(tokens_per_expert); + auto aux_loss_cu = makeTransformerEngineTensor(aux_loss); + auto Const_buf_cu = makeTransformerEngineTensor(Const_buf); + + nvte_fused_moe_aux_loss_forward(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens, + num_experts, num_rows, num_cols, topk, coeff, aux_loss_cu.data(), + Const_buf_cu.data(), at::cuda::getCurrentCUDAStream()); + + return std::make_tuple(aux_loss, Const_buf); +} + +at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows, + int num_cols, at::Tensor grad_aux_loss) { + // Create the output tensor + at::Tensor grad_probs = + at::empty({num_rows, num_cols}, at::dtype(grad_aux_loss.scalar_type()).device(at::kCUDA)); + + auto Const_buf_cu = makeTransformerEngineTensor(Const_buf); + auto tokens_per_expert_cu = makeTransformerEngineTensor(tokens_per_expert); + auto grad_aux_loss_cu = makeTransformerEngineTensor(grad_aux_loss); + auto grad_probs_cu = makeTransformerEngineTensor(grad_probs); + + // Meta data for the kernel + nvte_fused_moe_aux_loss_backward(Const_buf_cu.data(), tokens_per_expert_cu.data(), num_rows, + num_cols, grad_aux_loss_cu.data(), grad_probs_cu.data(), + at::cuda::getCurrentCUDAStream()); + + return grad_probs; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/softmax.cpp b/transformer_engine/pytorch/csrc/extensions/softmax.cpp index 0baa1d6e7..2e0e482eb 100644 --- a/transformer_engine/pytorch/csrc/extensions/softmax.cpp +++ b/transformer_engine/pytorch/csrc/extensions/softmax.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" +#include "../extensions.h" namespace transformer_engine::pytorch { diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index e4465fece..d2f7107fe 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -4,80 +4,16 @@ * See LICENSE for license information. ************************************************************************/ +#include + #include +#include -#include "extensions.h" +#include "../extensions.h" #include "pybind.h" -namespace transformer_engine::pytorch { - -std::vector fused_multi_quantize(std::vector input_list, - std::optional> output_list, - std::vector quantizer_list, DType otype) { - init_extension(); - std::vector nvte_tensor_input_list; - std::vector nvte_tensor_output_list; - std::vector py_output_objects_list; - std::vector tensor_wrappers; - if (output_list.has_value()) { - py_output_objects_list = output_list.value(); - } - - // Choose implementation - // Note: Currently only have fused kernel for FP8 cast-transpose - bool with_fused_kernel = true; - - // create TE tensors from input - for (size_t i = 0; i < input_list.size(); i++) { - auto input_tensor = makeTransformerEngineTensor(input_list[i]); - const NVTEShape input_shape = input_tensor.shape(); - - TensorWrapper output_tensor; - - if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) { - with_fused_kernel = false; - } - if (output_list == std::nullopt) { - std::unique_ptr quantizer = convert_quantizer(quantizer_list[i]); - std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); - py::object o; - std::tie(output_tensor, o) = quantizer->create_tensor(output_shape, otype); - py_output_objects_list.push_back(o); - } else { - output_tensor = makeTransformerEngineTensor((*output_list)[i], quantizer_list[i]); - } - if (input_tensor.numel() == 0) continue; - - nvte_tensor_output_list.emplace_back(output_tensor.data()); - nvte_tensor_input_list.emplace_back(input_tensor.data()); - tensor_wrappers.emplace_back(std::move(input_tensor)); - tensor_wrappers.emplace_back(std::move(output_tensor)); - } - - // Check tensor lists - NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(), - "Number of input and output tensors must match"); - - for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { - if (nvte_tensor_columnwise_data(nvte_tensor_output_list[i]) == nullptr) { - with_fused_kernel = false; - break; - } - } - - // Launch TE kernel - if (with_fused_kernel) { - NVTE_SCOPED_GIL_RELEASE({ - nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(), - nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); - }); - } else { - for (size_t i = 0; i < py_output_objects_list.size(); i++) { - quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt); - } - } - return py_output_objects_list; -} +namespace transformer_engine { +namespace pytorch { at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output) { init_extension(); @@ -108,4 +44,5 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optionalamax_epsilon = quantizer.attr("amax_epsilon").cast(); NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2, "Unsupported block scaling dim."); + this->all_gather_usage = quantizer.attr("all_gather_usage").cast(); } void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const { @@ -282,10 +283,8 @@ std::pair Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype, std::optional rowwise_data) const { using namespace pybind11::literals; std::vector torch_shape; - size_t numel = 1; for (auto s : shape) { torch_shape.emplace_back(static_cast(s)); - numel *= s; } TensorWrapper tensor(this->get_scaling_mode()); @@ -295,9 +294,9 @@ std::pair Float8BlockQuantizer::create_tensor( opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); - size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back(); - size_t m_dim = numel / k_dim; - constexpr size_t kBlockLen = 128; + Float8BlockScaleTensorFormat data_format = + (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT + : Float8BlockScaleTensorFormat::GEMM_READY); if (rowwise_usage) { if (rowwise_data.has_value()) { @@ -305,20 +304,9 @@ std::pair Float8BlockQuantizer::create_tensor( } else { data_rowwise = at::empty(torch_shape, opts); } - size_t sinv0 = 0; - size_t sinv1 = 0; - if (block_scaling_dim == 2) { - sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; - sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4); - } else if (block_scaling_dim == 1) { - sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; - sinv1 = roundup(m_dim, 4); - } else { - NVTE_CHECK(false, - "Unsupported block_scaling_dim in create_tensor rowwise." - "Expected 1 or 2. Got ", - block_scaling_dim); - } + auto scale_shape = get_scale_shape(shape, false); + size_t sinv0 = scale_shape[0]; + size_t sinv1 = scale_shape[1]; scale_inv_rowwise = at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape); @@ -332,29 +320,26 @@ std::pair Float8BlockQuantizer::create_tensor( NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ", columnwise_shape, " torch shape: ", torch_columnwise_shape); if (torch_shape.size() > 0) { - torch_columnwise_shape.reserve(torch_shape.size()); - columnwise_shape.reserve(shape.size()); - torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); - columnwise_shape.push_back(shape[shape.size() - 1]); - for (size_t i = 0; i < torch_shape.size() - 1; ++i) { - torch_columnwise_shape.push_back(torch_shape[i]); - columnwise_shape.push_back(shape[i]); + if (!all_gather_usage) { + torch_columnwise_shape.reserve(torch_shape.size()); + columnwise_shape.reserve(shape.size()); + torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); + columnwise_shape.push_back(shape[shape.size() - 1]); + for (size_t i = 0; i < torch_shape.size() - 1; ++i) { + torch_columnwise_shape.push_back(torch_shape[i]); + columnwise_shape.push_back(shape[i]); + } + } else { + // assert we are doing 1D scaling + NVTE_CHECK(block_scaling_dim == 1, + "Compact columnwise format is not supported for 128x128 2D block scaling."); + torch_columnwise_shape = torch_shape; + columnwise_shape = shape; } } - size_t sinv0 = 0; - size_t sinv1 = 0; - if (block_scaling_dim == 2) { - sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; - sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4); - } else if (block_scaling_dim == 1) { - sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; - sinv1 = roundup(k_dim, 4); - } else { - NVTE_CHECK(false, - "Unsupported block_scaling_dim in create_tensor columnwise." - "Expected 1 or 2. Got ", - block_scaling_dim); - } + auto scale_shape = get_scale_shape(shape, true); + size_t sinv0 = scale_shape[0]; + size_t sinv1 = scale_shape[1]; data_colwise = at::empty(torch_columnwise_shape, opts); scale_inv_colwise = at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); @@ -373,7 +358,7 @@ std::pair Float8BlockQuantizer::create_tensor( "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer, - "is_2D_scaled"_a = (block_scaling_dim == 2)); + "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format); } else { py::handle Float8BlockwiseQTensorClass( reinterpret_cast(Float8BlockwiseQTensorPythonClass)); @@ -381,12 +366,88 @@ std::pair Float8BlockQuantizer::create_tensor( "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2)); + "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2), + "data_format"_a = data_format); } return {std::move(tensor), std::move(ret)}; } +std::vector Float8BlockQuantizer::get_scale_shape(const std::vector& shape, + bool columnwise) const { + size_t numel = 1; + for (auto s : shape) { + numel *= s; + } + + size_t k_dim = shape.size() == 0 ? 1u : shape.back(); + size_t m_dim = numel / k_dim; + constexpr size_t kBlockLen = 128; + + Float8BlockScaleTensorFormat data_format = + (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT + : Float8BlockScaleTensorFormat::GEMM_READY); + + std::vector scale_shape; + + bool rowwise_usage = !columnwise; + + if (rowwise_usage) { + // rowwise scaling factor shape + size_t sinv0 = 0; + size_t sinv1 = 0; + if (block_scaling_dim == 2) { + // 2D scaling is always GEMM_READY for now + NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY, + "2D scaling is always GEMM_READY for now."); + sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; + sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4); + } else if (block_scaling_dim == 1) { + // 1D scaling can be GEMM_READY or COMPACT + bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT; + // default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY + sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; + sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4); + // if the rowwise format is compact, the scaling factor is not be transposed + if (rowwise_compact) { + std::swap(sinv0, sinv1); + } + } else { + NVTE_CHECK(false, + "Unsupported block_scaling_dim in create_tensor rowwise." + "Expected 1 or 2. Got ", + block_scaling_dim); + } + scale_shape = {sinv0, sinv1}; + } else { + // columnwise scaling factor shape + size_t sinv0 = 0; + size_t sinv1 = 0; + if (block_scaling_dim == 2) { + // 2D scaling is always GEMM_READY for now + NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY, + "2D scaling is always GEMM_READY for now."); + sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; + sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4); + } else if (block_scaling_dim == 1) { + // 1D scaling can be GEMM_READY or COMPACT + bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT; + sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; + sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4); + // GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS + // for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data, scaling factor is also [sinv0, sinv1] + // so no need to swap sinv0 and sinv1 here + } else { + NVTE_CHECK(false, + "Unsupported block_scaling_dim in create_tensor columnwise." + "Expected 1 or 2. Got ", + block_scaling_dim); + } + scale_shape = {sinv0, sinv1}; + } + return scale_shape; +} + MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); } @@ -419,11 +480,6 @@ std::pair MXFP8Quantizer::create_tensor( at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv, columnwise_scale_inv; // TODO(pgadzinski) - change opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - auto last_dim = static_cast(torch_shape.back()); - - NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, - " (got shape=", torch_shape, ")"); at::Tensor data; if (rowwise_usage) { @@ -432,9 +488,10 @@ std::pair MXFP8Quantizer::create_tensor( } else { data = at::empty(torch_shape, opts); } - auto sinv0 = roundup(numel / last_dim, 128); - auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); - rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts); + auto scale_shape = get_scale_shape(shape, false); + size_t sinv0 = scale_shape[0]; + size_t sinv1 = scale_shape[1]; + rowwise_scale_inv = at::zeros({static_cast(sinv0), static_cast(sinv1)}, opts); tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); tensor.set_rowwise_scale_inv( rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, @@ -442,10 +499,12 @@ std::pair MXFP8Quantizer::create_tensor( } if (columnwise_usage) { - auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); - auto sinv1 = roundup(last_dim, 128); + auto scale_shape = get_scale_shape(shape, true); + size_t sinv0 = scale_shape[0]; + size_t sinv1 = scale_shape[1]; columnwise_data = at::empty(torch_shape, opts); - columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts); + columnwise_scale_inv = + at::zeros({static_cast(sinv0), static_cast(sinv1)}, opts); tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); tensor.set_columnwise_scale_inv( @@ -473,4 +532,35 @@ std::pair MXFP8Quantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } +std::vector MXFP8Quantizer::get_scale_shape(const std::vector& shape, + bool columnwise) const { + size_t numel = 1; + for (auto s : shape) { + numel *= s; + } + + auto last_dim = shape.back(); + + NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, + "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + " (got shape=", shape, ")"); + + std::vector scale_shape; + + bool rowwise_usage = !columnwise; + + if (rowwise_usage) { + // rowwise scaling factor shape + size_t sinv0 = roundup(numel / last_dim, 128); + size_t sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); + scale_shape = {sinv0, sinv1}; + } else { + // columnwise scaling factor shape + size_t sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); + size_t sinv1 = roundup(last_dim, 128); + scale_shape = {sinv0, sinv1}; + } + return scale_shape; +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index ea601397a..c3ec514a5 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -10,6 +10,7 @@ from collections.abc import Iterable from contextlib import contextmanager, AbstractContextManager, ContextDecorator from functools import lru_cache +from dataclasses import dataclass import math from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings @@ -21,6 +22,15 @@ from torch.distributed.fsdp._common_utils import _get_module_fsdp_state from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules +try: + import torch.distributed._symmetric_memory as symm_mem + + HAS_TORCH_SYMMETRIC = True +except ImportError: + HAS_TORCH_SYMMETRIC = False + +import transformer_engine_torch as tex + from . import torch_version from .utils import ( is_non_tn_fp8_gemm_supported, @@ -36,14 +46,8 @@ from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase -from ..debug.pytorch.debug_quantization import DebugQuantizedTensor - -try: - import torch.distributed._symmetric_memory as symm_mem +from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer - HAS_TORCH_SYMMETRIC = True -except ImportError: - HAS_TORCH_SYMMETRIC = False __all__ = ["checkpoint", "CudaRNGStatesTracker"] @@ -947,7 +951,7 @@ def _all_gather_fp8( out = quantizer.make_empty(out_shape, dtype=dtype, device=device) elif isinstance(inp, Float8Tensor): out = inp.make_like(inp, shape=out_shape) - out._data = torch.empty_like( + out._data = torch.empty( out_shape, dtype=torch.uint8, device=inp.device, @@ -981,6 +985,67 @@ def _all_gather_fp8( return out, handle +def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None: + """Make quantizer compact""" + _quantizer = quantizer + if isinstance(quantizer, DebugQuantizer): + _quantizer = quantizer.parent_quantizer + if isinstance(_quantizer, Float8BlockQuantizer): + _quantizer.all_gather_usage = compact + + +def _post_process_fp8_blockwise_gather( + out: Float8BlockwiseQTensorBase, + quantizer: Float8BlockQuantizer, + handle: Optional[torch.distributed.Work] = None, +) -> Float8BlockwiseQTensorBase: + """Post-process FP8 blockwise gather.""" + if handle is not None: + handle.wait() + handle = None + + if out._is_gemm_ready_format(): + return out + + needs_columnwise_data_transpose = ( + quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported() + ) + need_rowwise_scale_transpose = ( + quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported() + ) + + # CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024 + # columnwise compact format means doing 128x1 quantization of it + # so quantized tensor is 256x1024, scale inv is 2x1024 + # If we were doing GEMM_READY format, then it's equivalent to do 1x128 quantization + # on a transposed 1024x256 tensor, so scale inv is 1024x2, cublas requries 2x1024 + # Thereforce, it turns out we don't need to transpose the scale inv, only columnwise data + if needs_columnwise_data_transpose: + out._transpose_columnwise_data() + if need_rowwise_scale_transpose: + out._rowwise_scale_inv = out._rowwise_scale_inv.transpose(-2, -1).contiguous() + out._data_format = tex.Float8BlockScaleTensorFormat.GEMM_READY + return out + + +@dataclass +class _FP8BlockwiseAllGatherAsyncHandle: + """Handle for asynchronous FP8 blockwise all-gather.""" + + tensor: Float8BlockwiseQTensorBase + quantizer: Float8BlockQuantizer + async_handle: torch.distributed.Work + _synchronized: bool = False + + def wait(self) -> None: + """Wait for the async operation to complete and post-process the tensor.""" + if self._synchronized: + return + self.async_handle.wait() + _post_process_fp8_blockwise_gather(self.tensor, self.quantizer) + self._synchronized = True + + def _all_gather_fp8_blockwise( inp: torch.Tensor, process_group: dist_group_type, @@ -994,8 +1059,9 @@ def _all_gather_fp8_blockwise( Returns: quantizer(gather(inp)) - NOTE: The implementation is not sophisticated enough to honor async_op=True. - In some cases it falls back to synchronous gather and invokes the quantizer. + NOTE: The implementation is only going to honor async_op=True for FP8 gather case. + In the case where tensor shape is not divisible by 128, the implementation will fall back + to synchronous gather and invoke the quantizer. """ # Input tensor attributes @@ -1031,7 +1097,11 @@ def _all_gather_fp8_blockwise( out_shape[0] *= world_size # Doing BF16 gather for now as baseline because it's simpler - if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None: + if ( + not isinstance(inp, Float8BlockwiseQTensorBase) + and quantizer is not None + and not quantizer.is_quantizable(inp) + ): out = torch.empty( out_shape, dtype=dtype, @@ -1039,14 +1109,93 @@ def _all_gather_fp8_blockwise( memory_format=torch.contiguous_format, ) torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False) + orig_all_gather_usage = quantizer.all_gather_usage + quantizer.all_gather_usage = False out = quantizer(out) + quantizer.all_gather_usage = orig_all_gather_usage return out, None + # Implementation of fp8 gather needs to account for: # * Getting columnwise data as a transpose of how it is stored for GEMMS. # * Gathering non GEMM swizzled scales. - # * Refer to scaffold code when implementing at: - # https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477 - raise NotImplementedError("fp8 blockwise allgather not yet implemented") + + # Cast input tensor to Float8BlockwiseQTensor with required data + # Set to compact usage in case the quantizer is not correctly configured + orig_all_gather_usage = quantizer.all_gather_usage + quantizer.all_gather_usage = True + if not isinstance(inp, Float8BlockwiseQTensorBase): + inp = quantizer(inp) + elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( + quantizer.columnwise_usage and inp._columnwise_data is None + ): + warnings.warn( + "Input and quantizer do not have matching usages. " + "Dequantizing and requantizing to Float8BlockwiseQTensor." + ) + inp = quantizer(inp.dequantize()) + quantizer.all_gather_usage = orig_all_gather_usage + + # Begin to do network communication, need to make sure compact format + if inp._data_format != tex.Float8BlockScaleTensorFormat.COMPACT: + raise RuntimeError( + "All-gather with FP8 block-wise quantized tensor requires compact data format, " + f"but found data_format={inp._data_format}" + ) + + # Construct Float8BlockwiseQTensor output tensor + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + + # Coalesce NCCL collectives + with torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) as coalescing_manager: + + # Gather Float8BlockwiseQTensor data for row-wise usage + if quantizer.rowwise_usage: + # Launch all-gathers + torch.distributed.all_gather_into_tensor( + out._rowwise_scale_inv, + inp._rowwise_scale_inv, + group=process_group, + ) + torch.distributed.all_gather_into_tensor( + out._rowwise_data, + inp._rowwise_data, + group=process_group, + ) + + # Gather Float8BlockwiseQTensor data for column-wise usage + if quantizer.columnwise_usage: + # Launch all-gathers + torch.distributed.all_gather_into_tensor( + out._columnwise_scale_inv, + inp._columnwise_scale_inv, + group=process_group, + ) + torch.distributed.all_gather_into_tensor( + out._columnwise_data, + inp._columnwise_data, + group=process_group, + ) + + handle = coalescing_manager if async_op else None + + # Unlike MXFP8, this fp8 blockwise tensor primarily works with Hopper + # This means that we need to transpose the gathered columnwise data + # Example usage is grad_output tensor, ie. dY in linear backward + # We want to gather two FP8 tensors (rowwise and columnwise) along dim0 + # and then transpose the columnwise data to match the rowwise data + # Make sure FP8 transpose is populated if needed + + if async_op: + handle = _FP8BlockwiseAllGatherAsyncHandle(out, quantizer, handle) + else: + # if it's a sync op, we need to do the transpose here as post processing step + _post_process_fp8_blockwise_gather(out, quantizer, handle) + + return out, handle def _all_gather_mxfp8( @@ -1243,12 +1392,18 @@ def gather_along_first_dim( final_quantizer = ( None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer ) + # Temporary fix for TP communication of Float8BlockwiseQTensorBase + if isinstance(rowwise, Float8BlockwiseQTensorBase): + rowwise = inp._original_tensor rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0] out_obj.rowwise_gemm_tensor = rowwise_total if rowwise is not columnwise: final_quantizer_columnwise = ( None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer ) + # Temporary fix for TP communication of Float8BlockwiseQTensorBase + if isinstance(columnwise, Float8BlockwiseQTensorBase): + columnwise = inp._original_tensor columnwise_total, _ = gather_along_first_dim( columnwise, process_group, False, final_quantizer_columnwise ) @@ -1265,6 +1420,9 @@ def gather_along_first_dim( ) if isinstance(inp, QuantizedTensor): inp = inp.dequantize() + # Falling back to high-precision all-gather for Float8BlockQuantizer + # means that it should directly output GEMM_READY format + _set_quantizer_format(quantizer, compact=False) out = torch.empty( out_shape, dtype=inp.dtype, diff --git a/transformer_engine/pytorch/export.py b/transformer_engine/pytorch/export.py new file mode 100644 index 000000000..f75271e2c --- /dev/null +++ b/transformer_engine/pytorch/export.py @@ -0,0 +1,71 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Export utilities for TransformerEngine""" + +from contextlib import contextmanager +from typing import Generator +import torch + + +_IN_ONNX_EXPORT_MODE = False +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) + + +@contextmanager +def onnx_export(enabled: bool = False) -> Generator[None, None, None]: + """ + Context manager for exporting to ONNX. + + .. code-block:: python + + from transformer_engine.pytorch.export import onnx_export, te_translation_table + + with onnx_export(enabled=True): + torch.onnx.export(model, dynamo=True, custom_translation_table=te_translation_table) + + Parameters + ---------- + enabled: bool, default = `False` + whether or not to enable export + """ + + global _IN_ONNX_EXPORT_MODE + onnx_export_state = _IN_ONNX_EXPORT_MODE + if (TORCH_MAJOR, TORCH_MINOR) < (2, 4): + raise RuntimeError("ONNX export is not supported for PyTorch versions less than 2.4") + try: + _IN_ONNX_EXPORT_MODE = enabled + yield + finally: + _IN_ONNX_EXPORT_MODE = onnx_export_state + + +def is_in_onnx_export_mode() -> bool: + """Returns True if onnx export mode is enabled, False otherwise.""" + return _IN_ONNX_EXPORT_MODE + + +def assert_warmed_up(module: torch.nn.Module) -> None: + """Assert that the model has been warmed up before exporting to ONNX.""" + assert hasattr(module, "forwarded_at_least_once"), ( + "Model must be warmed up before exporting to ONNX, please run model with the" + " same recipe before exporting." + ) + + +if TORCH_MAJOR == 2 and TORCH_MINOR >= 4 or TORCH_MAJOR > 2: + # pylint: disable=unused-import + from .onnx_extensions import ( + torch_onnx_gemm_inf_op, + onnx_quantize_fp8_op, + onnx_dequantize_fp8_op, + onnx_quantize_mxfp8_op, + onnx_dequantize_mxfp8_op, + onnx_layernorm, + onnx_attention_mask_func, + onnx_gemm, + te_translation_table, + ) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index c76a40807..55280aa82 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -60,6 +60,8 @@ def check_mxfp8_support() -> Tuple[bool, str]: if gpu_arch == (9, 5): return True, "" return False, "Gfx95x is required for MXFP8 execution." + if get_device_compute_capability() >= (12, 0): + return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." if get_device_compute_capability() >= (10, 0): # blackwell and above return True, "" return False, "Device compute capability 10.0 or higher required for MXFP8 execution." @@ -87,7 +89,11 @@ def get_default_fp8_recipe() -> Recipe: if gpu_arch == (9, 5): return MXFP8BlockScaling() return DelayedScaling() - if get_device_compute_capability() >= (10, 0): # blackwell and above + if check_mxfp8_support()[0]: + # This is a temporary restriction until MXFP8 is supported for all + # gemm layouts. + if get_device_compute_capability() >= (12, 0): + return Float8BlockScaling() return MXFP8BlockScaling() return DelayedScaling() diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 87d50ccc9..b4861286e 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -21,6 +21,7 @@ from .distributed import get_all_rng_states, graph_safe_rng_available from .module.base import TransformerEngineBaseModule from .ops.op import BasicOperation +from .utils import make_weak_ref __all__ = ["make_graphed_callables"] @@ -63,8 +64,10 @@ def _make_graphed_callables( fp8_weight_caching: bool = False, sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, _order: Optional[List[int]] = None, + _num_layers_per_chunk: Optional[List[int]] = None, pool: Optional[Tuple[int, ...]] = None, retain_graph_in_backward: bool = False, + _reuse_graph_input_output_buffers: bool = False, ) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` @@ -110,29 +113,113 @@ def _make_graphed_callables( # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py. # Note: The model is assumed to consist of layers # (corresponding to callables) that are grouped into - # equally-sized model chunks. _order is a list of chunk - # indices (1-indexed) that indicates the order in which the - # layers are evaluated. Positive values indicate forward - # passes and negative values indicate backward passes. Each + # model chunks. _num_layers_per_chunk is a list of integers + # that indicates the number of layers in each model chunk. + # _order is a list of chunk indices (1-indexed) that + # indicates the order in which the layers are evaluated. + # Positive values indicate forward passes and negative + # values indicate backward passes. Each # entry in sample_args corresponds to one of the forward # passes. num_model_chunks = max(_order) num_microbatches = len(_order) // num_model_chunks // 2 assert num_model_chunks * num_microbatches * 2 == len(_order) - assert len(sample_args) * 2 >= len(_order) and ( - len(sample_args) * 2 % len(_order) == 0 - ), f"{len(sample_args)} >= {len(_order)} and {len(sample_args)} % {len(_order)} == 0" - num_layers = len(sample_args) // num_model_chunks // num_microbatches - assert len(callables) == num_model_chunks * num_layers, ( - f"Callables should have ({num_model_chunks * num_layers}) " + + # Determine number of layers in each model chunk. + if _num_layers_per_chunk is None: + assert len(sample_args) * 2 >= len(_order) and ( + len(sample_args) * 2 % len(_order) == 0 + ), ( + f"{len(sample_args)} * 2 >= {len(_order)} and {len(sample_args)} * 2 %" + f" {len(_order)} == 0" + ) + num_layers = len(sample_args) // num_model_chunks // num_microbatches + _num_layers_per_chunk = [num_layers] * num_model_chunks + else: + assert ( + isinstance(_num_layers_per_chunk, int) + or len(_num_layers_per_chunk) == num_model_chunks + ), ( + "If _num_layers_per_chunk is provided, it must be an integer or a list of" + f" {num_model_chunks} integers, but got {_num_layers_per_chunk}." + ) + if isinstance(_num_layers_per_chunk, int): + _num_layers_per_chunk = [_num_layers_per_chunk] * num_model_chunks + total_num_layers = sum(_num_layers_per_chunk) + assert len(callables) == total_num_layers, ( + f"Callables should have ({total_num_layers}) " + f"entries when order input is provided but got {len(callables)}." ) - assert len(sample_args) == num_model_chunks * num_microbatches * num_layers, ( - f"Expected {num_model_chunks * num_microbatches}" + assert len(sample_args) == total_num_layers * num_microbatches, ( + f"Expected {total_num_layers * num_microbatches}" + f"args tuple, but got {len(sample_args)}." ) + + # Calculate the starting index of each chunk in callables for future use. + _prefix_num_layers = [0] + for m_chunk in range(num_model_chunks): + num_layers = _num_layers_per_chunk[m_chunk] + _prefix_num_layers.append(_prefix_num_layers[-1] + num_layers) + assert len(sample_kwargs) == len(sample_args) + # Check reuse graph conditions and reorganize sample_args and sample_kwargs. + # Note: When capturing a graph, we hold onto the args and kwargs so we have static buffers + # when the graph is replayed. If two model chunk microbatches have no overlap between their + # forward and backward, then we can reduce memory usage by reusing the same static buffers. + if _reuse_graph_input_output_buffers: + assert ( + _order is not None + ), "`_order` must be provided when `_reuse_graph_input_output_buffers` is True." + assert ( + is_training + ), "`_reuse_graph_input_output_buffers` is only available in training mode." + assert isinstance( + sample_args, list + ), "sample_args must be a list for _reuse_graph_input_output_buffers." + len_args = len(sample_args[0]) + for i, arg in enumerate(sample_args): + assert len_args == len( + arg + ), "Arguments must have same length and shape for `_reuse_graph_input_output_buffers`." + len_kwargs = len(sample_kwargs[0]) + assert isinstance( + sample_kwargs, list + ), "sample_kwargs must be a list for _reuse_graph_input_output_buffers." + for i, kwarg in enumerate(sample_kwargs): + assert len_kwargs == len(kwarg), ( + "Keyword arguments must have same length and shape for" + " `_reuse_graph_input_output_buffers`." + ) + + # Reorganize args and kwargs for input tensor reuse. + fwd_sample_qs = {} + consumed_sample_q = [] + fwd_idx = [0] * num_model_chunks + for c_id in _order: + m_chunk = abs(c_id) - 1 + + if c_id > 0: + sample_start_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( + fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + ) + fwd_sample_idx = [ + sample_start_idx + i for i in range(_num_layers_per_chunk[m_chunk]) + ] + fwd_sample_qs[m_chunk] = fwd_sample_qs.get(m_chunk, []) + fwd_sample_idx + for per_callable_fwd_idx in fwd_sample_idx: + if consumed_sample_q: + reuse_fwd_idx = consumed_sample_q.pop(0) + sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx] + sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx] + fwd_idx[m_chunk] += 1 + else: + num_consumed_samples = min( + len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk] + ) + consumed_sample_q += fwd_sample_qs[m_chunk][:num_consumed_samples] + fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:] + if fp8_weight_caching: # Initialize flag that controls FP8 weight updates FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) @@ -185,10 +272,13 @@ def _make_graphed_callables( per_callable_module_params = [] for m_chunk in range(num_model_chunks): for _ in range(num_microbatches): - for l_no in range(num_layers): + for l_no in range(_num_layers_per_chunk[m_chunk]): per_callable_module_params.append( - tuple(callables[m_chunk * num_layers + l_no].parameters()) - if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module) + tuple(callables[_prefix_num_layers[m_chunk] + l_no].parameters()) + if isinstance( + callables[_prefix_num_layers[m_chunk] + l_no], + torch.nn.Module, + ) else () ) assert len(per_callable_module_params) == len(flatten_sample_args) @@ -227,10 +317,10 @@ def _make_graphed_callables( for c_id in _order: if c_id > 0: m_chunk = c_id - 1 - for l_no in range(num_layers): - func = callables[m_chunk * num_layers + l_no] - func_idx = (m_chunk * num_microbatches * num_layers) + ( - fwd_idx[m_chunk] * num_layers + l_no + for l_no in range(_num_layers_per_chunk[m_chunk]): + func = callables[_prefix_num_layers[m_chunk] + l_no] + func_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( + fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no ) warmup_func_idx.append(func_idx) warmup_func.append(func) @@ -255,7 +345,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument args = sample_args[func_idx] kwargs = sample_kwargs[func_idx] static_input_surface = per_callable_static_input_surfaces[func_idx] - for _ in range(num_warmup_iters): + for warmup_iter in range(num_warmup_iters): hooks = [] for module in func.modules(): hook = module.register_forward_hook(hook_fn) @@ -271,6 +361,34 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument only_inputs=True, allow_unused=allow_unused_input, ) + + # Filter module params that get None grad from grad_inputs and remove them + # from static_input_surface. This is to ensure that the backward hooks + # registered to these params are not wrongly triggered. + num_required_grad_sample_args = sum( + arg.requires_grad for arg in flatten_sample_args[func_idx] + ) + required_grad_input_idx = [] + for i, arg in enumerate(static_input_surface): + if arg.requires_grad: + required_grad_input_idx.append(i) + module_params_with_grad = [] + for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx): + if ( + grad_inputs[grad_inputs_idx] is not None + and grad_inputs_idx >= num_required_grad_sample_args + ): + module_params_with_grad.append(static_input_surface[inputs_idx]) + if len(module_params_with_grad) != len(per_callable_module_params[func_idx]): + assert warmup_iter == 0, ( + "no-grad params should only be used as inputs in the first warmup" + " iteration" + ) + per_callable_module_params[func_idx] = tuple(module_params_with_grad) + static_input_surface = flatten_sample_args[func_idx] + tuple( + module_params_with_grad + ) + per_callable_static_input_surfaces[func_idx] = static_input_surface else: grad_inputs = None del outputs, grad_inputs @@ -292,14 +410,16 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument per_callable_static_grad_inputs = [None] * len(flatten_sample_args) fwd_idx = [0] * num_model_chunks bwd_idx = [0] * num_model_chunks + static_grad_outputs = None + previous_per_callable_bwd_idx = None for c_id in _order: if c_id > 0: # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] m_chunk = c_id - 1 - for l_no in range(num_layers): - func = callables[m_chunk * num_layers + l_no] - per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + ( - fwd_idx[m_chunk] * num_layers + l_no + for l_no in range(_num_layers_per_chunk[m_chunk]): + func = callables[_prefix_num_layers[m_chunk] + l_no] + per_callable_fwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( + fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no ) args = sample_args[per_callable_fwd_idx] kwargs = sample_kwargs[per_callable_fwd_idx] @@ -314,17 +434,20 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument else: # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1] m_chunk = -c_id - 1 - for l_no in list(reversed(range(num_layers))): - per_callable_bwd_idx = (m_chunk * num_microbatches * num_layers) + ( - bwd_idx[m_chunk] * num_layers + l_no + for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))): + per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( + bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no ) static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx] static_outputs = per_callable_static_outputs[per_callable_bwd_idx] bwd_graph = bwd_graphs[per_callable_bwd_idx] # For now, assumes all static_outputs require grad - static_grad_outputs = tuple( - torch.empty_like(o) if o.requires_grad else None for o in static_outputs - ) + if not _reuse_graph_input_output_buffers or static_grad_outputs is None: + # Note for _reuse_graph_input_output_buffers: grad output is only used + # within backward, so we can reuse the same static buffers every time. + static_grad_outputs = tuple( + torch.empty_like(o) if o.requires_grad else None for o in static_outputs + ) if is_training: with torch.cuda.graph(bwd_graph, pool=mempool): grad_inputs = torch.autograd.grad( @@ -350,6 +473,30 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs + + # Weak ref the static outputs and static grad inputs that are no longer needed + # in the following steps. These two type of tensors are both in cudagraph + # mempool, so we just deallocate them and let PyTorch's memory allocator + # reuse them elsewhere. + if _reuse_graph_input_output_buffers: + # Weak ref the static outputs of the forward pass of this backward. It's + # no longer needed after the corresponding backward graph is built up. + per_callable_static_outputs[per_callable_bwd_idx] = make_weak_ref( + static_outputs + ) + # Weak ref the static grad inputs of the previous backward pass. + # Note: After a backward pass, we assume Mcore will send the + # grad input to another pipeline parallel rank and that the + # communication is finished before the end of the next backward + # pass. + if previous_per_callable_bwd_idx is not None: + per_callable_static_grad_inputs[previous_per_callable_bwd_idx] = ( + make_weak_ref( + per_callable_static_grad_inputs[previous_per_callable_bwd_idx] + ) + ) + previous_per_callable_bwd_idx = per_callable_bwd_idx + bwd_idx[m_chunk] += 1 else: # Capture forward graphs @@ -595,7 +742,7 @@ def save_fp8_tensors( m.adjust_amax_history_length(fp8_recipe.amax_history_len) module_tensors = m.get_fp8_meta_tensors() elif isinstance(m, BasicOperation): - m.pre_forward(fp8_enabled=True, fp8_recipe=fp8_recipe) + m.pre_first_forward(recipe=fp8_recipe) module_tensors = m._save_fp8_metas() fp8_tensors.append(module_tensors) return fp8_tensors @@ -636,8 +783,10 @@ def make_graphed_callables( fp8_group: Optional[dist_group_type] = None, fp8_weight_caching: bool = False, _order: Optional[List[int]] = None, + _num_layers_per_chunk: Optional[List[int]] = None, pool: Optional[Tuple[int, ...]] = None, retain_graph_in_backward: bool = False, + _reuse_graph_input_output_buffers: bool = False, ) -> Union[Callable, Tuple[Callable, ...]]: """ Make CUDA graph version of Transformer Engine modules @@ -666,6 +815,11 @@ def make_graphed_callables( this graph may share memory with the indicated pool. retain_graph_in_backward: bool, default = `False` Whether to set retain_graph=True in backward graph capture. + _reuse_graph_input_output_buffers: bool, default = `False` + Reduce memory usage by reusing input/output data buffers between + graphs. Only supported with Mcore interleaved pipeline parallelism, i.e. + when `_order` is provided. All callables in `modules` are assumed to have + inputs and outputs with the same dtype and shape. FP8-related parameters ---------------------- @@ -704,10 +858,17 @@ def make_graphed_callables( saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe) # FP8 wrapper. + old_call_funcs = {} + def wrap_autocast(block): - old_forward = block.forward + block_cls = type(block) + if block_cls in old_call_funcs: + return + + old_call_funcs[block_cls] = block_cls.__call__ - def forward_func(*args, **kwargs): + # Wrap the original call function of the module class. + def call_func(*args, **kwargs): with fp8_autocast( enabled=fp8_enabled, calibrating=fp8_calibrating, @@ -715,10 +876,10 @@ def forward_func(*args, **kwargs): fp8_group=fp8_group, _graph=True, ): - outputs = old_forward(*args, **kwargs) + outputs = old_call_funcs[block_cls](*args, **kwargs) return outputs - block.forward = forward_func + block_cls.__call__ = call_func forward_funcs = [] for module in modules: @@ -749,8 +910,10 @@ def forward_func(*args, **kwargs): fp8_weight_caching=fp8_weight_caching, sample_kwargs=sample_kwargs, _order=_order, + _num_layers_per_chunk=_num_layers_per_chunk, pool=pool, retain_graph_in_backward=retain_graph_in_backward, + _reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers, ) # Ensures warmup does not affect numerics for ops such as dropout. @@ -760,6 +923,10 @@ def forward_func(*args, **kwargs): else: torch.cuda.set_rng_state(original_rng_states) + # Remove FP8 wrapper. + for module_cls, old_call in old_call_funcs.items(): + module_cls.__call__ = old_call + # Restore FP8 state. restore_fp8_tensors(modules, saved_fp8_tensors) diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index 365d79646..cdd08766c 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -8,11 +8,11 @@ import os from functools import wraps from typing import Callable, Optional, Tuple - import torch from torch.utils.cpp_extension import IS_HIP_EXTENSION from . import torch_version +from .export import is_in_onnx_export_mode from .utils import gpu_autocast_ctx # pylint: disable=unnecessary-lambda-assignment @@ -49,7 +49,17 @@ def wrapper(*args, **kwargs): # Decorator to disable Torch Dynamo # See: https://github.com/NVIDIA/TransformerEngine/issues/308 -no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive) +no_torch_dynamo = lambda recursive=True: lambda func: func +if torch.__version__ >= "2": + import torch._dynamo + + if torch.__version__ >= "2.1": + no_torch_dynamo = lambda recursive=True: lambda f: ( + f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive) + ) + else: + # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True + no_torch_dynamo = lambda recursive=True: torch._dynamo.disable def set_jit_fusion_options() -> None: @@ -125,6 +135,35 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor: return dgelu +@jit_fuser +def l2normalization_fused_(x: torch.Tensor, eps: float) -> torch.Tensor: + """L2 normalization fused - inference version""" + x_squared = x.pow(2) + l2_norm_squared = x_squared.sum(dim=-1, keepdim=True) + rsqrt_norm = torch.rsqrt(l2_norm_squared + eps) + return x * rsqrt_norm + + +@jit_fuser +def l2normalization_fwd_fused_(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]: + """L2 normalization fused - training version that returns intermediate values""" + x_squared = x.pow(2) + l2_norm_squared = x_squared.sum(dim=-1, keepdim=True) + rsqrt_norm = torch.rsqrt(l2_norm_squared + eps) + y = x * rsqrt_norm + return y, rsqrt_norm + + +@jit_fuser +def l2normalization_backward_fused_( + grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float +) -> torch.Tensor: + """L2 normalization backward fused""" + x_dy_sum = (x * grad_output).sum(dim=-1, keepdim=True) + x_norm_squared = x.pow(2).sum(dim=-1, keepdim=True) + eps + return rsqrt_norm * (grad_output - x * x_dy_sum / x_norm_squared) + + def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: """Disable native AMP for bias_gelu_fused_""" with gpu_autocast_ctx(enabled=False): @@ -143,6 +182,26 @@ def bgrad_dgelu_fused( return None, dgelu_fused_(grad_output, inp) +def l2normalization_fused(x: torch.Tensor, eps: float) -> torch.Tensor: + """Disable native AMP for l2normalization_fused_ - inference version""" + with gpu_autocast_ctx(enabled=False): + return l2normalization_fused_(x, eps) + + +def l2normalization_fwd_fused(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]: + """Disable native AMP for l2normalization_fwd_fused_ - training version""" + with gpu_autocast_ctx(enabled=False): + return l2normalization_fwd_fused_(x, eps) + + +def l2normalization_backward_fused( + grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float +) -> torch.Tensor: + """Disable native AMP for l2normalization_backward_fused_""" + with gpu_autocast_ctx(enabled=False): + return l2normalization_backward_fused_(grad_output, x, rsqrt_norm, eps) + + def bias_dropout_add( x: torch.Tensor, bias: torch.Tensor, @@ -268,3 +327,45 @@ def warmup_jit_bias_gelu_all_dtypes( """Call `warmup_jit_bias_gelu` for all training dtypes""" for dtype in [torch.float32, torch.bfloat16, torch.float16]: warmup_jit_bias_gelu(ffn_hidden_size, dtype, seq_length, micro_batch_size) + + +def warmup_jit_l2normalization( + hidden_size: int, dtype: torch.dtype, seq_length: int, micro_batch_size: int +) -> None: + """Compile L2Normalization JIT function before the main training steps""" + + # Save cuda RNG state to ensure warmup does not affect reproducibility. + rng_state = torch.cuda.get_rng_state() + + inp = torch.rand( + (seq_length * micro_batch_size, hidden_size), + dtype=dtype, + device="cuda", + ) + eps = 1e-6 + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for input_grad in [False, True]: + inp.requires_grad = input_grad + for _ in range(5): + if input_grad: + # Test training version that returns intermediate values + output, rsqrt_norm = l2normalization_fwd_fused_(inp, eps) + # Test backward pass as well + grad_out = torch.rand_like(output) + _ = l2normalization_backward_fused_(grad_out, inp, rsqrt_norm, eps) + else: + # Test inference version + output = l2normalization_fused_(inp, eps) + del inp, output + + torch.cuda.empty_cache() + torch.cuda.set_rng_state(rng_state) + + +def warmup_jit_l2normalization_all_dtypes( + hidden_size: int, seq_length: int, micro_batch_size: int +) -> None: + """Call `warmup_jit_l2normalization` for all training dtypes""" + for dtype in [torch.float32, torch.bfloat16, torch.float16]: + warmup_jit_l2normalization(hidden_size, dtype, seq_length, micro_batch_size) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index ef401a91f..1f38b493c 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -17,6 +17,7 @@ from .. import cpp_extensions as tex from ..constants import TE_DType from ..utils import get_default_init_method +from ..export import is_in_onnx_export_mode if IS_HIP_EXTENSION: from ..triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton @@ -182,6 +183,8 @@ def noop_cat( raise ValueError("Attempted to concatenate 0 tensors") if len(tensors) == 1: return tensors[0] + if is_in_onnx_export_mode(): + return torch.cat(tensors, dim=dim) return _NoopCatFunc.apply(dim, *tensors) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 781c20417..a6ab1b22a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -53,7 +53,7 @@ from ..utils import is_non_tn_fp8_gemm_supported from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase -from ...common.recipe import Recipe +from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor @@ -99,7 +99,7 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]: """Returns workspace for multi-stream cublas.""" global _multi_stream_cublas_workspace if not _multi_stream_cublas_workspace: - for _ in range(tex._num_cublas_streams): + for _ in range(tex.get_num_cublas_streams()): _multi_stream_cublas_workspace.append( torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda") ) @@ -657,6 +657,8 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> # Update quantizers with new amax pointers. self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers() + # Make sure weight tensors has correct quantizers + self._update_weight_quantizers() # Update the global buffers with new amax and history pointers. if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: @@ -710,6 +712,30 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: self.fp8_meta[fp8_meta_tensor_key] = recipe_state self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers() + def _update_weight_quantizers(self) -> None: + """Update the quantizers for the weight tensors.""" + weight_tensors = self._get_weight_tensors() + weight_quantizers = self._get_weight_quantizers() + assert len(weight_tensors) == len(weight_quantizers), ( + f"Number of weight tensors ({len(weight_tensors)}) and quantizers " + f"({len(weight_quantizers)}) must match" + ) + for weight, quantizer in zip(weight_tensors, weight_quantizers): + if quantizer is not None and isinstance(weight, QuantizedTensorBase): + weight.update_quantizer(quantizer) + + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + """Get the weight tensors of the module.""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement _get_weight_tensors function" + ) + + def _get_weight_quantizers(self) -> List[Quantizer]: + """Get the weight quantizers of the module.""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement _get_weight_quantizers function" + ) + def init_fp8_meta_tensors(self, recipe: Recipe) -> None: """Init scales and amaxes.""" self.set_meta_tensor(True, recipe) @@ -749,7 +775,7 @@ def reset(key): reset("scaling_fwd") reset("scaling_bwd") - def get_extra_state(self) -> Optional[torch.Tensor]: + def get_extra_state(self) -> torch.Tensor: """Save before checkpointing.""" # This implementation is working around a few issues: @@ -784,7 +810,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: state = None fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration if not fp8_checkpoint: - return None + return torch.empty(0, dtype=torch.uint8) # Copy tensors to CPU and store state = {} @@ -810,13 +836,18 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8) return state_serialized - def set_extra_state(self, state: Optional[torch.Tensor]) -> None: + def set_extra_state(self, state: torch.Tensor) -> None: """Load previous state.""" + + # Maintain backwards compatibility with older checkpoints. if state is None: return # Load state if isinstance(state, torch.Tensor): + # No FP8 is indicated by an empty tensor we don't need to unpickle. + if state.numel() == 0: + return # Default format: byte tensor with pickled data state = pickle.loads(state.detach().cpu().numpy().tobytes()) elif isinstance(state, io.BytesIO): @@ -829,6 +860,14 @@ def set_extra_state(self, state: Optional[torch.Tensor]) -> None: if state is None: return + # TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing + if "recipe" not in state: + # TE 1.x only supported delayed scaling, which was the default recipe + state["recipe"] = DelayedScaling() + # TE 1.x also saved scale_inv, which is not needed with Recipe object + state.pop("scale_inv_fwd", None) + state.pop("scale_inv_bwd", None) + # Load extra items self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta["recipe"] = state["recipe"] @@ -902,6 +941,8 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: # assume FP8 execution. def init_fp8_metadata(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" + _original_recipe = self.fp8_meta.get("recipe", None) + self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() self.fp8 = FP8GlobalStateManager.is_fp8_enabled() self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() @@ -945,6 +986,19 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: if self.fp8_meta["recipe"].mxfp8(): self.keep_fp8_weight_transpose_cache = True + _current_recipe = self.fp8_meta["recipe"] + if _original_recipe is not None and not ( + issubclass(_current_recipe.__class__, _original_recipe.__class__) + or issubclass(_original_recipe.__class__, _current_recipe.__class__) + ): + warnings.warn( + f"Recipe type changed from {_original_recipe.__class__.__name__} " + f"to {_current_recipe.__class__.__name__}. " + "This may affect model behavior." + ) + # Clear cached workspaces as they were created with the old recipe/quantizer type + self._fp8_workspaces.clear() + @contextmanager def prepare_forward( self, @@ -958,6 +1012,7 @@ def prepare_forward( to setup the forward aggregated amax reduction for every module just in case. The autocast exit will pick up the most recent one. """ + self.forwarded_at_least_once = True # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) @@ -969,6 +1024,7 @@ def prepare_forward( self.set_activation_dtype(inp) self.init_fp8_metadata(num_gemms=num_gemms) + self._check_weight_tensor_recipe_correspondence() if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): assert self.fp8_meta["recipe"].reduce_amax, ( @@ -1080,7 +1136,12 @@ def grad_output_preprocess( if ( isinstance( grad_output_.get_tensor(True), - (QuantizedTensor, Float8TensorBase, MXFP8TensorBase), + ( + QuantizedTensor, + Float8TensorBase, + MXFP8TensorBase, + Float8BlockwiseQTensorBase, + ), ) and ctx.use_bias ): @@ -1146,20 +1207,25 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: with get_rng_state_tracker().fork(): init_fn(param) - # If primary weights are in fp8, wrap the parameter as FP8Tensor + # Wrap parameters in QuantizedTensor if needed fp8_meta_index = self.param_init_meta[name].fp8_meta_index high_precision_init_val = None if self.primary_weights_in_fp8 and fp8_meta_index is not None: + + # Keep high-precision values on CPU if needed if self.preserve_high_precision_init_val: high_precision_init_val = param.detach().cpu() + # Configure quantizer quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] - assert ( - quantizer is not None - ) # to use primary fp8 weight one needs to use FP8 autocast with specific recipe. + if quantizer is None: + raise RuntimeError("Weight quantizer has not been initialized") + quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) quantizer.internal = False if IS_HIP_EXTENSION and not self.keep_fp8_weight_transpose_cache: quantizer.columnwise_usage=False + + # Quantize parameter param = quantizer(param) if IS_HIP_EXTENSION and self.use_fsdp2 and not self.primary_weights_in_fp8 and fp8_meta_index is not None: self.keep_fp8_weight_transpose_cache = False @@ -1175,6 +1241,8 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. param = torch.nn.Parameter(param) + + # Keep high-precision values on CPU if needed if high_precision_init_val is not None: # - Master weights are initialized from model weights, if we use fp8 primary @@ -1218,7 +1286,7 @@ def get_weight_workspace( fsdp_group: Optional[dist_group_type] = None, workspace_dtype: Optional[torch.dtype] = None, ) -> QuantizedTensor: - """Get FP8 workspace buffer and maybe update its values + """Get workspace buffer for weights and maybe update its values The workspace buffer may be cached for future function calls. @@ -1244,13 +1312,16 @@ def get_weight_workspace( for debug quantization, this is dtype of the tensor. """ - # FP8 primary weights + # Handle case where weights are already quantized + # Note: Make sure weights have required usages, but do not + # destroy unnecessary usages since they may be used later. if isinstance(tensor, QuantizedTensor): - if update_workspace and quantizer is not None: - tensor.update_usage( - rowwise_usage=quantizer.rowwise_usage, - columnwise_usage=quantizer.columnwise_usage, - ) + update_rowwise_usage = True if quantizer.rowwise_usage else None + update_columnwise_usage = True if quantizer.columnwise_usage else None + tensor.update_usage( + rowwise_usage=update_rowwise_usage, + columnwise_usage=update_columnwise_usage, + ) return tensor # Try getting workspace from cache @@ -1380,6 +1451,43 @@ def _validate_name(self): ) self.name = f"Layer_{TEDebugState.get_layer_count()}" + def _check_weight_tensor_recipe_correspondence(self) -> None: + """ + Verify that the weight tensor types match their corresponding recipe type. + This is invoked in the forward(). + + This establishes a 1:1 correspondence between recipe types and tensor types: + - DelayedScaling → Float8Tensor + - Float8CurrentScaling → Float8Tensor + - MXFP8BlockScaling → MXFP8Tensor + - Float8BlockScaling → Float8BlockTensor + + Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()), + but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()). + """ + if not self.fp8 and not self.fp8_calibration: + return + if not hasattr(self, "weight_names") or not self.weight_names: + return + + recipe = self.fp8_meta["recipe"] + weight_tensors = [getattr(self, name) for name in self.weight_names] + for i, tensor in enumerate(weight_tensors): + if isinstance(tensor, QuantizedTensorBase): + quantizer = tensor._get_quantizer() + if quantizer is None: + continue + compatible_recipe_class = quantizer._get_compatible_recipe() + if compatible_recipe_class is None: + continue + if not isinstance(recipe, compatible_recipe_class): + raise RuntimeError( + f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe" + f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}." + " Please check the recipes assigned during fp8_model_init() and" + " fp8_autocast() calls." + ) + def _turn_off_unsupported_features_in_debug(self): if ( getattr(self, "ub_bulk_wgrad", False) diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index 974840833..0d2e3e6d7 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -53,15 +53,16 @@ def backward(ctx, grad_output: torch.Tensor): if ctx.requires_dgrad: grad_output = grad_output.contiguous() - grad_output_mats = torch.split( - grad_output.view(-1, grad_output.shape[-1]), ctx.padded_m_splits + in_features = grad_output.shape[-1] + + # Allocate cast and transpose output tensor + total_row = sum(ctx.m_splits) + grad_input = torch.empty( + [total_row, in_features], dtype=grad_output.dtype, device=grad_output.device ) - grad_input = torch.cat( - [ - grad_output_mat[: ctx.m_splits[i]] - for i, grad_output_mat in enumerate(grad_output_mats) - ], - dim=0, + + tex.fused_multi_row_unpadding( + grad_output.view(-1, in_features), grad_input, ctx.padded_m_splits, ctx.m_splits ) return (grad_input, None, None, None) @@ -73,11 +74,12 @@ class Fp8Padding(torch.nn.Module): Parameters ---------- - num_gemms: int - number of GEMMs to be performed simutaneously. - align_size: int, optional - the alignment size for the input tensor. If not provided, the alignment size will - be determined by the FP8 recipe, 32 for MXFP8 and 16 for others. + num_gemms : int + number of GEMMs to be performed simultaneously. + align_size : int, optional + the alignment size for the input tensor. If not provided, the alignment size will + be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first + forward pass. """ def __init__( @@ -88,10 +90,7 @@ def __init__( super().__init__() self.num_gemms = num_gemms - if align_size is None: - self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 - else: - self.align_size = align_size + self.align_size = align_size @no_torch_dynamo() def forward( @@ -111,6 +110,8 @@ def forward( """ assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." + if self.align_size is None: + self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 # FP8 padding calculate padded_m_splits = [ diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index 7e1fbcb2a..3b0f8928f 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -29,10 +29,13 @@ def forward( is_grad_enabled: bool, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits) - out_ret = torch.cat( - [grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0 - ) + in_features = inp.shape[-1] + + # Allocate cast and transpose output tensor + total_row = sum(m_splits) + out_ret = torch.empty([total_row, in_features], dtype=inp.dtype, device=inp.device) + + tex.fused_multi_row_unpadding(inp.view(-1, in_features), out_ret, padded_m_splits, m_splits) if is_grad_enabled: ctx.m_splits = m_splits @@ -69,11 +72,12 @@ class Fp8Unpadding(torch.nn.Module): Parameters ---------- - num_gemms: int - number of GEMMs to be performed simutaneously. - align_size: int, optional - the alignment size for the input tensor. If not provided, the alignment size will - be determined by the FP8 recipe, 32 for MXFP8 and 16 for others. + num_gemms : int + number of GEMMs to be performed simultaneously. + align_size : int, optional + the alignment size for the input tensor. If not provided, the alignment size will + be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first + forward pass. """ def __init__( @@ -84,10 +88,7 @@ def __init__( super().__init__() self.num_gemms = num_gemms - if align_size is None: - self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 - else: - self.align_size = align_size + self.align_size = align_size @no_torch_dynamo() def forward( @@ -107,6 +108,8 @@ def forward( """ assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." + if self.align_size is None: + self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 # FP8 padding calculate padded_m_splits = [ diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 6225d3119..a684dad6b 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -27,7 +27,6 @@ from ..utils import ( divide, cast_if_needed, - assert_dim_for_fp8_exec, clear_tensor_data, init_method_constant, requires_grad, @@ -41,11 +40,12 @@ from ..cpp_extensions import ( general_grouped_gemm, ) -from ..constants import GemmParallelModes, dist_group_type, TE_DType +from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..cpu_offload import is_cpu_offload_enabled +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.quantized_tensor import ( QuantizedTensorBase, Quantizer, @@ -83,6 +83,7 @@ def forward( is_grad_enabled: bool, module, skip_fp8_weight_update, + save_original_input, *weights_and_biases, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -91,25 +92,18 @@ def forward( weights = weights_and_biases[:num_gemms] biases = weights_and_biases[num_gemms:] device = inp.device - - # Make sure input dimensions are compatible - in_features = weights[0].shape[-1] - assert inp.shape[-1] == in_features, "GEMM not possible" - inputmats = torch.split(inp.view(-1, in_features), m_splits) - if fp8: - assert_dim_for_fp8_exec(*inputmats, *weights) - - # Cast input to expected dtype - inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats] - inputmats = [] - weight_requires_grad = weights[0].requires_grad + # Configure quantizers + if save_original_input and isinstance(input_quantizers[0], Float8Quantizer): + raise ValueError("DelayedScaling recipe is not supported with save_original_input") if input_quantizers[0] is not None: for input_quantizer in input_quantizers: input_quantizer.set_usage( rowwise=True, - columnwise=(is_grad_enabled and weight_requires_grad), + columnwise=( + is_grad_enabled and weight_requires_grad and not save_original_input + ), ) columnwise_usage = is_grad_enabled and inp.requires_grad if not columnwise_usage: @@ -124,28 +118,25 @@ def forward( for output_quantizer in output_quantizers: output_quantizer.set_usage(rowwise=True, columnwise=False) - fprop_gemm_use_split_accumulator = _2X_ACC_FPROP + # Initialize input tensors + in_features = weights[0].size(-1) + if inp.size(-1) != in_features: + raise ValueError( + f"Input tensor (shape={tuple(inp.size())}) is not compatible with " + f"weight tensor (shape={tuple(weights[0].size())})" + ) + inp_view = inp.reshape(-1, in_features) + inputmats: list if fp8: - recipe = FP8GlobalStateManager.get_fp8_recipe() - if hasattr(recipe, "fp8_gemm_fprop"): - fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator - - if IS_HIP_EXTENSION and bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ): - # The Triton path has no equivalent for tex.fused_multi_quantize() - inputmats = [] - for i, x in enumerate(inputmats_no_fp8): - qi = input_quantizers[i] - dst = qi.make_empty(x.shape, dtype=x.dtype, device=x.device, requires_grad=False) - qi.update_quantized(x, dst, noop_flag=None) - inputmats.append(dst) - else: - inputmats = tex.fused_multi_quantize( - inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype] - ) + inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers) + else: + inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) - weights_fp8 = [] - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype + # Initialize weights + weights_fp8: list + if fp8: # FP8 cast to workspace buffer + weights_fp8 = [] update_workspace = is_first_microbatch is None or is_first_microbatch for i in range(num_gemms): weight_fp8 = module.get_weight_workspace( @@ -158,18 +149,29 @@ def forward( weights_fp8.append(weight_fp8) else: - inputmats = inputmats_no_fp8 - bias_dtype = activation_dtype weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights] + # Initialize biases + bias_dtype = activation_dtype + if fp8 and activation_dtype == torch.float32: + bias_dtype = torch.bfloat16 # FP8 GEMM only supports BF16/FP16 bias biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases + # Initialize output tensor out = torch.empty( [sum(m_splits), weights_fp8[0].size(0)], dtype=activation_dtype, device=device, ) + # Choose whether to use split accumulator + use_split_accumulator = _2X_ACC_FPROP + if fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(recipe, "fp8_gemm_fprop"): + use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator + + # Perform GEMM _ = general_grouped_gemm( weights_fp8, inputmats, @@ -180,7 +182,7 @@ def forward( m_splits=m_splits, bias=biases, use_bias=use_bias, - use_split_accumulator=fprop_gemm_use_split_accumulator, + use_split_accumulator=use_split_accumulator, ) if fp8_calibration: @@ -197,9 +199,15 @@ def forward( # TODO: update after #1638 is merged. # pylint: disable=fixme if weight_requires_grad: - for inputmat in inputmats: - if isinstance(inputmat, QuantizedTensorBase): - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + if save_original_input: + inputmats = [None] * num_gemms + inputmats[0] = inp + else: + for inputmat in inputmats: + if isinstance(inputmat, QuantizedTensorBase): + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + else: + inputmats = [None] * num_gemms if inp.requires_grad: for weight in weights_fp8: if isinstance(weight, QuantizedTensorBase): @@ -216,9 +224,18 @@ def forward( ctx.weights_requires_grad = weights[0].requires_grad if fuse_wgrad_accumulation and ctx.weights_requires_grad: - ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)] + # This check is needed to ensure that main_grad is not created + # during the forward pass when using MCore FSDP as it creates + # the main_grad buffer lazily before backprop + if hasattr(weights[0], "__fsdp_param__"): + # MCore FSDP creates main_grad lazily before backward + ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_gemms)] + else: + ctx.main_grad_funcs = [ + lambda j=i: weights[j].main_grad for i in range(num_gemms) + ] else: - ctx.main_grads = [None] * num_gemms + ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)] ctx.device = device ctx.grad_output_quantizers = grad_output_quantizers ctx.m_splits = m_splits @@ -240,6 +257,8 @@ def forward( or FP8GlobalStateManager.is_first_fp8_module() ) ctx.wgrad_store = wgrad_store + ctx.save_original_input = save_original_input + ctx.input_quantizers = input_quantizers # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -254,44 +273,52 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], weights = saved_tensors[N : 2 * N] origin_weights = saved_tensors[2 * N : 3 * N] biases = saved_tensors[3 * N : 4 * N] - main_grads = ctx.main_grads + main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: # TOSO - for i in ctx.num_gemms: + if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: + for i in range(ctx.num_gemms): w = torch.nn.Parameter(weights[i], weights[i].requires_grad) w.main_grad = main_grads[i] weights[i] = w - # preprocess grad_output - - grad_output = grad_output.contiguous() - grad_output_mats = torch.split( - grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits - ) + # Preprocess grad output + grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms if ctx.fp8: if ctx.use_bias: - # unfuse bgrad for now until cast_transpose + dgrad calculation is ready - # for Float8BlockQuantizer. - if ctx.fp8_recipe.float8_block_scaling(): - for i in range(ctx.num_gemms): - grad_biases[i] = grad_output_mats[i].sum(dim=0) - grad_output[i] = ctx.grad_output_quantizers[i](grad_output_mats[i]) - else: + grad_output_mats = torch.split(grad_output_view, ctx.m_splits) + recipe = ctx.fp8_recipe + if recipe.delayed() or recipe.float8_current_scaling() or recipe.mxfp8(): + # Fused bias grad + quantize kernel for i in range(ctx.num_gemms): grad_biases[i], grad_output[i] = tex.bgrad_quantize( - grad_output_mats[i], ctx.grad_output_quantizers[i] + grad_output_mats[i], + ctx.grad_output_quantizers[i], ) + else: + # Unfused bias grad and multi-tensor quantize + for i in range(ctx.num_gemms): + grad_biases[i] = grad_output_mats[i].sum(dim=0) + grad_output = tex.split_quantize( + grad_output_view, + ctx.m_splits, + ctx.grad_output_quantizers, + ) else: - grad_output = tex.fused_multi_quantize( - grad_output_mats, - None, + # Multi-tensor quantize + grad_output = tex.split_quantize( + grad_output_view, + ctx.m_splits, ctx.grad_output_quantizers, - TE_DType[ctx.activation_dtype], ) else: - grad_output = grad_output_mats + # Only split grad output. Grad bias is fused with + # wgrad GEMM. + grad_output = torch.split( + cast_if_needed(grad_output_view, ctx.activation_dtype), + ctx.m_splits, + ) if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( @@ -348,6 +375,27 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) for w in weights ] + + if ctx.save_original_input: + inp = inputmats[0] + in_features = inp.shape[-1] + inp_view = inp.reshape(-1, in_features) + if ctx.input_quantizers[0] is not None: + for input_quantizer in ctx.input_quantizers: + if isinstance( + input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ): + input_quantizer.set_usage(rowwise=True, columnwise=True) + else: + input_quantizer.set_usage(rowwise=False, columnwise=True) + inputmats: list + if ctx.fp8: + inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) + else: + inputmats = torch.split( + cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits + ) + grouped_gemm_wgrad = functools.partial( general_grouped_gemm, out_dtype=ctx.activation_dtype, @@ -439,6 +487,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): None, None, None, + None, *wgrad_list, *grad_biases, ) @@ -489,6 +538,11 @@ class GroupedLinear(TransformerEngineBaseModule): would not fit in GPU memory. delay_wgrad_compute : bool, default = `False` Whether to delay weight gradient computation + save_original_input : bool, default = `False` + If set to `True`, always saves the original input tensor rather than the + cast tensor. In some scenarios, the input tensor is used by multiple modules, + and saving the original input tensor may reduce the memory usage. + Cannot work with FP8 DelayedScaling recipe. Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and `parallel_mode` are used to determine the shapes of weights and biases. @@ -516,6 +570,7 @@ def __init__( ub_overlap_ag: bool = False, ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, + save_original_input: bool = False, ) -> None: super().__init__() @@ -530,6 +585,7 @@ def __init__( self.ub_overlap_rs = ub_overlap_rs self.ub_overlap_ag = ub_overlap_ag self.ub_name = ub_name + self.save_original_input = save_original_input assert ( not ub_overlap_rs and not ub_overlap_ag ), "GroupedLinear doesn't support Userbuffer overlap." @@ -683,26 +739,19 @@ def forward( ), "GroupedLinear doesn't support input tensor in FP8." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None if skip_fp8_weight_update is not None: is_first_microbatch = False with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: - - weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + weight_tensors = self._get_weight_tensors() bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] - if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors): - warnings.warn( - "You are using quantized weights without quantized compute. " - "Please make sure this is intentional." - ) - weight_tensors = [ - w.dequantize() if isinstance(w, QuantizedTensorBase) else w - for w in weight_tensors - ] - input_quantizers, weight_quantizers, output_quantizers = ( - [None] * self.num_gemms, + weight_quantizers = self._get_weight_quantizers() + input_quantizers, output_quantizers = ( [None] * self.num_gemms, [None] * self.num_gemms, ) @@ -717,14 +766,6 @@ def forward( # TODO: use internal after #1638 is merged. # pylint: disable=fixme for i in range(self.num_gemms): input_quantizers[i].internal = False - weight_quantizers = [ - self.quantizers["scaling_fwd"][ - self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"] - ] - for i in range(self.num_gemms) - ] - for i in range(self.num_gemms): - weight_quantizers[i].internal = True if torch.is_grad_enabled(): grad_output_quantizers = [ self.quantizers["scaling_bwd"][ @@ -760,6 +801,7 @@ def forward( torch.is_grad_enabled(), self, skip_fp8_weight_update, + self.save_original_input, *weight_tensors, *bias_tensors, ) @@ -823,3 +865,30 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_bwd"][ self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"] ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon + + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + """Get the weight tensors of the module.""" + weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors): + warnings.warn( + "You are using quantized weights without quantized compute. " + "Please make sure this is intentional." + ) + weight_tensors = [ + w.dequantize() if isinstance(w, QuantizedTensorBase) else w for w in weight_tensors + ] + return weight_tensors + + def _get_weight_quantizers(self) -> List[Quantizer]: + """Get the weight quantizers of the module.""" + if not self.fp8: + return [None] * self.num_gemms + weight_quantizers = [ + self.quantizers["scaling_fwd"][ + self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"] + ] + for i in range(self.num_gemms) + ] + for i in range(self.num_gemms): + weight_quantizers[i].internal = True + return weight_quantizers diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 3e0844a7a..d67b19c74 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -7,7 +7,7 @@ """LayerNormLinear API""" import os import warnings -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union, List from functools import reduce from operator import mul as multiply_op @@ -67,10 +67,12 @@ ) from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.utils import any_feature_enabled -from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..export import is_in_onnx_export_mode, assert_warmed_up from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpp_extensions import ( @@ -194,19 +196,13 @@ def forward( # All-gather is not supported with FP8 column-wise data input_quantizer.set_usage(columnwise=False) - # Do TP communication in high precision if quantized format - # does not support communication - force_hp_blockwise_ln_out_gather = ( - fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer) - ) - # Avoid quantized norm kernel if norm output will be returned # or if a gather of ln_out must be in high precision. with_quantized_norm = ( fp8 + and not debug and not return_layernorm_output and not return_layernorm_output_gathered - and not force_hp_blockwise_ln_out_gather ) # ROCm does not currently support quantized norm for Float8CurrentScalingQuantizer @@ -247,15 +243,16 @@ def forward( ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total if fp8 or debug: - if not force_hp_blockwise_ln_out_gather: - ln_out = input_quantizer(ln_out) + ln_out = input_quantizer(ln_out) input_quantizer.set_usage(rowwise=True, columnwise=False) + if isinstance(input_quantizer, Float8BlockQuantizer): + input_quantizer.all_gather_usage = False ln_out_total = input_quantizer(ln_out_total) else: quantizer = None if fp8 or debug: quantizer = input_quantizer - if not with_quantized_norm and not force_hp_blockwise_ln_out_gather: + if not with_quantized_norm: ln_out = quantizer(ln_out) quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather @@ -290,7 +287,7 @@ def forward( # Configure quantizer if weight_quantizer is not None: - weight_quantizer.set_usage(rowwise=True, columnwise=keep_fp8_weight_transpose_cache) + weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) # Get quantized weight update_workspace = is_first_microbatch is None or is_first_microbatch @@ -407,7 +404,6 @@ def forward( ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) - ctx.force_hp_blockwise_ln_out_gather = force_hp_blockwise_ln_out_gather # Input with column-wise usage is needed for wgrad GEMM. if backward_needs_input: @@ -415,7 +411,10 @@ def forward( # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. For MXFP8, columnwise only data # can be allgathered. - if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather: + if ( + isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase)) + or not ctx.ln_out_needs_gather + ): ln_out.update_usage(rowwise_usage=False) # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. @@ -467,7 +466,14 @@ def forward( ctx.requires_wgrad = weight.requires_grad ctx.quantized_weight = quantized_weight if fuse_wgrad_accumulation and weight.requires_grad: - ctx.main_grad = weight.main_grad + # This check is needed to ensure that main_grad is not created + # during the forward pass when using MCore FSDP as it creates + # the main_grad buffer lazily before backprop + if hasattr(weight, "__fsdp_param__"): + # MCore FSDP creates main_grad lazily before backward + ctx.main_grad_func = weight.get_main_grad + else: + ctx.main_grad_func = lambda: weight.main_grad ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer @@ -517,7 +523,7 @@ def forward( if return_layernorm_output: if return_layernorm_output_gathered: shape = list(inp_shape) - shape[0] *= tp_size + shape[0] *= tp_size if with_input_all_gather else 1 return out, ln_out_return.view(shape) return out, ln_out_return.view(inp_shape) return out @@ -551,7 +557,7 @@ def backward( # Since main_grad can be modified inplace, it should not be a part of saved_tensors main_grad = ( - ctx.main_grad + ctx.main_grad_func() if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad else None ) @@ -651,7 +657,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None and not ctx.force_hp_blockwise_ln_out_gather: + if ctx.input_quantizer is not None: quantizer = ctx.input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -770,6 +776,31 @@ def backward( wgrad = None if ctx.requires_wgrad: + # Prepare grad output tensor + # Note: Synchronize tensor-parallel communication and + # make sure required data is available + if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): + # UB does not support overlapping grad output + # all-gather with wgrad GEMM. Also, we can't + # convert row-scaled MXFP8 to column-scaled, so we + # can't reuse the grad output that was gathered + # for the dgrad GEMM. We work around by explicitly + # overlapping the NCCL operation with the dgrad GEMM. + ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) + + # Get the communication stream from the dgrad GEMM and set it as the current torch stream + dgrad_comm_stream = ub_obj_dgrad.get_communication_stream() + with torch.cuda.stream(dgrad_comm_stream): + # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather + # This ensures that we don't start until all communication for the dgrad GEMM is complete + grad_output, mxfp8_grad_output_work = gather_along_first_dim( + grad_outputs[0], + ctx.tp_group, + async_op=True, + quantizer=ctx.grad_output_quantizer, + ) + # Synchronize with the main stream + mxfp8_grad_output_work.wait() # Prepare input tensor # Note: Synchronize tensor-parallel communication and @@ -784,22 +815,6 @@ def backward( ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - # Prepare grad output tensor - # Note: Synchronize tensor-parallel communication and - # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): - # UB does not support overlapping grad output - # all-gather with wgrad GEMM. Also, we can't - # convert row-scaled MXFP8 to column-scaled, so we - # can't reuse the grad output that was gathered - # for the dgrad GEMM. We work around with blocking - # all-gather for column-scaled MXFP8 data. - ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output, _ = gather_along_first_dim( - grad_outputs[0], - ctx.tp_group, - quantizer=ctx.grad_output_quantizer, - ) if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorBase): grad_output.update_usage(columnwise_usage=True) @@ -1426,6 +1441,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) + elif recipe.float8_block_scaling(): + self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) # elif other recipes (mxfp8, etc) def reset_layer_norm_parameters(self) -> None: @@ -1498,6 +1515,8 @@ def forward( first microbatch (since it is the first gradient being produced) """ + if is_in_onnx_export_mode(): + return self.onnx_forward(inp, fp8_output) debug = TEDebugState.debug_enabled if debug: self._validate_name() @@ -1521,25 +1540,7 @@ def forward( ) as inp: # Get concatenated weight and bias tensors - unfused_weights = [getattr(self, name) for name in self.weight_names] - if any(isinstance(w, QuantizedTensor) for w in unfused_weights): - if self.fp8: - if len(unfused_weights) != 1: - raise RuntimeError( - "Splitting QuantizedTensor into multiple params is not supported" - ) - else: - warnings.warn( - "You are using quantized weights without quantized compute. " - "Please make sure this is intentional." - ) - unfused_weights = [w.dequantize() for w in unfused_weights] - - weight_tensor = noop_cat(unfused_weights) - if self.use_bias: - bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) - else: - bias_tensor = getattr(self, self.bias_names[0]) # Unused + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() quantizers = ( self._get_quantizers(fp8_output, fp8_grad) @@ -1642,10 +1643,7 @@ def _get_quantizers(self, fp8_output, fp8_grad): output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] input_quantizer.internal = True - weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] - weight_quantizer.internal = True - if IS_HIP_EXTENSION: - weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) + (weight_quantizer,) = self._get_weight_quantizers() if fp8_output: output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] if torch.is_grad_enabled(): @@ -1674,6 +1672,72 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad): for name, q in zip(names, original_quantizers) ) + def _get_weight_and_bias_tensors(self): + # Get concatenated weight and bias tensors + unfused_weights = self._get_weight_tensors() + + weight_tensor = noop_cat(unfused_weights) + if self.use_bias: + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) + else: + bias_tensor = getattr(self, self.bias_names[0]) # Unused + return weight_tensor, bias_tensor + + def onnx_forward( + self, + inp: torch.Tensor, + fp8_output: bool, + ) -> torch.Tensor: + """ + ONNX-compatible version of the forward function that provides numerical equivalence + while only using operations that have defined ONNX symbolic translations. + This simplified implementation is designed specifically for inference scenarios. + """ + from ..export import onnx_layernorm, onnx_gemm + + assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export" + assert_warmed_up(self) + ( + input_quantizer, + weight_quantizer, + output_quantizer, + *_, + ) = self._get_quantizers(fp8_output, fp8_grad=False) + inp_dtype = inp.dtype + + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + ln_out, ln_out_return = onnx_layernorm( + inp, + self.layer_norm_weight, + self.layer_norm_bias, + self.eps, + self.normalization, + self.zero_centered_gamma, + inp_dtype, + self.return_layernorm_output, + input_quantizer, + ) + + if weight_quantizer is not None: + weight_tensor_quantized = weight_quantizer.onnx_quantize(weight_tensor) + weight_tensor = weight_quantizer.onnx_dequantize(weight_tensor_quantized) + weight_tensor = weight_tensor.to(inp_dtype) + + if bias_tensor is not None: + bias_tensor = bias_tensor.to(inp_dtype) + + output = onnx_gemm(weight_tensor, ln_out, bias_tensor if self.apply_bias else None) + + if output_quantizer is not None: + raise NotImplementedError("ONNX export of quantized output is not supported") + if self.return_layernorm_output and self.return_bias: + return output, bias_tensor.to(inp_dtype), ln_out_return + if self.return_layernorm_output: + return output, ln_out_return + if self.return_bias: + return output, bias_tensor.to(inp_dtype) + return output + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: """Customize quantizers based on current scaling recipe + layernorm_linear.""" assert ( @@ -1720,3 +1784,41 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group + + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + """Get the weight tensors of the module.""" + unfused_weights = [getattr(self, name) for name in self.weight_names] + if any(isinstance(w, QuantizedTensor) for w in unfused_weights): + if self.fp8: + if len(unfused_weights) != 1: + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) + else: + warnings.warn( + "You are using quantized weights without quantized compute. " + "Please make sure this is intentional." + ) + unfused_weights = [w.dequantize() for w in unfused_weights] + return unfused_weights + + def _get_weight_quantizers(self) -> List[Quantizer]: + """Get the weight quantizers of the module.""" + if not self.fp8: + return [None] + weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer.internal = True + if IS_HIP_EXTENSION: + weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) + return [weight_quantizer] + + def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on blockwise scaling recipe + layernorm_linear.""" + assert ( + recipe.float8_block_scaling() + ), "blockwise scaling recipe quantizer customization here" + if fwd: + if self.sequence_parallel and self.parallel_mode == "column": + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].all_gather_usage = True diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 2edea459c..8772418c9 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -7,7 +7,7 @@ """LayerNormMLP API""" import os import warnings -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union, List from functools import reduce from operator import mul as multiply_op @@ -80,6 +80,7 @@ from ..cpp_extensions import ( general_gemm, ) +from ..export import is_in_onnx_export_mode, assert_warmed_up from ...debug.pytorch.utils import any_feature_enabled from ...debug.pytorch.debug_state import TEDebugState @@ -92,16 +93,16 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): if recipe is None: - # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + # bf16 (recipe is None): return { - "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), - "relu": (tex.relu, tex.drelu, tex.dbias_drelu), + "gelu": (tex.gelu, tex.dgelu, None), + "relu": (tex.relu, tex.drelu, None), "geglu": (tex.geglu, tex.dgeglu, None), "reglu": (tex.reglu, tex.dreglu, None), "swiglu": (tex.swiglu, tex.dswiglu, None), - "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), + "qgelu": (tex.qgelu, tex.dqgelu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), - "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), + "srelu": (tex.srelu, tex.dsrelu, None), } if recipe.delayed() or recipe.mxfp8(): # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] @@ -252,32 +253,23 @@ def forward( # All-gather is not supported with FP8 column-wise data fc1_input_quantizer.set_usage(columnwise=False) - # Do TP communication in high precision if quantized format - # does not support communication - force_hp_fc1_input_gather = ( - fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer) - ) - # for fp8 DelayedScaling: layernorm output = FP8 # only output of the linear is returned # for return_layernorm_output: layernorm output = High precision, then cast to FP8 # high precision layernorm output and output of the linear are returned # for debug: : layernorm output = High precision to enable processing of this norm + with_quantized_norm = ( fp8 + and not debug and not return_layernorm_output and not return_layernorm_output_gathered - and not debug ) # ROCm does not currently support quantized norm for Float8CurrentScalingQuantizer if IS_HIP_EXTENSION and isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer): with_quantized_norm = False - if isinstance(fc1_input_quantizer, Float8BlockQuantizer): - # Kernels not available for norm fusion. - with_quantized_norm = False - # Apply normalization ln_out, mu, rsigma = apply_normalization( inputmat, @@ -306,15 +298,16 @@ def forward( ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total if fp8 or debug: - if not force_hp_fc1_input_gather: - ln_out = fc1_input_quantizer(ln_out) + ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) + if isinstance(fc1_input_quantizer, Float8BlockQuantizer): + fc1_input_quantizer.all_gather_usage = False ln_out_total = fc1_input_quantizer(ln_out_total) else: quantizer = None if fp8 or debug: quantizer = fc1_input_quantizer - if not with_quantized_norm and not force_hp_fc1_input_gather: + if not with_quantized_norm: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: @@ -346,8 +339,8 @@ def forward( # which handles weight caching etc. # FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch - fc1_weight_quantizer.set_usage(rowwise=True, columnwise=keep_fp8_weight_transpose_cache) - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=keep_fp8_weight_transpose_cache) + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) fc1_weight_final = module.get_weight_workspace( tensor=fc1_weight, quantizer=fc1_weight_quantizer, @@ -581,14 +574,25 @@ def forward( ) if fuse_wgrad_accumulation: - ctx.fc1_main_grad = fc1_weight.main_grad if fc1_weight.requires_grad else None - ctx.fc2_main_grad = fc2_weight.main_grad if fc2_weight.requires_grad else None + # This check is needed to ensure that main_grad is not created + # during the forward pass when using MCore FSDP as it creates + # the main_grad buffer lazily before backprop + if hasattr(fc1_weight, "__fsdp_param__") and hasattr(fc2_weight, "__fsdp_param__"): + # MCore FSDP creates main_grad lazily before backward + ctx.fc1_main_grad_func = ( + fc1_weight.get_main_grad if fc1_weight.requires_grad else lambda: None + ) + ctx.fc2_main_grad_func = ( + fc2_weight.get_main_grad if fc2_weight.requires_grad else lambda: None + ) + else: + ctx.fc1_main_grad_func = lambda: fc1_weight.main_grad + ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer @@ -653,7 +657,7 @@ def forward( if return_layernorm_output: if return_layernorm_output_gathered: shape = list(inp_shape) - shape[0] *= tp_size + shape[0] *= tp_size if (sequence_parallel and set_parallel_mode) else 1 return fc2_out, ln_out_return.view(shape) return fc2_out, ln_out_return.view(inp_shape) return fc2_out @@ -687,14 +691,14 @@ def backward( # Since main_grad can be modified inplace, it should not be a part of saved_tensors fc1_weight_main_grad = ( - ctx.fc1_main_grad + ctx.fc1_main_grad_func() if fc1_weight is not None and ctx.fuse_wgrad_accumulation and ctx.fc1_weight_requires_grad else None ) fc2_weight_main_grad = ( - ctx.fc2_main_grad + ctx.fc2_main_grad_func() if origin_fc2_weight is not None and ctx.fuse_wgrad_accumulation and ctx.fc2_weight_requires_grad @@ -768,7 +772,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if ctx.fp8 or ctx.debug and not ctx.force_hp_fc1_input_gather: + if ctx.fp8 or ctx.debug: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -869,6 +873,30 @@ def backward( fc2_wgrad = None if ctx.fc2_weight_requires_grad: + # Prepare grad output tensor + # Note: Synchronize tensor-parallel communication and + # make sure required data is available + if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): + # UB does not support overlapping grad output + # all-gather with wgrad GEMM. Also, we can't + # convert row-scaled MXFP8 to column-scaled, so we + # can't reuse the grad output that was gathered + # for the dgrad GEMM. We work around by explicitly + # overlapping the NCCL operation with the dgrad GEMM. + ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) + # Get the communication stream from the dgrad GEMM and set it as the current torch stream + dgrad_comm_stream = ub_obj_fc2_dgrad.get_communication_stream() + with torch.cuda.stream(dgrad_comm_stream): + # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather + # This ensures that we don't start until all communication for the dgrad GEMM is complete + grad_output, mxfp8_fc2_grad_output_work = gather_along_first_dim( + grad_outputs[0], + ctx.tp_group, + async_op=True, + quantizer=ctx.fc2_grad_output_quantizer, + ) + # Synchronize with the main stream + mxfp8_fc2_grad_output_work.wait() # Prepare input tensor # Note: Synchronize tensor-parallel communication and @@ -880,22 +908,6 @@ def backward( ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - # Prepare grad output tensor - # Note: Synchronize tensor-parallel communication and - # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): - # UB does not support overlapping grad output - # all-gather with wgrad GEMM. Also, we can't - # convert row-scaled MXFP8 to column-scaled, so we - # can't reuse the grad output that was gathered - # for the dgrad GEMM. We work around with blocking - # all-gather for column-scaled MXFP8 data. - ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output, _ = gather_along_first_dim( - grad_outputs[0], - ctx.tp_group, - quantizer=ctx.fc2_grad_output_quantizer, - ) if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorBase): grad_output.update_usage(columnwise_usage=True) @@ -1701,8 +1713,11 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: super().set_meta_tensor(fwd, recipe) # customize quantizers based on each recipe & layer configs - if FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) + elif recipe.float8_block_scaling(): + self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) # elif for other recipes (mxfp8, etc.) def reset_layer_norm_parameters(self) -> None: @@ -1764,6 +1779,8 @@ def forward( first microbatch (since it is the first gradient being produced) """ + if is_in_onnx_export_mode(): + return self.onnx_forward(inp) debug = TEDebugState.debug_enabled if debug: self._validate_name() @@ -1812,15 +1829,14 @@ def forward( ) = quantizers # Get weight tensors - fc1_weight = self.fc1_weight + fc1_weight, fc2_weight = self._get_weight_tensors() fc1_bias = self.fc1_bias if self.use_bias else None - fc2_weight = self.fc2_weight fc2_bias = self.fc2_bias if self.use_bias else None if not self.fp8: if isinstance(fc1_weight, Float8Tensor): - fc1_weight = fc1_weight.from_float8() + fc1_weight = fc1_weight.dequantize() if isinstance(fc2_weight, Float8Tensor): - fc2_weight = fc2_weight.from_float8() + fc2_weight = fc2_weight.dequantize() # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode if ( not IS_HIP_EXTENSION @@ -1908,35 +1924,26 @@ def forward( def _get_quantizers(self, fp8_output): ( fc1_input_quantizer, - fc1_weight_quantizer, fc1_output_quantizer, fc1_grad_input_quantizer, fc1_grad_weight_quantizer, fc1_grad_output_quantizer, fc2_input_quantizer, - fc2_weight_quantizer, fc2_output_quantizer, fc2_grad_input_quantizer, fc2_grad_weight_quantizer, fc2_grad_output_quantizer, - ) = [None] * 12 + ) = [None] * 10 + fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers() if self.fp8: fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] fc1_input_quantizer.internal = True - fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] - fc1_weight_quantizer.internal = True - if IS_HIP_EXTENSION: - fc1_weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer.set_usage( rowwise=True, columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)), ) fc1_input_quantizer.internal = True - fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] - fc2_weight_quantizer.internal = True - if IS_HIP_EXTENSION: - fc2_weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) if fp8_output: fc2_output_quantizer = self.quantizers["scaling_fwd"][ tex.FP8FwdTensors.GEMM2_OUTPUT @@ -1966,6 +1973,89 @@ def _get_quantizers(self, fp8_output): fc2_grad_output_quantizer, ) + def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + """ + ONNX-compatible version of the forward function that provides numerical equivalence + while only using operations that have defined ONNX symbolic translations. + This simplified implementation is designed specifically for inference scenarios. + """ + from ..export import onnx_layernorm, onnx_gemm + + assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export" + assert_warmed_up(self) + ( + fc1_input_quantizer, + fc1_weight_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + output_quantizer, + *_, + ) = self._get_quantizers(False) + inp_dtype = inp.dtype + + fc1_weight, fc2_weight = self._get_weight_tensors() + fc1_bias = self.fc1_bias if self.use_bias else None + fc2_bias = self.fc2_bias if self.use_bias else None + + # layernorm + fp8 cast + ln_out, ln_out_return = onnx_layernorm( + inp, + self.layer_norm_weight, + self.layer_norm_bias, + self.eps, + self.normalization, + self.zero_centered_gamma, + inp_dtype, + self.return_layernorm_output, + fc1_input_quantizer, + ) + + if fc1_weight_quantizer is not None: + fc1_weight_q = fc1_weight_quantizer.onnx_quantize(fc1_weight) + fc1_weight = fc1_weight_quantizer.onnx_dequantize(fc1_weight_q) + fc1_weight = fc1_weight.to(inp_dtype) + + fc1_out = onnx_gemm(fc1_weight, ln_out, fc1_bias) + + fc1_out = fc1_out.to(torch.float32) # activation is computed in fp32 + + activation_map = { + "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + "relu": torch.nn.functional.relu, + "geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + "reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + "qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh") + * x.chunk(2, -1)[1], + "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + "srelu": torch.nn.functional.softplus, + } + if self.activation not in activation_map: + raise ValueError(f"Unsupported activation in onnx export: {self.activation}") + act_out = activation_map[self.activation](fc1_out) + if fc2_weight_quantizer is not None: + fc2_weight_q = fc2_weight_quantizer.onnx_quantize(fc2_weight) + fc2_weight = fc2_weight_quantizer.onnx_dequantize(fc2_weight_q) + fc2_weight = fc2_weight.to(inp_dtype) + + if fc2_input_quantizer is not None: + act_out_q = fc2_input_quantizer.onnx_quantize(act_out) + act_out = fc2_input_quantizer.onnx_dequantize(act_out_q) + act_out = act_out.to(inp_dtype) + + fc2_out = onnx_gemm(fc2_weight, act_out, fc2_bias) + + if output_quantizer is not None: + raise NotImplementedError("ONNX export of quantized output is not supported") + + if self.return_layernorm_output: + if self.return_bias: + return fc2_out, fc2_bias.to(inp_dtype), ln_out_return + return fc2_out, ln_out_return + if self.return_bias: + return fc2_out, fc2_bias.to(inp_dtype) + return fc2_out + def _get_debug_quantizers(self, fp8_output): from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -2053,6 +2143,40 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT2 ].amax_reduction_group = self.tp_group + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + """Get the weight tensors of the module.""" + return [self.fc1_weight, self.fc2_weight] + + def _get_weight_quantizers(self) -> List[Quantizer]: + """Get the weight quantizers of the module.""" + if not self.fp8: + return [None, None] + fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + fc1_weight_quantizer.internal = True + if IS_HIP_EXTENSION: + fc1_weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) + fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] + fc2_weight_quantizer.internal = True + if IS_HIP_EXTENSION: + fc2_weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) + return [fc1_weight_quantizer, fc2_weight_quantizer] + + def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on blockwise scaling recipe + layernorm_mlp.""" + assert ( + recipe.float8_block_scaling() + ), "blockwise scaling recipe quantizer customization here" + if fwd: + if self.sequence_parallel and self.set_parallel_mode: + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].all_gather_usage = True + else: + if self.sequence_parallel and self.set_parallel_mode: + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT2 + ].all_gather_usage = True + def backward_dw(self): """ Execute the delayed weight gradient computation. diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a8361ef94..9ca11801c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -5,7 +5,7 @@ # See LICENSE for license information. """Linear API""" -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union, List from functools import reduce from operator import mul as multiply_op import warnings @@ -68,7 +68,8 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..export import is_in_onnx_export_mode, assert_warmed_up from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.utils import any_feature_enabled @@ -119,6 +120,7 @@ def forward( module: torch.nn.Module, skip_fp8_weight_update: bool, symmetric_ar_type: str, + save_original_input: bool = False, debug: Optional[bool] = False, keep_fp8_weight_transpose_cache: bool = True, use_fsdp2: bool = False, @@ -141,12 +143,6 @@ def forward( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) - # Do TP communication in high precision if quantized format - # does not support communication - force_hp_input_gather = ( - fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer) - ) - # Configure Userbuffers communication (comm+GEMM overlap) ub_obj = None ub_type = None @@ -167,14 +163,21 @@ def forward( own_quantized_input = False if fp8: assert_dim_for_fp8_exec(inputmat, weight) + if save_original_input: + assert not isinstance( + input_quantizer, Float8Quantizer + ), "DelayedScaling recipe is not supported with save_original_input" + if with_input_all_gather_nccl or ub_overlap_ag_fprop: # All-gather input tensor # Cast local input tensor if needed if fp8 or debug: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - if not force_hp_input_gather and not isinstance(inputmat, QuantizedTensorBase): - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + if not isinstance(inputmat, QuantizedTensorBase): + input_quantizer.set_usage( + rowwise=True, columnwise=backward_needs_input and not save_original_input + ) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): @@ -211,7 +214,9 @@ def forward( else: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + input_quantizer.set_usage( + rowwise=True, columnwise=backward_needs_input and not save_original_input + ) inputmat = input_quantizer(inputmat) own_quantized_input = True else: @@ -343,6 +348,9 @@ def forward( # ------------------------------------------------------ if is_grad_enabled: + if save_original_input: + inputmat = inp + ctx.weight_quantizer = weight_quantizer saved_inputmat = None @@ -351,14 +359,16 @@ def forward( ) if backward_needs_input: - if own_quantized_input and isinstance(inputmat, QuantizedTensorBase): - # For sequence parallel in vanilla FP8, rowwise data is - # to gather the input. For MXFP8, columnwise only data - # can be allgathered. - if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather: - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) - if force_hp_input_gather: - assert not isinstance(inputmat, QuantizedTensorBase) + if not save_original_input: + if own_quantized_input and isinstance(inputmat, QuantizedTensorBase): + # For sequence parallel in vanilla FP8, rowwise data is + # to gather the input. For MXFP8, columnwise only data + # can be allgathered. + if ( + isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase)) + or not ctx.backward_input_needs_gather + ): + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) saved_inputmat = inputmat # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. @@ -404,14 +414,20 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.force_hp_input_gather = force_hp_input_gather ctx.input_quantizer = input_quantizer ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation if fuse_wgrad_accumulation and weight.requires_grad: - ctx.main_grad = weight.main_grad + # This check is needed to ensure that main_grad is not created + # during the forward pass when using MCore FSDP as it creates + # the main_grad buffer lazily before backprop + if hasattr(weight, "__fsdp_param__"): + # MCore FSDP creates main_grad lazily before backward + ctx.main_grad_func = weight.get_main_grad + else: + ctx.main_grad_func = lambda: weight.main_grad ctx.debug = debug ctx.cpu_offloading = cpu_offloading @@ -471,7 +487,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Since main_grad can be modified inplace, it should not be a part of saved_tensors main_grad = ( - ctx.main_grad + ctx.main_grad_func() if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad else None ) @@ -567,9 +583,27 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # -------------------------------------------------- inputmat_total = None inputmat_total_work = None + if ctx.requires_wgrad: + input_is_quantized = isinstance(inputmat, QuantizedTensorBase) + if ctx.fp8 or ctx.debug: + if not input_is_quantized: + quantizer = ctx.input_quantizer + if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): + quantizer.set_usage( + rowwise=True, + columnwise=not ctx.backward_input_needs_gather, + ) + else: + quantizer.set_usage(rowwise=False, columnwise=True) + inputmat = quantizer(inputmat) + else: + if input_is_quantized: + inputmat = inputmat.dequantize(dtype=ctx.activation_dtype) + else: + inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if (ctx.fp8 or ctx.debug) and not ctx.force_hp_input_gather: + if ctx.fp8 or ctx.debug: quantizer = ctx.input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -710,14 +744,23 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we # can't reuse the grad output that was gathered - # for the dgrad GEMM. We work around with blocking - # all-gather for column-scaled MXFP8 data. + # for the dgrad GEMM. We work around by explicitly + # overlapping the NCCL operation with the dgrad GEMM. ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output, _ = gather_along_first_dim( - grad_output_arg, - ctx.tp_group, - quantizer=ctx.grad_output_quantizer, - ) + # Get the communication stream from the dgrad GEMM and set it as the current torch stream + dgrad_comm_stream = ub_obj_dgrad.get_communication_stream() + with torch.cuda.stream(dgrad_comm_stream): + # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather + # This ensures that we don't start until all communication for the dgrad GEMM is complete + grad_output, grad_output_work = gather_along_first_dim( + grad_output_arg, + ctx.tp_group, + async_op=True, + quantizer=ctx.grad_output_quantizer, + ) + # Synchronize with the main stream + grad_output_work.wait() + if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorBase): grad_output.update_usage(columnwise_usage=True) @@ -904,6 +947,7 @@ def wgrad_gemm( None, # module None, # skip_fp8_weight_update None, # symmetric_ar_type + None, # save_original_input None, # debug None, # keep_fp8_weight_transpose_cache None, # use_fsdp2 @@ -1002,6 +1046,11 @@ class Linear(TransformerEngineBaseModule): reduced efficiency of PyTorch's caching allocator. Use this setting to balance memory usage and performance based on your training configuration. + save_original_input : bool, default = `False` + If set to `True`, always saves the original input tensor rather than the + cast tensor. In some scenarios, the input tensor is used by multiple modules, + and saving the original input tensor may reduce the memory usage. + Cannot work with FP8 DelayedScaling recipe. """ def __init__( @@ -1029,6 +1078,7 @@ def __init__( ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, + save_original_input: bool = False, name: Optional[str] = None, keep_fp8_weight_transpose_cache: bool = True, use_fsdp2: bool = False @@ -1045,6 +1095,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name self.symmetric_ar_type = symmetric_ar_type + self.save_original_input = save_original_input self.name = name if TEDebugState.debug_enabled: @@ -1251,6 +1302,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) + elif recipe.float8_block_scaling(): + self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) # elif for other recipes (mxfp8, etc.) def reset_parameters(self, defer_init=False): @@ -1303,6 +1356,9 @@ def forward( first microbatch (since it is the first gradient being produced) """ + if is_in_onnx_export_mode(): + return self.onnx_forward(inp, fp8_output) + debug = TEDebugState.debug_enabled if debug: self._validate_name() @@ -1326,26 +1382,7 @@ def forward( allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: - # Get concatenated weight and bias tensors - unfused_weights = [getattr(self, name) for name in self.weight_names] - if any(isinstance(w, QuantizedTensor) for w in unfused_weights): - if self.fp8: - if len(unfused_weights) != 1: - raise RuntimeError( - "Splitting QuantizedTensor into multiple params is not supported" - ) - else: - warnings.warn( - "You are using quantized weights without quantized compute. " - "Please make sure this is intentional." - ) - unfused_weights = [w.dequantize() for w in unfused_weights] - - weight_tensor = noop_cat(unfused_weights) - if self.use_bias: - bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) - else: - bias_tensor = None + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() quantizers = ( self._get_quantizers(fp8_output, fp8_grad) @@ -1370,12 +1407,6 @@ def forward( grad_output_quantizer, ) = quantizers - # Make sure weight tensor has correct quantizer - # Note: Quantizer might have changed if quantization - # recipe changed - if weight_quantizer is not None and isinstance(weight_tensor, QuantizedTensor): - weight_tensor._quantizer = weight_quantizer - if torch.is_grad_enabled(): linear_fn = _Linear.apply args = [] @@ -1417,6 +1448,7 @@ def forward( self, skip_fp8_weight_update, self.symmetric_ar_type, + self.save_original_input, debug, self.keep_fp8_weight_transpose_cache, self.use_fsdp2 @@ -1438,11 +1470,7 @@ def _get_quantizers(self, fp8_output, fp8_grad): output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] input_quantizer.internal = True - weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] - weight_quantizer.internal = True - if IS_HIP_EXTENSION: - weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) - + (weight_quantizer,) = self._get_weight_quantizers() if fp8_output: output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] if torch.is_grad_enabled(): @@ -1470,6 +1498,95 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad): for name, q in zip(names, original_quantizers) ) + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + """Get the weight tensors of the module.""" + unfused_weights = [getattr(self, name) for name in self.weight_names] + if any(isinstance(w, QuantizedTensor) for w in unfused_weights): + if self.fp8: + if len(unfused_weights) != 1: + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) + else: + warnings.warn( + "You are using quantized weights without quantized compute. " + "Please make sure this is intentional." + ) + unfused_weights = [w.dequantize() for w in unfused_weights] + return unfused_weights + + def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Get concatenated weight and bias tensors + unfused_weights = self._get_weight_tensors() + if any(isinstance(w, QuantizedTensor) for w in unfused_weights): + if self.fp8: + if len(unfused_weights) != 1: + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) + else: + warnings.warn( + "You are using quantized weights without quantized compute. " + "Please make sure this is intentional." + ) + unfused_weights = [w.dequantize() for w in unfused_weights] + + weight_tensor = noop_cat(unfused_weights) + if self.use_bias: + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) + else: + bias_tensor = None + + return weight_tensor, bias_tensor + + def onnx_forward( + self, + inp: torch.Tensor, + fp8_output: bool, + ) -> torch.Tensor: + """ + ONNX-compatible version of the forward function that provides numerical equivalence + while only using operations that have defined ONNX symbolic translations. + This simplified implementation is designed specifically for inference scenarios. + """ + from ..export import onnx_gemm + + assert_warmed_up(self) + assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export." + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + ( + input_quantizer, + weight_quantizer, + output_quantizer, + *_, + ) = self._get_quantizers(fp8_output, False) + inp_dtype = inp.dtype + + if input_quantizer is not None: + inp_q = input_quantizer.onnx_quantize(inp) + inp = input_quantizer.onnx_dequantize(inp_q) + inp = inp.to(inp_dtype) + + if weight_quantizer is not None: + weight_q = weight_quantizer.onnx_quantize(weight_tensor) + weight_tensor = weight_quantizer.onnx_dequantize(weight_q) + if bias_tensor is not None: + bias_tensor = bias_tensor.to(inp_dtype) + weight_tensor = weight_tensor.to(inp_dtype) + + if self.apply_bias: + output = onnx_gemm(weight_tensor, inp, bias_tensor) + else: + output = onnx_gemm(weight_tensor, inp, None) + + if output_quantizer is not None: + raise NotImplementedError("ONNX export of quantized output is not supported") + + if self.return_bias: + return output, bias_tensor + + return output + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: """Customize quantizers based on current scaling recipe + linear.""" assert ( @@ -1516,3 +1633,32 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group + + def _get_weight_quantizers(self) -> List[Quantizer]: + """Get the weight quantizers of the module.""" + if not self.fp8: + return [None] + weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer.internal = True + if IS_HIP_EXTENSION: + weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) + return [weight_quantizer] + + def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on blockwise scaling recipe + linear.""" + assert ( + recipe.float8_block_scaling() + ), "blockwise scaling recipe quantizer customization here" + + if fwd: + if self.sequence_parallel and self.parallel_mode == "column": + # set compact for inp tensor X + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].all_gather_usage = True + else: + if self.sequence_parallel and self.parallel_mode == "row": + # set compact for grad_output tensor dY + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].all_gather_usage = True diff --git a/transformer_engine/pytorch/onnx_extensions.py b/transformer_engine/pytorch/onnx_extensions.py new file mode 100644 index 000000000..e34fd7846 --- /dev/null +++ b/transformer_engine/pytorch/onnx_extensions.py @@ -0,0 +1,362 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" + +File containing torch.ops extensions and their corresponding ONNX symbolic functions. + +Many transformer engine layers rely on custom calls from the transformer_engine_torch module, making ONNX export challenging because: +1. They often accept Python objects (quantizers), which ONNX does not support. +2. They are complex, incorporating fusions and precomputing certain values for backward passes—mechanisms unnecessary for ONNX export. + +For these reasons, we introduce onnx_forward methods in each layer that are simpler and +primarily leverage torch operators with known ONNX symbolic functions. +These methods avoid fusions and backward pass precomputations. +The main considerations are quantization—which PyTorch does not natively support, so we need to implement onnx symbolic functions on our own. + +Since ONNX does not yet support quantization, operators from TensorRT are employed. +The primary goal of ONNX export is to enable inference compatibility with TensorRT. + +""" + +from typing import Tuple +import math +import torch +import onnxscript +from onnxscript import opset18 as op +from onnx import defs +import transformer_engine_torch as tex + +from .tensor.float8_tensor import Float8Quantizer +from .tensor.mxfp8_tensor import MXFP8Quantizer +from .constants import MXFP8_BLOCK_SCALING_SIZE +from .utils import round_up_to_nearest_multiple +from .export import is_in_onnx_export_mode + +trt_opset = onnxscript.values.Opset( + "trt", version=1 +) # opset from TensorRT which supports FP8 quantization + +# ONNX GEMM for inference + + +def onnx_gemm(weight: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: + """ONNX GEMM used for inference.""" + reshaped_inp = inp.reshape(-1, inp.shape[-1]) + out = torch_onnx_gemm_inf_op(weight, reshaped_inp, bias) + return out.reshape(inp.shape[:-1] + (-1,)) + + +@torch.library.custom_op("tex::gemm_inf", mutates_args=[]) +def torch_onnx_gemm_inf_op( + weight: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + """Gemm used for inference -- weight is transposed""" + out = inp @ weight.T + if bias is not None: + out = out + bias + return out + + +@torch_onnx_gemm_inf_op.register_fake +def _(weight, inp, bias): + """Fake gemm used for inference.""" + out = inp @ weight.T + if bias is not None: + out = out + bias + return out + + +def onnx_gemm_inf_symbolic( + weight: onnxscript.onnx_types.TensorType, + inp: onnxscript.onnx_types.TensorType, + bias: onnxscript.onnx_types.TensorType, +) -> onnxscript.onnx_types.TensorType: + """Symbolic gemm used for inference.""" + return op.Gemm(inp, weight, bias, transA=0, transB=1) + + +# ONNX FP8 Quantization + + +@torch.library.custom_op("tex::fp8_quantize", mutates_args=[]) +def onnx_quantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor: + """Quantize to Float8Tensor used for inference.""" + scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device) + amax_tensor = torch.tensor([1], dtype=torch.float32, device=tensor.device) + quantizer = Float8Quantizer(scale_tensor, amax_tensor, tex.DType.kFloat8E4M3) + return quantizer.quantize(tensor)._data + + +@onnx_quantize_fp8_op.register_fake +def _(tensor, *_): + """Fake quantize to Float8Tensor used for inference.""" + return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device) + + +def onnx_quantize_fp8_symbolic( + tensor: onnxscript.onnx_types.TensorType, + scale: float, +) -> onnxscript.onnx_types.UINT8: + """Symbolic quantize used for inference.""" + scale_inv = op.Constant(value_float=1 / scale) + return TRT_FP8QuantizeLinear(tensor, scale_inv) + + +# Define the schema for the custom operator +schema = defs.OpSchema( + name="TRT_FP8QuantizeLinear", + domain="trt", + since_version=1, + doc="TRT FP8 Quantize Linear used for inference.", + inputs=[ + defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"), + defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"), + ], + outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")], +) + +TRT_FP8QuantizeLinear = onnxscript.values.Op( + opset=trt_opset, name="TRT_FP8QuantizeLinear", op_schema=schema +) + + +# ONNX FP8 Dequantization + + +@torch.library.custom_op("tex::fp8_dequantize", mutates_args=[]) +def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor: + """Dequantize from Float8Tensor used for inference.""" + scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device) + quantizer = Float8Quantizer( + scale_tensor, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3 + ) + quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32) + return quantizer_tensor.dequantize() + + +@onnx_dequantize_fp8_op.register_fake +def _(tensor: torch.Tensor, _) -> torch.Tensor: + """Fake dequantize from Float8Tensor used for inference.""" + return torch.empty(tensor.shape, dtype=torch.float32, device=tensor.device) + + +def onnx_dequantize_fp8_symbolic( + tensor: onnxscript.onnx_types.TensorType, scale: float +) -> onnxscript.onnx_types.TensorType: + """Symbolic dequantize from Float8Tensor used for inference.""" + scale_inv = op.Constant(value_float=1 / scale) + return TRT_FP8DequantizeLinear(tensor, scale_inv) + + +schema = defs.OpSchema( + name="TRT_FP8DequantizeLinear", + domain="trt", + since_version=1, + doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.", + inputs=[ + defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"), + defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"), + ], + outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")], +) + +TRT_FP8DequantizeLinear = onnxscript.values.Op( + opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema +) + + +# ONNX MXFP8 Quantization + + +@torch.library.custom_op("tex::mxfp8_quantize", mutates_args=[]) +def onnx_quantize_mxfp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize to MXFP8Tensor used for inference.""" + quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3) + quantized_tensor = quantizer(tensor) + return quantized_tensor._rowwise_data, quantized_tensor._rowwise_scale_inv + + +@onnx_quantize_mxfp8_op.register_fake +def _(tensor: torch.Tensor): + """Fake quantize to MXFP8Tensor used for inference.""" + mxfp8_scale_shape = [ + round_up_to_nearest_multiple(math.prod(tensor.shape[:-1]), 128), + round_up_to_nearest_multiple(tensor.shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + ] + return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.empty( + mxfp8_scale_shape, dtype=torch.uint8, device=tensor.device + ) + + +def onnx_quantize_mxfp8_symbolic( + tensor: onnxscript.onnx_types.TensorType, +) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]: + """Symbolic quantize to MXFP8Tensor used for inference.""" + tensor_out, scale_inv_out = TRT_MXFP8QuantizeLinear(tensor) + return tensor_out, scale_inv_out + + +schema = defs.OpSchema( + name="TRT_MXFP8QuantizeLinear", + domain="trt", + since_version=1, + doc="TRT MXFP8 Quantize Linear used for inference.", + inputs=[ + defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"), + ], + outputs=[ + defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor"), + defs.OpSchema.FormalParameter( + "scale_inv", "tensor(uint8)", "Scale factor for quantization" + ), + ], +) + +TRT_MXFP8QuantizeLinear = onnxscript.values.Op( + opset=trt_opset, name="TRT_MXFP8QuantizeLinear", op_schema=schema +) + + +# ONNX MXFP8 Dequantization + + +@torch.library.custom_op("tex::mxfp8_dequantize", mutates_args=[]) +def onnx_dequantize_mxfp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor: + """Dequantize from MXFP8Tensor used for inference.""" + quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3) + quantizer_tensor = quantizer.create_tensor_from_data( + tensor, scale_inv, fake_dtype=torch.float32 + ) + return quantizer_tensor.dequantize() + + +@onnx_dequantize_mxfp8_op.register_fake +def _(tensor: torch.Tensor, _): + """Fake dequantize from MXFP8Tensor used for inference.""" + return torch.empty(tensor.shape, dtype=torch.float32, device=tensor.device) + + +def onnx_dequantize_mxfp8_symbolic( + tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType +) -> onnxscript.onnx_types.TensorType: + """Symbolic dequantize from MXFP8Tensor used for inference.""" + return TRT_MXFP8DequantizeLinear(tensor, scale_inv) + + +schema = defs.OpSchema( + name="TRT_MXFP8DequantizeLinear", + domain="trt", + since_version=1, + doc="TRT MXFP8 Dequantize Linear from MXFP8Tensor used for inference.", + inputs=[ + defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"), + defs.OpSchema.FormalParameter( + "scale_inv", "tensor(uint8)", "Scale factor for dequantization" + ), + ], + outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")], +) + +TRT_MXFP8DequantizeLinear = onnxscript.values.Op( + opset=trt_opset, name="TRT_MXFP8DequantizeLinear", op_schema=schema +) + + +# ONNX LayerNorm + + +@torch.library.custom_op("tex::layernorm", mutates_args=[]) +def onnx_layernorm_op( + inp: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float +) -> torch.Tensor: + """ONNX LayerNorm used for inference.""" + model = tex.LayerNorm(inp.shape[1], eps=eps) + model.weight.data = weight + model.bias.data = bias + return model(inp) + + +@onnx_layernorm_op.register_fake +def _(inp, *_): + """Fake ONNX LayerNorm used for inference.""" + return inp + + +def onnx_layernorm_symbolic( + inp: onnxscript.onnx_types.TensorType, + weight: onnxscript.onnx_types.TensorType, + bias: onnxscript.onnx_types.TensorType, + eps: float, +) -> onnxscript.onnx_types.TensorType: + """Symbolic ONNX LayerNorm used for inference.""" + return op.LayerNormalization(inp, weight, bias, epsilon=eps) + + +# onnx layernorm helper function - handles layernorm with quantization + + +def onnx_layernorm( + inp: torch.Tensor, + layer_norm_weight: torch.Tensor, + layer_norm_bias: torch.Tensor, + eps: float, + normalization: str, + zero_centered_gamma: bool, + output_dtype: torch.dtype, + return_layernorm_output: bool, + input_quantizer, +) -> torch.Tensor: + """ONNX LayerNorm used for inference.""" + ln_weight = layer_norm_weight if not zero_centered_gamma else layer_norm_weight + 1 + ln_weight = ln_weight.to(inp.dtype).to(torch.float32) + inp = inp.to(torch.float32) + layer_norm_bias = ( + layer_norm_bias.to(output_dtype).to(torch.float32) if layer_norm_bias is not None else None + ) + + if normalization == "RMSNorm": + ln_out = torch.nn.functional.rms_norm(inp, inp.shape[-1:], ln_weight, eps) + else: + ln_out = torch.nn.functional.layer_norm( + inp, inp.shape[-1:], ln_weight, layer_norm_bias, eps + ) + ln_out_return = ln_out + + if input_quantizer is not None: + if return_layernorm_output: + # In case of return_layernorm_output, layernorm is not fused with fp8 cast, + # so we cast to input_dtype and then perform cast to fp8 if needed + ln_out = ln_out.to(output_dtype).to(torch.float32) + ln_out_return = ln_out + elif isinstance(input_quantizer, MXFP8Quantizer): + # layernorm + mxfp8 quantizer behaves differently + ln_out = ln_out.to(output_dtype).to(torch.float32) + ln_out_quantized = input_quantizer.onnx_quantize(ln_out) + ln_out = input_quantizer.onnx_dequantize(ln_out_quantized) + ln_out = ln_out.to(output_dtype) + return ln_out, ln_out_return + + +# utility functions + + +def onnx_attention_mask_func( + attention_scores: torch.Tensor, attention_mask: torch.Tensor +) -> torch.Tensor: + """Get attention mask without inp""" + assert is_in_onnx_export_mode() + return attention_scores.masked_fill(attention_mask, -10000.0) + + +# This translation table should be passed to torch.onnx.export function +# using the custom_translation_table=te_translation_table option. +te_translation_table = { + torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic, + torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic, + torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic, + torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic, + torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic, + torch.ops.tex.layernorm.default: onnx_layernorm_symbolic, +} diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 20e63e0e6..8e997428f 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -5,7 +5,7 @@ """Helper functions used in fusible operations.""" from __future__ import annotations -from typing import Any, Iterable, Optional +from typing import Optional import torch @@ -13,84 +13,24 @@ from .. import torch_version from ..fp8 import FP8GlobalStateManager from ..tensor.float8_tensor import Float8Tensor -from ..utils import ( - canonicalize_device, - canonicalize_dtype, - devices_match, -) - - -def is_float8_tensor(tensor: Any) -> bool: - """Check if object is a `Float8Tensor`""" - return isinstance(tensor, Float8Tensor) - - -def convert_tensor( - tensor: torch.Tensor | Float8Tensor, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - memory_format: torch.memory_format = torch.preserve_format, -) -> torch.Tensor | Float8Tensor: - """Convert tensor attributes, keeping same data if possible""" - - # Default kwargs - if device is None: - device = tensor.device - device = canonicalize_device(device) - if dtype is None: - dtype = tensor.dtype - dtype = canonicalize_dtype(dtype) - - # Make sure output is detached from autograd graph - tensor = tensor.detach() - - # Return immediately if tensor already has desired attributes - if devices_match(device, tensor.device) and dtype == tensor.dtype: - if memory_format == torch.preserve_format or tensor.is_contiguous( - memory_format=memory_format - ): - return tensor - - # Convert FP8 tensor - if is_float8_tensor(tensor): - data = tensor._data - if not devices_match(device, data.device): - data = data.to(device=device) - if memory_format != torch.preserve_format and not data.is_contiguous( - memory_format=memory_format - ): - # Note: torch.Tensor.to ignores memory_format kwarg (see - # https://github.com/pytorch/pytorch/issues/132020). - data = data.contiguous(memory_format=memory_format) - out = Float8Tensor.make_like(tensor, dtype=dtype) - out.data = data - return out - - # Convert standard PyTorch tensor - tensor = tensor.to(device=device, dtype=dtype) - if memory_format != torch.preserve_format and not tensor.is_contiguous( - memory_format=memory_format - ): - # Note: torch.Tensor.to ignores memory_format kwarg (see - # https://github.com/pytorch/pytorch/issues/132020). - tensor = tensor.contiguous(memory_format=memory_format) - return tensor +from ..tensor.quantized_tensor import QuantizedTensorBase +from ..utils import canonicalize_dtype + +def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorBase) -> bool: + """Check if tensor is a quantized tensor""" + return isinstance(tensor, QuantizedTensorBase) -def reshape( - tensor: torch.Tensor | Float8Tensor, - shape: Iterable[int], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, -) -> torch.Tensor | Float8Tensor: - """Reshape tensor, keeping same data if possible""" - tensor = convert_tensor( - tensor, - device=device, - dtype=dtype, - memory_format=torch.contiguous_format, - ) - return tensor.reshape(*shape) + +def maybe_dequantize( + tensor: torch.Tensor | QuantizedTensorBase, dtype: torch.dtype | None = None +) -> torch.Tensor: + """Dequantize tensor to given dtype or just convert if not a quantized tensor""" + if is_quantized_tensor(tensor): + return tensor.dequantize(dtype=dtype) + if dtype is not None and tensor.dtype != dtype: + return tensor.to(dtype) + return tensor def maybe_autocast_dtype( diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index ae635c956..c69e3df02 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -11,6 +11,7 @@ from .basic_linear import BasicLinear from .bias import Bias from .identity import Identity +from .l2normalization import L2Normalization from .layer_norm import LayerNorm from .make_extra_output import MakeExtraOutput from .quantize import Quantize diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index aa0bb1a52..c077829a3 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -12,11 +12,10 @@ import transformer_engine_torch as tex from ...fp8 import FP8GlobalStateManager -from ...tensor import QuantizedTensor -from ...tensor.float8_tensor import Float8CurrentScalingQuantizer -from ...utils import clear_tensor_data, devices_match +from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer +from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext -from .._common import reshape +from .._common import maybe_dequantize class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): @@ -72,8 +71,8 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op: Optional[BasicOperation] = None, - next_op: Optional[BasicOperation] = None, + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: # Compute dtype @@ -86,44 +85,28 @@ def op_forward( raise RuntimeError(f"Unsupported dtype ({dtype})") # Check input tensor - x = input_ - if isinstance(x, QuantizedTensor): - x = x.dequantize() - if x.device.type != "cuda": - x = x.cuda() - if x.dtype != dtype: - x = x.to(dtype=dtype) - if not x.is_contiguous(): - x = x.contiguous() - - # Check if FP8 is enabled - fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() - if fp8_enabled and next_op is not None and next_op.num_quantizers("forward") > 0: - quantizer = next_op.get_quantizer("forward", 0) - else: - quantizer = None + x = maybe_dequantize(input_.contiguous(), dtype) - # Launch kernel - y = self._activation_forward_impl( - reshape(x, (-1, x.size(-1))), - quantizer, - ) + # Check if quantized compute is enabled + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + quantizer = None + if with_quantized_compute: + quantizer = next_op_input_quantizer - # Check output tensor - if y.dim() != x.dim(): - y = y.reshape(list(x.shape[:-1]) + [-1]) + # Launch kernel + y = self._activation_forward_impl(x, quantizer) # Quantize input to FP8 before caching if needed if self.cache_quantized_input: - quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device) - quantizer.set_usage(rowwise=True, columnwise=False) - x = quantizer(x) + input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device) + input_quantizer.set_usage(rowwise=True, columnwise=False) + x = input_quantizer(x) # Save state for backward pass - ctx.save_for_backward(x.detach()) - ctx.fp8_enabled = fp8_enabled + ctx.save_for_backward(x) + ctx.with_quantized_compute = with_quantized_compute ctx.dtype = dtype - ctx.prev_op = prev_op + ctx.prev_op_grad_input_quantizer = prev_op_grad_input_quantizer return y @@ -137,36 +120,21 @@ def op_backward( (x,) = ctx.saved_tensors # Check input tensor - if isinstance(x, QuantizedTensor): - x = x.dequantize(dtype=ctx.dtype) - elif x.dtype != ctx.dtype: - x = x.to(dtype=ctx.dtype) - if not x.is_contiguous(): - x = x.contiguous() + x = maybe_dequantize(x.contiguous(), ctx.dtype) # Check grad output tensor - dy = grad_output - if isinstance(dy, QuantizedTensor): - dy = dy.dequantize(dtype=ctx.dtype) - if not devices_match(dy.device, x.device) or dy.dtype != x.dtype: - dy = dy.to(device=x.device, dtype=x.dtype) - if not dy.is_contiguous(): - dy = dy.contiguous() + dy = maybe_dequantize(grad_output.contiguous(), x.dtype) - # Launch kernel - dx = self._activation_backward_impl( - reshape(dy, (-1, dy.size(-1))), - reshape(x, (-1, x.size(-1))), - None, - ) + # Check if quantized compute is enabled + quantizer = None + if ctx.with_quantized_compute: + quantizer = ctx.prev_op_grad_input_quantizer - # Check grad input tensor - if dx.size() != x.size(): - dx = dx.reshape(x.size()) + # Launch kernel + dx = self._activation_backward_impl(dy, x, quantizer) # Clear input tensor if possible - if ctx.prev_op is not None: - clear_tensor_data(x) + clear_tensor_data(x) return dx, () diff --git a/transformer_engine/pytorch/ops/basic/add_in_place.py b/transformer_engine/pytorch/ops/basic/add_in_place.py index 4ccbaef1c..e1493d3c7 100644 --- a/transformer_engine/pytorch/ops/basic/add_in_place.py +++ b/transformer_engine/pytorch/ops/basic/add_in_place.py @@ -15,6 +15,8 @@ OperationContext, ) +from transformer_engine.pytorch.tensor import Quantizer + class AddInPlace(BasicOperation): """Add in-place @@ -57,8 +59,8 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - basic_op_prev_ops: list[Optional[BasicOperation]], - basic_op_next_ops: list[Optional[BasicOperation]], + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: output = basic_op_extra_inputs[0][0].detach() @@ -76,4 +78,4 @@ def fuser_backward( Iterable[Iterable[Optional[torch.Tensor]]], Iterable[Iterable[Optional[torch.Tensor]]], ]: - return grad_output, [], [(grad_output,)] + return grad_output, [()], [(grad_output,)] diff --git a/transformer_engine/pytorch/ops/basic/all_gather.py b/transformer_engine/pytorch/ops/basic/all_gather.py index 15b1f65d8..0df165a06 100644 --- a/transformer_engine/pytorch/ops/basic/all_gather.py +++ b/transformer_engine/pytorch/ops/basic/all_gather.py @@ -10,8 +10,9 @@ import torch from ...distributed import gather_along_first_dim -from ...tensor import QuantizedTensor +from .._common import maybe_dequantize from ..op import BasicOperation, OperationContext +from ...tensor import Quantizer class AllGather(BasicOperation): @@ -39,8 +40,8 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op: Optional[BasicOperation] = None, - next_op: Optional[BasicOperation] = None, + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: out: torch.Tensor if self.process_group_size == 1: @@ -71,10 +72,7 @@ def op_backward( input_dims[0] //= self.process_group_size # Check output gradient tensor - dy = grad_output - if isinstance(dy, QuantizedTensor): - dy = dy.dequantize() - dy = dy.contiguous() + dy = maybe_dequantize(grad_output.contiguous()) # Perform reduce-scatter dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device) diff --git a/transformer_engine/pytorch/ops/basic/all_reduce.py b/transformer_engine/pytorch/ops/basic/all_reduce.py index 8b4593b93..af928dd24 100644 --- a/transformer_engine/pytorch/ops/basic/all_reduce.py +++ b/transformer_engine/pytorch/ops/basic/all_reduce.py @@ -9,8 +9,9 @@ import torch -from ...tensor import QuantizedTensor +from .._common import maybe_dequantize from ..op import BasicOperation, OperationContext +from ...tensor import Quantizer class AllReduce(BasicOperation): @@ -41,8 +42,8 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op: Optional[BasicOperation] = None, - next_op: Optional[BasicOperation] = None, + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: # Trivial case @@ -50,10 +51,7 @@ def op_forward( return input_ # Perform all-reduce - x = input_ - if isinstance(x, QuantizedTensor): - x = x.dequantize() - x = x.contiguous() + x = maybe_dequantize(input_.contiguous()) torch.distributed.all_reduce(x, group=self.process_group) return x diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 0e976a49e..59fc09607 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -19,20 +19,21 @@ gather_along_first_dim, reduce_scatter_along_first_dim, ) -from ...fp8 import FP8GlobalStateManager +from ...fp8 import FP8GlobalStateManager, Recipe from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD -from ...tensor import Quantizer, QuantizedTensor -from ...tensor.float8_tensor import Float8Quantizer +from ...tensor import Quantizer +from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase from ..op import BasicOperation, OperationContext -from .._common import ( +from .._common import maybe_dequantize, is_quantized_tensor +from ...utils import ( canonicalize_device, canonicalize_dtype, + clear_tensor_data, devices_match, ) -from ...utils import clear_tensor_data def _wait_async(handle: Optional[Any]) -> None: @@ -271,7 +272,7 @@ def reset_parameters(self) -> None: device = canonicalize_device(None) # Allocate buffer if needed - if isinstance(weight, QuantizedTensor): + if is_quantized_tensor(weight): weight = torch.empty( weight.size(), dtype=weight.dtype, @@ -302,8 +303,12 @@ def reset_parameters(self) -> None: weight = torch.nn.Parameter(weight) self.weight = weight - def pre_forward(self, *args, **kwargs) -> None: - super().pre_forward(*args, **kwargs) + def pre_first_forward( + self, + *, + recipe: Optional[Recipe], + ) -> None: + super().pre_first_forward(recipe=recipe) # Initialize weights if needed weight = self.weight @@ -312,24 +317,47 @@ def pre_forward(self, *args, **kwargs) -> None: weight = self.weight # Configure quantizers - if FP8GlobalStateManager.is_fp8_enabled(): + if recipe is not None: input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) grad_output_quantizer = self.get_quantizer("backward", 0) # Specify required tensor formats - is_grad_enabled = torch.is_grad_enabled() - weight_requires_grad = is_grad_enabled and weight.requires_grad - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) - grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.internal = True + weight_quantizer.internal = True + grad_output_quantizer.internal = True + + # Recipe-specific configuration + if recipe.float8_current_scaling(): + if any( + not isinstance(q, Float8CurrentScalingQuantizer) + for q in (input_quantizer, weight_quantizer, grad_output_quantizer) + ): + raise RuntimeError( + "FP8 current-scaling recipe is enabled, " + f"but input quantizer is {input_quantizer.__class__.__name__}, " + f"weight quantizer is {weight_quantizer.__class__.__name__}, " + f"grad output quantizer is {grad_output_quantizer.__class__.__name__}" + ) + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + if self.sequence_parallel and self.tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + if self.sequence_parallel and self.tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group # Make sure weight tensor has correct quantizer # Note: Quantizer might have changed if quantization # recipe changed - if isinstance(weight_quantizer, Float8Quantizer) and isinstance( - weight, Float8TensorBase - ): + if isinstance( + weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ) and isinstance(weight, Float8TensorBase): weight._quantizer = weight_quantizer @staticmethod @@ -349,7 +377,9 @@ def _functional_forward( input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + input_requires_grad: bool = True, + weight_requires_grad: bool = True, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """Functional API for forward pass Parameters @@ -362,7 +392,7 @@ def _functional_forward( Bias tensor device: torch.device, default = default CUDA device Tensor device - dtype: torch.dtype, default = default dtype + dtype: torch.dtype, default = infer from out or weight Tensor datatype out: torch.Tensor, optional Output tensor @@ -385,24 +415,38 @@ def _functional_forward( Builder class for quantized weight tensor. output_quantizer: Quantizer, optional Builder class for quantized output tensor. + input_requires_grad: bool, default = `True` + Whether the loss gradient w.r.t. the input tensor is + required in the backward pass. + weight_requires_grad: bool, default = `True` + Whether the loss gradient w.r.t. the weight tensor is + required in the backward pass. Returns ------- torch.Tensor Output tensor - torch.Tensor - Input tensor used in GEMM, possibly cast and reshaped from - provided input tensor - torch.Tensor - Weight tensor used in GEMM, possibly cast and reshaped from - provided weight tensor + torch.Tensor, optional + Input tensor, ready for use in backward pass. `None` is + returned if loss gradient w.r.t. the weight tensor is not + required. + torch.Tensor, optional + Weight tensor, ready for use in backward pass. `None` is + returned if loss gradient w.r.t. the input tensor is not + required. """ # Check datatype if dtype is None: - dtype = weight.dtype if out is None else out.dtype - dtype = canonicalize_dtype(dtype) + if out is not None and isinstance(out, torch.Tensor): + dtype = out.dtype + elif weight is not None and isinstance(out, torch.Tensor): + dtype = weight.dtype + else: + raise ValueError( + "Could not infer dtype from weight nor out and dtype was not provided" + ) if dtype not in (torch.float32, torch.float16, torch.bfloat16): raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") if out is not None and out.dtype != dtype: @@ -416,7 +460,7 @@ def _functional_forward( if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) if with_x_all_gather: input_quantizer.set_usage(columnwise=False) x, x_async = gather_along_first_dim( @@ -426,14 +470,12 @@ def _functional_forward( quantizer=input_quantizer, ) else: - if not isinstance(x_local, QuantizedTensor): + if not is_quantized_tensor(x_local): x_local = input_quantizer(x_local) x = x_local else: - if isinstance(x_local, QuantizedTensor): - x_local = x_local.dequantize() - if x_local.dtype != dtype: - x_local = x_local.to(dtype=dtype) + x_local = maybe_dequantize(x_local, dtype) + if with_x_all_gather: x, x_async = gather_along_first_dim( x_local, @@ -445,16 +487,13 @@ def _functional_forward( # Check weight tensor w = weight - w_is_quantized = isinstance(w, QuantizedTensor) - if with_quantized_compute and not w_is_quantized: + if not with_quantized_compute: + w = maybe_dequantize(w, dtype) + elif with_quantized_compute and not is_quantized_tensor(w): if weight_quantizer is None: raise ValueError("Missing quantizer for weight tensor") - weight_quantizer.set_usage(rowwise=True) + weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) w = weight_quantizer(w) - elif not with_quantized_compute and w_is_quantized: - w = w.dequantize() - if not with_quantized_compute and w.dtype != dtype: - w = w.to(dtype=dtype) # Check output tensor y = out @@ -463,7 +502,7 @@ def _functional_forward( output_quantizer = None if tensor_parallel_mode == "row": output_quantizer = None - elif isinstance(y, QuantizedTensor): + elif is_quantized_tensor(y): if not with_quantized_compute: raise ValueError("Output tensor is quantized, but quantized compute is not enabled") if tensor_parallel_mode == "row": @@ -526,17 +565,21 @@ def _functional_forward( else: torch.distributed.all_reduce(y, group=tensor_parallel_group) - # Detach input tensor if needed - # Note: PyTorch autograd produces esoteric errors if we save - # input tensor as context for backward pass. - if x_local is input: - x_local = x_local.detach() + # Prepare weight tensor for backward pass + if input_requires_grad: + if w is not weight and with_quantized_compute and is_quantized_tensor(w): + w.update_usage(rowwise_usage=False, columnwise_usage=True) + else: + w = None - # Configure input tensor for backward pass - if with_quantized_compute and isinstance(x_local, QuantizedTensor): - if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather): - # FP8 does not support all-gather of transpose data - x_local.update_usage(rowwise_usage=False, columnwise_usage=True) + # Prepare input tensor for backward pass + if weight_requires_grad: + if with_quantized_compute and is_quantized_tensor(x_local): + if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather): + # FP8 does not support all-gather of transpose data + x_local.update_usage(rowwise_usage=False, columnwise_usage=True) + else: + x_local = None return y, x_local, w @@ -624,9 +667,9 @@ def _functional_backward( # Check datatype if dtype is None: - if weight is not None: + if isinstance(weight, torch.Tensor): dtype = weight.dtype - else: + elif isinstance(grad_output, torch.Tensor): dtype = grad_output.dtype dtype = canonicalize_dtype(dtype) if dtype not in (torch.float32, torch.float16, torch.bfloat16): @@ -652,14 +695,17 @@ def _functional_backward( quantizer=grad_output_quantizer, ) else: - if not isinstance(dy_local, QuantizedTensor): + if not is_quantized_tensor(dy_local): dy_local = grad_output_quantizer(dy_local) + else: + dy_local.update_usage( + rowwise_usage=input_requires_grad, + columnwise_usage=weight_requires_grad, + ) dy = dy_local else: - if isinstance(dy_local, QuantizedTensor): - dy_local = dy_local.dequantize() - if dy_local.dtype != dtype: - dy_local = dy_local.to(dtype=dtype) + dy_local = maybe_dequantize(dy_local, dtype) + if with_dy_all_gather: dy, dy_async = gather_along_first_dim( dy_local, @@ -689,16 +735,14 @@ def _functional_backward( quantizer=input_quantizer, ) else: - if isinstance(x_local, QuantizedTensor): + if is_quantized_tensor(x_local): x_local.update_usage(columnwise_usage=True) else: x_local = input_quantizer(x_local) x = x_local else: - if isinstance(x_local, QuantizedTensor): - x_local = x_local.dequantize() - if x_local.dtype != dtype: - x_local = x_local.to(dtype=dtype) + x_local = maybe_dequantize(x_local, dtype) + if with_x_all_gather: x, x_async = gather_along_first_dim( x_local, @@ -717,9 +761,8 @@ def _functional_backward( if weight is None: raise ValueError("Weight tensor is required to compute input grad") w = weight - w_is_quantized = isinstance(w, QuantizedTensor) if with_quantized_compute: - if w_is_quantized: + if is_quantized_tensor(w): w.update_usage(columnwise_usage=True) else: if weight_quantizer is None: @@ -727,10 +770,7 @@ def _functional_backward( weight_quantizer.set_usage(columnwise=True) w = weight_quantizer(w) else: - if w_is_quantized: - w = w.dequantize(dtype=dtype) - elif w.dtype != dtype: - w = w.to(dtype=dtype) + w = maybe_dequantize(w, dtype) # Synchronize tensor-parallel communication _wait_async(dy_async) @@ -743,7 +783,7 @@ def _functional_backward( grad_input_quantizer = None if tensor_parallel_mode == "column": grad_input_quantizer = None - elif isinstance(dx, QuantizedTensor): + elif is_quantized_tensor(dx): if not with_quantized_compute: raise ValueError( "Grad input tensor is quantized, but quantized compute is not enabled" @@ -854,12 +894,12 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op: Optional[BasicOperation] = None, - next_op: Optional[BasicOperation] = None, + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: # Check which grads are required - input_requires_grad = ctx.requires_grad and input_.requires_grad + input_requires_grad = ctx.requires_grad weight_requires_grad = ctx.requires_grad and self.weight.requires_grad # FP8 metadata @@ -874,11 +914,9 @@ def op_forward( # Get quantizers input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) - if next_op is not None and next_op.num_quantizers("forward") > 0: - output_quantizer = next_op.get_quantizer("forward", 0) + output_quantizer = next_op_input_quantizer grad_output_quantizer = self.get_quantizer("backward", 0) - if prev_op is not None and prev_op.num_quantizers("backward") > 0: - grad_input_quantizer = prev_op.get_quantizer("backward", 0) + grad_input_quantizer = prev_op_grad_input_quantizer # Configure quantizers # Note: We cache the quantized input for backward pass, @@ -887,12 +925,13 @@ def op_forward( weight_quantizer.set_usage(rowwise=True, columnwise=False) # Get autocast dtype if needed - dtype = None if torch.is_autocast_enabled(): dtype = torch.get_autocast_dtype("cuda") + else: + dtype = self.weight.dtype # Linear forward - output, x_local, _ = BasicLinear._functional_forward( + output, x_local, w = BasicLinear._functional_forward( input=input_, weight=self.weight, dtype=dtype, @@ -903,10 +942,12 @@ def op_forward( input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, + input_requires_grad=input_requires_grad, + weight_requires_grad=weight_requires_grad, ) # Save state for backward pass - ctx.save_for_backward(x_local) + ctx.save_for_backward(x_local, w) ctx.with_quantized_compute = with_quantized_compute ctx.input_quantizer = input_quantizer ctx.weight_quantizer = weight_quantizer @@ -915,7 +956,6 @@ def op_forward( ctx.dtype = dtype ctx.input_requires_grad = input_requires_grad ctx.weight_requires_grad = weight_requires_grad - ctx.has_prev_op = prev_op is not None return output @@ -926,12 +966,15 @@ def op_backward( ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]: # Saved tensors from forward pass - (x_local,) = ctx.saved_tensors + (x_local, w) = ctx.saved_tensors # wgrad fusion accumulate_into_main_grad = self._accumulate_into_main_grad grad_weight = None if ctx.weight_requires_grad and accumulate_into_main_grad: + if hasattr(self.weight, "__fsdp_param__"): + self.weight.main_grad = self.weight.get_main_grad() + if not hasattr(self.weight, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " @@ -946,7 +989,7 @@ def op_backward( grad_input, grad_weight = BasicLinear._functional_backward( grad_output=grad_output, input=x_local, - weight=self.weight, + weight=w, input_requires_grad=ctx.input_requires_grad, weight_requires_grad=ctx.weight_requires_grad, dtype=ctx.dtype, @@ -963,8 +1006,7 @@ def op_backward( ) # Clear input tensor if possible - if ctx.has_prev_op: - clear_tensor_data(x_local) + clear_tensor_data(x_local) if accumulate_into_main_grad: grad_weight = None diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index 5a73ec6c2..a985601e2 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -9,14 +9,17 @@ import torch +import transformer_engine_torch as tex from transformer_engine.pytorch.ops.op import ( BasicOperation, OperationContext, ) -from .._common import ( +from ...utils import ( canonicalize_device, canonicalize_dtype, ) +from ...fp8 import FP8GlobalStateManager +from ...tensor import Quantizer class Bias(BasicOperation): @@ -111,8 +114,8 @@ def reset_parameters(self) -> None: bias = torch.nn.Parameter(bias) self.bias = bias - def pre_forward(self, *args, **kwargs) -> None: - super().pre_forward(*args, **kwargs) + def pre_first_forward(self, *args, **kwargs) -> None: + super().pre_first_forward(*args, **kwargs) if self.bias.device.type == "meta": self.reset_parameters() @@ -120,11 +123,25 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op: Optional[BasicOperation] = None, - next_op: Optional[BasicOperation] = None, + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: x = input_ - b = self.bias.reshape([1] * (x.dim() - 1) + [self.local_size]) + b = self.bias.view([1] * (x.dim() - 1) + [self.local_size]) + + # Check if backward pass is needed + requires_grad = ctx.requires_grad + + # Check if previous op quantizes its output's gradient + grad_input_quantizer = None + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + if with_quantized_compute: + grad_input_quantizer = prev_op_grad_input_quantizer + + if requires_grad: + ctx.with_quantized_compute = with_quantized_compute + ctx.grad_input_quantizer = grad_input_quantizer + return x + b def op_backward( @@ -134,7 +151,11 @@ def op_backward( ) -> tuple[torch.Tensor, tuple[()]]: dy = grad_output if dy.dim() > 1: - db = dy.sum(tuple(range(dy.dim() - 1))) + quantizer = ctx.grad_input_quantizer + if ctx.with_quantized_compute and quantizer is not None: + db, dy = tex.bgrad_quantize(dy, quantizer) + else: + db = dy.sum(tuple(range(dy.dim() - 1))) else: db = dy return dy, (db,) diff --git a/transformer_engine/pytorch/ops/basic/identity.py b/transformer_engine/pytorch/ops/basic/identity.py index d0466be15..3161e77c7 100644 --- a/transformer_engine/pytorch/ops/basic/identity.py +++ b/transformer_engine/pytorch/ops/basic/identity.py @@ -13,6 +13,7 @@ BasicOperation, OperationContext, ) +from ...tensor import Quantizer class Identity(BasicOperation): @@ -22,8 +23,8 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op: Optional[BasicOperation] = None, - next_op: Optional[BasicOperation] = None, + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: return input_ diff --git a/transformer_engine/pytorch/ops/basic/l2normalization.py b/transformer_engine/pytorch/ops/basic/l2normalization.py new file mode 100644 index 000000000..d8196c1bd --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/l2normalization.py @@ -0,0 +1,121 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusable operation for L2 Normalization.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from ...utils import clear_tensor_data +from .._common import maybe_dequantize +from ..op import BasicOperation, OperationContext +from ...jit import ( + l2normalization_fused, + l2normalization_fwd_fused, + l2normalization_backward_fused, + set_jit_fusion_options, + warmup_jit_l2normalization_all_dtypes, +) +from ...tensor import Quantizer + + +class L2Normalization(BasicOperation): + r"""L2 Normalization + + Applies L2 normalization over the last dimension of input tensors. + This is a parameter-free normalization that scales each vector to unit L2 norm. + + .. math:: + y = \frac{x}{\sqrt{\sum_{i} x_i^2 + \varepsilon}} + + This operation is used e.g. for query-key normalization in attention mechanisms. + + Parameters + ---------- + eps : float, default = 1e-6 + A value added to the denominator for numerical stability + seq_length: int, default = None + sequence length of input samples. Needed for JIT Warmup, a technique where jit fused + functions are warmed up before training to ensure same kernels are used for forward + propagation and activation recompute phase. + micro_batch_size: int, default = None + batch size per training step. Needed for JIT Warmup, a technique where jit + fused functions are warmed up before training to ensure same kernels are + used for forward propagation and activation recompute phase. + + """ + + def __init__( + self, + *, + eps: float = 1e-6, + seq_length: Optional[int] = None, + micro_batch_size: Optional[int] = None, + ) -> None: + super().__init__() + self.eps: float = eps + + # JIT warmup for L2Normalization fused operations + if seq_length and micro_batch_size: + if torch.cuda.is_available(): + set_jit_fusion_options() + # For L2Normalization, we don't know the hidden size until forward pass, + # but we can warm up with common sizes. For QK normalization, this will be + # the attention head dimension (hidden_size_per_attention_head), not the full + # model hidden dimension. Common head dimensions are 32, 64, 80, 96, 128, 256. + common_hidden_sizes = [32, 64, 80, 96, 128, 256] + for hidden_size in common_hidden_sizes: + warmup_jit_l2normalization_all_dtypes(hidden_size, seq_length, micro_batch_size) + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + # Use input directly - torch.compile can handle multi-dimensional tensors + x = maybe_dequantize(input_) + + # Check if backward pass is needed + requires_grad = ctx.requires_grad + + # Compute L2 normalization using fused implementation + # L2 norm: x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps) + if requires_grad: + # Training: use version that returns both output and intermediate values + y, rsqrt_norm = l2normalization_fwd_fused(x, self.eps) + else: + # Inference: use lightweight version that only returns output + y = l2normalization_fused(x, self.eps) + rsqrt_norm = None # Not needed for inference + + # Save state for backward pass + if requires_grad: + ctx.save_for_backward(x, rsqrt_norm) + + return y + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + x, rsqrt_norm = ctx.saved_tensors + + dy = maybe_dequantize(grad_output) + + # Compute L2 norm backward pass using fused implementation + dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps) + + # Clear saved tensors if possible + clear_tensor_data(x) + clear_tensor_data(rsqrt_norm) + + # No parameters, so empty tuple for param grads + return dx, () diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index c94459bc3..26c39909e 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -19,7 +19,6 @@ if IS_HIP_EXTENSION: from ...triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton from ...fp8 import FP8GlobalStateManager -from ...tensor import QuantizedTensor from ...constants import TE_DType from ...utils import ( canonicalize_device, @@ -28,7 +27,9 @@ devices_match, ) from ..op import BasicOperation, OperationContext -from .._common import maybe_autocast_dtype, reshape +from .._common import maybe_autocast_dtype, maybe_dequantize +from ...export import is_in_onnx_export_mode +from ...tensor import Quantizer class LayerNorm(BasicOperation): @@ -172,8 +173,8 @@ def reset_parameters(self) -> None: self.weight = weight self.bias = bias - def pre_forward(self, *args, **kwargs) -> None: - super().pre_forward(*args, **kwargs) + def pre_first_forward(self, *args, **kwargs) -> None: + super().pre_first_forward(*args, **kwargs) if self.weight.device.type == "meta" or self.bias.device.type == "meta": self.reset_parameters() @@ -181,9 +182,11 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op: Optional[BasicOperation] = None, - next_op: Optional[BasicOperation] = None, + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: + if is_in_onnx_export_mode(): + return self.op_onnx_forward(input_) # Check tensor dims weight = self.weight @@ -197,31 +200,19 @@ def op_forward( # Check input tensors inner_dim = math.prod(weight_dims) - device = weight.device - if device.type != "cuda": - device = canonicalize_device(None) dtype = maybe_autocast_dtype(default_dtype=weight.dtype) - x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype) - w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) - b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype) - if isinstance(x, QuantizedTensor): - x = x.dequantize() - if isinstance(w, QuantizedTensor): - w = w.dequantize() - if isinstance(b, QuantizedTensor): - b = b.dequantize() + x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim)) + w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) + b = maybe_dequantize(self.bias, dtype).view((inner_dim,)) # Check if backward pass is needed requires_grad = ctx.requires_grad # Check if output is quantized output_quantizer = None - if ( - FP8GlobalStateManager.is_fp8_enabled() - and next_op is not None - and next_op.num_quantizers("forward") > 0 - ): - output_quantizer = next_op.get_quantizer("forward", 0) + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + if with_quantized_compute: + output_quantizer = next_op_input_quantizer # Compute layer norm sm_margin = self._sm_margins["forward" if requires_grad else "inference"] @@ -242,12 +233,10 @@ def op_forward( # Save state for backward pass if requires_grad: ctx.save_for_backward(x, means, rstdevs) - ctx.device = device ctx.dtype = dtype - ctx.has_prev_op = prev_op is not None # Reshape output tensor - out = reshape(y, input_dims) + out = y.view(input_dims) return out def op_backward( @@ -264,14 +253,9 @@ def op_backward( inner_dim = math.prod(weight_dims) # Check input tensors - device = ctx.device dtype = ctx.dtype - dy = reshape(grad_output, x.size(), device=device, dtype=dtype) - w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) - if isinstance(w, QuantizedTensor): - w = w.dequantize() - if isinstance(dy, QuantizedTensor): - dy = dy.dequantize() + dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size()) + w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) # Compute layer norm backward pass use_layernorm_triton = bool( int(os.environ.get('NVTE_USE_LAYERNORM_TRITON', '0')) ) and IS_HIP_EXTENSION @@ -287,13 +271,22 @@ def op_backward( ) # Clear saved tensors if possible - if ctx.has_prev_op: - clear_tensor_data(x) + clear_tensor_data(x) clear_tensor_data(means) clear_tensor_data(rstdevs) # Reshape results - grad_input = reshape(dx, grad_output.size()) - grad_weight = reshape(dw, weight_dims) - grad_bias = reshape(db, weight_dims) + grad_input = dx.view(grad_output.size()) + grad_weight = dw.view(weight_dims) + grad_bias = db.view(weight_dims) return grad_input, (grad_weight, grad_bias) + + def op_onnx_forward( + self, + input_: torch.Tensor, + ) -> torch.Tensor: + """Every operand in this function has a defined ONNX translation.""" + weight = self.weight + 1 if self.zero_centered_gamma else self.weight + return torch.nn.functional.layer_norm( + input_, input_.shape[-1:], weight, self.bias, self.eps + ) diff --git a/transformer_engine/pytorch/ops/basic/make_extra_output.py b/transformer_engine/pytorch/ops/basic/make_extra_output.py index 73d08b5c7..81b581ae2 100644 --- a/transformer_engine/pytorch/ops/basic/make_extra_output.py +++ b/transformer_engine/pytorch/ops/basic/make_extra_output.py @@ -14,6 +14,7 @@ BasicOperation, OperationContext, ) +from ...tensor import Quantizer class MakeExtraOutput(BasicOperation): @@ -58,8 +59,8 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - basic_op_prev_ops: list[Optional[BasicOperation]], - basic_op_next_ops: list[Optional[BasicOperation]], + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: return input_, [(input_,)] @@ -77,4 +78,4 @@ def fuser_backward( ]: grad_input = basic_op_grad_extra_outputs[0][0] grad_input += grad_output - return grad_input, [], [()] + return grad_input, [()], [()] diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index 448954fc6..005e9fd8d 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -10,8 +10,9 @@ import torch from ...fp8 import FP8GlobalStateManager -from ...tensor import QuantizedTensor +from .._common import is_quantized_tensor from ..op import BasicOperation, OperationContext +from ...tensor import Quantizer class Quantize(BasicOperation): @@ -49,8 +50,8 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op: Optional[BasicOperation] = None, - next_op: Optional[BasicOperation] = None, + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: # Check if FP8 is enabled @@ -60,7 +61,7 @@ def op_forward( # Quantize if needed out = input_ - if quantize_forward and not isinstance(out, QuantizedTensor): + if quantize_forward and not is_quantized_tensor(out): out = self.get_quantizer("forward", 0)(out) ctx.quantize_backward = quantize_backward @@ -72,6 +73,6 @@ def op_backward( grad_output: torch.Tensor, ) -> tuple[torch.Tensor, tuple[()]]: grad_input = grad_output - if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor): + if ctx.quantize_backward and not is_quantized_tensor(grad_input): grad_input = self.get_quantizer("backward", 0)(grad_input) return grad_input, () diff --git a/transformer_engine/pytorch/ops/basic/reduce_scatter.py b/transformer_engine/pytorch/ops/basic/reduce_scatter.py index adfd46641..1238b0879 100644 --- a/transformer_engine/pytorch/ops/basic/reduce_scatter.py +++ b/transformer_engine/pytorch/ops/basic/reduce_scatter.py @@ -10,8 +10,9 @@ import torch from ...distributed import gather_along_first_dim -from ...tensor import QuantizedTensor +from .._common import maybe_dequantize from ..op import BasicOperation, OperationContext +from ...tensor import Quantizer class ReduceScatter(BasicOperation): @@ -39,8 +40,8 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op: Optional[BasicOperation] = None, - next_op: Optional[BasicOperation] = None, + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: # Trivial case @@ -59,10 +60,7 @@ def op_forward( output_dims[0] //= self.process_group_size # Check input tensor - x = input_ - if isinstance(x, QuantizedTensor): - x = x.dequantize() - x = x.contiguous() + x = maybe_dequantize(input_.contiguous()) # Perform reduce-scatter y = torch.empty(output_dims, dtype=x.dtype, device=x.device) diff --git a/transformer_engine/pytorch/ops/basic/reshape.py b/transformer_engine/pytorch/ops/basic/reshape.py index 1e9095169..8d8b75ff0 100644 --- a/transformer_engine/pytorch/ops/basic/reshape.py +++ b/transformer_engine/pytorch/ops/basic/reshape.py @@ -14,6 +14,7 @@ BasicOperation, OperationContext, ) +from ...tensor import Quantizer class Reshape(BasicOperation): @@ -37,8 +38,8 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op: Optional[BasicOperation] = None, - next_op: Optional[BasicOperation] = None, + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: ctx.input_shape = input_.size() return input_.reshape(*self._shape) diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index e945d25fc..a0f111b6c 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -22,7 +22,6 @@ te_rmsnorm_fwd_triton ) from ...fp8 import FP8GlobalStateManager -from ...tensor import QuantizedTensor from ...constants import TE_DType from ...utils import ( canonicalize_device, @@ -31,7 +30,9 @@ devices_match, ) from ..op import BasicOperation, OperationContext -from .._common import maybe_autocast_dtype, reshape +from .._common import maybe_autocast_dtype, maybe_dequantize +from ...export import is_in_onnx_export_mode +from ...tensor import Quantizer class RMSNorm(BasicOperation): r"""Root Mean Square Layer Normalization @@ -158,8 +159,8 @@ def reset_parameters(self) -> None: weight = torch.nn.Parameter(weight) self.weight = weight - def pre_forward(self, *args, **kwargs) -> None: - super().pre_forward(*args, **kwargs) + def pre_first_forward(self, *args, **kwargs) -> None: + super().pre_first_forward(*args, **kwargs) if self.weight.device.type == "meta": self.reset_parameters() @@ -167,9 +168,11 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op: Optional[BasicOperation] = None, - next_op: Optional[BasicOperation] = None, + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: + if is_in_onnx_export_mode(): + return self.op_onnx_forward(input_) # Check tensor dims weight = self.weight @@ -183,28 +186,18 @@ def op_forward( # Check input tensors inner_dim = math.prod(weight_dims) - device = weight.device - if device.type != "cuda": - device = canonicalize_device(None) dtype = maybe_autocast_dtype(default_dtype=weight.dtype) - x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype) - w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) - if isinstance(x, QuantizedTensor): - x = x.dequantize() - if isinstance(w, QuantizedTensor): - w = w.dequantize() + x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim)) + w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) # Check if backward pass is needed requires_grad = ctx.requires_grad # Check if output is quantized output_quantizer = None - if ( - FP8GlobalStateManager.is_fp8_enabled() - and next_op is not None - and next_op.num_quantizers("forward") > 0 - ): - output_quantizer = next_op.get_quantizer("forward", 0) + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + if with_quantized_compute: + output_quantizer = next_op_input_quantizer # Compute RMSNorm sm_margin = self._sm_margins["forward" if requires_grad else "inference"] @@ -224,12 +217,10 @@ def op_forward( # Save state for backward pass if requires_grad: ctx.save_for_backward(x, rstdevs) - ctx.device = device ctx.dtype = dtype - ctx.has_prev_op = prev_op is not None # Reshape output tensor - out = reshape(y, input_dims) + out = y.view(input_dims) return out def op_backward( @@ -246,14 +237,9 @@ def op_backward( inner_dim = math.prod(weight_dims) # Check input tensors - device = ctx.device dtype = ctx.dtype - dy = reshape(grad_output, x.size(), device=device, dtype=dtype) - w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) - if isinstance(w, QuantizedTensor): - w = w.dequantize() - if isinstance(dy, QuantizedTensor): - dy = dy.dequantize() + dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size()) + w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) # Compute RMSNorm backward pass rmsnorm_bwd_func = te_rmsnorm_bwd_triton if self.use_rmsnorm_triton else rmsnorm_bwd @@ -268,11 +254,18 @@ def op_backward( ) # Clear saved tensors if possible - if ctx.has_prev_op: - clear_tensor_data(x) + clear_tensor_data(x) clear_tensor_data(rstdevs) # Reshape results - grad_input = reshape(dx, grad_output.size()) - grad_weight = reshape(dw, weight_dims) + grad_input = dx.view(grad_output.size()) + grad_weight = dw.view(weight_dims) return grad_input, (grad_weight,) + + def op_onnx_forward( + self, + input_: torch.Tensor, + ) -> torch.Tensor: + """Every operand in this function has a defined ONNX translation.""" + weight = self.weight + 1 if self.zero_centered_gamma else self.weight + return torch.nn.functional.rms_norm(input_, input_.shape[-1:], weight, self.eps) diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index e95c4d031..29d3c50cd 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -6,6 +6,10 @@ """Compound tensor operation supported by the operation fuser.""" +from .backward_bias_activation import ( + BackwardBiasActivation, + fuse_backward_bias_activation, +) from .backward_linear_add import ( BackwardLinearAdd, fuse_backward_linear_add, diff --git a/transformer_engine/pytorch/ops/fused/backward_bias_activation.py b/transformer_engine/pytorch/ops/fused/backward_bias_activation.py new file mode 100644 index 000000000..f4b7b9ec3 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/backward_bias_activation.py @@ -0,0 +1,151 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fused backward dbias + dact + quantize.""" + +from __future__ import annotations +from typing import Optional + +import torch + +import transformer_engine_torch as tex +from transformer_engine.pytorch.fp8 import Recipe +from transformer_engine.pytorch.ops.basic import Bias +from transformer_engine.pytorch.ops.basic.activation import ( + _ActivationOperation, + GELU, + ReLU, +) +from transformer_engine.pytorch.ops.op import ( + FusedOperation, + FusibleOperation, + OperationContext, +) +from ...utils import clear_tensor_data +from .._common import maybe_dequantize + +_fused_activations = {GELU: tex.dbias_dgelu, ReLU: tex.dbias_drelu} +_fusible_activations = tuple(_fused_activations.keys()) + + +class BackwardBiasActivation(FusedOperation): + """Fused backward dbias + dact + quantize + + Uses the next operation's input quantizer. + + """ + + def __init__(self, *, bias: Bias, activation: _ActivationOperation): + super().__init__((bias, activation)) + self._fused_function = _fused_activations[type(activation)] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + + # Get basic operation contexts + activation_op_ctx = basic_op_ctxs[0] + bias_op_ctx = basic_op_ctxs[1] + + # Saved tensors from forward pass + (act_input,) = activation_op_ctx.saved_tensors + + # Check activation input tensor + act_input = maybe_dequantize(act_input.contiguous(), activation_op_ctx.dtype) + + # Check grad output tensor + dy = maybe_dequantize(grad_output.contiguous(), act_input.dtype) + + # Get previous op quantizer + if not bias_op_ctx.with_quantized_compute: + raise RuntimeError( + "BackwardBiasActivation requires quantized compute, " + "but Bias context has it disabled" + ) + quantizer = bias_op_ctx.grad_input_quantizer + if quantizer is None: + raise RuntimeError( + "BackwardBiasActivation requires previous op's grad output quantizer, " + "but Bias context has no quantizer" + ) + + # Launch kernel + db, dx = self._fused_function(dy, act_input, quantizer) + + # Clear activation input tensor + clear_tensor_data(act_input) + + return dx, [(), (db,)], [(), ()] + + +def fuse_backward_bias_activation( + ops: list[tuple[FusibleOperation, list[int]]], + recipe: Optional[Recipe], +) -> list[tuple[FusibleOperation, list[int]]]: + """Fused backward dbias + dact + quantize + + Parameters + ---------- + ops: list of tuples + Backward pass operations and the indices of the corresponding + basic operations. + recipe: Recipe, optional + Used quantization recipe + + Returns + ------- + ops: list of tuples + Updated backward pass operations + + """ + + # Check if recipe supports bias activation fusion + if recipe is None or not (recipe.delayed() or recipe.mxfp8()): + return ops + + # Scan through ops, fusing if possible + out = [] + window = [] + while len(ops) >= 3: + out.extend(window) + + # Check if first op is a supported activation + window, ops = ops[:1], ops[1:] + op, _ = window[0] + if not isinstance(op, _fusible_activations): + continue + + # Check if second op is bias + op, _ = ops[0] + if not isinstance(op, Bias): + continue + + # Check if third op has a grad input quantizer + op, _ = ops[1] + if not op.num_quantizers("backward") > 0: + continue + + window.extend(ops[:1]) + ops = ops[1:] + + # Replace window with fused op + op = BackwardBiasActivation( + activation=window[0][0], + bias=window[1][0], + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + out.extend(ops) + return out diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index e295929e9..54ddfaa5c 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -51,12 +51,15 @@ def fuser_backward( linear_op_ctx = basic_op_ctxs[0] # Saved tensors from forward pass - (x_local,) = linear_op_ctx.saved_tensors + (x_local, w) = linear_op_ctx.saved_tensors # wgrad fusion accumulate_into_main_grad = linear_op._accumulate_into_main_grad grad_weight = None if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: + if hasattr(linear_op.weight, "__fsdp_param__"): + linear_op.weight.main_grad = linear_op.weight.get_main_grad() + if not hasattr(linear_op.weight, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " @@ -72,7 +75,7 @@ def fuser_backward( grad_input, grad_weight = BasicLinear._functional_backward( grad_output=grad_output, input=x_local, - weight=linear_op.weight, + weight=w, input_requires_grad=linear_op_ctx.input_requires_grad, weight_requires_grad=linear_op_ctx.weight_requires_grad, dtype=grad_input.dtype, @@ -93,8 +96,7 @@ def fuser_backward( grad_weight = None # Clear input tensor if possible - if linear_op_ctx.has_prev_op: - clear_tensor_data(x_local) + clear_tensor_data(x_local) return grad_input, [(grad_weight,), ()], [(), ()] @@ -107,13 +109,13 @@ def fuse_backward_linear_add( Parameters ---------- ops: list of tuples - Forward pass operations and the indices of the corresponding + Backward pass operations and the indices of the corresponding basic operations. Returns ------- ops: list of tuples - Updated forward pass operations + Updated backward pass operations """ diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 6088b3c0d..5d1223bd8 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -13,11 +13,11 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.ops.basic import BasicLinear, Bias from transformer_engine.pytorch.ops.op import ( - BasicOperation, FusedOperation, FusibleOperation, OperationContext, ) +from ...tensor import Quantizer class ForwardLinearBiasActivation(FusedOperation): @@ -59,8 +59,8 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - basic_op_prev_ops: list[Optional[BasicOperation]], - basic_op_next_ops: list[Optional[BasicOperation]], + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: @@ -70,10 +70,12 @@ def fuser_forward( linear_op_ctx = basic_op_ctxs[idx] if self._op_idxs["bias"] is None: bias_op = None + bias_op_ctx = None bias = None else: idx = self._op_idxs["bias"] bias_op = self.basic_ops[idx] + bias_op_ctx = basic_op_ctxs[idx] bias = bias_op.bias if basic_op_kwargs[idx]: raise ValueError("Bias operation forward does not expect keyword arguments") @@ -82,6 +84,10 @@ def fuser_forward( else: raise NotImplementedError("Activations are not yet supported") + # Check which grads are required + input_requires_grad = linear_op_ctx.requires_grad + weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad + # FP8 metadata with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() input_quantizer = None @@ -92,21 +98,18 @@ def fuser_forward( if with_quantized_compute: input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) - next_op = basic_op_next_ops[-1] - if next_op is not None and next_op.num_quantizers("forward") > 0: - output_quantizer = next_op.get_quantizer("forward", 0) + output_quantizer = next_op_input_quantizer grad_output_quantizer = linear_op.get_quantizer("backward", 0) - prev_op = basic_op_prev_ops[0] - if prev_op is not None and prev_op.num_quantizers("backward") > 0: - grad_input_quantizer = prev_op.get_quantizer("backward", 0) + grad_input_quantizer = prev_op_grad_input_quantizer # Get autocast dtype if needed - dtype = None if torch.is_autocast_enabled(): dtype = torch.get_autocast_dtype("cuda") + else: + dtype = linear_op.weight.dtype # Linear forward - output, x_local, _ = BasicLinear._functional_forward( + output, x_local, w = BasicLinear._functional_forward( input=input_, weight=linear_op.weight, bias=bias, @@ -118,19 +121,23 @@ def fuser_forward( input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, + input_requires_grad=input_requires_grad, + weight_requires_grad=weight_requires_grad, ) # Save state for backward pass - linear_op_ctx.save_for_backward(x_local) + linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer linear_op_ctx.grad_input_quantizer = grad_input_quantizer linear_op_ctx.dtype = dtype - linear_op_ctx.input_requires_grad = input_.requires_grad - linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad - linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None + linear_op_ctx.input_requires_grad = input_requires_grad + linear_op_ctx.weight_requires_grad = weight_requires_grad + if bias_op is not None: + bias_op_ctx.with_quantized_compute = with_quantized_compute + bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer() return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 69b0c3ba5..5055bc60a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -13,11 +13,11 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias from transformer_engine.pytorch.ops.op import ( - BasicOperation, FusedOperation, FusibleOperation, OperationContext, ) +from transformer_engine.pytorch.tensor import Quantizer class ForwardLinearBiasAdd(FusedOperation): @@ -57,8 +57,8 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - basic_op_prev_ops: list[Optional[BasicOperation]], - basic_op_next_ops: list[Optional[BasicOperation]], + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: @@ -68,14 +68,20 @@ def fuser_forward( linear_op_ctx = basic_op_ctxs[idx] if self._op_idxs["bias"] is None: bias_op = None + bias_op_ctx = None bias = None else: idx = self._op_idxs["bias"] bias_op = self.basic_ops[idx] + bias_op_ctx = basic_op_ctxs[idx] bias = bias_op.bias if basic_op_kwargs[idx]: raise ValueError("Bias operation forward does not expect keyword arguments") + # Check which grads are required + input_requires_grad = linear_op_ctx.requires_grad + weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad + # FP8 metadata with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() input_quantizer = None @@ -87,21 +93,21 @@ def fuser_forward( input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) grad_output_quantizer = linear_op.get_quantizer("backward", 0) - prev_op = basic_op_prev_ops[0] - if prev_op is not None and prev_op.num_quantizers("backward") > 0: - grad_input_quantizer = prev_op.get_quantizer("backward", 0) + grad_input_quantizer = prev_op_grad_input_quantizer # Get autocast dtype if needed - dtype = None if torch.is_autocast_enabled(): dtype = torch.get_autocast_dtype("cuda") + else: + dtype = linear_op.weight.dtype # Linear forward output = basic_op_extra_inputs[self._op_idxs["add"]][0] - output, x_local, _ = BasicLinear._functional_forward( + output, x_local, w = BasicLinear._functional_forward( input=input_, weight=linear_op.weight, bias=bias, + dtype=output.dtype, out=output, accumulate_into_out=True, tensor_parallel_mode=linear_op.tensor_parallel_mode, @@ -111,19 +117,23 @@ def fuser_forward( input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, + input_requires_grad=input_requires_grad, + weight_requires_grad=weight_requires_grad, ) # Save state for backward pass - linear_op_ctx.save_for_backward(x_local) + linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer linear_op_ctx.grad_input_quantizer = grad_input_quantizer linear_op_ctx.dtype = dtype - linear_op_ctx.input_requires_grad = input_.requires_grad - linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad - linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None + linear_op_ctx.input_requires_grad = input_requires_grad + linear_op_ctx.weight_requires_grad = weight_requires_grad + if bias_op is not None: + bias_op_ctx.with_quantized_compute = with_quantized_compute + bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer() return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 32748b877..4fbc28482 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -20,10 +20,11 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer +from ...tensor.quantized_tensor import Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data from ..basic import BasicLinear, Bias, ReduceScatter +from .._common import maybe_dequantize, is_quantized_tensor from ..op import FusedOperation, FusibleOperation, OperationContext @@ -279,7 +280,7 @@ def _functional_backward( # Cast grad output tensor dtype if needed dy_local = grad_output if with_quantized_compute: - if not isinstance(dy_local, QuantizedTensorBase): + if not is_quantized_tensor(dy_local): with_columnwise = weight_requires_grad if ( with_columnwise @@ -293,24 +294,18 @@ def _functional_backward( ) dy_local = grad_output_quantizer(dy_local) else: - if isinstance(dy_local, QuantizedTensorBase): - dy_local = dy_local.dequantize(dtype=dtype) - elif dy_local.dtype != dtype: - dy_local = dy_local.to(dtype=dtype) + dy_local = maybe_dequantize(dy_local, dtype) # Cast weight tensor dtype if needed if weight is None: raise ValueError("Weight tensor is required to compute input grad") w = weight if with_quantized_compute: - if not isinstance(w, QuantizedTensorBase): + if not is_quantized_tensor(w): weight_quantizer.set_usage(columnwise=True) w = weight_quantizer(w) else: - if isinstance(w, QuantizedTensorBase): - w = w.dequantize(dtype=dtype) - elif w.dtype != dtype: - w = w.to(dtype=dtype) + w = maybe_dequantize(w, dtype) # Cast input tensor dtype if needed x_local = None @@ -319,14 +314,11 @@ def _functional_backward( raise ValueError("Input tensor is required to compute weight grad") x_local = input if with_quantized_compute: - if not isinstance(x_local, QuantizedTensorBase): + if not is_quantized_tensor(x_local): input_quantizer.set_usage(columnwise=True) x_local = input_quantizer(x_local) else: - if isinstance(x_local, QuantizedTensorBase): - x_local = x_local.dequantize(dtype=dtype) - elif x_local.dtype != dtype: - x_local = x_local.to(dtype=dtype) + x_local = maybe_dequantize(x_local, dtype) # dgrad GEMM dx_local = None @@ -407,23 +399,33 @@ def _functional_backward( # Initialize grad output if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer): # UB does not support overlapping grad output - # all-gather with wgrad GEMM. Also, MXFP8 does not - # allow reusing the grad output that was gathered for - # the dgrad GEMM. We work around with blocking - # all-gather for column-scaled MXFP8 data. + # all-gather with wgrad GEMM. Also, we can't + # convert row-scaled MXFP8 to column-scaled, so we + # can't reuse the grad output that was gathered + # for the dgrad GEMM. We work around by explicitly + # overlapping the NCCL operation with the dgrad GEMM. grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - dy, _ = gather_along_first_dim( - grad_output, - tensor_parallel_group, - quantizer=grad_output_quantizer, - ) + # Get the communication stream from the dgrad GEMM and set it as the current torch stream + dgrad_comm_stream = ub_comm_dgrad.get_communication_stream() + with torch.cuda.stream(dgrad_comm_stream): + # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather + # This ensures that we don't start until all communication for the dgrad GEMM is complete + dy, dy_work = gather_along_first_dim( + dy_local, + tensor_parallel_group, + async_op=True, + quantizer=grad_output_quantizer, + ) + # Synchronize with the main stream + dy_work.wait() + if tensor_parallel_mode == "column": dy = dy_local if dy is None: raise RuntimeError( "wgrad GEMM requires grad output tensor, which has not been initialized" ) - if isinstance(dy, QuantizedTensorBase): + if is_quantized_tensor(dy): dy.update_usage(rowwise_usage=False, columnwise_usage=True) # Initialize input tensor @@ -433,7 +435,7 @@ def _functional_backward( raise RuntimeError( "wgrad GEMM requires input tensor, which has not been initialized" ) - if isinstance(x, QuantizedTensorBase): + if is_quantized_tensor(x): x.update_usage(rowwise_usage=False, columnwise_usage=True) # Check grad weight tensor @@ -500,12 +502,15 @@ def fuser_backward( bias_op = self.basic_ops[idx] # Saved tensors from forward pass - (x_local,) = linear_op_ctx.saved_tensors + (x_local, w) = linear_op_ctx.saved_tensors # wgrad fusion accumulate_into_main_grad = linear_op._accumulate_into_main_grad grad_weight = None if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: + if hasattr(linear_op.weight, "__fsdp_param__"): + linear_op.weight.main_grad = linear_op.weight.get_main_grad() + if not hasattr(linear_op.weight, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " @@ -520,7 +525,7 @@ def fuser_backward( retval = UserbuffersBackwardLinear._functional_backward( grad_output=grad_output, input=x_local, - weight=linear_op.weight, + weight=w, weight_requires_grad=linear_op_ctx.weight_requires_grad, bias_requires_grad=(bias_op is not None), dtype=linear_op_ctx.dtype, @@ -542,8 +547,7 @@ def fuser_backward( grad_bias = extra_outputs["grad_bias"] # Clear input tensor if possible - if linear_op_ctx.has_prev_op: - clear_tensor_data(x_local) + clear_tensor_data(x_local) # Return gradients grad_params = [() for _ in range(len(self.basic_ops))] @@ -564,13 +568,13 @@ def fuse_userbuffers_backward_linear( Parameters ---------- ops: list of tuples - Forward pass operations and the indices of the corresponding + Backward pass operations and the indices of the corresponding basic operations. Returns ------- ops: list of tuples - Updated forward pass operations + Updated backward pass operations """ diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 92c2741ff..30d9cdaae 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -20,13 +20,12 @@ get_workspace, _2X_ACC_FPROP, ) -from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer -from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.quantized_tensor import Quantizer +from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase -from ...utils import canonicalize_device, canonicalize_dtype +from .._common import maybe_dequantize, is_quantized_tensor from ..basic import BasicLinear, Bias, ReduceScatter from ..op import ( - BasicOperation, FusedOperation, FusibleOperation, OperationContext, @@ -88,8 +87,8 @@ def _functional_forward( weight: torch.Tensor, *, bias: Optional[torch.Tensor] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device, + dtype: torch.dtype, tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, tensor_parallel_size: Optional[int] = None, @@ -98,6 +97,8 @@ def _functional_forward( input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, + input_requires_grad: bool = True, + weight_requires_grad: bool = True, ub_comm_name: str, ) -> tuple[torch.Tensor, dict]: """Functional API for forward pass @@ -110,9 +111,9 @@ def _functional_forward( Weight tensor bias: torch.Tensor, optional Bias tensor - device: torch.device, default = default CUDA device + device: torch.device Tensor device - dtype: torch.dtype, default = default dtype + dtype: torch.dtype Tensor datatype tensor_parallel_mode: {`None`, "column", "row"}, default = `None` Mode for tensor parallelism @@ -131,6 +132,12 @@ def _functional_forward( Builder class for quantized weight tensor. output_quantizer: Quantizer, optional Builder class for quantized output tensor. + input_requires_grad: bool, default = `True` + Whether the loss gradient w.r.t. the input tensor is + required in the backward pass. + weight_requires_grad: bool, default = `True` + Whether the loss gradient w.r.t. the weight tensor is + required in the backward pass. ub_comm_name: str Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is used to access the corresponding Userbuffers communicators @@ -141,22 +148,17 @@ def _functional_forward( torch.Tensor Output tensor dict - Extra output tensors. "input" is the input tensor, - possibly cast and reshaped from the provided input tensor. + Extra output tensors. "input" is the input tensor and + "weight" is the weight tensor, both ready for use in the + backward pass. """ # Check device - if device is None: - device = weight.device - device = canonicalize_device(device) if device.type != "cuda": raise ValueError(f"Only CUDA devices are supported (got {device})") # Check datatype - if dtype is None: - dtype = weight.dtype - dtype = canonicalize_dtype(dtype) if dtype not in (torch.float32, torch.float16, torch.bfloat16): raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") @@ -197,9 +199,11 @@ def _functional_forward( x = None if with_ub_all_gather: if input_quantizer is not None: - if not isinstance(x_local, QuantizedTensorBase): - input_quantizer.set_usage(rowwise=True, columnwise=True) - if isinstance(input_quantizer, Float8Quantizer): + if not is_quantized_tensor(x_local): + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + if isinstance( + input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ): input_quantizer.set_usage(columnwise=False) x_local = input_quantizer(x_local) input_quantizer.set_usage(rowwise=True, columnwise=False) @@ -211,26 +215,20 @@ def _functional_forward( ) else: if with_quantized_compute: - if not isinstance(x_local, QuantizedTensorBase): - input_quantizer.set_usage(rowwise=True, columnwise=True) + if not is_quantized_tensor(x_local): + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) x_local = input_quantizer(x_local) else: - if isinstance(x_local, QuantizedTensorBase): - x_local = x_local.dequantize(dtype=dtype) - if x_local.dtype != dtype: - x_local = x_local.to(dtype=dtype) + x_local = maybe_dequantize(x_local, dtype) x = x_local # Initialize weight tensor w = weight - w_is_quantized = isinstance(w, QuantizedTensorBase) - if with_quantized_compute and not w_is_quantized: - weight_quantizer.set_usage(rowwise=True) + if not with_quantized_compute: + w = maybe_dequantize(w, dtype) + elif with_quantized_compute and not is_quantized_tensor(w): + weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) w = weight_quantizer(w) - elif not with_quantized_compute and w_is_quantized: - w = w.dequantize() - if not with_quantized_compute and w.dtype != dtype: - w = w.to(dtype=dtype) # Construct output tensor if needed reduce_scatter_output = None @@ -258,17 +256,21 @@ def _functional_forward( else: y_local = gemm_output - # Detach input tensor if needed - # Note: PyTorch autograd produces esoteric errors if we save - # input tensor as context for backward pass. - if x_local is input: - x_local = x_local.detach() - - # Configure input tensor for backward pass - if with_quantized_compute and isinstance(x_local, QuantizedTensorBase): - if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather): - # FP8 does not support all-gather of transpose data - x_local.update_usage(rowwise_usage=False, columnwise_usage=True) + # Prepare weight tensor for backward pass + if input_requires_grad: + if w is not weight and with_quantized_compute and is_quantized_tensor(w): + w.update_usage(rowwise_usage=False, columnwise_usage=True) + else: + w = None + + # Prepare input tensor for backward pass + if weight_requires_grad: + if with_quantized_compute and is_quantized_tensor(x_local): + if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather): + # FP8 does not support all-gather of transpose data + x_local.update_usage(rowwise_usage=False, columnwise_usage=True) + else: + x_local = None # Return cast tensors extra_outputs = {"input": x_local, "weight": w} @@ -280,8 +282,8 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - basic_op_prev_ops: list[Optional[BasicOperation]], - basic_op_next_ops: list[Optional[BasicOperation]], + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: @@ -290,14 +292,20 @@ def fuser_forward( linear_op = self.basic_ops[idx] linear_op_ctx = basic_op_ctxs[idx] bias_op = None + bias_op_ctx = None bias = None if self._op_idxs["bias"] is not None: idx = self._op_idxs["bias"] bias_op = self.basic_ops[idx] + bias_op_ctx = basic_op_ctxs[idx] bias = bias_op.bias if basic_op_kwargs[idx]: raise ValueError("Bias operation forward does not expect keyword arguments") + # Check which grads are required + input_requires_grad = linear_op_ctx.requires_grad + weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad + # Quantization metadata with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() input_quantizer = None @@ -306,19 +314,20 @@ def fuser_forward( grad_input_quantizer = None if with_quantized_compute: recipe = FP8GlobalStateManager.get_fp8_recipe() - if not recipe.delayed() and not recipe.mxfp8(): - raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe") + if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())): + raise RuntimeError( + f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})" + ) input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) grad_output_quantizer = linear_op.get_quantizer("backward", 0) - prev_op = basic_op_prev_ops[0] - if prev_op is not None and prev_op.num_quantizers("backward") > 0 and recipe.delayed(): - grad_input_quantizer = prev_op.get_quantizer("backward", 0) + grad_input_quantizer = prev_op_grad_input_quantizer # Get autocast dtype if needed - dtype = None if torch.is_autocast_enabled(): dtype = torch.get_autocast_dtype("cuda") + else: + dtype = linear_op.weight.dtype # Userbuffers options if linear_op._userbuffers_options is None: @@ -330,6 +339,7 @@ def fuser_forward( weight=linear_op.weight, bias=bias, dtype=dtype, + device=linear_op.weight.device, tensor_parallel_mode=self.tensor_parallel_mode, tensor_parallel_group=self.tensor_parallel_group, tensor_parallel_size=self.tensor_parallel_size, @@ -338,12 +348,15 @@ def fuser_forward( input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=None, # Not supported + input_requires_grad=input_requires_grad, + weight_requires_grad=weight_requires_grad, ub_comm_name=linear_op._userbuffers_options["comm_name"], ) x_local = extra_outputs["input"] + w = extra_outputs["weight"] # Save state for backward pass - linear_op_ctx.save_for_backward(x_local) + linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer @@ -351,9 +364,11 @@ def fuser_forward( linear_op_ctx.grad_input_quantizer = grad_input_quantizer linear_op_ctx.dtype = dtype linear_op_ctx.input_dims = input_.size() - linear_op_ctx.input_requires_grad = input_.requires_grad - linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad - linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None + linear_op_ctx.input_requires_grad = input_requires_grad + linear_op_ctx.weight_requires_grad = weight_requires_grad + if bias_op is not None: + bias_op_ctx.with_quantized_compute = with_quantized_compute + bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer() return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index e87bfc945..7549cda71 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -12,7 +12,7 @@ import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe from transformer_engine.pytorch.ops.op import ( BasicOperation, FusibleOperation, @@ -20,6 +20,7 @@ ) from torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine.pytorch.ops.fused import ( + fuse_backward_bias_activation, fuse_backward_linear_add, fuse_forward_linear_bias_activation, fuse_forward_linear_bias_add, @@ -29,6 +30,11 @@ fuse_userbuffers_backward_linear, fuse_userbuffers_forward_linear, ) +from transformer_engine.pytorch.tensor.quantized_tensor import ( + prepare_for_saving, + restore_from_saved, +) + def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: """Split tuple at index""" @@ -66,13 +72,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): def forward( func_ctx: Optional[torch.autograd.function.FunctionCtx], input_: torch.Tensor, - forward_ops: list[tuple[FusibleOperation, list[int]]], - backward_ops: list[tuple[FusibleOperation, list[int]]], - basic_ops: list[BasicOperation], + fuser: OperationFuser, basic_op_kwargs: list[dict[str, Any]], is_grad_enabled: bool, - num_params: int, - num_extra_inputs: int, *params_and_extra_inputs: torch.nn.Parameter, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass @@ -83,20 +85,12 @@ def forward( Context for PyTorch autograd function input_: torch.Tensor Input to first operation in pipeline - forward_ops: list of tuple - Forward pass operations and the indices of the - corresponding basic operations. The order should match - basic_ops. - backward_ops: list of tuple - Backward pass operations and the indices of the - corresponding basic operations. The order should be the - reverse of basic_ops. - basic_ops: list of BasicOperation - Basic operations + fuser: OperationFuser + Container for the pipeline of operations to run basic_op_kwargs: list of dict Keyword arguments to BasicOperation - num_params: int - Number of parameter tensors to include in autograd graph. + is_grad_enabled: bool + Should context be saved for backward *params_and_extra_inputs: torch.Tensor Other tensor inputs to include in autograd graph. Consists of parameter tensors, followed by extra operation inputs. @@ -111,26 +105,25 @@ def forward( """ # Operation autograd contexts - basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))] + basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)] + + # Mark input tensors as not deletable in backward + for tensor in (input_,) + params_and_extra_inputs: + tensor.do_not_clear = True # Unflatten list of parameters and extra tensor inputs - if len(params_and_extra_inputs) != num_params + num_extra_inputs: - raise ValueError( - f"Expected {num_params + num_extra_inputs} extra tensor arguments " - f"({num_params} parameters, {num_extra_inputs} extra inputs), " - f"but got {len(params_and_extra_inputs)}" - ) - _, extra_inputs = _split_tuple(params_and_extra_inputs, num_params) + extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :] basic_op_extra_inputs = [] - for op in basic_ops: + for op in fuser._basic_ops: xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs) basic_op_extra_inputs.append(xs) # Apply forward ops x = input_ requires_grad = is_grad_enabled and x.requires_grad - extra_outputs = [None for _ in range(len(basic_ops))] - for op, basic_op_idxs in forward_ops: + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + extra_outputs = [None] * fuser._num_basic_ops + for op, basic_op_idxs in fuser._forward_ops: # Check if backward op is required if is_grad_enabled: @@ -140,37 +133,38 @@ def forward( requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs) for idx in basic_op_idxs: basic_op_ctxs[idx].requires_grad = requires_grad - if requires_grad != x.requires_grad: - if requires_grad: - x.requires_grad_() - else: - x = x.detach() # Forward op extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] - prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs] - next_ops = [ - basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs - ] + prev_op_idx = basic_op_idxs[0] - 1 + prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None + prev_op_grad_input_quantizer = None + if prev_op is not None and with_quantized_compute: + prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer() + next_op_idx = basic_op_idxs[-1] + 1 + next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None + next_op_input_quantizer = None + if next_op is not None and with_quantized_compute: + next_op_input_quantizer = next_op.get_input_quantizer() + x, fused_op_extra_outputs = op.fuser_forward( [basic_op_ctxs[idx] for idx in basic_op_idxs], x, basic_op_extra_inputs=extra_inputs, - basic_op_prev_ops=prev_ops, - basic_op_next_ops=next_ops, + prev_op_grad_input_quantizer=prev_op_grad_input_quantizer, + next_op_input_quantizer=next_op_input_quantizer, basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs], ) - x.requires_grad_(requires_grad=requires_grad) for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): for y in ys: - y.requires_grad_(requires_grad=requires_grad) + y.requires_grad_(requires_grad) extra_outputs[idx] = ys # Flatten list of extra outputs extra_outputs_flat = [] for idx, ys in enumerate(extra_outputs): ys = list(ys) - num_extra_outputs = basic_ops[idx].num_extra_outputs + num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs if len(ys) != num_extra_outputs: raise RuntimeError( f"Expected op {idx} to generate " @@ -191,19 +185,30 @@ def forward( range_end = len(to_save) ctx.to_save = None ctx._saved_tensors_range = (range_start, range_end) - func_ctx.save_for_backward(*to_save) + + # Save tensors for backward + if with_quantized_compute: + tensors_to_save, tensor_objects = prepare_for_saving(*to_save) + func_ctx.save_for_backward(*tensors_to_save) + func_ctx.tensor_objects = tensor_objects + else: + func_ctx.save_for_backward(*to_save) # Other context - func_ctx.backward_ops = backward_ops - func_ctx.basic_ops = basic_ops + func_ctx.backward_ops = fuser._backward_ops + func_ctx.basic_ops = fuser._basic_ops func_ctx.basic_op_ctxs = basic_op_ctxs - func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops] - func_ctx.num_extra_inputs = num_extra_inputs + func_ctx.basic_op_num_params = fuser._num_list_basic_op_params + func_ctx.num_extra_inputs = fuser._num_extra_inputs func_ctx.num_extra_outputs = len(extra_outputs_flat) func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() + func_ctx.with_quantized_compute = with_quantized_compute + + x.requires_grad_(requires_grad) if extra_outputs_flat: return x, *extra_outputs_flat + return x @staticmethod @@ -220,9 +225,15 @@ def backward( basic_ops = func_ctx.basic_ops basic_op_ctxs = func_ctx.basic_op_ctxs + # Restore saved tensors + if func_ctx.with_quantized_compute: + saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors) + else: + saved_tensors = func_ctx.saved_tensors + # Unflatten list of saved tensors for ctx in basic_op_ctxs: - ctx.saved_tensors = func_ctx.saved_tensors[slice(*ctx._saved_tensors_range)] + ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)] ctx._saved_tensors_range = None # Unflatten list of extra tensor output grads @@ -297,13 +308,9 @@ def backward( return ( dx, # input_ - None, # forward_ops - None, # backward_ops - None, # basic_ops + None, # fuser None, # basic_op_kwargs None, # is_grad_enabled - None, # num_params - None, # num_extra_inputs *grad_params_flat, *grad_extra_inputs_flat, ) @@ -316,15 +323,19 @@ class OperationFuser: ---------- ops: list of FusibleOperation Pipeline of operations - fuse_ops: bool, default = `True` + fuse_ops: bool Whether to attempt fusing operations + recipe: Recipe, optional + Quantization recipe to use when fusing and executing operations. + Note: certain fusions may depend on what kind of recipe is being used. """ def __init__( self, ops: list[FusibleOperation], - fuse_ops: bool = True, + fuse_ops: bool, + recipe: Optional[Recipe], ) -> None: # Get list of basic operations @@ -346,14 +357,23 @@ def __init__( self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)] self._backward_ops = list(reversed(self._forward_ops)) + # Flag for checking if this is the first iteration + self._is_first_forward = True + # Fuse ops if needed + self.recipe = recipe if fuse_ops: self.fuse_ops() + # Flatten list of parameters + self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()] + self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops] + @classmethod def _fuse_forward_ops( cls, ops: list[tuple[FusibleOperation, list[int]]], + recipe: Optional[Recipe], # pylint: disable=unused-argument ) -> list[tuple[FusibleOperation, list[int]]]: """Attempt to fuse operations in forward pass""" if not IS_HIP_EXTENSION: @@ -366,17 +386,19 @@ def _fuse_forward_ops( def _fuse_backward_ops( cls, ops: list[tuple[FusibleOperation, list[int]]], + recipe: Optional[Recipe], ) -> list[tuple[FusibleOperation, list[int]]]: """Attempt to fuse operations in backward pass""" if not IS_HIP_EXTENSION: ops = fuse_userbuffers_backward_linear(ops) ops = fuse_backward_linear_add(ops) + ops = fuse_backward_bias_activation(ops, recipe) return ops def fuse_ops(self) -> None: """Attempt to fuse operations""" - self._forward_ops = self._fuse_forward_ops(self._forward_ops) - self._backward_ops = self._fuse_backward_ops(self._backward_ops) + self._forward_ops = self._fuse_forward_ops(self._forward_ops, self.recipe) + self._backward_ops = self._fuse_backward_ops(self._backward_ops, self.recipe) def __call__( self, @@ -384,17 +406,21 @@ def __call__( *extra_inputs: torch.Tensor, basic_op_kwargs: Optional[list[dict[str, Any]]] = None, ) -> torch.Tensor | tuple[torch.Tensor, ...]: + # Verify extra input count + if len(extra_inputs) != self._num_extra_inputs: + raise ValueError( + f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}" + ) # Initialization before forward pass - for op in self._basic_ops: - op.pre_forward() + if self._is_first_forward: + for op in self._basic_ops: + op.pre_first_forward(recipe=self.recipe) + self._is_first_forward = False # Canonicalize op kwargs if basic_op_kwargs is None: - basic_op_kwargs = [{} for _ in range(len(self._basic_ops))] - - # Flatten list of parameters - params = [param for op in self._basic_ops for param in op.parameters()] + basic_op_kwargs = [{}] * self._num_basic_ops # Fuser forward pass is_grad_enabled = torch.is_grad_enabled() @@ -406,14 +432,10 @@ def __call__( args = [None] args += ( input, - self._forward_ops, - self._backward_ops, - self._basic_ops, + self, basic_op_kwargs, is_grad_enabled, - len(params), - self._num_extra_inputs, - *params, + *self._basic_op_params, *extra_inputs, ) return forward_func(*args) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 802f4c25e..8490019e5 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -65,17 +65,27 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): def is_fused_op(self) -> bool: """Whether this op is the fusion of one or more basic ops""" - def pre_forward(self) -> None: + def pre_first_forward( + self, + *, + recipe: Optional[Recipe], + ) -> None: """Preprocessing before forward pass""" + def get_input_quantizer(self) -> Optional[Quantizer]: + """Get builder class for quantized input tensor""" + + def get_grad_input_quantizer(self) -> Optional[Quantizer]: + """Get builder class for quantized input's grad tensor""" + def fuser_forward( self, basic_op_ctxs: list[OperationContext], input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - basic_op_prev_ops: list[Optional[BasicOperation]], - basic_op_next_ops: list[Optional[BasicOperation]], + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: """Forward pass @@ -94,12 +104,10 @@ def fuser_forward( Input tensor basic_op_extra_inputs: list of torch.Tensor Extra tensor inputs to basic operations - basic_op_prev_ops: list of BasicOperation - Basic operations that preceed this operation's basic - operations - basic_op_next_ops: list of BasicOperation - Basic operations that follow this operation's basic - operations + prev_op_grad_input_quantizer: Quantizer, optional + The grad_input_quantizer of the preceeding operation + next_op_input_quantizer: Quantizer, optional + The input_quantizer of the following operation basic_op_kwargs: list of dict Keyword arguments to forward functions of basic operations. @@ -201,17 +209,23 @@ def num_quantizers( """ return 0 + def get_input_quantizer(self) -> Optional[Quantizer]: + if self.num_quantizers("forward") > 0: + return self.get_quantizer("forward", 0) + return None + + def get_grad_input_quantizer(self) -> Optional[Quantizer]: + if self.num_quantizers("backward") > 0: + return self.get_quantizer("backward", 0) + return None + def _reset_quantization_recipe_state( self, *, - recipe: Optional[Recipe] = None, + recipe: Recipe, ) -> None: """Construct state for quantization recipe""" - # Quantization recipe - if recipe is None: - recipe = FP8GlobalStateManager.get_fp8_recipe() - # Quantization recipe state for forward and backward pass self._fp8_metas = {"forward": None, "backward": None} self._quantizers = {"forward": [], "backward": []} @@ -246,14 +260,10 @@ def _reset_quantization_recipe_state( def _update_quantization_recipe_state( self, *, - recipe: Optional[Recipe] = None, + recipe: Recipe, ) -> None: """Make sure quantizer state matches quantization recipe""" - # Quantization recipe - if recipe is None: - recipe = FP8GlobalStateManager.get_fp8_recipe() - # Reset quantization state if needed if self._fp8_metas is None or self._quantizers is None: self._reset_quantization_recipe_state(recipe=recipe) @@ -327,7 +337,7 @@ def get_quantizer( """ if self._quantizers is None: - self._reset_quantization_recipe_state() + self._reset_quantization_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe()) return self._quantizers[mode][index] @torch.no_grad() @@ -378,19 +388,16 @@ def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None: self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale) self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history) - def pre_forward( + def pre_first_forward( self, *, - fp8_enabled: Optional[bool] = None, - fp8_recipe: Optional[Recipe] = None, + recipe: Optional[Recipe], ) -> None: """Preprocessing before forward pass""" # Initialize FP8 metadata if needed - if fp8_enabled is None: - fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() - if fp8_enabled: - self._update_quantization_recipe_state(recipe=fp8_recipe) + if recipe is not None: + self._update_quantization_recipe_state(recipe=recipe) if not FP8GlobalStateManager.fp8_graph_capturing(): if self.num_quantizers("forward"): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( @@ -407,8 +414,8 @@ def op_forward( ctx: OperationContext, input_: torch.Tensor, *, - prev_op: Optional[BasicOperation] = None, - next_op: Optional[BasicOperation] = None, + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], **kwargs: Any, ) -> torch.Tensor: """Forward pass @@ -419,10 +426,10 @@ def op_forward( Context to coordinate between forward and backward passes input_: torch.Tensor Input tensor - prev_op: BasicOperation, optional - Basic operation that preceeds this operation - next_op: BasicOperation, optional - Basic operation that follows this operation + prev_op_grad_input_quantizer: Quantizer, optional + The grad_input_quantizer of the preceeding operation + next_op_input_quantizer: Quantizer, optional + The input_quantizer of the following operation Returns ------- @@ -461,8 +468,8 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - basic_op_prev_ops: list[Optional[BasicOperation]], - basic_op_next_ops: list[Optional[BasicOperation]], + prev_op_grad_input_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, list[tuple[()]]]: if self.num_extra_inputs > 0 or self.num_extra_outputs > 0: @@ -475,8 +482,8 @@ def fuser_forward( output = self.op_forward( basic_op_ctxs[0], input_, - prev_op=basic_op_prev_ops[0], - next_op=basic_op_next_ops[0], + prev_op_grad_input_quantizer=prev_op_grad_input_quantizer, + next_op_input_quantizer=next_op_input_quantizer, **basic_op_kwargs[0], ) return output, [()] @@ -511,7 +518,9 @@ def forward( """Apply operation""" from .fuser import OperationFuser - return OperationFuser([self], fuse_ops=False)( + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None + return OperationFuser([self], fuse_ops=False, recipe=recipe)( input, *extra_inputs, basic_op_kwargs=[kwargs], @@ -621,7 +630,7 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: # Get op's quantizer state, initializing if needed if self._fp8_metas is None or self._fp8_metas[mode] is None: with fp8_autocast(fp8_recipe=state[mode]["recipe"]): - self._reset_quantization_recipe_state() + self._reset_quantization_recipe_state(recipe=state[mode]["recipe"]) fp8_meta = self._fp8_metas[mode] # Load extra items @@ -696,10 +705,16 @@ def __init__( def is_fused_op(self) -> bool: return True - def pre_forward(self) -> None: + def get_input_quantizer(self) -> Optional[Quantizer]: + return self.basic_ops[0].get_input_quantizer() + + def get_grad_input_quantizer(self) -> Optional[Quantizer]: + return self.basic_ops[-1].get_grad_input_quantizer() + + def pre_first_forward(self, *args, **kwargs) -> None: """Preprocessing before forward pass""" for op in self.basic_ops: - op.pre_forward() + op.pre_first_forward(*args, **kwargs) def forward( self, @@ -712,7 +727,9 @@ def forward( basic_op_kwargs = [{} for _ in range(len(self.basic_ops))] from .fuser import OperationFuser - return OperationFuser([self], fuse_ops=False)( + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None + return OperationFuser([self], fuse_ops=False, recipe=recipe)( input, *extra_inputs, basic_op_kwargs=basic_op_kwargs, diff --git a/transformer_engine/pytorch/ops/sequential.py b/transformer_engine/pytorch/ops/sequential.py index 3240bd73d..f18678309 100644 --- a/transformer_engine/pytorch/ops/sequential.py +++ b/transformer_engine/pytorch/ops/sequential.py @@ -10,6 +10,7 @@ import torch +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe from transformer_engine.pytorch.ops.op import FusibleOperation from transformer_engine.pytorch.ops.fuser import OperationFuser @@ -37,6 +38,9 @@ def __init__( self._module_groups: Optional[list[OperationFuser | torch.nn.Module]] self._module_groups = None + # Global state of last iteration + self._last_global_state = None + # Add modules if len(args) == 1 and isinstance(args[0], dict): for key, module in args[0].items(): @@ -143,6 +147,7 @@ def __add__(self, modules: Iterable[torch.nn.Modules]) -> Sequential: def _make_module_groups( cls, modules: Iterable[torch.nn.Module], + recipe: Optional[Recipe], ) -> list[OperationFuser | torch.nn.Module]: """Make list of modules, with fusible operations grouped together""" @@ -157,7 +162,7 @@ def _make_module_groups( groups.append(module) for idx, group in enumerate(groups): if isinstance(group, list): - groups[idx] = OperationFuser(group, fuse_ops=True) + groups[idx] = OperationFuser(group, fuse_ops=True, recipe=recipe) # Check if operations expect extra input or output tensors # Note: If any op has extra inputs or outputs, then the entire @@ -185,9 +190,19 @@ def forward( ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass""" + # Get current global state + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None + global_state = (with_quantized_compute, type(recipe)) + + # Reset module groups is global state changed + if self._last_global_state != global_state: + self._module_groups = None + self._last_global_state = global_state + # Create module groups if needed if self._module_groups is None: - self._module_groups = self._make_module_groups(self._modules.values()) + self._module_groups = self._make_module_groups(self._modules.values(), recipe) # Forward pass for each module group x = input diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 58f5c5289..209d116ef 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -351,7 +351,7 @@ def forward( if restore_shape is None: restore_shape = inp.shape num_tokens, hidden_size = restore_shape - num_experts = row_id_map.size(0) + num_experts = (row_id_map.size(1) - 1) // 2 with_probs = merging_probs is not None if with_probs: @@ -653,14 +653,20 @@ def forward( fp8_scale_inv = inp._scale_inv fake_dtype = inp.dtype inp = inp._data - output, row_id_map, permuted_probs = triton_permutation.sort_chunks_by_idx( - inp, + + row_id_map = triton_permutation.make_chunk_sort_map( split_sizes, sorted_idxs, + num_tokens, + num_splits, + ) + output, permuted_probs = triton_permutation.sort_chunks_by_map( + inp, + row_id_map, probs, num_tokens, hidden_size, - num_splits, + is_forward=True, ) if fp8: output = Float8Tensor( @@ -702,6 +708,7 @@ def backward( permuted_probs_grad, ctx.num_tokens, ctx.hidden_size, + is_forward=False, ) if fp8: act_grad = Float8Tensor( diff --git a/transformer_engine/pytorch/pyproject.toml b/transformer_engine/pytorch/pyproject.toml new file mode 100755 index 000000000..e5a4549db --- /dev/null +++ b/transformer_engine/pytorch/pyproject.toml @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +[build-system] +requires = ["setuptools>=61.0", "pip", "torch>=2.1"] + +# Use legacy backend to import local packages in setup.py +build-backend = "setuptools.build_meta:__legacy__" + diff --git a/transformer_engine/pytorch/router.py b/transformer_engine/pytorch/router.py new file mode 100644 index 000000000..db5114ae0 --- /dev/null +++ b/transformer_engine/pytorch/router.py @@ -0,0 +1,275 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" +Fused functions used in the MoE router +""" +import torch +import transformer_engine_torch as tex + + +class FusedTopkScoreFunction(torch.autograd.Function): + """ + Fused Topk with Score Function router. + Currently, only support softmax and sigmoid. + """ + + @staticmethod + def forward( + ctx, + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool, + num_groups: int, + group_topk: int, + scaling_factor: float, + score_function: str, + expert_bias: torch.Tensor, + ): + # pylint: disable=missing-function-docstring + # Save the shape of the logits + tensor_shape = logits.shape + logits = logits.view(-1, tensor_shape[-1]) + # Get the metadata of the viewed logits + num_tokens = logits.size(0) + num_experts = logits.size(1) + probs, routing_map, intermediate_output = tex.fused_topk_with_score_function_fwd( + logits, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + expert_bias, + ) + # Restore the shape + probs = probs.view(tensor_shape) + ctx.save_for_backward(routing_map, intermediate_output) + ctx.num_tokens = num_tokens + ctx.num_experts = num_experts + ctx.use_pre_softmax = use_pre_softmax + ctx.topk = topk + ctx.scaling_factor = scaling_factor + ctx.score_function = score_function + return probs, routing_map + + @staticmethod + def backward(ctx, grad_probs, _): + # pylint: disable=missing-function-docstring + routing_map, intermediate_output = ctx.saved_tensors + # Save the shape of the grad_probs + tensor_shape = grad_probs.shape + # Adjust the shape of the grad_probs to 2D shape + grad_probs = grad_probs.contiguous().view(-1, tensor_shape[-1]) + grad_logits = tex.fused_topk_with_score_function_bwd( + ctx.num_tokens, + ctx.num_experts, + routing_map, + intermediate_output, + grad_probs, + ctx.topk, + ctx.use_pre_softmax, + ctx.scaling_factor, + ctx.score_function, + ) + # Restore the shape + grad_logits = grad_logits.view(tensor_shape) + return grad_logits, None, None, None, None, None, None, None + + +def fused_topk_with_score_function( + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool, + num_groups: int, + group_topk: int, + scaling_factor: float, + score_function: str, + expert_bias: torch.Tensor, +): + """ + Fused topk with score function router. + Parameters + ---------- + logits: torch.Tensor + topk: int + use_pre_softmax: bool + if enabled, the computation order: softmax -> topk + num_groups: int + used in the group topk + group_topk: int + used in the group topk + scaling_factor: float + score_function: str + currently only support softmax and sigmoid + expert_bias: torch.Tensor + could be used in the sigmoid + + Returns + ------- + probs: torch.Tensor + routing_map: torch.Tensor + """ + if logits.dtype == torch.float64: + raise ValueError("Current TE does not support float64 router type") + return FusedTopkScoreFunction.apply( + logits, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + expert_bias, + ) + + +class FusedComputeScoresForMoEAuxLoss(torch.autograd.Function): + """ + Fused compute scores for MoE aux loss. + """ + + @staticmethod + def forward( + ctx, + logits: torch.Tensor, + topk: int, + score_function: str, + ): + # pylint: disable=missing-function-docstring + # Save the shape of the logits + tensor_shape = logits.shape + logits = logits.view(-1, tensor_shape[-1]) + # Get the metadata of the viewed logits + num_tokens = logits.size(0) + num_experts = logits.size(1) + scores, routing_map, intermediate_output = tex.fused_score_for_moe_aux_loss_fwd( + logits=logits, + topk=topk, + score_function=score_function, + ) + ctx.save_for_backward(intermediate_output) + ctx.topk = topk + ctx.score_function = score_function + ctx.num_tokens = num_tokens + ctx.num_experts = num_experts + return routing_map, scores + + @staticmethod + def backward(ctx, _, grad_scores): + # pylint: disable=missing-function-docstring + intermediate_output = ctx.saved_tensors[0] + # Save the shape of the grad_scores + tensor_shape = grad_scores.shape + # Adjust the shape of the grad_scores to 2D shape + grad_scores = grad_scores.contiguous().view(-1, tensor_shape[-1]) + grad_logits = tex.fused_score_for_moe_aux_loss_bwd( + num_tokens=ctx.num_tokens, + num_experts=ctx.num_experts, + intermediate_output=intermediate_output, + grad_scores=grad_scores, + topk=ctx.topk, + score_function=ctx.score_function, + ) + # Restore the shape + grad_logits = grad_logits.view(tensor_shape) + return grad_logits, None, None + + +def fused_compute_score_for_moe_aux_loss( + logits: torch.Tensor, + topk: int, + score_function: str, +): + """ + Fused compute scores for MoE aux loss, subset of the fused_topk_with_score_function. + Parameters + ---------- + logits: torch.Tensor + topk: int + score_function: str + currently only support softmax and sigmoid + + Returns + ------- + routing_map: torch.Tensor + scores: torch.Tensor + """ + return FusedComputeScoresForMoEAuxLoss.apply(logits, topk, score_function) + + +class FusedAuxLoss(torch.autograd.Function): + """ + Fused MoE aux loss. + """ + + @staticmethod + def forward( + ctx, + probs: torch.Tensor, + tokens_per_expert: torch.Tensor, + total_num_tokens: int, + num_experts: int, + topk: int, + coeff: float, + ): + # pylint: disable=missing-function-docstring + num_rows = probs.size(0) + num_cols = probs.size(1) + aux_loss, Const_buf = tex.fused_moe_aux_loss_fwd( + probs=probs, + tokens_per_expert=tokens_per_expert, + total_num_tokens=total_num_tokens, + num_experts=num_experts, + num_rows=num_rows, + num_cols=num_cols, + topk=topk, + coeff=coeff, + ) + ctx.save_for_backward(Const_buf, tokens_per_expert) + ctx.num_rows = num_rows + ctx.num_cols = num_cols + return aux_loss + + @staticmethod + def backward(ctx, grad_aux_loss): + # pylint: disable=missing-function-docstring + Const_buf, tokens_per_expert = ctx.saved_tensors + grad_probs = tex.fused_moe_aux_loss_bwd( + Const_buf=Const_buf, + tokens_per_expert=tokens_per_expert, + num_rows=ctx.num_rows, + num_cols=ctx.num_cols, + grad_aux_loss=grad_aux_loss, + ) + return grad_probs, None, None, None, None, None + + +def fused_moe_aux_loss( + probs: torch.Tensor, + tokens_per_expert: torch.Tensor, + total_num_tokens: int, + num_experts: int, + topk: int, + coeff: float, +): + """ + Fused MoE aux loss. + Parameters + ---------- + probs: torch.Tensor + tokens_per_expert: torch.Tensor + the number of tokens per expert + total_num_tokens: int + the total number of tokens, involved in the aux loss calculation + num_experts: int + topk: int + coeff: float + the coefficient of the aux loss + + Returns + ------- + aux_loss: torch.scalar + """ + return FusedAuxLoss.apply(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff) diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 2b99f7f95..0dbf9fdb5 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -34,7 +34,7 @@ from build_tools.utils import ( rocm_build, copy_common_headers, copy_hipify_tools, clear_hipify_tools_copy ) from build_tools.te_version import te_version -from build_tools.pytorch import setup_pytorch_extension +from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements os.environ["NVTE_PROJECT_BUILDING"] = "1" @@ -59,8 +59,8 @@ description="Transformer acceleration library - Torch Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, - install_requires=["einops"] if rocm_build() else ["torch"], - tests_require=[] if rocm_build() else ["numpy", "torchvision"], + install_requires=["einops"] if rocm_build() else install_requirements(), + tests_require=[] if rocm_build() else test_requirements(), ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 42a618171..882650ffb 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -11,6 +11,7 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType +from transformer_engine_torch import Float8BlockScaleTensorFormat from ..quantized_tensor import QuantizedTensorBase @@ -37,10 +38,10 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): _rowwise_scale_inv: Optional[torch.Tensor] _columnwise_scale_inv: Optional[torch.Tensor] _is_2D_scaled: bool + _data_format: Float8BlockScaleTensorFormat def __new__( cls, - *args, rowwise_data: Optional[torch.Tensor], rowwise_scale_inv: Optional[torch.Tensor], columnwise_data: Optional[torch.Tensor], @@ -48,9 +49,14 @@ def __new__( fp8_dtype: TE_DType, quantizer: Quantizer, is_2D_scaled: bool, + data_format: Float8BlockScaleTensorFormat, + *args, **kwargs, ): - instance = super().__new__(cls, *args, **kwargs) + if cls is Float8BlockwiseQTensorBase: + instance = object.__new__(cls) + else: + instance = super().__new__(cls, *args, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data instance._quantizer = quantizer @@ -58,6 +64,7 @@ def __new__( instance._rowwise_scale_inv = rowwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv instance._is_2D_scaled = is_2D_scaled + instance._data_format = data_format return instance @@ -82,8 +89,13 @@ def get_metadata(self) -> Dict[str, Any]: "fp8_dtype": self._fp8_dtype, "quantizer": self._quantizer, "is_2D_scaled": self._is_2D_scaled, + "data_format": self._data_format, } + def _is_gemm_ready_format(self) -> bool: + """Whether data is in GEMM_READY format""" + return self._data_format == Float8BlockScaleTensorFormat.GEMM_READY + def prepare_for_saving( self, ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: @@ -136,34 +148,69 @@ def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch q_K = q.shape[-1] for i in range(len(q.shape) - 1): q_M *= q.shape[i] + inner_q_dimension_tiled = True + if self._is_gemm_ready_format(): + scales_tiled_dim, scales_untiled_dim = scale_inv.shape + inner_scale_dimension_tiled = False + scales_are_compact = False + else: + scales_untiled_dim, scales_tiled_dim = scale_inv.shape + inner_scale_dimension_tiled = True + scales_are_compact = True else: assert self._columnwise_data is not None, "No data to dequantize" q = self._columnwise_data scale_inv = self._columnwise_scale_inv - transpose_output = True - if len(q.shape) >= 1: - q_M = q.shape[0] - for i in range(1, len(q.shape)): - q_K *= q.shape[i] + scales_tiled_dim, scales_untiled_dim = scale_inv.shape + inner_scale_dimension_tiled = False + if self._is_gemm_ready_format(): + inner_q_dimension_tiled = True + transpose_output = True + if len(q.shape) >= 1: + q_M = q.shape[0] + for i in range(1, len(q.shape)): + q_K *= q.shape[i] + scales_are_compact = False + else: + inner_q_dimension_tiled = False + transpose_output = False + if len(q.shape) >= 1: + q_K = q.shape[-1] + for i in range(len(q.shape) - 1): + q_M *= q.shape[i] + scales_are_compact = True orig_shape = q.shape q = q.reshape(q_M, q_K) - k_tiles, scale_m = scale_inv.shape - if q_K % block_len != 0: - k_pad_amount = (block_len - (q_K % block_len)) % block_len - q = torch.nn.functional.pad( - q, (0, k_pad_amount, 0, 0), mode="constant", value=0 - ).contiguous() - _, padded_K = q.shape - q_tiled = q.reshape(q_M, k_tiles, block_len) - if scale_m > q_M: - # scale_m is 4 element aligned. + if inner_q_dimension_tiled: + if q_K % block_len != 0: + k_pad_amount = (block_len - (q_K % block_len)) % block_len + q = torch.nn.functional.pad( + q, (0, k_pad_amount, 0, 0), mode="constant", value=0 + ).contiguous() + padded_M, padded_K = q.shape + q_tiled = q.reshape(q_M, scales_tiled_dim, block_len) + else: + if q_M % block_len != 0: + m_pad_amount = (block_len - (q_M % block_len)) % block_len + q = torch.nn.functional.pad( + q, (0, 0, 0, m_pad_amount), mode="constant", value=0 + ).contiguous() + padded_M, padded_K = q.shape + q_tiled = q.reshape(scales_tiled_dim, block_len, q_K) + if not scales_are_compact and scales_untiled_dim > q_M: + # untiled scale dimension is 4 element aligned. scale_inv = scale_inv[:, :q_M].contiguous() - dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, k_tiles, 1) + if scales_are_compact and inner_scale_dimension_tiled: + dq_scale = scale_inv.contiguous().reshape(q_M, scales_tiled_dim, 1) + elif scales_are_compact and not inner_scale_dimension_tiled: + dq_scale = scale_inv.contiguous().reshape(scales_tiled_dim, 1, q_K) + else: + dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, scales_tiled_dim, 1) torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype] result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale - if padded_K != q_K: - result = result.reshape(q_M, padded_K)[:, :q_K] + if padded_M != q_M or padded_K != q_K: + result = result.reshape(padded_M, padded_K)[:q_M, :q_K] result = result.to(dtype) if len(orig_shape) == 0: result = result.reshape([]) @@ -182,6 +229,12 @@ def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: if not self._is_2D_scaled: return self._dequantize_vectorwise(dtype=dtype) + if not self._is_gemm_ready_format(): + raise NotImplementedError( + "Dequantize is only supported with GEMM_READY data format, " + f"but found _data_format={self._data_format}" + ) + def format_scale_as_logical_shape(q_K, scales, block_len): # The GEMM for 2D blocks required padding in the scales. derived_scale_k_shape = math.ceil(q_K / block_len) @@ -247,6 +300,8 @@ def size(self, *args, **kwargs): if self._rowwise_data is not None: return self._rowwise_data.size(*args, **kwargs) dims = list(self._columnwise_data.size(*args, **kwargs)) + if not self._is_gemm_ready_format(): # compact format + return torch.Size(dims) reordered = [] for i in range(1, len(dims)): reordered.append(dims[i]) @@ -285,6 +340,13 @@ def _create_columnwise(self): w = min(self._columnwise_scale_inv.shape[1], columnwise_scale_inv.shape[1]) self._columnwise_scale_inv[0:h, 0:w].copy_(columnwise_scale_inv[0:h, 0:w]) + def _transpose_columnwise_data(self): + """Plainly transpose the columnwise data and scale inv.""" + if self._columnwise_data is not None: + self._columnwise_data = tex.fp8_transpose( + self._columnwise_data, self._fp8_dtype, out=None + ) + def __repr__(self): if self._rowwise_data is not None: data = self.dequantize() diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index 4124511cd..c0dc6e651 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -143,6 +143,23 @@ def size(self, *args, **kwargs): size = self._transpose.size(*args, **kwargs) return torch.Size([size[-1], math.prod(size[:-1])]) + def view(self, shape: torch.Size): + # pylint: disable=missing-function-docstring + out_data = self._data.view(shape) + out_transpose = None if self._transpose_invalid else self._transpose + if out_transpose is not None: + out_transpose_shape = out_transpose.size() + if out_transpose_shape[0] != shape[-1] or out_transpose_shape[1:] != shape[:-1]: + out_transpose = None + + return Float8TensorBase( + data=out_data, + fp8_scale_inv=self._scale_inv, + fp8_dtype=self._fp8_dtype, + data_transpose=out_transpose, + quantizer=self._quantizer, + ) + def __repr__(self): return ( "Float8TensorBase(" diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index 150d65a51..211a20be8 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -6,6 +6,8 @@ from __future__ import annotations from typing import Optional, Dict, Any, Tuple +from collections.abc import Iterable +import math import torch import os @@ -72,16 +74,19 @@ class MXFP8TensorBase(QuantizedTensorBase): def __new__( cls, - *args, rowwise_data: Optional[torch.Tensor], - rowwise_scale_inv: torch.Tensor, + rowwise_scale_inv: Optional[torch.Tensor], columnwise_data: Optional[torch.Tensor], - columnwise_scale_inv: torch.Tensor, + columnwise_scale_inv: Optional[torch.Tensor], fp8_dtype: TE_DType, - quantizer: Optional[Quantizer] = None, + quantizer: Optional[Quantizer], + *args, **kwargs, ): - instance = super().__new__(cls, *args, **kwargs) + if cls is MXFP8TensorBase: + instance = object.__new__(cls) + else: + instance = super().__new__(cls, *args, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data instance._quantizer = quantizer @@ -151,6 +156,51 @@ def size(self, *args, **kwargs): return self._rowwise_data.size(*args, **kwargs) return self._columnwise_data.size(*args, **kwargs) + def view(self, shape: torch.Size): + # pylint: disable=missing-function-docstring + + # Return input tensor if view not needed + cur_shape = self.size() + if shape is None or shape == cur_shape: + return self + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(cur_shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != cur_shape[-1]: + raise RuntimeError( + "MXFP8Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(cur_shape)} to {tuple(shape)})" + ) + + # Construct new tensor + cur_rowwise_data = self._rowwise_data + cur_columnwise_data = self._columnwise_data + new_rowwise_data = None + new_columnwise_data = None + if cur_rowwise_data is not None: + new_rowwise_data = cur_rowwise_data.view(*shape) + if cur_columnwise_data is not None: + new_columnwise_data = cur_columnwise_data.view(*shape) + + return MXFP8TensorBase( + rowwise_data=new_rowwise_data, + rowwise_scale_inv=self._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=self._columnwise_scale_inv, + fp8_dtype=self._fp8_dtype, + quantizer=self._quantizer, + ) + def __repr__(self): data_rowwise = self.dequantize() diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ce4137c66..bac715949 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -4,13 +4,15 @@ """Tensor class with FP8 data quantized with NxN tiles""" from __future__ import annotations -from typing import Optional, Tuple, Iterable +from typing import Optional, Tuple, Iterable, Union import math import torch import transformer_engine_torch as tex - from transformer_engine_torch import DType as TE_DType +from transformer_engine_torch import Float8BlockScaleTensorFormat + +from transformer_engine.common.recipe import Float8BlockScaling, Recipe from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc from ..utils import devices_match, round_up_to_nearest_multiple @@ -32,6 +34,8 @@ class Float8BlockQuantizer(Quantizer): amax_epsilon: float force_pow_2_scales: bool block_scaling_dim: int + # Whether to produce tensors that will be used in all-gather + all_gather_usage: bool def __init__( self, @@ -42,6 +46,7 @@ def __init__( amax_epsilon: float = 0.0, force_pow_2_scales: bool = True, block_scaling_dim: int = 2, + all_gather_usage: bool = False, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) self.dtype = fp8_dtype @@ -49,6 +54,7 @@ def __init__( self.force_pow_2_scales = force_pow_2_scales self.amax_epsilon = amax_epsilon self.block_scaling_dim = block_scaling_dim + self.all_gather_usage = all_gather_usage def update_quantized( self, @@ -125,22 +131,36 @@ def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, M *= shape[i] if len(shape) > 0: K = shape[-1] + # 2D 128x128 quantization block scaling + # CuBLAS requries 128x128 scaling factor to be padded + # currently rowwise and columnwise format option doesn't apply to 2D scaling if self.block_scaling_dim == 2: if columnwise: outer = math.ceil(K / self.block_len) inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4) return (outer, inner) + # rowwise outer = math.ceil(M / self.block_len) inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4) return (outer, inner) + # 1D 1x128 quantization block scaling + # CuBLAS requries 1x128 scaling factor to be padded and transposed assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported" if columnwise: + columnwise_compact = self.all_gather_usage outer = math.ceil(M / self.block_len) - inner = round_up_to_nearest_multiple(K, 4) + inner = round_up_to_nearest_multiple(K, 4) if not columnwise_compact else K + # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS + # for COMPACT case, since we apply 1x128 scaling here without transposing columnwise data, scaling factor is also [outer, inner] + # so no need to swap inner outer here return (outer, inner) + # rowwise + rowwise_compact = self.all_gather_usage outer = math.ceil(K / self.block_len) - inner = round_up_to_nearest_multiple(M, 4) - return (outer, inner) + inner = round_up_to_nearest_multiple(M, 4) if not rowwise_compact else M + # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS need + # for COMPACT case, since we apply 128x1 scaling, scaling block applies to inner dim, so we need to swap outer and inner here + return (outer, inner) if not rowwise_compact else (inner, outer) def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: """Calculate the shape of a tensor after columnwise permutation. @@ -162,15 +182,25 @@ def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: """ if len(shape) == 0: return tuple() + # currently columnwise format option only applies to 1D quantizer + # for 2D scaling, columnwise format should always be GEMM_READY_DATA_AND_SCALES + # since currently 2D scaling only applies to module weights + if self.block_scaling_dim == 1 and self.all_gather_usage: + return shape colwise_shape = [shape[-1]] for i in range(len(shape) - 1): colwise_shape.append(shape[i]) return tuple(colwise_shape) - # TODO(kwyss): With FP8 gather support, we need to implement a - # shape/layout/swizzle check to know whether FP8 gather works - # cleanly by stacking data without aliasing tiles and whether - # the scales also stack on the proper dimensions. + def is_quantizable(self, inp: torch.Tensor) -> bool: + """Returns whether or not given inp can be quantized""" + if inp.ndim < 2: + return False + if inp.shape[-1] % self.block_len != 0: + return False + if math.prod(inp.shape[:-1]) % self.block_len != 0: + return False + return True def make_empty( self, @@ -184,6 +214,12 @@ def make_empty( if device is None: device = torch.device("cuda") + data_format = ( + tex.Float8BlockScaleTensorFormat.COMPACT + if self.all_gather_usage + else tex.Float8BlockScaleTensorFormat.GEMM_READY + ) + # Allocate FP8 data data = None scale_inv = None @@ -221,6 +257,7 @@ def make_empty( columnwise_scale_inv=columnwise_scale_inv, quantizer=self, is_2D_scaled=self.block_scaling_dim == 2, + data_format=data_format, requires_grad=requires_grad, ) @@ -229,6 +266,9 @@ def calibrate(self, tensor: torch.Tensor) -> None: # where state from an estimator influences distribution parameters. pass + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return Float8BlockScaling + class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. @@ -255,11 +295,43 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): holds configuration about quantization and dequantization modes. """ + # NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorBase with positional args, + # which significantly reduces the Pybind11 overhead when calling the constructor from C++. + def __new__( + cls, + *args, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: Optional[torch.Tensor], + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: Optional[torch.Tensor], + fp8_dtype: TE_DType, + quantizer: Quantizer, + is_2D_scaled: bool, + data_format: tex.Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY, + **kwargs, + ): + instance = super().__new__( + cls, + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + fp8_dtype, + quantizer, + is_2D_scaled, + data_format, + *args, + **kwargs, + ) + + return instance + def __repr__(self, *, tensor_contents=None): return ( f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," f" is_2D_scaled={self._is_2D_scaled}," - f" data={self.dequantize(dtype=self.dtype)})" + f" data={self.dequantize(dtype=self.dtype)})," + f" data_format={self._data_format}" ) def _get_quantizer(self) -> Quantizer: @@ -392,6 +464,7 @@ def _make_in_reduce_ex( dtype: torch.dtype, quantizer: Quantizer, is_2D_scaled: bool, + data_format: tex.Float8BlockScaleTensorFormat, ) -> Float8BlockwiseQTensor: """Build Float8BlockwiseQTensor, for use in __reduce__ @@ -409,6 +482,7 @@ def _make_in_reduce_ex( dtype=dtype, quantizer=quantizer, is_2D_scaled=is_2D_scaled, + data_format=data_format, ) def __reduce_ex__(self, protocol: int) -> tuple: @@ -425,6 +499,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self.dtype, self._quantizer, self._is_2D_scaled, + self._data_format, ), ) @@ -450,6 +525,7 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): dst._fp8_dtype = src._fp8_dtype dst._rowwise_scale_inv = src._rowwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv + dst._data_format = src._data_format # Check that tensor dimensions match if ( @@ -497,6 +573,13 @@ def forward( ) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring + # Check for invalid configurations + if not tensor._is_gemm_ready_format(): + raise NotImplementedError( + "View is only supported with GEMM_READY data format, " + f"but found data_format={tensor._data_format}" + ) + # Return input tensor if shape is not provided ctx.shape = tensor.shape if shape is None: @@ -565,6 +648,14 @@ def backward( # pylint: disable=missing-function-docstring if isinstance(grad, Float8BlockwiseQTensor): + + # Check for invalid configurations + if not grad._is_gemm_ready_format(): + raise NotImplementedError( + "View is only supported with GEMM_READY data format, " + f"but found data_format={grad._data_format}" + ) + new_data = ( grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None ) @@ -604,6 +695,13 @@ def forward( ) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring + # Check for invalid configurations + if not tensor._is_gemm_ready_format(): + raise NotImplementedError( + "Reshape is only supported with GEMM_READY data format, " + f"but found data_format={tensor._data_format}" + ) + # Return input tensor if shape is not provided ctx.shape = tensor.shape if shape is None: @@ -671,6 +769,14 @@ def backward( # pylint: disable=missing-function-docstring if isinstance(grad, Float8BlockwiseQTensor): + + # Check for invalid configurations + if not grad._is_gemm_ready_format(): + raise NotImplementedError( + "Reshape is only supported with GEMM_READY data format, " + f"but found data_format={grad._data_format}" + ) + new_rowwise_data = None new_columnwise_data = None if grad._rowwise_data is not None: diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index cce37dde8..742096dca 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -7,14 +7,15 @@ """Tensor class with FP8 data""" from __future__ import annotations import os -from typing import Optional, Tuple, Iterable +from typing import Optional, Tuple, Iterable, Union import warnings from torch.utils.cpp_extension import IS_HIP_EXTENSION import torch import transformer_engine_torch as tex - from transformer_engine_torch import DType as TE_DType + +from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe from ..utils import canonicalize_process_group, devices_match from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc @@ -177,6 +178,24 @@ def create_tensor_from_data( quantizer=self, ) + def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: + """Function using primitives with ONNX defined translations.""" + # Q inputs are currently constrained to FP32 due to a similar limitation in ORT + # custom ops, so cast the input if needed. + if tensor.dtype != torch.float32: + tensor = tensor.to(torch.float32) + data = torch.ops.tex.fp8_quantize(tensor, self.scale.item()) + return self.create_tensor_from_data(data, fake_dtype=torch.float32) + + def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor: + """Function using primitives with ONNX defined translations.""" + out = torch.ops.tex.fp8_dequantize(tensor._data, self.scale.item()) + out = out.to(tensor.dtype) + return out + + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return DelayedScaling + class Float8CurrentScalingQuantizer(Quantizer): """Builder class for FP8 tensors with per-tensor current scaling @@ -340,10 +359,25 @@ def create_tensor_from_data( quantizer=self, ) + def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: + """Function using primitives with ONNX defined translations.""" + raise NotImplementedError( + "Float8CurrentScalingQuantizer does not support ONNX quantization yet." + ) + + def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor: + """Function using primitives with ONNX defined translations.""" + raise NotImplementedError( + "Float8CurrentScalingQuantizer does not support ONNX dequantization yet." + ) + def _canonicalized_amax_reduction_group(self) -> dist_group_type: """Get process group for amax reduction""" return canonicalize_process_group(self.amax_reduction_group) + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return Float8CurrentScaling + class Float8Tensor(Float8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 920b7d6b0..b3504b175 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -7,7 +7,7 @@ from collections.abc import Iterable import math import os -from typing import Optional, Tuple +from typing import Optional, Tuple, Union from torch.utils.cpp_extension import IS_HIP_EXTENSION import torch @@ -15,8 +15,9 @@ from ..triton_kernels.cast import te_quantize_triton import transformer_engine_torch as tex - from transformer_engine_torch import DType as TE_DType + +from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe from ..constants import MXFP8_BLOCK_SCALING_SIZE from ..utils import devices_match, round_up_to_nearest_multiple @@ -145,6 +146,37 @@ def calibrate(self, tensor: torch.Tensor) -> None: # TODO(ksivamani): No calibration needed for mxfp8? pass + def create_tensor_from_data( + self, + data: torch.Tensor, + scale_inv: torch.Tensor, + fake_dtype: torch.dtype, + fp8_dtype: TE_DType = tex.DType.kFloat8E4M3, + ) -> MXFP8Tensor: + """Create a new MXFP8Tensor from data and scale_inv.""" + return MXFP8Tensor( + shape=data.shape, + dtype=fake_dtype, + rowwise_data=data, + rowwise_scale_inv=scale_inv, + columnwise_data=None, + columnwise_scale_inv=None, + fp8_dtype=fp8_dtype, + quantizer=self, + ) + + def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: + if tensor.dtype != torch.float32: + tensor = tensor.to(dtype=torch.float32) + data, scale_inv = torch.ops.tex.mxfp8_quantize(tensor) + return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32) + + def onnx_dequantize(self, tensor: Union[MXFP8TensorBase, MXFP8Tensor]) -> torch.Tensor: + return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv) + + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return MXFP8BlockScaling + class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data @@ -171,6 +203,32 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): """ + # NOTE: We reorder the *args so that we can instantiate a MXFP8TensorBase with positional args, + # which significantly reduces the Pybind11 overhead when calling the constructor from C++. + def __new__( + cls, + *args, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: Optional[torch.Tensor], + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: Optional[torch.Tensor], + fp8_dtype: TE_DType, + quantizer: Optional[Quantizer], + **kwargs, + ): + instance = super().__new__( + cls, + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + fp8_dtype, + quantizer, + *args, + **kwargs, + ) + return instance + def __repr__(self, *, tensor_contents=None): return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})" @@ -308,6 +366,7 @@ def _make_in_reduce_ex( fp8_dtype: TE_DType, dtype: torch.dtype, shape: torch.shape, + quantizer: Optional[Quantizer] = None, ) -> MXFP8Tensor: """Build MXFP8Tensor, for use in __reduce__ @@ -323,6 +382,7 @@ def _make_in_reduce_ex( columnwise_scale_inv=columnwise_scale_inv, dtype=dtype, shape=shape, + quantizer=quantizer, ) def __reduce_ex__(self, protocol: int) -> tuple: @@ -337,6 +397,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._fp8_dtype, self.dtype, self.shape, + self._quantizer, ), ) @@ -390,6 +451,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Quantize to FP8 assert self._quantizer is not None, "Can't quantize without a quantizer" + self._quantizer.internal = False self.data = self._quantizer.quantize(tensor) if self.requires_grad != tensor.requires_grad: self.requires_grad_(requires_grad=tensor.requires_grad) @@ -442,8 +504,7 @@ def forward( if tensor._rowwise_data is not None: new_rowwise_data = tensor._rowwise_data.view(*shape) if tensor._columnwise_data is not None: - columnwise_shape = [shape[-1]] + list(shape[:-1]) - new_columnwise_data = tensor._columnwise_data.view(columnwise_shape) + new_columnwise_data = tensor._columnwise_data.view(*shape) return MXFP8Tensor( shape, tensor.dtype, @@ -467,7 +528,7 @@ def backward( grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None ) if grad._columnwise_data is not None: - new_columnwise_data = grad._columnwise_data.view(ctx.shape[-1], -1) + new_columnwise_data = grad._columnwise_data.view(*ctx.shape) else: new_columnwise_data = None dgrad = MXFP8Tensor( @@ -528,8 +589,7 @@ def forward( if tensor._rowwise_data is not None: new_rowwise_data = tensor._rowwise_data.reshape(*shape) if tensor._columnwise_data is not None: - columnwise_shape = [shape[-1]] + list(shape[:-1]) - new_columnwise_data = tensor._columnwise_data.view(columnwise_shape) + new_columnwise_data = tensor._columnwise_data.view(*shape) return MXFP8Tensor( shape, @@ -555,8 +615,7 @@ def backward( if grad._rowwise_data is not None: new_rowwise_data = grad._rowwise_data.view(*ctx.shape) if grad._columnwise_data is not None: - columnwise_shape = [ctx.shape[-1]] + list(ctx.shape[:-1]) - new_columnwise_data = grad._columnwise_data.view(columnwise_shape) + new_columnwise_data = grad._columnwise_data.view(*ctx.shape) dgrad = MXFP8Tensor( ctx.shape, grad.dtype, diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index e521d4279..3a6eb7290 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -12,11 +12,13 @@ from typing import Optional, Tuple, Iterable, Any, Dict, Union import abc import copy +import warnings import torch from torch.utils._pytree import tree_map import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe class QuantizedTensorBase: @@ -35,6 +37,8 @@ class QuantizedTensorBase: XTensor should only implement the functionality needed to behave like regular torch.Tensor (liek __torch_dispatch__).""" + _quantizer: Optional[Quantizer] + def update_usage( self, rowwise_usage: Optional[bool] = None, @@ -73,6 +77,14 @@ def restore_from_saved( f"{self.__class__.__name__} class does not implement restore_from_saved function" ) + def update_quantizer(self, quantizer: Quantizer): + """Update quantizer for the tensor""" + if self._quantizer is None: + raise RuntimeError("To be updated, quantizer must be set") + if self._quantizer is not quantizer: + warnings.warn("Quantizer is being updated, this may affect model behavior") + self._quantizer = quantizer + def prepare_for_saving( *tensors: Union[torch.Tensor, QuantizedTensorBase], @@ -242,6 +254,16 @@ def copy(self) -> Quantizer: """Create shallow copy""" return copy.copy(self) + def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: + """Symbolic function for ONNX export""" + + def onnx_dequantize(self, tensor) -> torch.Tensor: + """Symbolic function for ONNX export""" + + @abc.abstractmethod + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + """Returns recipe class that is compatible with this quantizer""" + class _QuantizeFunc(torch.autograd.Function): """Cast to FP8 from other dtype""" diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 59bf44938..23f56da5d 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -193,7 +193,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_mo quantizer.update_quantized(master_weight.view(1, -1), shard_model_weight_fp8) if len(amaxes) > 0: - dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=amaxes[0].device) + dummy_overflow_buf = torch.zeros(1, dtype=torch.int, device=amaxes[0].device) # Reduce amaxes. packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 1f5e6a3ee..b9d59f496 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -33,6 +33,7 @@ dist_group_type, ) from transformer_engine.pytorch.distributed import get_distributed_world_size +from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.module.base import TransformerEngineBaseModule @@ -179,12 +180,12 @@ class TransformerLayer(torch.nn.Module): The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. - attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd' + attn_input_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd' This controls whether the dimensions of the - intermediate hidden states is 'batch first' ('bshd') or - 'sequence first' ('sbhd'). `s` stands for the sequence - length, `b` batch size, `h` the number of heads, `d` - head size. Note that these formats are very closely + intermediate hidden states is 'sequence first' ('sbhd'), 'batch first' ('bshd'), + or 'token first' ('thd'). `s` stands for the sequence length, `b` batch size, + `t` the total number of tokens, `h` the number of heads, `d` head size. + Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules. name: str, default = `None` @@ -235,6 +236,14 @@ class TransformerLayer(torch.nn.Module): parameter for query-key-value. This enables optimizations such as QKV fusion without concatentations/splits and also enables the argument `fuse_wgrad_accumulation`. + use_qk_norm: bool, default = 'False' + if set to `True`, L2 normalization is applied to query and key tensors + after RoPE (if applicable) but before attention computation. + This follows the Llama4 approach for QK normalization to improve + training stability and model performance. + qk_norm_eps: float, default = 1e-6 + epsilon value for L2 normalization of query and key tensors. + Only used when `use_qk_norm` is True. """ def __init__( @@ -284,6 +293,8 @@ def __init__( device: Union[torch.device, str] = "cuda", attn_input_format: str = "sbhd", name: str = None, + use_qk_norm: bool = False, + qk_norm_eps: float = 1e-6, ) -> None: super().__init__() @@ -373,6 +384,8 @@ def __init__( "ub_overlap_rs": ub_overlap_rs, "ub_overlap_rs_dgrad": ub_overlap_rs_dgrad, "qkv_format": self.attn_input_format, + "seq_length": seq_length, + "micro_batch_size": micro_batch_size, } self.self_attention = MultiheadAttention( @@ -384,6 +397,8 @@ def __init__( return_bias=not self.parallel_attention_mlp, normalization=normalization, device=device, + use_qk_norm=use_qk_norm, + qk_norm_eps=qk_norm_eps, name=name + ".self_attention" if name is not None else None, ) @@ -398,6 +413,8 @@ def __init__( return_bias=True, normalization=normalization, device=device, + use_qk_norm=use_qk_norm, + qk_norm_eps=qk_norm_eps, name=name + ".inter_attention" if name is not None else None, ) @@ -552,6 +569,8 @@ def forward( alibi_slopes: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None, + cu_seqlens_q_padded: Optional[torch.Tensor] = None, + cu_seqlens_kv_padded: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_kv: Optional[int] = None, fast_zero_fill: bool = True, @@ -568,88 +587,99 @@ def forward( Parameters ---------- hidden_states : torch.Tensor - Input tensor. + Input tensor. attention_mask : Optional[torch.Tensor], default = `None` - Boolean tensor used to mask out self-attention softmax input. It should be - in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable - to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`" - mask. It should be `None` for causal masks and "`no_mask`" type. - A `True` value means the corresponding position is masked out and - a `False` means that position is allowed to participate in attention. + Boolean tensor used to mask out self-attention softmax input. It should be + in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable + to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`" + mask. It should be `None` for causal masks and "`no_mask`" type. + A `True` value means the corresponding position is masked out and + a `False` means that position is allowed to participate in attention. self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', - 'causal_bottom_right', 'padding_causal_bottom_right','arbitrary'}, - default = `causal` - Type of attention mask passed into softmax operation for encoder. - By default, causal masks are aligned to the top left corner of - the softmax matrix. When "`bottom_right`" is specified in the mask type, - causal masks are aligned to the bottom right corner. + 'causal_bottom_right', 'padding_causal_bottom_right','arbitrary'}, + default = `causal` + Type of attention mask passed into softmax operation for encoder. + By default, causal masks are aligned to the top left corner of + the softmax matrix. When "`bottom_right`" is specified in the mask type, + causal masks are aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = `None` - Sliding window size for local attention in encoder. + Sliding window size for local attention in encoder. encoder_output : Optional[torch.Tensor], default = `None` - Output of the encoder block to be fed into the decoder block if using - `layer_type="decoder"`. + Output of the encoder block to be fed into the decoder block if using + `layer_type="decoder"`. enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], - default = `None`. Boolean tensors used to mask out inter-attention softmax input if - using `layer_type="decoder"`. It should be a tuple of two masks in - [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks. - It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] - for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`". - A `True` value means the corresponding position is masked out and a `False` - means that position is allowed to participate in attention. + default = `None`. Boolean tensors used to mask out inter-attention softmax input if + using `layer_type="decoder"`. It should be a tuple of two masks in + [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks. + It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] + for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`". + A `True` value means the corresponding position is masked out and a `False` + means that position is allowed to participate in attention. enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, - default = `None` - Type of attention mask passed into softmax operation for decoder. + default = `None` + Type of attention mask passed into softmax operation for decoder. enc_dec_window_size: Optional[Tuple[int, int]], default = `None` - Sliding window size for local attention in decoder. + Sliding window size for local attention in decoder. is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - * it also allows skipping gradient accumulation during the - first microbatch (since it is the first gradient being - produced) + During training using either gradient accumulation or + pipeline parallelism a minibatch of data is further split + into microbatches. Between the microbatches of the same minibatch + the model weights are not updated. Setting this parameter indicates + whether the current microbatch is the first in a minibatch or not. + When set, this parameter enables additional optimizations: + + * during FP8 training, it allows caching of the FP8 versions of + the weights + * it also allows skipping gradient accumulation during the + first microbatch (since it is the first gradient being + produced) checkpoint_core_attention: bool, default = `False` - If true, forward activations for core attention are recomputed - during the backward pass in order to save memory that would - otherwise be occupied to store the forward activations until - backprop. + If true, forward activations for core attention are recomputed + during the backward pass in order to save memory that would + otherwise be occupied to store the forward activations until + backprop. rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None` - Embeddings for query and key tensors for applying rotary position - embedding. By default no input embedding is applied. + Embeddings for query and key tensors for applying rotary position + embedding. By default no input embedding is applied. core_attention_bias_type: str, default = `no_bias` - Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`} + Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`} core_attention_bias: Optional[torch.Tensor], default = `None` - Bias tensor for Q * K.T + Bias tensor for Q * K.T alibi_slopes: Optional[torch.Tensor], default = `None` - ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. - It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) - to the attention score of query i and key j. + ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. + It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) + to the attention score of query i and key j. cu_seqlens_q: Optional[torch.Tensor], default = `None` - Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, - with shape [batch_size + 1] and dtype torch.int32. + Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, + with shape [batch_size + 1] and dtype torch.int32. + Used by encoders, or decoders' self-attention. cu_seqlens_kv: Optional[torch.Tensor], default = `None` - Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` - and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` + and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + Used by decoders' cross-attention. + cu_seqlens_q_padded: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`, + with shape [batch_size + 1] and dtype torch.int32. Set to `cu_seqlens_q` if None. + Used by encoders, or decoders' self-attention. + cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (with offset) in a batch for `key_layer` + and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + Set to `cu_seqlens_kv` if None. Used by decoders' cross-attention. max_seqlen_q: Optional[int], default = `None` - Maximum sequence length in `query_layer`. - Calculated from `cu_seqlens_q` if not provided. + Maximum sequence length in `query_layer`. + Calculated from `cu_seqlens_q_padded` if not provided. max_seqlen_kv: Optional[int], default = `None` - Maximum sequence length in `key_layer` and `value_layer`. - Calculated from `cu_seqlens_kv` if not provided. + Maximum sequence length in `key_layer` and `value_layer`. + Calculated from `cu_seqlens_kv_padded` if not provided. fast_zero_fill: bool, default = `True` - Whether to set output tensors to 0 or not before use. + Whether to set output tensors to 0 or not before use. inference_params: InferenceParams, default = None - Inference parameters that are passed to the main model in order - to efficiently calculate and store the context during inference. + Inference parameters that are passed to the main model in order + to efficiently calculate and store the context during inference. pad_between_seqs: Optional[bool], default = `None` If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. - If true, there are padding tokens between individual sequences in a packed batch. + If true, there are padding tokens between individual sequences in a packed batch, + i.e. qkv_format = 'thd'. """ if self_attn_mask_type is None: @@ -678,7 +708,9 @@ def forward( if ( "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary" ) and attention_mask is not None: - assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor" + assert all( + attention_mask[i].dtype == torch.bool for i in range(len(attention_mask)) + ), "Attention mask must be a boolean tensor or a list/tuple of two boolean tensors" if ( "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary" ) and enc_dec_attn_mask is not None: @@ -707,9 +739,11 @@ def forward( core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_kv=cu_seqlens_q, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_q_padded, max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, + max_seqlen_kv=max_seqlen_q, fast_zero_fill=fast_zero_fill, pad_between_seqs=pad_between_seqs, ) @@ -733,12 +767,21 @@ def forward( attn_mask_type=enc_dec_attn_mask_type, window_size=enc_dec_window_size, encoder_output=encoder_output, + inference_params=inference_params, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, + rotary_pos_emb=rotary_pos_emb, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, fast_zero_fill=fast_zero_fill, + pad_between_seqs=pad_between_seqs, ) if self.apply_residual_connection_post_layernorm: attention_output, attention_bias, residual = inter_attention_outputs @@ -772,7 +815,12 @@ def forward( return output def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None): - if drop_path is None and bias is not None and bias.numel() != 0: + if ( + drop_path is None + and bias is not None + and bias.numel() != 0 + and not is_in_onnx_export_mode() + ): if self.bias_dropout_fusion: if self.training: bias_dropout_add_func = bias_dropout_add_fused_train diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index f015c8871..20e0b737d 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -100,6 +100,7 @@ def cross_entropy_kernel( ignore_idx, n_cols, n_non_ignore, + reduce_loss: tl.constexpr, label_smoothing: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): @@ -179,7 +180,13 @@ def cross_entropy_kernel( if label_smoothing > 0: # scale X beforehand to avoid overflow scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) - X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) + # Scale gradients based on reduction mode + # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore + # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here + if reduce_loss: + X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) + else: + X_block = tl.exp(X_block - m) / d - eps tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) # We need tl.debug_barrier() to ensure the new result of X_ptr is written @@ -207,7 +214,11 @@ def cross_entropy_kernel( if y >= vocab_start_idx: if y < vocab_end_idx: X_y = tl.load(X_ptr + y - vocab_start_idx) - X_y += -(1 - label_smoothing) / (n_non_ignore) + # Apply the same conditional scaling logic for the target token + if reduce_loss: + X_y += -(1 - label_smoothing) / (n_non_ignore) + else: + X_y += -(1 - label_smoothing) tl.store(X_ptr + y - vocab_start_idx, X_y) tl.store(loss_ptr, loss) @@ -324,6 +335,7 @@ def cross_entropy_forward( ignore_idx=ignore_idx, n_cols=V, n_non_ignore=n_rows, + reduce_loss=reduce_loss, label_smoothing=label_smoothing, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS, diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index ebf8dd551..9ce01362f 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -10,6 +10,72 @@ import triton import triton.language as tl +from triton.language import core +from triton.language.standard import _log2 + + +# The following three argsort related kernels are adapted from +# the issue https://github.com/triton-lang/triton/issues/3698 + + +@triton.jit +def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr): + n_outer: tl.constexpr = x.numel >> n_dims + shape: tl.constexpr = [n_outer * (2**i), 2, 2 ** (n_dims - i - 1)] + y = tl.reshape(x, shape) + z = tl.reshape(indices, shape) + + mask = tl.arange(0, 2)[None, :, None] + + l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to( + x.dtype + ) + r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to( + x.dtype + ) + + l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape) + r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape) + + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + + il_value = l_value.to(idtype, bitcast=True) + ir_value = r_value.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + + flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix)) + ret = ix ^ flag1 + flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix)) + ind = indices ^ flag2 + + return ret.to(x.dtype, bitcast=True), ind + + +@triton.jit +def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr): + n_outer: tl.constexpr = x.numel >> n_dims + tl.static_assert(stage <= n_dims) + """ + order_type 0 == ascending + order_type 1 == descending + order_type 2 == alternating + """ + if order == 2: + shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 2**stage] + flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape) + else: + flip = tl.full(x.shape, value=order, dtype=tl.int32) + for i in tl.static_range(stage): + x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims) + return x, indices + + +@triton.jit +def _argsort(x, indices, n_dims: tl.constexpr): + for i in tl.static_range(1, n_dims + 1): + x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims) + return x, indices + @triton.jit def _row_id_map_pass_1_kernel( @@ -22,6 +88,8 @@ def _row_id_map_pass_1_kernel( # strides stride_routing_map_token, stride_routing_map_expert, + stride_row_id_map_token, + stride_row_id_map_expert, # metas BLOCK_SIZE: tl.constexpr, ): @@ -32,10 +100,10 @@ def _row_id_map_pass_1_kernel( routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token, mask=(offset < num_tokens), other=0, - ).to(tl.int64) + ).to(tl.int32) row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask tl.store( - row_id_map_ptr + pid_m * num_tokens + offset, + row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token, row_id_within_token_block, mask=offset < num_tokens, ) @@ -50,6 +118,9 @@ def _row_id_map_pass_2_kernel( workspace_ptr, # sizes num_tokens, + # strides + stride_row_id_map_token, + stride_row_id_map_expert, # metas WORKSPACE_LOAD_WIDTH: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -59,7 +130,9 @@ def _row_id_map_pass_2_kernel( chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) row_id_within_token_block = tl.load( - row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0 + row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token, + mask=(offset < num_tokens), + other=0, ) workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH) @@ -70,23 +143,102 @@ def _row_id_map_pass_2_kernel( row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1, ) tl.store( - row_id_map_ptr + pid_m * num_tokens + offset, + row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token, row_id, mask=(offset < num_tokens), ) +@triton.jit +def _row_id_map_pass_3_kernel( + # pointers + row_id_map_ptr, + # sizes + num_experts: tl.constexpr, + # strides + stride_row_id_map_token, + stride_row_id_map_expert, + # metas + LOAD_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + n_dims: tl.constexpr = _log2(LOAD_SIZE) + off = tl.arange(0, LOAD_SIZE) + row_id_map = tl.load( + row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off, + mask=off < num_experts, + other=-1, + ) + n_routed = tl.sum(tl.where(row_id_map != -1, 1, 0)) + indices = off + sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims) + tl.store( + row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert, + sorted_map, + mask=off < n_routed, + ) + tl.store( + row_id_map_ptr + + pid * stride_row_id_map_token + + (num_experts + off) * stride_row_id_map_expert, + indices, + mask=off < n_routed, + ) + tl.store( + row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert, + n_routed, + ) + + def make_row_id_map( routing_map: torch.Tensor, num_tokens: int, num_experts: int, ): - # pylint: disable=missing-function-docstring - row_id_map = torch.empty((num_experts, num_tokens), dtype=torch.int64, device="cuda") - block_size = 256 + """ + Prepare the row_id_map for the permutation. + + Parameters + ---------- + routing_map: torch.Tensor + Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates + which experts are routed to which tokens. The values in it: 1 means the token is routed to + this expert and 0 means not. + num_tokens: int + Number of tokens in the input tensor. + num_experts: int + Number of experts in the input tensor. + + Returns + ------- + row_id_map: torch.Tensor + The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`. + For each token, the last item is the number of experts that are routed (n_routed). + The first n_routed items are the destination row indices in the permuted tokens. + The [num_experts, num_experts + n_routed) items are the indices of the experts corresponding + to the first n_routed row indices above. + """ + row_id_map = torch.empty((num_tokens, num_experts * 2 + 1), dtype=torch.int32, device="cuda") + block_size = 1024 grid = (num_experts, triton.cdiv(num_tokens, block_size)) - workspace_tensor = torch.empty(grid, dtype=torch.int64, device="cuda") - # block cumsum + workspace_tensor = torch.empty(grid, dtype=torch.int32, device="cuda") + + # supposing num_tokens == 5, num_experts == 3, block_size == 3 + # and we have a routing_map like this: + # [[1, 1, 0], + # [1, 0, 1], + # [0, 0, 1], + # [1, 1, 0], + # [0, 0, 0]] + + # pass 1: block cumsum + # for each expert, compute the cumsum of every block_size tokens + # the row_id_map will be like this after pass 1 (r means useless values): + # [[1, 1, 0, r, r, r, r], + # [2, 0, 1, r, r, r, r], + # [0, 0, 2, r, r, r, r], + # [1, 1, 0, r, r, r, r], + # [0, 0, 0, r, r, r, r]] _row_id_map_pass_1_kernel[grid]( routing_map, row_id_map, @@ -94,16 +246,44 @@ def make_row_id_map( num_tokens, routing_map.stride(0), routing_map.stride(1), + row_id_map.stride(0), + row_id_map.stride(1), block_size, ) - # cumsum all and process the mask + + # pass 2: cumsum all and process the mask + # process the block cumsum into the global cumsum and then into the dst row indices + # the row_id_map will be like this after pass 2 (r means useless value): + # [[ 0, 3, -1, r, r, r, r], + # [ 1, -1, 5, r, r, r, r], + # [-1, -1, 6, r, r, r, r], + # [ 2, 4, -1, r, r, r, r], + # [-1, -1, -1, r, r, r, r]] _row_id_map_pass_2_kernel[grid]( row_id_map, workspace_tensor, num_tokens, + row_id_map.stride(0), + row_id_map.stride(1), triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)), block_size, ) + + # pass 3: make the row_id_map from the sparse structure to the dense structure + # the row_id_map will be like this after pass 3 (r means useless value): + # [[3, 0, r, 1, 0, r, 2], + # [5, 1, r, 2, 0, r, 2], + # [6, r, r, 2, r, r, 1], + # [4, 2, r, 1, 0, r, 2], + # [r, r, r, r, r, r, 0]] + grid = (num_tokens,) + _row_id_map_pass_3_kernel[grid]( + row_id_map, + num_experts, + row_id_map.stride(0), + row_id_map.stride(1), + triton.next_power_of_2(num_experts), + ) return row_id_map @@ -118,11 +298,12 @@ def _permute_kernel( permuted_probs_ptr, permuted_scale_ptr, # sizes - num_tokens, - num_experts, - hidden_size, + num_experts: tl.constexpr, + hidden_size: tl.constexpr, scale_hidden_dim, # strides + stride_row_id_map_token, + stride_row_id_map_expert, stride_input_token, stride_input_hidden, stride_output_token, @@ -139,35 +320,50 @@ def _permute_kernel( PERMUTE_SCALE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - pid = tl.program_id(0) - cur_pos = 0 - while cur_pos < hidden_size: - cur_off = cur_pos + tl.arange(0, BLOCK_SIZE) - mask = cur_off < hidden_size - input_off = pid * stride_input_token + cur_off * stride_input_hidden - inp = tl.load(input_ptr + input_off, mask=mask) + pid_t = tl.program_id(0) + pid_h = tl.program_id(1) + cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = cur_off < hidden_size + input_off = pid_t * stride_input_token + cur_off * stride_input_hidden + inp = tl.load(input_ptr + input_off, mask=mask) + if PERMUTE_SCALE: + mask_scale = cur_off < scale_hidden_dim + scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden + scale = tl.load(scale_ptr + scale_off, mask=mask_scale) + n_routed = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + num_experts * 2 * stride_row_id_map_expert + ) + for idx in tl.range(n_routed): + dst_row = tl.load( + row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert + ) + output_off = dst_row * stride_output_token + cur_off * stride_output_hidden if PERMUTE_SCALE: - mask_scale = cur_off < scale_hidden_dim - scale_off = pid * stride_scale_token + cur_off * stride_scale_hidden - scale = tl.load(scale_ptr + scale_off, mask=mask_scale) - for expert_idx in range(num_experts): - dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) - if dst_row != -1: - output_off = dst_row * stride_output_token + cur_off * stride_output_hidden + permuted_scale_off = ( + dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden + ) + tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale) + if PERMUTE_PROBS: + expert_idx = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + (num_experts + idx) * stride_row_id_map_expert + ) + prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert + prob = tl.load(probs_ptr + prob_off) + if pid_h == 0: + permuted_prob_off = dst_row * stride_permuted_probs_token + tl.store(permuted_probs_ptr + permuted_prob_off, prob) + if prob == 0.0: + # for routing_map padding + # dst_row != -1 and prob == 0.0 means that this slot is padded + tl.store(output_ptr + output_off, 0, mask=mask) + else: tl.store(output_ptr + output_off, inp, mask=mask) - if PERMUTE_SCALE: - permuted_scale_off = ( - dst_row * stride_permuted_scale_token - + cur_off * stride_permuted_scale_hidden - ) - tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale) - if PERMUTE_PROBS: - if cur_pos == 0: - prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert - prob = tl.load(probs_ptr + prob_off) - permuted_prob_off = dst_row * stride_permuted_probs_token - tl.store(permuted_probs_ptr + permuted_prob_off, prob) - cur_pos += BLOCK_SIZE + else: + tl.store(output_ptr + output_off, inp, mask=mask) try: @@ -178,6 +374,8 @@ def _permute_kernel( triton.Config({"BLOCK_SIZE": 256}), triton.Config({"BLOCK_SIZE": 512}), triton.Config({"BLOCK_SIZE": 1024}), + triton.Config({"BLOCK_SIZE": 2048}), + triton.Config({"BLOCK_SIZE": 4096}), ], key=["hidden_size"], )(_permute_kernel) @@ -196,7 +394,30 @@ def permute_with_mask_map( hidden_size: int, scale_hidden_dim: int, ): - # pylint: disable=missing-function-docstring + """ + Permute the input tensor based on the row_id_map. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + row_id_map: torch.Tensor + The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`. + probs: torch.Tensor + The probabilities of the input tensor. If it is not None, it will be permuted. + scale: torch.Tensor + The scale of the input tensor. If it is not None, it will be permuted. + num_tokens: int + Number of tokens in the input tensor. + num_experts: int + Number of experts in the input tensor. + num_out_tokens: int + Number of tokens in the permuted tensor. + hidden_size: int + Hidden size of the input tensor. + scale_hidden_dim: int + Hidden size of the scale tensor. + """ output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") if probs is not None: permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda") @@ -209,8 +430,8 @@ def permute_with_mask_map( ) else: permuted_scale = None - - grid = (num_tokens,) + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"])) _permute_kernel[grid]( inp, output, @@ -219,10 +440,11 @@ def permute_with_mask_map( scale, permuted_probs, permuted_scale, - num_tokens, num_experts, hidden_size, scale_hidden_dim, + row_id_map.stride(0), + row_id_map.stride(1), inp.stride(0), inp.stride(1), output.stride(0), @@ -250,10 +472,11 @@ def _unpermute_kernel( permuted_probs_ptr, unpermuted_probs_ptr, # sizes - num_tokens, - num_experts, - hidden_size, + num_experts: tl.constexpr, + hidden_size: tl.constexpr, # strides + stride_row_id_map_token, + stride_row_id_map_expert, stride_input_token, stride_input_hidden, stride_output_token, @@ -264,6 +487,7 @@ def _unpermute_kernel( stride_unpermuted_probs_token, stride_unpermuted_probs_expert, # metas + PROBS_LOAD_WIDTH: tl.constexpr, WITH_MERGING_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -271,41 +495,63 @@ def _unpermute_kernel( data_type = input_ptr.dtype.element_ty compute_type = tl.float32 - pid = tl.program_id(0) - current_start = 0 - while current_start < hidden_size: - current_offset = current_start + tl.arange(0, BLOCK_SIZE) - mask = current_offset < hidden_size - accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type) - for expert_idx in range(num_experts): - src_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) - if src_row != -1: - input_off = src_row * stride_input_token + current_offset * stride_input_hidden - inp = tl.load(input_ptr + input_off, mask=mask) - inp = inp.to(compute_type) - if WITH_MERGING_PROBS: - merging_prob_off = ( - pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert - ) - merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) - inp *= merging_prob - accumulator += inp - if PERMUTE_PROBS: - if current_start == 0: - unpermuted_prob_off = ( - pid * stride_unpermuted_probs_token - + expert_idx * stride_unpermuted_probs_expert - ) - if src_row != -1: - permuted_prob_off = src_row * stride_permuted_probs_token - prob = tl.load(permuted_probs_ptr + permuted_prob_off) - tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) - else: - tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0) - accumulator = accumulator.to(data_type) - output_off = pid * stride_output_token + current_offset * stride_output_hidden - tl.store(output_ptr + output_off, accumulator, mask=mask) - current_start += BLOCK_SIZE + pid_t = tl.program_id(0) + pid_h = tl.program_id(1) + current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + if PERMUTE_PROBS: + # write 0.0 to probs_grad that are not routed + if pid_h == 0: + map_load_off = tl.arange(0, PROBS_LOAD_WIDTH) + unpermuted_prob_off = ( + pid_t * stride_unpermuted_probs_token + + stride_unpermuted_probs_expert * map_load_off + ) + tl.store( + unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts + ) + accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type) + n_routed = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + num_experts * 2 * stride_row_id_map_expert + ) + for idx in tl.range(n_routed): + src_row = tl.load( + row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert + ) + input_off = src_row * stride_input_token + current_offset * stride_input_hidden + inp = tl.load(input_ptr + input_off, mask=mask) + inp = inp.to(compute_type) + if WITH_MERGING_PROBS: + expert_idx = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + (num_experts + idx) * stride_row_id_map_expert + ) + merging_prob_off = ( + pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert + ) + merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) + inp *= merging_prob + accumulator += inp + if PERMUTE_PROBS: + if pid_h == 0: + expert_idx = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + (num_experts + idx) * stride_row_id_map_expert + ) + unpermuted_prob_off = ( + pid_t * stride_unpermuted_probs_token + + expert_idx * stride_unpermuted_probs_expert + ) + permuted_prob_off = src_row * stride_permuted_probs_token + prob = tl.load(permuted_probs_ptr + permuted_prob_off) + tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) + accumulator = accumulator.to(data_type) + output_off = pid_t * stride_output_token + current_offset * stride_output_hidden + tl.store(output_ptr + output_off, accumulator, mask=mask) try: @@ -316,6 +562,8 @@ def _unpermute_kernel( triton.Config({"BLOCK_SIZE": 256}), triton.Config({"BLOCK_SIZE": 512}), triton.Config({"BLOCK_SIZE": 1024}), + triton.Config({"BLOCK_SIZE": 2048}), + triton.Config({"BLOCK_SIZE": 4096}), ], key=["hidden_size"], )(_unpermute_kernel) @@ -332,7 +580,27 @@ def unpermute_with_mask_map( num_experts: int, hidden_size: int, ): - # pylint: disable=missing-function-docstring + """ + Unpermute the input tensor based on the row_id_map. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_out_tokens, hidden_size]`. + row_id_map: torch.Tensor + The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`. + merging_probs: torch.Tensor + The merging probabilities of the input tensor. If it is not None, it will be used as weights + to reduce the unpermuted tokens. + permuted_probs: torch.Tensor + The permuted probabilities of the input tensor. If it is not None, it will be unpermuted. + num_tokens: int + Number of tokens in the permuted tensor. + num_experts: int + Number of experts in the permuted tensor. + hidden_size: int + Hidden size of the permuted tensor. + """ output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") if permuted_probs is not None: unpermuted_probs = torch.empty( @@ -340,7 +608,8 @@ def unpermute_with_mask_map( ) else: unpermuted_probs = None - grid = (num_tokens,) + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"])) _unpermute_kernel[grid]( inp, output, @@ -348,9 +617,10 @@ def unpermute_with_mask_map( merging_probs, permuted_probs, unpermuted_probs, - num_tokens, num_experts, hidden_size, + row_id_map.stride(0), + row_id_map.stride(1), inp.stride(0), inp.stride(1), output.stride(0), @@ -360,6 +630,7 @@ def unpermute_with_mask_map( permuted_probs.stride(0) if permuted_probs is not None else None, unpermuted_probs.stride(0) if unpermuted_probs is not None else None, unpermuted_probs.stride(1) if unpermuted_probs is not None else None, + PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), WITH_MERGING_PROBS=merging_probs is not None, PERMUTE_PROBS=permuted_probs is not None, ) @@ -376,10 +647,11 @@ def _unpermute_bwd_with_merging_probs_kernel( merging_probs_grad_ptr, row_id_map_ptr, # sizes - num_tokens, - num_experts, - hidden_size, + num_experts: tl.constexpr, + hidden_size: tl.constexpr, # strides + stride_row_id_map_token, + stride_row_id_map_expert, stride_fwd_output_grad_token, stride_fwd_output_grad_hidden, stride_fwd_input_grad_token, @@ -391,56 +663,63 @@ def _unpermute_bwd_with_merging_probs_kernel( stride_merging_probs_grad_token, stride_merging_probs_grad_expert, # metas + PROBS_LOAD_WIDTH: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): data_type = fwd_output_grad_ptr.dtype.element_ty compute_type = tl.float32 pid = tl.program_id(0) - for expert_idx in range(num_experts): - dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) - if dst_row != -1: - prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) - current_start = 0 - while current_start < hidden_size: - current_offset = current_start + tl.arange(0, BLOCK_SIZE) - mask = current_offset < hidden_size - input_off = ( - pid * stride_fwd_output_grad_token - + current_offset * stride_fwd_output_grad_hidden - ) - inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) - inp = inp.to(compute_type) - merging_prob_off = ( - pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert - ) - merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) - output = inp * merging_prob - output = output.to(data_type) - output_off = ( - dst_row * stride_fwd_input_grad_token - + current_offset * stride_fwd_input_grad_hidden - ) - tl.store(fwd_input_grad_ptr + output_off, output, mask=mask) - - fwd_input_off = ( - dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden - ) - fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) - prob_grad_accum += fwd_input.to(compute_type) * inp - current_start += BLOCK_SIZE - probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty) - probs_grad_off = ( - pid * stride_merging_probs_grad_token - + expert_idx * stride_merging_probs_grad_expert + map_load_off = tl.arange(0, PROBS_LOAD_WIDTH) + token_probs_grad_off = ( + pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off + ) + tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts) + n_routed = tl.load( + row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert + ) + for idx in tl.range(n_routed): + dst_row = tl.load( + row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert + ) + expert_idx = tl.load( + row_id_map_ptr + + pid * stride_row_id_map_token + + (num_experts + idx) * stride_row_id_map_expert + ) + prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) + current_start = 0 + while current_start < hidden_size: + current_offset = current_start + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + input_off = ( + pid * stride_fwd_output_grad_token + current_offset * stride_fwd_output_grad_hidden ) - tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad) - else: - probs_grad_off = ( - pid * stride_merging_probs_grad_token - + expert_idx * stride_merging_probs_grad_expert + inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) + inp = inp.to(compute_type) + merging_prob_off = ( + pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert + ) + merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) + output = inp * merging_prob + output = output.to(data_type) + output_off = ( + dst_row * stride_fwd_input_grad_token + + current_offset * stride_fwd_input_grad_hidden ) - tl.store(merging_probs_grad_ptr + probs_grad_off, 0.0) + tl.store(fwd_input_grad_ptr + output_off, output, mask=mask) + + fwd_input_off = ( + dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden + ) + fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) + prob_grad_accum += fwd_input.to(compute_type) * inp + current_start += BLOCK_SIZE + probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty) + probs_grad_off = ( + pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert + ) + tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad) try: @@ -451,6 +730,8 @@ def _unpermute_bwd_with_merging_probs_kernel( triton.Config({"BLOCK_SIZE": 256}), triton.Config({"BLOCK_SIZE": 512}), triton.Config({"BLOCK_SIZE": 1024}), + triton.Config({"BLOCK_SIZE": 2048}), + triton.Config({"BLOCK_SIZE": 4096}), ], key=["hidden_size"], )(_unpermute_bwd_with_merging_probs_kernel) @@ -468,7 +749,28 @@ def unpermute_with_mask_map_bwd_with_merging_probs( num_out_tokens: int, hidden_size: int, ): - # pylint: disable=missing-function-docstring + """ + Unpermute backward pass kernel with merging probs. + + Parameters + ---------- + fwd_output_grad: torch.Tensor + The gradient of the output tensor of shape `[num_tokens, hidden_size]`. + row_id_map: torch.Tensor + The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`. + fwd_input: torch.Tensor + The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`. + merging_probs: torch.Tensor + The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`. + num_tokens: int + Number of tokens in the permuted tensor. + num_experts: int + Number of experts in the permuted tensor. + num_out_tokens: int + Number of tokens in the output tensor. + hidden_size: int + Hidden size of the output tensor. + """ act_grad = torch.empty( (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" ) @@ -483,9 +785,10 @@ def unpermute_with_mask_map_bwd_with_merging_probs( merging_probs, merging_probs_grad, row_id_map, - num_tokens, num_experts, hidden_size, + row_id_map.stride(0), + row_id_map.stride(1), fwd_output_grad.stride(0), fwd_output_grad.stride(1), act_grad.stride(0), @@ -496,34 +799,21 @@ def unpermute_with_mask_map_bwd_with_merging_probs( merging_probs.stride(1), merging_probs_grad.stride(0), merging_probs_grad.stride(1), + PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), ) return act_grad, merging_probs_grad @triton.jit -def _sort_chunks_by_idxs_kernel( +def _make_chunk_sort_map_kernel( # pointers - input_ptr, split_sizes_ptr, sorted_indices_ptr, - output_ptr, dst_rows_ptr, - probs_ptr, - permuted_probs_ptr, # sizes - num_splits, - hidden_size, - # strides - stride_input_token, - stride_input_hidden, - stride_output_token, - stride_output_hidden, - stride_probs_token, - stride_permuted_probs_token, + num_splits: tl.constexpr, # metas - PERMUTE_PROBS: tl.constexpr, IDX_LOAD_WIDTH: tl.constexpr, - BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(0) @@ -533,104 +823,58 @@ def _sort_chunks_by_idxs_kernel( ) # get chunk idx of the current token in the input tensor - input_chunk_idx = -1 - in_chunk_offset = tl.zeros([], dtype=tl.int64) - acc_chunk_sizes = tl.zeros([], dtype=tl.int64) - cursor = 0 - while cursor < num_splits: - cur_chunk_size = tl.load(split_sizes_ptr + cursor).to(tl.int64) - acc_chunk_sizes += cur_chunk_size - if input_chunk_idx == -1 and acc_chunk_sizes > pid: - input_chunk_idx = cursor - in_chunk_offset = pid - (acc_chunk_sizes - cur_chunk_size) - cursor += 1 + input_split_sizes = tl.load( + split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0 + ).to(tl.int32) + input_split_sizes_cumsum = tl.cumsum(input_split_sizes) + input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0) + input_chunk_idx = tl.sum(input_split_sizes_mask) + input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask) + in_chunk_offset = pid - input_split_sizes_presum # get chunk idx of the current token in the output tensor - output_chunk_idx = 0 - cursor = 0 - while cursor < num_splits: - cur_input_idx = tl.load(sorted_indices_ptr + cursor) - if cur_input_idx == input_chunk_idx: - output_chunk_idx = cursor - cursor += 1 + output_chunk_mask = tl.where(sorted_indices == input_chunk_idx, 1, 0) + output_chunk_idx = tl.argmax(output_chunk_mask, axis=-1) # make row_id_map output_split_sizes = tl.load( split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits - ).to(tl.int64) + ).to(tl.int32) output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0) dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset tl.store(dst_rows_ptr + pid, dst_row) - current_start = 0 - while current_start < hidden_size: - current_offset = current_start + tl.arange(0, BLOCK_SIZE) - mask = current_offset < hidden_size - input_offsets = pid * stride_input_token + current_offset * stride_input_hidden - output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden - inp = tl.load(input_ptr + input_offsets, mask=mask) - tl.store(output_ptr + output_offsets, inp, mask=mask) - current_start += BLOCK_SIZE - if PERMUTE_PROBS: - prob_off = pid * stride_probs_token - prob = tl.load(probs_ptr + prob_off) - permuted_prob_off = dst_row * stride_permuted_probs_token - tl.store(permuted_probs_ptr + permuted_prob_off, prob) - - -try: - _sort_chunks_by_idxs_kernel = triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - ], - key=["hidden_size"], - )(_sort_chunks_by_idxs_kernel) -except RuntimeError: - pass - - -def sort_chunks_by_idx( - inp: torch.Tensor, +def make_chunk_sort_map( split_sizes: torch.Tensor, sorted_indices: torch.Tensor, - probs: torch.Tensor, num_tokens: int, - hidden_size: int, num_splits: int, ): - # pylint: disable=missing-function-docstring - row_id_map = torch.empty((num_tokens,), dtype=torch.int64, device="cuda") - output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") - if probs is not None: - permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda") - else: - permuted_probs = None + """ + Make a row_id_map for chunk sort. + + Parameters + ---------- + split_sizes: torch.Tensor + The sizes of the chunks of shape `[num_splits,]`. + sorted_indices: torch.Tensor + The indices of the sorted chunks of shape `[num_splits,]`. + num_tokens: int + Number of tokens in the input tensor. + num_splits: int + Number of splits of split_sizes and sorted_indices. + """ + row_id_map = torch.empty((num_tokens,), dtype=torch.int32, device="cuda") grid = (num_tokens,) - _sort_chunks_by_idxs_kernel[grid]( - inp, + _make_chunk_sort_map_kernel[grid]( split_sizes, sorted_indices, - output, row_id_map, - probs, - permuted_probs, num_splits, - hidden_size, - inp.stride(0), - inp.stride(1), - output.stride(0), - output.stride(1), - probs.stride(0) if probs is not None else None, - permuted_probs.stride(0) if permuted_probs is not None else None, - PERMUTE_PROBS=probs is not None, IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits), ) - return output, row_id_map, permuted_probs + return row_id_map @triton.jit @@ -642,7 +886,7 @@ def _sort_chunks_by_map_kernel( probs_ptr, permuted_probs_ptr, # sizes - hidden_size, + hidden_size: tl.constexpr, # strides stride_input_token, stride_input_hidden, @@ -653,23 +897,28 @@ def _sort_chunks_by_map_kernel( # metas PERMUTE_PROBS: tl.constexpr, BLOCK_SIZE: tl.constexpr, + FORWARD: tl.constexpr, ): - pid = tl.program_id(0) - dst_row = tl.load(row_id_map_ptr + pid) - current_start = 0 - while current_start < hidden_size: - current_offset = current_start + tl.arange(0, BLOCK_SIZE) - mask = current_offset < hidden_size - input_offsets = dst_row * stride_input_token + current_offset * stride_input_hidden - output_offsets = pid * stride_output_token + current_offset * stride_output_hidden - inp = tl.load(input_ptr + input_offsets, mask=mask) - tl.store(output_ptr + output_offsets, inp, mask=mask) - current_start += BLOCK_SIZE + pid_t = tl.program_id(0) + pid_h = tl.program_id(1) + if FORWARD: + src_row = pid_t + dst_row = tl.load(row_id_map_ptr + pid_t) + else: + src_row = tl.load(row_id_map_ptr + pid_t) + dst_row = pid_t + current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden + output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden + inp = tl.load(input_ptr + input_offsets, mask=mask) + tl.store(output_ptr + output_offsets, inp, mask=mask) if PERMUTE_PROBS: - prob_off = dst_row * stride_probs_token - prob = tl.load(probs_ptr + prob_off) - permuted_prob_off = pid * stride_permuted_probs_token - tl.store(permuted_probs_ptr + permuted_prob_off, prob) + if pid_h == 0: + prob_off = src_row * stride_probs_token + prob = tl.load(probs_ptr + prob_off) + permuted_prob_off = dst_row * stride_permuted_probs_token + tl.store(permuted_probs_ptr + permuted_prob_off, prob) try: @@ -680,6 +929,8 @@ def _sort_chunks_by_map_kernel( triton.Config({"BLOCK_SIZE": 256}), triton.Config({"BLOCK_SIZE": 512}), triton.Config({"BLOCK_SIZE": 1024}), + triton.Config({"BLOCK_SIZE": 2048}), + triton.Config({"BLOCK_SIZE": 4096}), ], key=["hidden_size"], )(_sort_chunks_by_map_kernel) @@ -693,14 +944,33 @@ def sort_chunks_by_map( probs: torch.Tensor, num_tokens: int, hidden_size: int, + is_forward: bool, ): - # pylint: disable=missing-function-docstring + """ + Sort chunks with row_id_map. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`. + row_id_map: torch.Tensor + The token to expert mapping tensor of shape `[num_tokens,]`. + probs: torch.Tensor + The probabilities of the input tensor. If it is not None, it will be permuted. + num_tokens: int + Number of tokens in the input tensor. + hidden_size: int + Hidden size of the input tensor. + is_forward: bool + Whether the sort is for forward or backward. + """ output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") if probs is not None: permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda") else: permuted_probs = None - grid = (num_tokens,) + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"])) _sort_chunks_by_map_kernel[grid]( inp, output, @@ -715,5 +985,6 @@ def sort_chunks_by_map( probs.stride(0) if probs is not None else None, permuted_probs.stride(0) if permuted_probs is not None else None, PERMUTE_PROBS=probs is not None, + FORWARD=is_forward, ) return output, permuted_probs diff --git a/transformer_engine/pytorch/triton_kernels/cast.py b/transformer_engine/pytorch/triton_kernels/cast.py index 4c7033132..3ba81118f 100644 --- a/transformer_engine/pytorch/triton_kernels/cast.py +++ b/transformer_engine/pytorch/triton_kernels/cast.py @@ -56,6 +56,9 @@ def te_quantize_triton( Quantizes the input tensor using a specified quantizer, with an option to utilize Triton-based `cast_transpose` for performance. """ + from ..tensor.float8_tensor import Float8CurrentScalingQuantizer + if isinstance(quantizer, Float8CurrentScalingQuantizer): + return tex.quantize(tensor, quantizer, output, noop_flag) input_tensor = tensor.contiguous() fake_tensor_type = input_tensor.dtype if not fake_tensor_type.is_floating_point: diff --git a/transformer_engine/pytorch/triton_kernels/layernorm.py b/transformer_engine/pytorch/triton_kernels/layernorm.py index 265093b73..3baa64697 100644 --- a/transformer_engine/pytorch/triton_kernels/layernorm.py +++ b/transformer_engine/pytorch/triton_kernels/layernorm.py @@ -7,7 +7,7 @@ import torch -from ..tensor.float8_tensor import Float8Quantizer +from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..constants import TE_DType from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.quantized_tensor import Quantizer @@ -489,8 +489,9 @@ def te_layernorm_fwd_triton(input: torch.Tensor, # To update the amax ptr directly with atomic max APPLY_ATOMIC = M < 512 - # MXFP8 is handled regularly, hence quantizer of Float8Quantizer is considered FP8 + # MXFP8/fp8_current_scaling requires unfused quantization. IS_FP8 = isinstance(quantizer, Float8Quantizer) + IS_FP8_CURRENT_SCALING = isinstance(quantizer, Float8CurrentScalingQuantizer) amax_temp = torch.empty((M,), dtype=torch.float32, device=device) if IS_FP8 else None @@ -551,7 +552,7 @@ def te_layernorm_fwd_triton(input: torch.Tensor, ) # For MXFP8, we do regular layernorm and then quantize it separately - if IS_MXFP8: + if IS_MXFP8 or IS_FP8_CURRENT_SCALING: ln_out = te_quantize_triton(ln_out, quantizer) # Reduce and find amax if "not APPLY_ATOMIC" is True. diff --git a/transformer_engine/pytorch/triton_kernels/norm_common.py b/transformer_engine/pytorch/triton_kernels/norm_common.py index 7c553592f..87cc4714d 100644 --- a/transformer_engine/pytorch/triton_kernels/norm_common.py +++ b/transformer_engine/pytorch/triton_kernels/norm_common.py @@ -4,7 +4,7 @@ import os import torch import triton -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8CurrentScalingQuantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .common import te_dtype_to_torch_dtype @@ -59,19 +59,24 @@ def make_ln_out(ln_out, quantizer=None, input_shape=None, out_dtype=torch.float3 if ln_out is None: # TODO(micky774): Remove MXFP8Quantizer check when kernels - # properly support MXFP8 as a fused operation - if quantizer is None or isinstance(quantizer, MXFP8Quantizer): + # properly support MXFP8/float8_current_scaling as a fused operation + if quantizer is None or isinstance(quantizer, MXFP8Quantizer) or isinstance(quantizer, Float8CurrentScalingQuantizer): return torch.empty(input_shape, dtype=out_dtype, device='cuda') return quantizer.make_empty(input_shape, dtype=out_dtype) + # TODO: revisit the logic here, whether we should create dequantized/higher precision based on quantizer or quantized tensor type + # TODO(micky774): Remove when kernels properly support MXFP8 as a fused operation if isinstance(ln_out, MXFP8Tensor): return ln_out.dequantize(dtype=out_dtype).to("cuda") - # TODO(micky774): Remove when kernels properly support MXFP8 as a fused operation if isinstance(quantizer, MXFP8Quantizer): return torch.empty(input_shape, dtype=out_dtype, device='cuda') + # TODO: remove when triton kernels support fp8 current scaling + if isinstance(quantizer, Float8CurrentScalingQuantizer): + return torch.empty(input_shape, dtype=out_dtype, device='cuda') + if isinstance(ln_out, Float8Tensor): if ln_out.dtype == out_dtype: return ln_out diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index b62d61ced..c48a2a9b2 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -6,7 +6,7 @@ import triton.language as tl from itertools import product from .norm_common import num_programs, block_size, use_blocked, make_ln_out -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.triton_kernels.common import ( te_dtype_to_torch_dtype, @@ -389,6 +389,7 @@ def te_rmsnorm_fwd_triton( f"but {weight.shape[0]=} while {input.shape[1]=}" ) IS_FP8 = isinstance(quantizer, Float8Quantizer) + IS_FP8_CURRENT_SCALING = isinstance(quantizer, Float8CurrentScalingQuantizer) IS_MXFP8 = isinstance(quantizer, MXFP8Quantizer) BLOCK_SIZE = block_size(input) USE_BLOCKED = use_blocked(input) @@ -460,7 +461,7 @@ def te_rmsnorm_fwd_triton( FP8_MAX, MAKE_TRANSPOSE, ) - if IS_MXFP8: + if IS_MXFP8 or IS_FP8_CURRENT_SCALING: out = quantizer.quantize(out, out=ln_out) return out, None, rsigma diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 5cc6a11e2..9d0d71fdc 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -9,7 +9,7 @@ import functools import math import os -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union import numpy as np import torch from torch.utils.cpp_extension import IS_HIP_EXTENSION @@ -40,8 +40,16 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: Must be used carefully. """ + for t in tensors: if t is not None: + # Workaround for double buffering in cpu offload + if hasattr(t, "do_not_clear"): + continue + if hasattr(t, "get_data_tensors"): + if any(hasattr(tensor, "do_not_clear") for tensor in t.get_data_tensors()): + continue + if hasattr(t, "clear"): t.clear() else: @@ -469,6 +477,7 @@ def is_bf16_compatible() -> None: return torch.cuda.get_device_capability()[0] >= 8 +@functools.lru_cache(maxsize=None) def is_non_tn_fp8_gemm_supported() -> bool: """Checks whether the device supports non-TN layouts for FP8 GEMMs. @@ -648,3 +657,111 @@ def torch_get_autocast_gpu_dtype() -> torch.dtype: gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda") else: gpu_autocast_ctx = torch.cuda.amp.autocast + + +_torch_dtype_to_np_typestr_dict = { + torch.float16: " 0 else 0, False), + "version": 3, + } + + def torch_dtype_to_np_typestr(self): + """Convert PyTorch dtype to numpy typestr.""" + ret = _torch_dtype_to_np_typestr_dict.get(self.dtype) + assert ret is not None, f"Unsupported dtype: {self.dtype}" + return ret + + +def make_weak_ref(x): + """ + This function is to make a weak reference to the input so that the memory can be released. + """ + + def convert_to_torch_tensor(tensor: Union[_WeakRefTensor, torch.Tensor]) -> torch.Tensor: + """ + This function is to convert the `_WeakRefTensor` to torch.Tensor. + """ + if isinstance(tensor, torch.Tensor): + return tensor + + old_ptr = tensor.data_ptr() + new_tensor = torch.as_tensor(tensor).view(tensor.dtype) + new_ptr = new_tensor.data_ptr() + if old_ptr != new_ptr: + raise RuntimeError("Data pointer mismatch after converting to torch.Tensor") + return new_tensor + + if isinstance(x, torch.Tensor): + return ( + convert_to_torch_tensor(_WeakRefTensor(x.data_ptr(), x.dtype, x.shape)) + if x.is_cuda + else x + ) + if isinstance(x, tuple): + return tuple(make_weak_ref(i) for i in x) + if isinstance(x, list): + return [make_weak_ref(i) for i in x] + if isinstance(x, dict): + return {k: make_weak_ref(v) for k, v in x.items()} + if isinstance(x, (int, float, bool)): + return x + if x is None: + return None + raise TypeError(f"Invalid type {type(x)} to make weak ref")