diff --git a/.gitignore b/.gitignore index 44de0a19e..874eed018 100644 --- a/.gitignore +++ b/.gitignore @@ -52,7 +52,5 @@ compile_commands.json **/profiler_outputs/ **/times.csv tensor_dumps/ -aiter/ transformer_engine/build_info.txt transformer_engine/common/util/hip_nvml.* -transformer_engine/aiter/ diff --git a/3rdparty/aiter b/3rdparty/aiter index a2ca1b460..1b00a0e8a 160000 --- a/3rdparty/aiter +++ b/3rdparty/aiter @@ -1 +1 @@ -Subproject commit a2ca1b460f097a309ee5a128c7454b1c419dc331 +Subproject commit 1b00a0e8a54be0411490a69a5d7042abd33a56d9 diff --git a/README.rst b/README.rst index 49e19bd7e..09f204f68 100644 --- a/README.rst +++ b/README.rst @@ -450,7 +450,7 @@ Installation ============ System Requirements -^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^ * **Hardware:** Blackwell, Hopper, Grace Hopper/Blackwell, Ada, Ampere @@ -468,10 +468,10 @@ System Requirements * **Notes:** FP8 features require Compute Capability 8.9+ (Ada/Hopper/Blackwell) Installation Methods -^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^ Docker (Recommended) -^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^ The quickest way to get started with Transformer Engine is by using Docker images on `NVIDIA GPU Cloud (NGC) Catalog `_. @@ -496,7 +496,7 @@ Where 25.04 (corresponding to April 2025 release) is the container version. * NGC PyTorch 23.08+ containers include FlashAttention-2 pip Installation -^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^ **Prerequisites for pip installation:** @@ -534,7 +534,7 @@ Source Installation `See the installation guide `_ Environment Variables -^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^ These environment variables can be set before installation to customize the build process: * **CUDA_PATH**: Path to CUDA installation @@ -545,7 +545,7 @@ These environment variables can be set before installation to customize the buil * **NVTE_BUILD_THREADS_PER_JOB**: Control threads per build job Compiling with FlashAttention -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Transformer Engine supports both FlashAttention-2 and FlashAttention-3 in PyTorch for improved performance. FlashAttention-3 was added in release v1.11 and is prioritized over FlashAttention-2 when both are present in the environment. You can verify which FlashAttention version is being used by setting these environment variables: @@ -557,8 +557,9 @@ You can verify which FlashAttention version is being used by setting these envir It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug `_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. .. 
troubleshooting-begin-marker-do-not-remove + Troubleshooting -^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^ **Common Issues and Solutions:** @@ -692,7 +693,7 @@ Papers Videos ====== -* `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 `_ +* `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 `__ * `Blackwell Numerics for AI | GTC 2025 `_ * `Building LLMs: Accelerating Pretraining of Foundational Models With FP8 Precision | GTC 2025 `_ * `From FP8 LLM Training to Inference: Language AI at Scale | GTC 2025 `_ diff --git a/build_tools/jax.py b/build_tools/jax.py index ae8e696c8..4e587b965 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -21,7 +21,12 @@ def xla_path() -> str: Throws FileNotFoundError if XLA source is not found.""" try: - from jax.extend import ffi + import jax + from packaging import version + if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi + else: + from jax.extend import ffi except ImportError: if os.getenv("XLA_HOME"): xla_home = Path(os.getenv("XLA_HOME")) diff --git a/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 b/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 index 2b78544df..cf5dbb3bc 100644 --- a/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 +++ b/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 @@ -2,22 +2,32 @@ # # See LICENSE for license information. -# This Dockerfile is used to build TransformerEngine wheels for ROCm on x86_64 architecture. -# It is based on the manylinux_2_28_x86_64 based image with ROCm installed. -ARG BASE_IMAGE=quay.io/pypa/manylinux_2_28_x86_64:non_existent_rocm_tag +# This Dockerfile is used to build TransformerEngine wheels for ROCm on x86_64 architecture +# on top of the manylinux_2_28_x86_64 base image. + +# Build args: +# BASE_IMAGE - Base manylinux image to use. Default: quay.io/pypa/manylinux_2_28_x86_64 +# ROCM_REPO_URL - ROCm repository URL. Default: https://repo.radeon.com/rocm/rhel8/latest/main/ +# GPU_TARGETS - Semicolon-separated list of target GPU architectures. Default: "gfx942;gfx950" +# TARGET_BRANCH - Target branch for TransformerEngine. Default: none (use git default) +# GPU_TARGETS and TARGET_BRANCH can be overridden when starting a container with the NVTE_ROCM_ARCH and TARGET_BRANCH environment variables.
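+#
+# Illustrative example of building this image with the build args above
+# (the image tag below is a placeholder, not something defined in this repo):
+#   docker build -f build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 \
+#     --build-arg ROCM_REPO_URL=https://repo.radeon.com/rocm/rhel8/latest/main/ \
+#     --build-arg GPU_TARGETS="gfx942;gfx950" \
+#     -t te-rocm-manylinux-x86 .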
+ +# Set base image ARG BASE_IMAGE=quay.io/pypa/manylinux_2_28_x86_64 FROM $BASE_IMAGE -# Setup the build_system repo -RUN echo -e "[build_system]\nname=ROCm\nbaseurl=https://repo.almalinux.org/build_system/8/x86_64/\nenabled=1\ngpgcheck=0" >/etc/yum.repos.d/build_system.repo +ARG ROCM_REPO_URL=https://repo.radeon.com/rocm/rhel8/latest/main/ -# Add and enable repos -RUN dnf update -y || true -RUN dnf install -y epel-release elrepo-release -RUN dnf config-manager --set-enabled build_system powertools extras epel elrepo +# Set up ROCm repo +RUN echo -e "[rocm]\nname=ROCm\nbaseurl=${ROCM_REPO_URL}\nenabled=1\ngpgcheck=0" > /etc/yum.repos.d/rocm.repo + +# Setup packages +RUN dnf install -y --disablerepo=epel rocm-dev hipblaslt hipblaslt-devel hipcub hipcub-devel +RUN dnf group install -y "Development Tools" && dnf install -y git cmake llvm-toolset gcc-toolset-12 + +#Uncomment the next line for the ROCm 6.4 cmake workaround: remove the newer incompatible cmake preinstalled on the base image +#RUN rm /usr/local/bin/cmake || true -# Setup dev packages -RUN dnf group install -y "Development Tools" && \ - dnf install -y git cmake llvm-toolset hipblaslt hipblaslt-devel gcc-toolset-12 RUN dnf clean all RUN rm -rf /var/cache/dnf/* diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 5320b8a39..5d37ae1d9 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -30,11 +30,17 @@ fi ROCM_BUILD=`${PYBINDIR}python -c "import build_tools.utils as u; print(int(u.rocm_build()))"` +if [ "$LOCAL_TREE_BUILD" != "1" ]; then + if [ "$ROCM_BUILD" = "1" ]; then + git pull + fi + git checkout $TARGET_BRANCH + git submodule update --init --recursive +fi + if [ "$ROCM_BUILD" = "1" ]; then - git pull + ${PYBINDIR}pip install setuptools wheel fi -git checkout $TARGET_BRANCH -git submodule update --init --recursive if $BUILD_METAPACKAGE ; then cd /TransformerEngine @@ -50,10 +56,10 @@ if $BUILD_COMMON ; then WHL_BASE="transformer_engine-${VERSION}" if [ "$ROCM_BUILD" = "1" ]; then TE_CUDA_VERS="rocm" - ${PYBINDIR}pip install ninja dataclasses - if [ -n "$PYBINDIR" ]; then - PATH="$PYBINDIR:$PATH" #hipify expects python in PATH - fi + #dataclasses, psutil are needed for AITER + ${PYBINDIR}pip install ninja dataclasses psutil + #hipify expects python in PATH, also ninja may be installed to python bindir + test -n "$PYBINDIR" && PATH="$PYBINDIR:$PATH" || true else TE_CUDA_VERS="cu12" PYBINDIR=/opt/python/cp38-cp38/bin/ diff --git a/ci/core.sh b/ci/core.sh index 0953d7bde..35b4000e9 100755 --- a/ci/core.sh +++ b/ci/core.sh @@ -31,14 +31,14 @@ fi check_test_filter "nongemm" if [ $? -eq 0 ]; then echo ===== Run non GEMM tests ===== - ctest --test-dir build -j"$n_parallel_jobs" -V --output-on-failure -E "OperatorTest/GEMMTestSuite" + ctest --test-dir build -j"$n_parallel_jobs" -V --output-on-failure -E "GEMMTestSuite" test $? -eq 0 || test_run_error "non-GEMM" fi check_test_filter "gemm" if [ $? -eq 0 ]; then echo ===== Run GEMM tests ===== - ctest --test-dir build -j"$n_parallel_jobs" -V --output-on-failure -R "OperatorTest/GEMMTestSuite" + ctest --test-dir build -j"$n_parallel_jobs" -V --output-on-failure -R "GEMMTestSuite" test $?
-eq 0 || test_run_error "GEMM" fi diff --git a/ci/jax.sh b/ci/jax.sh index 80c61ce9b..cc080916c 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -66,33 +66,23 @@ run_test_config() { run_test_config_mgpu() { echo ==== Run mGPU with Fused attention backend: $_fus_attn ==== - _JAX_DISABLE_JIT_FLAG=${JAX_DISABLE_JIT:-0} _ver=$(pip show jaxlib | grep Version) case "$_ver" in *0.4.35*) - # Workaround for distributed tests hang with JIT enabled - JAX_DISABLE_JIT=1 run 3 test_distributed_fused_attn.py -k 'not (test_context_parallel_allgather_attn[BALANCED or test_context_parallel_ring_attn)' - _JAX_DISABLE_JIT_FLAG=1 - - # Run tests that fail with JIT disabled - #run_lbl "allgather_balanced" 3 test_distributed_fused_attn.py -k 'test_context_parallel_allgather_attn[BALANCED' - + # Workaround for distributed tests hang with xla_flag + XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn' + # Test ring attention with xla_flag --xla_experimental_ignore_channel_id only - # TODO: remove this flag after jax/xla update - XLA_FLAGS="--xla_experimental_ignore_channel_id" run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn - ;; - *0.6.*) - # Workaround for distributed tests hang with JIT enabled - JAX_DISABLE_JIT=1 run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_allgather_attn[BALANCED' - _JAX_DISABLE_JIT_FLAG=1 + XLA_FLAGS="--xla_experimental_ignore_channel_id" run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn ;; *) - run 3 test_distributed_fused_attn.py + # Workaround for distributed tests hang with xla_flag + XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py ;; esac run_default_fa 3 test_distributed_layernorm.py - JAX_DISABLE_JIT=$_JAX_DISABLE_JIT_FLAG run_default_fa 3 test_distributed_layernorm_mlp.py + XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run_default_fa 3 test_distributed_layernorm_mlp.py run_default_fa 3 test_distributed_softmax.py run_default_fa 3 test_sanity_import.py diff --git a/ci/pytorch.sh b/ci/pytorch.sh index e4f8380f5..93b9ded7f 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -12,7 +12,7 @@ TEST_DIR=${TE_PATH}tests/pytorch #: ${TEST_WORKERS:=4} install_prerequisites() { - pip install 'numpy>=1.22.4,<2.0' pandas + pip install 'numpy>=1.22.4' pandas rc=$? if [ $rc -ne 0 ]; then script_error "Failed to install test prerequisites" @@ -58,7 +58,7 @@ run_test_config(){ run_default_fa 1 test_deferred_init.py run_default_fa 1 test_float8tensor.py run_default_fa 1 test_float8_current_scaling_exact.py - run_default_fa 1 test_cpu_offloading.py + test $_fus_attn = auto -o $_fus_attn = ck -o $_fus_attn = aotriton && NVTE_FLASH_ATTN=0 run 1 test_cpu_offloading.py run_default_fa 1 test_fused_rope.py run_default_fa 1 test_fusible_ops.py run_default_fa 3 test_gemm_autotune.py diff --git a/docs/api/c/cast_transpose_noop.rst b/docs/api/c/cast_transpose_noop.rst new file mode 100644 index 000000000..ae80c5d2d --- /dev/null +++ b/docs/api/c/cast_transpose_noop.rst @@ -0,0 +1,9 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +cast_transpose_noop.h +===================== + +.. doxygenfile:: cast_transpose_noop.h diff --git a/docs/api/c/cudnn.rst b/docs/api/c/cudnn.rst new file mode 100644 index 000000000..5d93c4d6e --- /dev/null +++ b/docs/api/c/cudnn.rst @@ -0,0 +1,9 @@ +.. 
+ Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +cudnn.h +======= + +.. doxygenfile:: cudnn.h diff --git a/docs/api/c/index.rst b/docs/api/c/index.rst index 7bc864dcc..0499f52f0 100644 --- a/docs/api/c/index.rst +++ b/docs/api/c/index.rst @@ -14,10 +14,13 @@ directly from C/C++, without Python. transformer_engine.h activation.h + cast_transpose_noop.h cast.h + cudnn.h fused_attn.h fused_rope.h gemm.h + multi_tensor.h normalization.h padding.h permutation.h diff --git a/docs/api/c/multi_tensor.rst b/docs/api/c/multi_tensor.rst new file mode 100644 index 000000000..8ba2d274c --- /dev/null +++ b/docs/api/c/multi_tensor.rst @@ -0,0 +1,9 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +multi_tensor.h +============== + +.. doxygenfile:: multi_tensor.h diff --git a/docs/api/common.rst b/docs/api/common.rst index 95d4b50f3..541118985 100644 --- a/docs/api/common.rst +++ b/docs/api/common.rst @@ -11,3 +11,7 @@ Common API .. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None) .. autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3) + +.. autoapiclass:: transformer_engine.common.recipe.Float8CurrentScaling(fp8_format=Format.HYBRID) + +.. autoapiclass:: transformer_engine.common.recipe.Float8BlockScaling(fp8_format=Format.E4M3) diff --git a/docs/debug.rst b/docs/debug.rst new file mode 100644 index 000000000..d33568ea3 --- /dev/null +++ b/docs/debug.rst @@ -0,0 +1,14 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. +Precision debug tools +============================================== + +.. toctree:: + :caption: Precision debug tools + + debug/1_getting_started.rst + debug/2_config_file_structure.rst + debug/api + debug/4_distributed.rst \ No newline at end of file diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst new file mode 100644 index 000000000..bc2b95057 --- /dev/null +++ b/docs/debug/1_getting_started.rst @@ -0,0 +1,241 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Getting started +=============== + +.. note:: + + Precision debug tools with `Nvidia-DL-Framework-Inspect `_ for Transformer Engine are currently supported only for PyTorch. + +Transformer Engine provides a set of precision debug tools which allow you to easily: + +- log the statistics for each of the tensors in every matrix multiply (GEMM) operation, +- run selected GEMMs in higher precision, +- run current scaling - with one scaling factor per tensor - for particular GEMMs, +- test new precisions and integrate them with FP8 training, +- ... and many more. + +There are 4 things one needs to do to use Transformer Engine debug features: + +1. Create a configuration YAML file to configure the desired features. +2. Import and initialize the `Nvidia-DL-Framework-Inspect `_ tool, which is installed as a dependency of Transformer Engine. +3. One can pass ``name="..."`` when creating TE layers to identify layer names more easily. If this is not provided, names will be inferred automatically. +4. Invoke ``debug_api.step()`` at the end of every forward-backward pass.
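+
+A minimal sketch of these four steps is shown below (the config and feature-directory paths are placeholders; a complete, runnable example is developed in the sections that follow):
+
+.. code-block:: python
+
+    import nvdlfw_inspect.api as debug_api
+
+    # Steps 1 and 2: point debug_api at the config file and the TE debug features.
+    debug_api.initialize(
+        config_file="./config.yaml",
+        feature_dirs=["/path/to/transformer_engine/debug/features"],
+        log_dir="./log")
+
+    # Step 3 (optional): pass name="..." when constructing TE layers,
+    # e.g. TransformerLayer(..., name="transformer_layer"), so that they are
+    # easy to refer to in config.yaml.
+
+    for step in range(5):
+        # one forward-backward pass of the training loop goes here
+        # ...
+        debug_api.step()  # Step 4: advance the debug step once per iteration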
+ +To start debugging, one needs to create a configuration YAML file. This file lists the features to be used in particular layers. There are 2 kinds of features: + +- provided by the Transformer Engine - for example, DisableFP8GEMM or LogTensorStats - they are listed in the :doc:`debug features API <3_api_features>` section +- defined by the user. For details on how to create a custom feature - please read the :doc:`calls to Nvidia-DL-Framework-Inspect <3_api_te_calls>` section. + +.. figure:: ./img/introduction.svg + :align: center + + Fig 1: Example of Nvidia-DL-Framework-Inspect affecting training script with 3 TE Linear Layers. + ``config.yaml`` contains the specification of the features used for each Linear layer. Some feature classes are provided by TE, + one - ``UserProvidedPrecision`` - is a custom feature implemented by the user. Nvidia-DL-Framework-Inspect inserts features into the layers according to the config. + +Example training script +---------------------- + +Let's look at a simple example of training a Transformer layer using Transformer Engine with FP8 precision. This example demonstrates how to set up the layer, define an optimizer, and perform a few training iterations using synthetic data. + +.. code-block:: python + + # train.py + + from transformer_engine.pytorch import TransformerLayer + import torch + import torch.nn as nn + import torch.optim as optim + import transformer_engine.pytorch as te + + hidden_size = 512 + num_attention_heads = 8 + + transformer_layer = TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=hidden_size, + num_attention_heads=num_attention_heads + ).cuda() + + dummy_input = torch.randn(10, 32, hidden_size).cuda() + criterion = nn.MSELoss() + optimizer = optim.Adam(transformer_layer.parameters(), lr=1e-4) + dummy_target = torch.randn(10, 32, hidden_size).cuda() + + for epoch in range(5): + transformer_layer.train() + optimizer.zero_grad() + with te.fp8_autocast(enabled=True): + output = transformer_layer(dummy_input) + loss = criterion(output, dummy_target) + loss.backward() + optimizer.step() + +We will demonstrate two debug features on the code above: + +1. Disabling FP8 precision for specific GEMM operations, such as the FC1 and FC2 forward propagation GEMM. +2. Logging statistics for other GEMM operations, such as gradient statistics for data gradient GEMM within the LayerNormLinear sub-layer of the TransformerLayer. + +Config file +---------- + +We need to prepare the configuration YAML file, as below + +.. code-block:: yaml + + # config.yaml + + fc1_fprop_to_fp8: + enabled: True + layers: + layer_types: [fc1, fc2] # contains fc1 or fc2 in name + transformer_engine: + DisableFP8GEMM: + enabled: True + gemms: [fprop] + + log_tensor_stats: + enabled: True + layers: + layer_types: [layernorm_linear] # contains layernorm_linear in name + transformer_engine: + LogTensorStats: + enabled: True + stats: [max, min, mean, std, l1_norm] + tensors: [activation] + freq: 1 + start_step: 2 + end_step: 5 + +Further explanation on how to create config files is in the :doc:`next part of the documentation <2_config_file_structure>`. + +Adjusting Python file +-------------------- + +.. code-block:: python + + # (...) 
+ + import nvdlfw_inspect.api as debug_api + debug_api.initialize( + config_file="./config.yaml", + feature_dirs=["/path/to/transformer_engine/debug/features"], + log_dir="./log", + default_logging_enabled=True) + + # initialization of the TransformerLayer with the name + transformer_layer = TransformerLayer( + name="transformer_layer", + # ...) + + # (...) + for epoch in range(5): + # forward and backward pass + # ... + debug_api.step() + +In the modified code above, the following changes were made: + +1. Added an import for ``nvdlfw_inspect.api``. +2. Initialized the Nvidia-DL-Framework-Inspect by calling ``debug_api.initialize()`` with appropriate configuration, specifying the path to the config file, feature directories, and log directory. +3. Added ``debug_api.step()`` after each of the forward-backward pass. + +Inspecting the logs +------------------ + +Let's look at the files with the logs. Two files will be created: + +1. debug logs. +2. statistics logs. + +Let's look inside them! + +In the main log file, you can find detailed information about the transformer layer's GEMMs behavior. You can see that ``fc1`` and ``fc2`` fprop GEMMs are run in high precision, as intended. + +.. code-block:: text + + # log/nvdlfw_inspect_logs/nvdlfw_inspect_globalrank-0.log + + INFO - Default logging to file enabled at ./log + INFO - Reading config from ./config.yaml. + INFO - Loaded configs for dict_keys(['fc1_fprop_to_fp8', 'log_tensor_stats']). + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: activation, gemm fprop - FP8 quantization + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: activation, gemm wgrad - FP8 quantization + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: weight, gemm fprop - FP8 quantization + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: weight, gemm dgrad - FP8 quantization + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: gradient, gemm dgrad - FP8 quantization + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: gradient, gemm wgrad - FP8 quantization + INFO - transformer_layer.self_attention.proj: Tensor: activation, gemm fprop - FP8 quantization + INFO - transformer_layer.self_attention.proj: Tensor: activation, gemm wgrad - FP8 quantization + INFO - transformer_layer.self_attention.proj: Tensor: weight, gemm fprop - FP8 quantization + INFO - transformer_layer.self_attention.proj: Tensor: weight, gemm dgrad - FP8 quantization + INFO - transformer_layer.self_attention.proj: Tensor: gradient, gemm dgrad - FP8 quantization + INFO - transformer_layer.self_attention.proj: Tensor: gradient, gemm wgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: activation, gemm fprop - High precision + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: activation, gemm wgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: weight, gemm fprop - High precision + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: weight, gemm dgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: gradient, gemm dgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: gradient, gemm wgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: activation, gemm fprop - High precision + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: activation, gemm wgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: weight, gemm fprop - High precision + INFO - transformer_layer.layernorm_mlp.fc2: 
Tensor: weight, gemm dgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: gradient, gemm dgrad - FP8 quantization + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: gradient, gemm wgrad - FP8 quantization + INFO - transformer_layer.self_attention.layernorm_qkv: Feature=LogTensorStats, API=look_at_tensor_before_process: activation + .... + +The second log file (``nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log``) contains statistics for tensors we requested in ``config.yaml``. + +.. code-block:: text + + # log/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log + + INFO - transformer_layer.self_attention.layernorm_qkv_activation_max iteration=000002 value=4.3188 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_min iteration=000002 value=-4.3386 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_mean iteration=000002 value=0.0000 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_std iteration=000002 value=0.9998 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_l1_norm iteration=000002 value=130799.6953 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_max iteration=000003 value=4.3184 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_min iteration=000003 value=-4.3381 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_mean iteration=000003 value=0.0000 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_std iteration=000003 value=0.9997 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_l1_norm iteration=000003 value=130788.1016 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_max iteration=000004 value=4.3181 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_min iteration=000004 value=-4.3377 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_mean iteration=000004 value=0.0000 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_std iteration=000004 value=0.9996 + INFO - transformer_layer.self_attention.layernorm_qkv_activation_l1_norm iteration=000004 value=130776.7969 + +Logging using TensorBoard +------------------------ + +Precision debug tools support logging using `TensorBoard `_. To enable it, one needs to pass the argument ``tb_writer`` to the ``debug_api.initialize()``. Let's modify ``train.py`` file. + +.. code-block:: python + + # (...) + + from torch.utils.tensorboard import SummaryWriter + tb_writer = SummaryWriter('./tensorboard_dir/run1') + + # add tb_writer to the Debug API initialization + debug_api.initialize( + config_file="./config.yaml", + feature_dirs=["/path/to/transformer_engine/debug/features"], + log_dir="./log", + tb_writer=tb_writer) + + # (...) + +Let's run training and open TensorBoard by ``tensorboard --logdir=./tensorboard_dir/run1``: + +.. figure:: ./img/tensorboard.png + :align: center + + Fig 2: TensorBoard with plotted stats. \ No newline at end of file diff --git a/docs/debug/2_config_file_structure.rst b/docs/debug/2_config_file_structure.rst new file mode 100644 index 000000000..f1069b0c8 --- /dev/null +++ b/docs/debug/2_config_file_structure.rst @@ -0,0 +1,241 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. 
+ +Config File Structure +==================== + +To enable debug features, create a configuration YAML file to specify the desired behavior, such as determining which GEMMs (General Matrix Multiply operations) should run in higher precision rather than FP8 and defining which statistics to log. +Below, we outline how to structure the configuration YAML file. + +General Format +------------- + +A config file can have one or more sections, each containing settings for specific layers and features: + +.. code-block:: yaml + + section_name_1: + enabled: ... + layers: + # Specify layers here... + transformer_engine: + Feature1Name: + enabled: ... + # Feature details... + Feature2Name: + enabled: ... + # Feature details... + + section_name_2: + enabled: ... + layers: + # Specify layers here... + Feature1Name: # If feature has no namespace, then it is in the default namespace. + enabled: ... + # Feature details... + + section_name_3: + enabled: ... + layers: + # Specify layers here... + transformer_engine: + Feature1Name: + enabled: ... + # Feature details... + Feature2Name: + enabled: ... + # Feature details... + +Sections may have any name and must contain: + +1. An ``enabled`` field that specifies whether the features in that section will be active. +2. A ``layers`` field specifying which layers the section applies to. Each layer can belong to only one section. +3. Additional fields describing features for those layers. + +Layer Specification +------------------ + +Debug layers can be identified by a ``name`` parameter: + +.. code-block:: python + + linear = transformer_engine.debug.pytorch.Linear(in_features, out_features, name="linear1") + +This name is used in the config file to identify the layer. To specify the ``layers`` field, you can use one of the following methods: + +1. ``layer_name_regex_pattern``: Use a regular expression to match layer names. This expression must adhere to the Python ``re`` module syntax. +2. ``layer_types``: Provide a list of strings, where a layer will be selected if any string matches part of its name. + +Examples: + +.. code-block:: yaml + + # Example 1: Using regular expression to select layers + my_section: + enabled: ... + layers: + layer_name_regex_pattern: 'self_attn.*' + transformer_engine: + (...) + + # Example 2: Using layer type to select layers + another_section: + enabled: ... + layers: + layer_types: ['fc1', 'layernorm_linear'] + transformer_engine: + (...) + +Names in Transformer Layers +-------------------------- + +There are three ways to assign a name to a layer in the Transformer Engine: + +- Initialize the layer with the ``name=...`` argument. +- Use ``debug_api.infer_and_assign_layer_names(model)``, which assigns names based on class names. +- Rely on the default names assigned during module initialization, such as ``Layer_n``, where ``n`` represents the layer number. + +The ``TransformerLayer`` in Transformer Engine is a composition of multiple sub-layers. We can modify some of these layers using precision debug tools, particularly those that contain exactly one linear layer. To see the names of all such layers, we can inspect log files. 
For instance, a ``TransformerLayer`` named ``transformer_layer`` might consist of: + +- ``transformer_layer.self_attn.layernorm_linear_qkv`` / ``transformer_layer.self_attn.linear_qkv`` / ``transformer_layer.self_attn.layernorm_linear_q`` / ``transformer_layer.self_attn.linear_q`` / ``transformer_layer.self_attn.linear_kv``, +- ``transformer_layer.self_attn.proj``, +- ``transformer_layer.inter_attn.*`` for ``layer_type="decoder"``, +- ``transformer_layer.layernorm_mlp.fc1``, +- ``transformer_layer.layernorm_mlp.fc2``, + +depending on the configuration. Some layers, like ``LayerNormLinear``, are fusions of two layers: ``LayerNorm`` and ``Linear``. When referring to such layers in precision debug tools, only the ``Linear`` part is affected. + +Below is an example ``TransformerLayer`` with four linear layers that can be influenced by the precision debug tools. + +.. figure:: ./img/names.svg + :align: center + :width: 80% + + Fig 1: Names of layers in an example configuration of TransformerLayer. The most nested blocks represent the most basic layers, each containing one linear layer. Layers that do not contain linear layers, such as ``DotProductAttention``, are omitted. + +**Configuration File Example** + +.. code-block:: yaml + + # Disables wgrad in all 4 GEMMs + section1: + enabled: True + layers: + layer_types: [transformer_layer] + transformer_engine: + DisableFP8GEMM: + enabled: True + gemms: [wgrad] + + # Disables all GEMMs in layernorm_mlp layer + section2: + enabled: True + layers: + layer_types: [layernorm_mlp] + transformer_engine: + DisableFP8Layer: + enabled: True + + # Logs wgrad stats in fc1 + section3: + enabled: True + layers: + layer_types: [fc1] + transformer_engine: + LogTensorStats: + enabled: True + stats: [min] + tensors: [wgrad] + freq: 1 + start_step: 0 + end_step: 50 + + +Structured Configuration for GEMMs and Tensors +--------------------------------------------- + +Sometimes a feature is parameterized by a list of tensors or by a list of GEMMs. +There are multiple ways of describing this parameterization. + +We can pass lists, as below. + +.. code-block:: yaml + + Feature: + enabled: ... + gemms: [gemm1, gemm2] + tensors: [tensor1, tensor2] + ... + +We can use struct for tensors. + +.. code-block:: yaml + + Feature: + gemms: [gemm1, gemm2] + tensors_struct: + - tensor: tensor1 + feature_param1: value + - tensor: tensor2 + feature_param1: value + gemm_feature_param1: value + +Similarly, we can use struct for GEMMs. + +.. code-block:: yaml + + Feature: + enabled: ... + tensors: [tensor1, tensor2] + gemms_struct: + - gemm: gemm1 + feature_param1: value + - gemm: gemm2 + feature_param1: value + gemm_feature_param1: value + +We can use both structs for tensors and GEMMs. The tensors_struct should be nested inside gemms_struct. + +.. code-block:: yaml + + Feature: + enabled: ... + gemms_struct: + - gemm: gemm1 + tensors: [tensor1, tensor2] + tensor_feature_param1: value + gemm_feature_param1: value + - gemm: gemm2 + tensors_struct: + - tensor: tensor1 + tensor_feature_param1: value + - tensor: tensor2 + tensor_feature_param2: value + gemm_feature_param1: value + +Enabling or Disabling Sections and Features +------------------------------------------ + +Debug features can be enabled or disabled with the ``enabled`` keyword: + +.. 
code-block:: yaml + + section1: + enabled: True + layers: + layer_types: [self_attention] + transformer_engine: + LogTensorStats: + enabled: False # Disables the LogTensorStats feature + stats: [max, min, mean, std, l1_norm] + + section2: + enabled: False # Disables entire section2 + transformer_engine: + LogFp8TensorStats: + enabled: True # Does not enable the LogFp8TensorStats feature, because section2 is disabled + stats: [underflows, overflows] + +By organizing your ``config.yaml`` properly, you can easily manage debugging features, ensuring a more streamlined and customizable debugging experience. diff --git a/docs/debug/3_api_debug_setup.rst b/docs/debug/3_api_debug_setup.rst new file mode 100644 index 000000000..bda8f096d --- /dev/null +++ b/docs/debug/3_api_debug_setup.rst @@ -0,0 +1,87 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Setup +===== + +Precision debug tools for the Transformer Engine use `Nvidia-DL-Framework-Inspect `_ package from NVIDIA. +Please refer to the Nvidia-DL-Framework-Inspect `documentation `_ for more details. +Below, we outline the steps for debug initialization. + +initialize() +----------- + +Must be called once on every rank in the global context to initialize Nvidia-DL-Framework-Inspect. + +**Parameters** + +- **config_file** (*str*, default=""): Path to the configuration YAML file containing features to enable and layer names. If one wants to run without the configuration file, pass ``""``. +- **feature_dirs** (*List[str] | str*): List of directories containing features to load and register. One needs to pass ``[/path/to/transformerengine/transformer_engine/debug/features]`` to use TE features. +- **logger** (*Union[BaseLogger, None]*, default=None): Logger for logging tensor statistics. Should adhere to ``BaseLogger`` from the `Nvidia-DL-Framework-Inspect `_ package. +- **log_dir** (*str*, default= "."): Directory path to hold ``debug_logs`` and ``debug_statistics_logs``. +- **tb_writer** (*TensorBoardWriter*, default=None): TensorBoard writer for logging. +- **default_logging_enabled** (*bool*, default=False): Enable default logging to the file. + +.. code-block:: python + + import nvdlfw_inspect.api as debug_api + + debug_api.initialize( + config_file="./config.yaml", + feature_dirs=["/path/to/transformer_engine/debug/features"], + log_dir="./log_dir") + +set_tensor_reduction_group() +-------------------------- + +Needed only for logging tensor stats. In multi-GPU training, activation and gradient tensors are distributed across multiple nodes. This method lets you specify the group for the reduction of stats; see the `reduction group section <./4_distributed.rst#reduction-groups>`_ for more details. + +If the tensor reduction group is not specified, then statistics are reduced across all nodes in the run. + +**Parameters** + +- **group** (torch.distributed.ProcessGroup): The process group across which tensors will be reduced to get stats. + + +.. code-block:: python + + import nvdlfw_inspect.api as debug_api + + # initialization + # (...) + + pipeline_parallel_group = initialize_pipeline_parallel_group() + + debug_api.set_tensor_reduction_group(pipeline_parallel_group) + + # training + # (...) + # activation/gradient tensor statistics are reduced along pipeline_parallel_group + +set_weight_tensor_tp_group_reduce() +--------------------------------- + +By default, weight tensor statistics are reduced within the tensor parallel group. 
This function allows you to disable that behavior; for more details, see `reduction group section <./4_distributed.rst#reduction-groups>`_. + +This method is not provided by the ``debug_api``, but by the ``transformer_engine.debug``. + +**Parameters** + +- **enabled** (*bool*, default=True): A boolean flag to enable or disable the reduction of weight tensor statistics within the tensor parallel group. + + +.. code-block:: python + + import nvdlfw_inspect.api as debug_api + from transformer_engine.debug import set_weight_tensor_tp_group_reduce + + # initialization + # (...) + + set_weight_tensor_tp_group_reduce(False) + + # training + # (...) + # weight tensor statistics are not reduced diff --git a/docs/debug/3_api_features.rst b/docs/debug/3_api_features.rst new file mode 100644 index 000000000..b31c437b2 --- /dev/null +++ b/docs/debug/3_api_features.rst @@ -0,0 +1,14 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Debug features +========== + +.. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats +.. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats +.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM +.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer +.. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling +.. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant diff --git a/docs/debug/3_api_te_calls.rst b/docs/debug/3_api_te_calls.rst new file mode 100644 index 000000000..eb66c8ff2 --- /dev/null +++ b/docs/debug/3_api_te_calls.rst @@ -0,0 +1,45 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Calls to Nvidia-DL-Framework-Inspect +==================================== +Let's look deeper into how Nvidia-DL-Framework-Inspect with Transformer Engine work together. TransformerEngine layers have some hook calls inside each of the GEMMs. Users can define feature classes or use feature classes provided with TE. File ``config.yaml`` describes which hooks need to be used for which layers. Nvidia-DL-Framework-Inspect combines 3 things: TE training, feature classes and ``config.yaml`` and takes care of inserting hooks in the correct places. This process is illustrated in the image below. + +.. figure:: ./img/api_calls1.svg + :align: center + + Fig 1: Example of Nvidia-DL-Framework-Inspect affecting training script with 1 Linear Layer. For tensors mentioned in ``config.yaml``, behavior of ``modify_tensor_enabled()`` and ``modify_tensor()`` calls are substituted with definitions from the feature class. Other calls return default values - in fact they do nothing. + +In this page, all calls from TransformerEngine to the Nvidia-DL-Framework-Inspect for each GEMM are listed. The order of these calls is illustrated in the image below. + +.. figure:: ./img/api_calls2.svg + :align: center + + Fig 2: The calls to Nvidia-DL-Framework-Inspect done for Transformer Engine. There are 2 types of calls: GEMM calls and routing calls. + + +There are 2 categories of API calls, each is used for different purposes: + +- GEMM calls - invoked during every GEMM, used to process or quantize tensors and collect information about them, +- Routing calls - invoked at the beginning of every forward pass - they indicate whether a feature is going to use `modify_tensor()`, etc. 
+ +If all routing calls for the layer return `False`, then the layer is invoked in an optimized version with Transformer Engine fusions. +If any of the routing calls return `True`, layers are run without the fusions. This is necessary because otherwise some tensors cannot be accessed +if fusions happen. An important remark is that if no feature is used for the layer, then it should perform as fast as the layer without initializing `debug_api`. + + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor_enabled + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.fp8_gemm_enabled + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_enabled + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize_enabled diff --git a/docs/debug/4_distributed.rst b/docs/debug/4_distributed.rst new file mode 100644 index 000000000..6f69f2712 --- /dev/null +++ b/docs/debug/4_distributed.rst @@ -0,0 +1,91 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Distributed training +=================== + +Nvidia-Pytorch-Inspect with Transformer Engine supports multi-GPU training. This guide describes how to run it and how the supported features work in the distributed setting. + +To use precision debug tools in multi-GPU training, one needs to: + +1. Run ``debug_api.initialize(...)`` and provide the same configuration YAML file on every node. +2. If one wants to log stats, one may want to invoke ``debug_api.set_tensor_reduction_group`` with a proper reduction group. + +Behavior of the features +----------------------- + +In a distributed setting, **DisableFP8GEMM** and **DisableFP8Layer** function similarly to the single-GPU case, with no notable differences. + +**PerTensorScaling** and **FakeQuant** calculate FP8 scaling factors independently on each node, meaning the number of GPUs may affect results. This differs from the delayed scaling FP8 recipe behavior, in which scaling factors are synchronized. + +.. figure:: ./img/scaling_factors.svg + :align: center + + Fig 1: For **PerTensorScaling** and **FakeQuant** tensor scaling factors are computed separately for each of the tensor shards. This is not the case for delayed scaling FP8 scaling factors, which are synchronized. + +Logging-related features are more complex and will be discussed further in the next sections. + +Reduction groups +-------------- + +In setups with tensor, data, or pipeline parallelism, some tensors are distributed across multiple GPUs, requiring a reduction operation to compute statistics for these tensors. + +The weight tensor is always split among the tensor parallel group, and debug tools automatically reduce statistics within this group by default. To disable this automatic reduction, use: + +.. code-block:: python + + transformer_engine.debug.set_weight_tensor_tp_group_reduce(False) + +In cases of data parallelism, Transformer Engine modules lack the process group needed for reduction. To manually specify the group, use: + +.. 
code-block:: python + + debug_api.set_tensor_reduction_group(group) + +This command ensures statistics are reduced across the defined group. Activation statistics are logged after the forward pass (immediately after exiting autocast), while gradient (dgrad and wgrad) statistics are logged following the backward pass. + +Below, we illustrate configurations for a 4-node setup with tensor parallelism size 2 and data parallelism size 2, showcasing different reduction configurations. + +.. figure:: ./img/reduction1.svg + :align: center + + Fig 2: There is a single tensor reduction group composed of all nodes. As a result, each node logs the same statistics for the tensors, as they are fully reduced across all nodes. + +.. figure:: ./img/reduction2.svg + :align: center + + Fig 3: Every node is set with a tensor reduction group consisting of itself. Every node prints the same statistics for weights (which are still synchronized within TP groups), but the statistics of activations and gradients are not synchronized. + +.. figure:: ./img/reduction3.svg + :align: center + + Fig 4: Weight synchronization is disabled by ``set_weight_tensor_tp_group_reduce(False)``, so every node logs stats for its shard of the weight. + + +Microbatching +----------- + +Let's dive into how statistics collection works with microbatching. By microbatching, we mean invoking multiple ``forward()`` calls for each ``debug_api.step()``. The behavior is as follows: + +- For weight tensors, the stats remain the same for each microbatch because the weight does not change. +- For other tensors, the stats are accumulated. + +Logging to files and TensorBoard +------------------------------ + +In a single-node setup with ``default_logging_enabled=True``, all logs are saved by default to ``log_dir/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log``. In multi-GPU training, each node writes its reduced statistics to its unique file, named ``log_dir/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-i.log`` for rank i. Because these logs contain reduced statistics, the logged values are identical for all nodes within a reduction group. + +If certain nodes are given a TensorBoard writer, only those nodes will log to TensorBoard. This is useful in scenarios involving pipeline, data, and tensor parallelism, such as with two transformer layers and settings TP_SIZE = 2, DP_SIZE = 2, and PP_SIZE = 2. To log all stats to TensorBoard, you should pass a TensorBoard writer to one process in each pipeline parallel group. + +.. figure:: ./img/pipeline_logging.svg + :align: center + + Fig 5: Example with pipeline parallelism, where a ``tb_writer`` is assigned to one node within each pipeline parallel group, setting these as tensor reduction groups. + +Alternatively, setting the tensor reduction group to None will yield unreduced statistics for wgrad and dgrad tensors on each node, allowing for post-processing. For weight statistics without reduction in the TP parallel group, use: + +.. code-block:: python + + transformer_engine.debug.set_weight_tensor_tp_group_reduce(False) \ No newline at end of file diff --git a/docs/debug/api.rst b/docs/debug/api.rst new file mode 100644 index 000000000..ac593d353 --- /dev/null +++ b/docs/debug/api.rst @@ -0,0 +1,13 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. +API +============ + +.. 
toctree:: + :caption: Precision debug tools API + + 3_api_debug_setup.rst + 3_api_features.rst + 3_api_te_calls.rst \ No newline at end of file diff --git a/docs/debug/img/api_calls1.svg b/docs/debug/img/api_calls1.svg new file mode 100644 index 000000000..098f384b2 --- /dev/null +++ b/docs/debug/img/api_calls1.svg @@ -0,0 +1 @@ +te.LinearLinear1Nvidia-DLFramework-Inspectconfig.yamlSection1:enabled: Truelayer_names: [Linear1]UserProvidedPrecision:enabled: Truegemms_struct:-gemm: frop-tensors: [activation, output]-gemm: dgrad-tensors: [weight]FeatureclassesUserProvidedPrecisionFPROPWGRADDGRADmodify_tensor_enabledDefaultmodify_tensormodify_tensor_enabledmodify_tensor \ No newline at end of file diff --git a/docs/debug/img/api_calls2.svg b/docs/debug/img/api_calls2.svg new file mode 100644 index 000000000..5df72fc2e --- /dev/null +++ b/docs/debug/img/api_calls2.svg @@ -0,0 +1 @@ +Tensor Ainspect_tensorfp8 castmodify_tensorinspect_tensor_postquantizeGEMMinspect_tensormodify_tensorinspect_tensor_enabledinspect_tensor_postquantize_enabledfp8_gemm_enabledmodify_tensor_enabledTensor Binspect_tensorfp8 castmodify_tensorinspect_tensor_postquantizeRouting callsGEMM calls \ No newline at end of file diff --git a/docs/debug/img/fake_quant.svg b/docs/debug/img/fake_quant.svg new file mode 100644 index 000000000..3ba6973d5 --- /dev/null +++ b/docs/debug/img/fake_quant.svg @@ -0,0 +1 @@ +FP8 GEMMBF16weightBF16inputFP8inputFP8weightBF16activationBF16 GEMMBF16weightBF16inputBF16activationBF16 Inputfake quantizedto FP8FP8inputBF16 Inputfake quantizedto FP8 \ No newline at end of file diff --git a/docs/debug/img/introduction.svg b/docs/debug/img/introduction.svg new file mode 100644 index 000000000..0eae8e820 --- /dev/null +++ b/docs/debug/img/introduction.svg @@ -0,0 +1 @@ +te.LinearLinear1Nvidia-DLFramework-InspectDisableFp8LayerLogTensorStatsconfig.yamlte.LinearLinear2DisableFp8LayerLogTensorStatsSection1:enabled: Truelayer_names: [Linear1, Linear2]DisableFp8Layer:enabled: TrueSection2:enabled: Truelayer_names: [Linear2]LogTensorStats:enabled: TrueotherparamsSection3:enabled: Truelayer_names: [Linear3]UserProvidedPrecision:enabled: Truete.LinearLinear3FeatureclassesDisableFp8LayerUserProvidedPrecisionUserProvidedPrecisionProvidedby the Transformer EngineUser candefinecustomfeatureclasses \ No newline at end of file diff --git a/docs/debug/img/names.svg b/docs/debug/img/names.svg new file mode 100644 index 000000000..3990939e7 --- /dev/null +++ b/docs/debug/img/names.svg @@ -0,0 +1 @@ +Transformer Layer with name transformer_layertransformer_layer.self_attntransformer_layer.self_attn.projtransformer_layer.self_attn.layernorm_linear_qkvtransformer_layer.layernorm_mlptransformer_layer.layernorm_mlp.fc1transformer_layer.layernorm_mlp.fc21 Linear1 Linear1 Linear1 Linear \ No newline at end of file diff --git a/docs/debug/img/pipeline_logging.svg b/docs/debug/img/pipeline_logging.svg new file mode 100644 index 000000000..b87254315 --- /dev/null +++ b/docs/debug/img/pipeline_logging.svg @@ -0,0 +1 @@ +Node 1Node 2Node 3Node 4Node 5Node 6Node 7Node 8TensorBoard logstb_writertb_writertensor reduction group 1=pipeline parallel group 1tensor reduction group 2=pipeline parallel group 2 \ No newline at end of file diff --git a/docs/debug/img/reduction1.svg b/docs/debug/img/reduction1.svg new file mode 100644 index 000000000..184799d53 --- /dev/null +++ b/docs/debug/img/reduction1.svg @@ -0,0 +1 @@ +Node 1Node 2Node 3Node 4TP group 1activation/gradient tensorsweight tensorsTensor reduction groupTP group 
2StatsStatsStatsStats \ No newline at end of file diff --git a/docs/debug/img/reduction2.svg b/docs/debug/img/reduction2.svg new file mode 100644 index 000000000..36f94611e --- /dev/null +++ b/docs/debug/img/reduction2.svg @@ -0,0 +1 @@ +TP group 1activation/gradient tensorsweight tensorsTensor reduction groupTP group 2StatsStatsStatsStatsTensor reduction groupTensor reduction groupTensor reduction groupNode 1Node 2Node 3Node 4 \ No newline at end of file diff --git a/docs/debug/img/reduction3.svg b/docs/debug/img/reduction3.svg new file mode 100644 index 000000000..601fb8502 --- /dev/null +++ b/docs/debug/img/reduction3.svg @@ -0,0 +1 @@ +TP group 1activation/gradient tensorsweight tensorsTensor reduction groupTP group 2StatsStatsStatsStatsNode 1Node 2Node 3Node 4 \ No newline at end of file diff --git a/docs/debug/img/scaling_factors.svg b/docs/debug/img/scaling_factors.svg new file mode 100644 index 000000000..b70b51e66 --- /dev/null +++ b/docs/debug/img/scaling_factors.svg @@ -0,0 +1 @@ +One Scaling FactorScaling Factor No. 1Scaling Factor No. 2Node 1Node 2NodeOne Scaling FactorOne Scaling FactorNode 1Node 2PerTensorScalingandFakeQuantFP8 Delayed Scaling \ No newline at end of file diff --git a/docs/debug/img/tensorboard.png b/docs/debug/img/tensorboard.png new file mode 100644 index 000000000..481dbd2eb Binary files /dev/null and b/docs/debug/img/tensorboard.png differ diff --git a/docs/index.rst b/docs/index.rst index cd9ce41cf..bbdb4fea6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -52,4 +52,5 @@ Transformer Engine documentation :caption: Advanced api/c/index + debug examples/attention/attention.ipynb diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh new file mode 100644 index 000000000..9339777f4 --- /dev/null +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -0,0 +1,26 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + + + +: ${TE_PATH:=/opt/transformerengine} +: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} +: ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/} + +# Config with the dummy feature which prevents nvinspect from being disabled. +# Nvinspect will be disabled if no feature is active. 
+: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} + +FAIL=0 + +pip install pytest==8.2.1 +pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 + +# standard numerics tests with initialized debug +NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 + +exit $FAIL diff --git a/qa/L0_pytorch_lint/test.sh b/qa/L0_pytorch_lint/test.sh index 81d7822d7..e2c50c445 100644 --- a/qa/L0_pytorch_lint/test.sh +++ b/qa/L0_pytorch_lint/test.sh @@ -20,5 +20,5 @@ if [ -z "${CPP_ONLY}" ] then cd $TE_PATH echo "Checking Python files" - python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch + python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch transformer_engine/debug fi diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 79f3c8fb9..ea5236502 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -44,6 +44,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entro NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 4319e96c7..09ef661c4 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -20,6 +20,7 @@ FAILED_CASES="" : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" + pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" @@ -30,6 +31,19 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail 
"test_cast_master_weights_to_fp8.py" + +# debug tests + + +# Config with the dummy feature which prevents nvinspect from being disabled. +# Nvinspect will be disabled if no feature is active. +: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} +: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} + +pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" +# standard numerics tests with initialized debug +NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" + if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" exit 1 diff --git a/setup.py b/setup.py index 41893644c..b7b234ba3 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,6 @@ from setuptools.command.build_ext import build_ext as BuildExtension -from setuptools.command.develop import develop as _develop os.environ["NVTE_PROJECT_BUILDING"] = "1" @@ -48,26 +47,6 @@ if not rocm_build(): archs = cuda_archs() -# A custom develop command only used for ROCm builds -class develop(_develop): - def run(self): - super().run() - if ( - int(os.getenv("NVTE_FUSED_ATTN_CK", "1")) and - int(os.getenv("NVTE_FUSED_ATTN", "1")) - ): - # Ensure that the AITER ASM kernels are properly available at runtime - # by creating a symlink to them. This is only necessary for editable - # mode since our C++ code assumes the AITER ASM kernel paths relative - # to trasnformer_engine.so, which is different in editable installs. 
- project_dir = Path(__file__).parent - asm_src_dir = project_dir / 'transformer_engine' / 'aiter' - # Must be synced with - # TransformerEngine/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp - asm_target_dir = project_dir / 'aiter' - if asm_src_dir.is_dir() and not asm_target_dir.is_dir(): - asm_target_dir.symlink_to(asm_src_dir) - class TimedBdist(bdist_wheel): """Helper class to measure build time""" @@ -89,7 +68,7 @@ def setup_common_extension() -> CMakeExtension: cmake_flags.append(f"-DCK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT={os.getenv('NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT', 3)}") if os.getenv("NVTE_CK_FUSED_ATTN_PATH"): ck_path = Path(os.getenv("NVTE_CK_FUSED_ATTN_PATH")) - cmake_flags.append(f"-DCK_FUSED_ATTN_PATH={ck_path}") + cmake_flags.append(f"-DAITER_MHA_PATH={ck_path}") if int(os.getenv("NVTE_FUSED_ATTN_AOTRITON", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0: cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF") if int(os.getenv("NVTE_FUSED_ATTN_CK", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0: @@ -173,7 +152,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ) # Blackwell is not supported as of Triton 3.2.0, need custom internal build # install_reqs.append("triton") - test_reqs.extend(["numpy", "torchvision"]) + test_reqs.extend(["numpy", "torchvision", "transformers"]) if "jax" in frameworks: if rocm_build(): from build_tools.jax import jax_install_requires @@ -192,7 +171,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: with open("README.rst", encoding="utf-8") as f: long_description = f.read() - cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} # Settings for building top level empty package for dependency management. if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))): assert bool( @@ -200,6 +178,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ), "NVTE_RELEASE_BUILD env must be set for metapackage build." te_cuda_vers = "rocm" if rocm_build() else "cu12" ext_modules = [] + cmdclass = {} package_data = {} include_package_data = False setup_requires = [] @@ -211,8 +190,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: else: setup_requires, install_requires, test_requires = setup_requirements() ext_modules = [setup_common_extension()] - if rocm_build(): - cmdclass["develop"] = develop + cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} package_data = {"": ["VERSION.txt"]} include_package_data = True extras_require = {"test": test_requires} @@ -255,7 +233,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: long_description=long_description, long_description_content_type="text/x-rst", ext_modules=ext_modules, - cmdclass=cmdclass, + cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, python_requires=">=3.8, <3.13", classifiers=[ "Programming Language :: Python :: 3.8", diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index da8a37ba8..4ab5fd237 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -64,8 +64,7 @@ else() project(transformer_engine_tests LANGUAGES HIP CXX) # Ask hcc to generate device code during compilation so we can use # host linker to link. 
- set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${HIP_HCC_FLAGS}") + set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted -Wno-unused-result") endif() add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest) diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index 33a9b8629..6855c9487 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -76,12 +76,12 @@ void scale_block(const ProcessingMethod processing_method, continue; } amax = std::max(amax, std::abs(elt)); -#else +#else // #ifdef __HIP_PLATFORM_AMD__ if (std::isinf(elt) || std::isnan(elt)) { continue; } amax = fmaxf(amax, fabsf(elt)); -#endif +#endif // #ifdef __HIP_PLATFORM_AMD__ } } @@ -312,6 +312,23 @@ void performTest_x1(const ProcessingMethod processing_method, block_size_cols, scales_stride); + +#ifdef __HIP_PLATFORM_AMD__ + if (processing_method != ProcessingMethod::CAST_ONLY) { + std::vector> mismatch_idx; + compare_e8m0_scaling_factors("scales", output_c, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, rowwise, mismatch_idx); + + if (mismatch_idx.size()) { + adjust_ref(mismatch_idx, ref_output_c.get(), unpadded_blocks_Y, unpadded_blocks_X, rows, cols, otype); + } + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol); + } + else +#endif // #ifdef __HIP_PLATFORM_AMD__ + { auto [atol, rtol] = getTolerances(otype); compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol); @@ -321,6 +338,7 @@ void performTest_x1(const ProcessingMethod processing_method, compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + } if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { auto [atol_dbias, rtol_dbias] = getTolerances(itype); @@ -454,7 +472,29 @@ void performTest_x2(const ProcessingMethod processing_method, block_size_cols, scales_stride_rowwise, scales_stride_colwise); +#ifdef __HIP_PLATFORM_AMD__ + if (processing_method != ProcessingMethod::CAST_ONLY) { + std::vector> mismatch_idx_r; + compare_e8m0_scaling_factors("scales_rowwise", output, ref_scales_rowwise.get(), + unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise, 0.01, true, mismatch_idx_r); + + if (mismatch_idx_r.size()) { + adjust_ref(mismatch_idx_r, ref_output_c_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, rows, cols, otype); + } + std::vector> mismatch_idx_c; + compare_e8m0_scaling_factors("scales_colwise", output, ref_scales_colwise.get(), + unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise, 0.01, false, mismatch_idx_c); + + if (mismatch_idx_c.size()) { + adjust_ref(mismatch_idx_c, ref_output_c_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, rows, cols, otype); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol); + compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol); + } else +#endif // #ifdef __HIP_PLATFORM_AMD__ + { auto [atol, rtol] = getTolerances(otype); compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol); compareResults("output_c_colwise", 
output, ref_output_c_colwise.get(), false, atol, rtol); @@ -464,6 +504,7 @@ void performTest_x2(const ProcessingMethod processing_method, compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise); + } if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { auto [atol_dbias, rtol_dbias] = getTolerances(itype); @@ -563,7 +604,7 @@ TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { if (getDeviceComputeCapability() < blackwellComputeCapability) { GTEST_SKIP(); } -#endif +#endif // #ifdef __HIP_PLATFORM_AMD__ using namespace transformer_engine; using namespace test; diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index f93c8c9e0..96663e752 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -262,9 +262,24 @@ void performTest_x1(const size_t rows, block_size_rows, block_size_cols, scales_stride); +#ifdef __HIP_PLATFORM_AMD__ + std::vector> mismatch_idx; + if (rowwise) { + compare_e8m0_scaling_factors("rowwise scales", output, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, true, mismatch_idx); + } else { + compare_e8m0_scaling_factors("colwise scales", output, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, false, mismatch_idx); + } + if (mismatch_idx.size()) { + adjust_ref(mismatch_idx, ref_output.get(), unpadded_blocks_Y, unpadded_blocks_X, rows, cols, otype); + } auto [atol, rtol] = getTolerances(otype); compareResults("output", output, ref_output.get(), rowwise, atol, rtol); +#else // #ifdef __HIP_PLATFORM_AMD__ + auto [atol, rtol] = getTolerances(otype); + compareResults("output", output, ref_output.get(), rowwise, atol, rtol); const uint8_t * const gpu_scales_ptr = rowwise ? 
output.rowwise_cpu_scale_inv_ptr() @@ -276,6 +291,7 @@ void performTest_x1(const size_t rows, compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, scales_stride); } +#endif // #ifdef __HIP_PLATFORM_AMD__ } /** @@ -361,17 +377,41 @@ void performTest_x2(const size_t rows, block_size_cols, scales_stride_rowwise, scales_stride_colwise); +#ifdef __HIP_PLATFORM_AMD__ + std::vector> mismatch_idx_r; + compare_e8m0_scaling_factors("scales_rowwise", output, + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise, 0.01, true, mismatch_idx_r); + + if (mismatch_idx_r.size()) { + adjust_ref(mismatch_idx_r, ref_output_colwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, rows, cols, otype); + } + + std::vector> mismatch_idx_c; + compare_e8m0_scaling_factors("scales_colwise", output, + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise, 0.01, false, mismatch_idx_c); + + if (mismatch_idx_c.size()) { + adjust_ref(mismatch_idx_c, ref_output_rowwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, rows, cols, otype); + } auto [atol, rtol] = getTolerances(otype); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol); +#else // #ifdef __HIP_PLATFORM_AMD__ + auto [atol, rtol] = getTolerances(otype); + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); + compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol); compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise); compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise); +#endif // #ifdef __HIP_PLATFORM_AMD__ } std::vector> matrix_sizes = { @@ -382,7 +422,7 @@ std::vector> matrix_sizes = { {256, 256}, {993, 512}, {768, 1024}, - {65536, 128}, + {65504, 128}, {16384, 1632}, }; @@ -418,12 +458,12 @@ class CastMXFP8_GatedActTestSuite : public ::testing::TestWithParam TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) { #ifdef __HIP_PLATFORM_AMD__ omp_set_num_threads(std::min(128, omp_get_max_threads())); // Using threads = # of vcpus causes occasional errors. -#else +#else // #ifdef __HIP_PLATFORM_AMD__ // Skip tests for pre-Blackwell architectures if (getDeviceComputeCapability() < blackwellComputeCapability) { GTEST_SKIP(); } -#endif +#endif // #ifdef __HIP_PLATFORM_AMD__ using namespace transformer_engine; diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 7d0597ef7..1ef3f7ee3 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -3,17 +3,15 @@ * * License for AMD contributions = MIT. 
See LICENSE for more information ************************************************************************/ +#include +#include +#include +#include +#include +#include +#include #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include "../test_common.h" using namespace transformer_engine; @@ -30,29 +28,17 @@ std::vector> test_case_sizes = { {29, 29, 17389}, //primes }; +std::vector> test_case_sizes_mxfp8 = { + {2304, 768, 4096}, +}; + // A, B, Bias, Gelu, D // Bias type choose as bf16 in use_fp8, D_type otherwise // Gelu type the same as Bias_Type -// {DType::kFloat32, DType::kFloat32, DType::kFloat32, DType::kFloat32, DType::kFloat32}, -// {DType::kFloat16, DType::kFloat16, DType::kFloat16, DType::kFloat16, DType::kFloat16}, -// {DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16}, -// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat32}, -// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat16}, -// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16}, -// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E4M3}, -// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E5M2}, -// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kFloat32}, -// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kFloat16}, -// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16}, -// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E4M3}, -// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E5M2}, -// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat32}, -// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat16}, -// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16}, -// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E4M3}, -// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E5M2}, -} // namespace +using fp32=float; +using fp8=fp8e4m3; +using bf8=fp8e5m2; using Layout = std::pair;// {transa, transb} static const Layout kNN{false,false}; @@ -61,10 +47,9 @@ static const Layout kNT{false,true }; static const std::vector kLayouts = { kNN, kTN, kNT }; -// , -class GEMMTestSuite - : public ::testing::TestWithParam< - std::tuple, bool, bool, Layout, NVTEScalingMode>> {}; +using TShape = std::vector; +} // namespace + float ref_gelu(float x){ float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); @@ -81,12 +66,14 @@ void compute_ref( const float d_scale, size_t m, size_t k, size_t n, D_Type* ref_d_data, - float* ref_d_amax, + float* ref_d_amax_ptr, Gelu_Type* ref_gelu_data, bool transa, bool transb){ - *ref_d_amax = 0; + float ref_d_amax = 0; + + #pragma omp parallel for schedule(static) collapse(2) reduction(max: ref_d_amax) proc_bind(spread) for(size_t ii = 0; ii < m; ii++){ for(size_t jj = 0; jj < n; jj++){ float val = 0; @@ -106,41 +93,45 @@ void compute_ref( // update ref_d_amax if in fp8 DType dtype = TypeInfo::dtype; if(isFp8Type(dtype)){ - *ref_d_amax = std::max(*ref_d_amax, 
std::fabs(val)); + ref_d_amax = std::max(ref_d_amax, std::fabs(val)); } } } + if (ref_d_amax_ptr) + { + *ref_d_amax_ptr = ref_d_amax; + } } template void compute_mxfp8_ref( const A_Type* a_data, const B_Type* b_data, - const NVTEShape& a_scale_inv_shape, const fp8e8m0* a_scale_inv_data, - const NVTEShape& b_scale_inv_shape, const fp8e8m0* b_scale_inv_data, const Bias_Type* bias_data, //bias is of dim m const float d_scale, size_t m, size_t k, size_t n, D_Type* ref_d_data, - float* ref_d_amax, + float* ref_d_amax_ptr, Gelu_Type* ref_gelu_data, bool transa, bool transb){ - *ref_d_amax = 0; + float ref_d_amax = 0; + + #pragma omp parallel for schedule(static) collapse(2) reduction(max: ref_d_amax) proc_bind(spread) for(size_t ii = 0; ii < m; ii++){ for(size_t jj = 0; jj < n; jj++){ float val = 0; for(size_t kk = 0; kk < k; kk++){ - float a_val = a_data[ii*k + kk]; - float b_val = b_data[kk + jj*k]; - float a_scale_inv_val = - (float)std::pow(2, a_scale_inv_data[ii * a_scale_inv_shape.data[1] + kk / 32] - 127); - float b_scale_inv_val = - (float)std::pow(2, b_scale_inv_data[kk / 32 + jj * b_scale_inv_shape.data[1]] - 127); - val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; + size_t a_idx = transa ? (ii*k + kk) : (kk*m + ii); + size_t b_idx = transb ? (kk*n + jj) : (jj*k + kk); + float a_scale_inv_val = (float)std::pow(2, + a_scale_inv_data[transa ? a_idx/32 : (kk/32 * m + ii)] - 127); + float b_scale_inv_val = (float)std::pow(2, + b_scale_inv_data[transb ? (kk/32 * n + jj) : b_idx/32] - 127); + val += a_scale_inv_val * (float)a_data[a_idx] * b_scale_inv_val * (float)b_data[b_idx]; } if(bias_data){ val += (float)bias_data[ii]; @@ -153,10 +144,14 @@ void compute_mxfp8_ref( // update ref_d_amax if in fp8 DType dtype = TypeInfo::dtype; if(isFp8Type(dtype)){ - *ref_d_amax = std::max(*ref_d_amax, std::fabs(val)); + ref_d_amax = std::max(ref_d_amax, std::fabs(val)); } } } + if (ref_d_amax_ptr) + { + *ref_d_amax_ptr = ref_d_amax; + } } template @@ -172,6 +167,36 @@ void cpu_rowwise_to_columnwise( } } +std::pair getTestTolerances(const DType type, bool use_fp8, bool use_mxfp8) { + auto [atol, rtol] = getTolerances(type); + + //relax for certain prime number gemm + if (type == DType::kFloat32) { + atol = 1e-5; + } + // relax for certain FP8 gemm with hipblaslt + if (use_mxfp8) { + atol = 5e-4; + /*During hipifying std::max is converted to ::max + to w/a HIP bug with using std:: in device functions. 
+       Without explicit template arguments, the compiler uses the non-templated
+       int method variant from the HIP headers.
+       TODO: remove when switching to a new hipify version after the HIP bug is fixed. */
+    rtol = std::max(rtol, 1e-3);
+  }
+  else if (use_fp8) {
+    atol = 1e-3;
+    //TODO: remove (see comment above)
+    rtol = std::max(rtol, 1e-2);
+  }
+  else if (type == DType::kBFloat16) {
+    //relax for certain prime number TN gemm
+    rtol = 5e-2;
+  }
+  else if (type == DType::kFloat32) {
+    rtol = 1e-5;
+  }
+  return {atol, rtol};
+}

 struct TestParams {
   size_t m;
@@ -258,8 +283,13 @@ void performTest(const TestParams& params) {
     if (params.use_gelu && dtype == DType::kBFloat16) {
       GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config";
     }
-    if (has_fp8 && params.use_bias && dtype == DType::kFloat32) {
-      GTEST_SKIP() << "FP8 GEMM with bias and FP32 output is not supported in current config";
+    if constexpr ((std::is_same<A_Type, bf8>::value || std::is_same<B_Type, bf8>::value) &&
+                  std::is_same<D_Type, fp32>::value)
+    {
+      //GEMM with bias and fp32 output is not supported with bf8 A/B
+      if (params.use_bias) {
+        GTEST_SKIP() << "FP8 GEMM with bias is not supported in current config";
+      }
     }
   }
   if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations
@@ -273,49 +303,39 @@ void performTest(const TestParams& params) {
   }
 #endif
-  // pytorch tensor storage is row-major while cublas/hipblaslt is column-major
-  Tensor A;
-  if (params.transa){
-    A = Tensor("A", std::vector{ params.m, params.k }, atype, true, false, params.scaling_mode);
-  }else {
-    // hipblaslt path need fp8-gemm with TN layout
-    A = Tensor("A", std::vector{ params.k, params.m }, atype, true, isFp8Type(atype), params.scaling_mode);
-  }
-  Tensor B;
-  if (params.transb){
-    //hipblaslt path need fp8-gemm with TN layout
-    B = Tensor("B", std::vector{ params.k, params.n }, btype, true, isFp8Type(btype), params.scaling_mode);
-  }else {
-    B = Tensor("B", std::vector{ params.n, params.k }, btype, true, false, params.scaling_mode);
-  }
-  Tensor D("D", std::vector{ params.n, params.m }, dtype);
+  // FP8 GEMM path needs columnwise data for A/B tensor with non TN layout
+  const bool a_colwise = !params.transa && isFp8Type(atype);
+  const bool b_colwise = params.transb && isFp8Type(btype);
+  Tensor A("A", params.transa ? TShape{ params.m, params.k } : TShape{ params.k, params.m },
+           atype, (!a_colwise || !use_mxfp8), a_colwise, params.scaling_mode);
+  Tensor B("B", params.transb ?
TShape{ params.k, params.n } : TShape{ params.n, params.k }, + btype, (!b_colwise || !use_mxfp8), b_colwise, params.scaling_mode); + + Tensor D("D", TShape{ params.n, params.m }, dtype); Tensor bias; if(params.use_bias){ - bias = Tensor("bias", std::vector{params.m}, bias_type); + bias = Tensor("bias", TShape{params.m}, bias_type); } Tensor pre_gelu_out; if(params.use_gelu){ - pre_gelu_out = Tensor("pre_gelu_out", std::vector{ params.n, params.m }, gelu_type); + pre_gelu_out = Tensor("pre_gelu_out", TShape{ params.n, params.m }, gelu_type); } //initialize the data and scale inv of A, B + //fillUniform does not initialize columnwise data if rowwise data exist fillUniform(&A); - if (isFp8Type(atype) && !params.transa && !use_mxfp8) { + if (a_colwise && !use_mxfp8) { // A must be of shape k, m - cpu_rowwise_to_columnwise( - params.k, params.m, - A.rowwise_cpu_dptr(), - A.columnwise_cpu_dptr()); + cpu_rowwise_to_columnwise(params.k, params.m, + A.rowwise_cpu_dptr(), A.columnwise_cpu_dptr()); // sync the columnwise data on GPU as well A.from_cpu(); } fillUniform(&B); - if (isFp8Type(btype) && params.transb && !use_mxfp8) { - // B must be of shape k, m - cpu_rowwise_to_columnwise( - params.k, params.n, - B.rowwise_cpu_dptr(), - B.columnwise_cpu_dptr()); + if (b_colwise && !use_mxfp8) { + // B must be of shape k, n + cpu_rowwise_to_columnwise(params.k, params.n, + B.rowwise_cpu_dptr(), B.columnwise_cpu_dptr()); // sync the columnwise data on GPU as well B.from_cpu(); } @@ -335,7 +355,7 @@ void performTest(const TestParams& params) { workspace_size = 67108864; } #endif - Tensor Workspace("Workspace", std::vector{ workspace_size }, DType::kByte); + Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte); //perform the gemm in GPU nvte_cublas_gemm(A.data(), @@ -370,28 +390,23 @@ void performTest(const TestParams& params) { const A_Type *a_data; const B_Type *b_data; const fp8e8m0 *a_scale_inv_data, *b_scale_inv_data; - NVTEShape a_scale_inv_shape, b_scale_inv_shape; if (params.transa) { a_data = A.rowwise_cpu_dptr(); a_scale_inv_data = A.rowwise_cpu_scale_inv_ptr(); - a_scale_inv_shape = A.rowwise_scale_inv_shape(); } else { a_data = A.columnwise_cpu_dptr(); a_scale_inv_data = A.columnwise_cpu_scale_inv_ptr(); - a_scale_inv_shape = A.columnwise_scale_inv_shape(); } if (params.transb) { b_data = B.columnwise_cpu_dptr(); b_scale_inv_data = B.columnwise_cpu_scale_inv_ptr(); - b_scale_inv_shape = B.columnwise_scale_inv_shape(); } else { b_data = B.rowwise_cpu_dptr(); b_scale_inv_data = B.rowwise_cpu_scale_inv_ptr(); - b_scale_inv_shape = B.rowwise_scale_inv_shape(); } compute_mxfp8_ref( - a_data, b_data, a_scale_inv_shape, a_scale_inv_data, b_scale_inv_shape, b_scale_inv_data, + a_data, b_data, a_scale_inv_data, b_scale_inv_data, params.use_bias ? bias.rowwise_cpu_dptr() : nullptr, D.scale(), params.m, params.k, params.n, ref_D.get(), &ref_amax_d, params.use_gelu ? ref_pre_gelu_out.get() : nullptr, @@ -416,49 +431,91 @@ void performTest(const TestParams& params) { compareResults("D_amax", D.amax(), ref_amax_d, atol_amax, rtol_amax); } - auto [atol, rtol] = getTolerances(dtype); - //relax for certain prime number gemm - if (dtype == DType::kFloat32) { - atol = 1e-5; - } -#ifdef __HIP_PLATFORM_AMD__ - // relax for certain FP8 gemm with hipblaslt - if (use_mxfp8) { - atol = 5e-4; - /*During hipifying std::max is converted to ::max - to w/a HIP bug with using std:: in device functions. 
- W/o explicitlit , compiler uses non-templated int method variant from HIP headers - TODO: remove when switch to new hipify version after fixing HIP bug */ - rtol = std::max(rtol, 1e-3); - } - else if (has_fp8) { - atol = 1e-3; - //TODO: remove (see comment above) - rtol = std::max(rtol, 5e-3); - } - else if (dtype == DType::kBFloat16) { - //relax for certain prime number TN gemm - rtol = 5e-2; - } - else if (dtype == DType::kFloat32) { - rtol = 1e-5; - } -#endif + auto [atol, rtol] = getTestTolerances(dtype, has_fp8, use_mxfp8); compareResults("D", D, ref_D.get(), true, atol, rtol); if(params.use_gelu){ - auto [atol, rtol] = getTolerances(gelu_type); - //relax for certain prime number gemm - if (dtype == DType::kFloat32) { - atol = 1e-5; - } + auto [atol, rtol] = getTestTolerances(gelu_type, false, false); compareResults("gelu", pre_gelu_out, ref_pre_gelu_out.get(), true, atol, rtol); } } -using fp32=float; -using fp8=fp8e4m3; -using bf8=fp8e5m2; +#ifdef __HIP_PLATFORM_AMD__ +template +void performDqTest(const TestParams ¶ms) { + DType atype = TypeInfo::dtype; + DType btype = TypeInfo::dtype; + DType dtype = TypeInfo::dtype; + + GTEST_ASSERT_TRUE(isFp8Type(atype) && isFp8Type(btype)) << "FP8/BF8 input datatype is expected"; + GTEST_ASSERT_FALSE(isFp8Type(dtype)) << "Non FP8/BF8 output datatype is expected"; + + if (params.m % 32 != 0 || params.n % 32 != 0 || params.k % 32 != 0) { + GTEST_SKIP() << "MXFP8 requires M, N, K to be multiples of 32"; + } + + cudaDeviceProp prop; + (void)cudaGetDeviceProperties(&prop, 0); + + bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5); + if (!mxfp8_supported) { + GTEST_SKIP() << "MXFP8 is not supported in current config"; + } + + DType ref_type = dtype; + TShape a_shape = params.transa ? TShape{params.m, params.k} : TShape{params.k, params.m}; + TShape b_shape = params.transb ? 
TShape{params.k, params.n} : TShape{params.n, params.k}; + + Tensor A_src("A", a_shape, ref_type); + Tensor B_src("B", b_shape, ref_type); + //initialize A, B + fillUniform(&A_src); + fillUniform(&B_src); + + // FP8 GEMM path needs columnwise data for A/B tensor with non TN layout + Tensor A_fp8("A_fp8", a_shape, atype, params.transa, !params.transa, + NVTEScalingMode::NVTE_MXFP8_1D_SCALING); + Tensor B_fp8("B_fp8", b_shape, btype, !params.transb, params.transb, + NVTEScalingMode::NVTE_MXFP8_1D_SCALING); + nvte_quantize(A_src.data(), A_fp8.data(), 0); + nvte_quantize(B_src.data(), B_fp8.data(), 0); + + Tensor A_ref("A_ref", a_shape, ref_type); + Tensor B_ref("B_ref", b_shape, ref_type); + nvte_dequantize(A_fp8.data(), A_ref.data(), 0); + nvte_dequantize(B_fp8.data(), B_ref.data(), 0); + + Tensor bias; + Tensor pre_gelu_out; + + size_t workspace_size = 67108864; + Tensor Workspace("Workspace", TShape{workspace_size}, DType::kByte); + + //perform FP8 gemm and copy the output results from GPU memory to CPU memory + Tensor D("D", TShape{params.n, params.m}, dtype); + nvte_cublas_gemm(A_fp8.data(), B_fp8.data(), D.data(), bias.data(), pre_gelu_out.data(), + params.transa, params.transb, false, Workspace.data(), false, false, + prop.multiProcessorCount, 0); + D.to_cpu(); + + + //perform non-FP8 gemm and copy the output results from GPU memory to CPU memory + Tensor D_ref("D", TShape{params.n, params.m}, dtype); + nvte_cublas_gemm(A_ref.data(), B_ref.data(), D_ref.data(), bias.data(), pre_gelu_out.data(), + params.transa, params.transb, false, Workspace.data(), false, false, + prop.multiProcessorCount, 0); + D_ref.to_cpu(); + + // check if error message happens in running + (void)cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + //compare results + auto [atol, rtol] = getTestTolerances(dtype, true, true); + compareResults("D", D, D_ref.rowwise_cpu_dptr(), true, atol, rtol); +} +#endif // __HIP_PLATFORM_AMD__ #define MAKE_TEST_PARAMS(P_) \ TestParams P_ = {.m = std::get<0>(std::get<0>(GetParam())), \ @@ -472,10 +529,13 @@ using bf8=fp8e5m2; ? NVTEScalingMode::NVTE_MXFP8_1D_SCALING \ : NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING} +// , use_bias, use_gelu, Layout, fp8_scalinig +class GEMMTestSuite + : public ::testing::TestWithParam< + std::tuple, bool, bool, Layout, NVTEScalingMode>> {}; + #define MAKE_GEMM_TEST(NAME_, A_, B_, BIAS_, GELU_, D_) \ TEST_P(GEMMTestSuite, NAME_) { \ - using namespace transformer_engine; \ - using namespace test; \ MAKE_TEST_PARAMS(test_params); \ using A_Type = A_; \ using B_Type = B_; \ @@ -523,24 +583,51 @@ MAKE_GEMM_TEST(Testbf8xfp8xbf16xbf16xbf8, bf8, fp8, bf16, bf16, bf8); MAKE_GEMM_TEST(Testfp8xfp8xfp16xfp16xfp8, fp8, fp8, fp16, fp16, fp8); -INSTANTIATE_TEST_SUITE_P( - OperatorTest, - GEMMTestSuite, - ::testing::Combine( - ::testing::ValuesIn(test_case_sizes), - ::testing::Values(false, true), //use bias - ::testing::Values(false, true), //use_gelu - ::testing::ValuesIn(kLayouts), //transa,transb - ::testing::Values(false, true)), //use mxfp8 - [](const testing::TestParamInfo& info) { - auto TN = [](bool v){ return v ? 
"T" : "N"; }; - const auto layout = std::get<3>(info.param); - std::string name = std::to_string(std::get<0>(std::get<0>(info.param))) + "X" + - std::to_string(std::get<1>(std::get<0>(info.param))) + "X" + - std::to_string(std::get<2>(std::get<0>(info.param))) + "X" + - std::to_string(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param)) + "X" + - TN(layout.first) + TN(layout.second) + "X" + - (std::get<4>(info.param) ? "M" : "S"); - return name; - }); +static inline auto TN(const Layout& layout) { + static const char* map[2][2] = {{"NN", "NT"}, {"TN", "TT"}}; + return std::string(map[layout.first][layout.second]); +} + +static inline auto MKN(const std::tuple& shape) { + return std::to_string(std::get<0>(shape)) + "x" + std::to_string(std::get<1>(shape)) + "x" + + std::to_string(std::get<2>(shape)); +} + +INSTANTIATE_TEST_SUITE_P(OperatorTest, GEMMTestSuite, + ::testing::Combine(::testing::ValuesIn(test_case_sizes), + ::testing::Values(false, true), //use bias + ::testing::Values(false, true), //use_gelu + ::testing::ValuesIn(kLayouts), //transa,transb + ::testing::Values(false, true)), //use mxfp8 + [](const testing::TestParamInfo& info) { + return MKN(std::get<0>(info.param)) + "x" + + std::to_string(std::get<1>(info.param)) + "x" + + std::to_string(std::get<2>(info.param)) + "x" + + TN(std::get<3>(info.param)) + "x" + + (std::get<4>(info.param) ? "M" : "S"); + }); + +#ifdef __HIP_PLATFORM_AMD__ +class DqGEMMTestSuite: public GEMMTestSuite {}; + +#define MAKE_DQ_GEMM_TEST(NAME_, A_, B_, D_) \ + TEST_P(DqGEMMTestSuite, NAME_) { \ + MAKE_TEST_PARAMS(test_params); \ + using A_Type = A_; \ + using B_Type = B_; \ + using D_Type = D_; \ + performDqTest(test_params); \ + } + +MAKE_DQ_GEMM_TEST(Testfp8xfp8xfp16, fp8, fp8, fp16) + +INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite, + ::testing::Combine(::testing::ValuesIn(test_case_sizes_mxfp8), + ::testing::Values(false), // bias - unused + ::testing::Values(false), // gelu - unused + ::testing::ValuesIn(kLayouts), //transa,transb + ::testing::Values(true)), //use mxfp8 + [](const testing::TestParamInfo& info) { + return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param)); + }); +#endif // __HIP_PLATFORM_AMD__ diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 32eb1d63a..d3dd6e95f 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -454,6 +454,9 @@ void Tensor::set_scale_inv(float scale_inv) { columnwise_cpu_scale_inv_ptr()[0] = scale_inv; } else { std::uniform_int_distribution dis(0, 127); + if (rowwise_) { + from_cpu(); //Need it because scale_inv_ptr getting does to_cpu() + } auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); for (size_t i = 0; i < num_scales; i++) { scale_inv_ptr[i] = dis(gen_); @@ -711,6 +714,74 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, } } +#ifdef __HIP_PLATFORM_AMD__ +void compare_e8m0_scaling_factors(const std::string &name, Tensor &output, const uint8_t *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + double tol, bool rowwise, std::vector> &mismatch_idx) { + const uint8_t *const test = rowwise ? 
output.rowwise_cpu_scale_inv_ptr() + : output.columnwise_cpu_scale_inv_ptr(); + + const float scale_tol = std::max(1.f, row_blocks * col_blocks * tol); + + for (int i = 0; i < row_blocks; i++) { + for (int j = 0; j < col_blocks; j++) { + const int idx = i * stride + j; + if (test[idx] != ref[idx]) { + int t_scale = static_cast(test[idx]); + int r_scale = static_cast(ref[idx]); + if (std::abs(t_scale - r_scale) == 1) { + mismatch_idx.emplace_back(i, j, r_scale-t_scale); + } else { + GTEST_FAIL() << "Error in " << name << std::endl + << "Mismatch: " << t_scale << " vs " + << r_scale << " at index " << idx; + } + } + } + } + const size_t scale_mismatches = mismatch_idx.size(); + + ASSERT_FALSE(scale_mismatches > scale_tol) + << "Error in " << name << std::endl << std::setprecision(4) + << "Total scale mismatches: " << scale_mismatches << " (" << 100.*(double)scale_mismatches/(double)(row_blocks*col_blocks) + << "%) Exceeds tolerance of " << scale_tol << " (" << 100.*tol << "%) mismatches"; + + if (scale_mismatches) { + std::cout << "\x1b[33mWARNING:\x1b[0m " << scale_mismatches + << " scale mismatches were found. This does not imply an accuracy issue." << std::endl; + } +} + +void adjust_ref(std::vector> mismatch_idx, void *ref, const size_t row_blocks, + const size_t col_blocks, const size_t rows, const size_t cols, DType otype) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( otype, T, + T *ref_data = reinterpret_cast(ref); + double scale_val; + const size_t col_blocks_size = cols / col_blocks; + const size_t row_blocks_size = rows / row_blocks; + for (const auto &[i, j, scale_diff] : mismatch_idx) { + if (scale_diff == 1) { + scale_val = 2.; + } else if (scale_diff == -1) { + scale_val = .5; + } else { // Shouldn't ever reach this + GTEST_FAIL() << "Error in adjust_ref, |scale_diff| > 1"; + } + size_t ii_min = i * row_blocks_size; + const size_t ii_max = std::min(ii_min + row_blocks_size, rows); + for (; ii_min < ii_max; ii_min++) { + size_t jj_min = j * col_blocks_size; + const size_t jj_max = std::min(jj_min + col_blocks_size, cols); + for (; jj_min < jj_max; jj_min++) { + const size_t data_idx = ii_min * cols + jj_min; + ref_data[data_idx] = static_cast(static_cast(ref_data[data_idx]) * scale_val); + } + } + } + ); // NOLINT(*) +} +#endif // #ifdef __HIP_PLATFORM_AMD__ + std::pair getTolerances(const DType type) { switch(type) { case DType::kFloat32: @@ -737,7 +808,21 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { for (int i = 0; i < size; i++) { data[i] = static_cast(dis(*gen)); } + gen->discard(size); #else + // Check how many RNG calls are required to generate one uniform random value + int rng_calls_per_val = 0; + { + std::mt19937 gen1 = *gen, gen2 = *gen; + std::uniform_real_distribution<> dis(-2.0, 1.0); + const float _ = dis(gen1); + while (gen2 != gen1) { + auto _ = gen2(); + ++rng_calls_per_val; + } + } + + // Generate uniform random values in parallel #pragma omp parallel proc_bind(spread) { std::mt19937 gen_local = *gen; @@ -746,15 +831,15 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { const int chunk_size = (size + threads_num - 1) / threads_num; const int idx_min = chunk_size * thread_ID; const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast(size)); - gen_local.discard(idx_min); + gen_local.discard(idx_min * rng_calls_per_val); std::uniform_real_distribution<> dis(-2.0, 1.0); for (int i = idx_min; i < idx_max; ++i) { data[i] = static_cast(dis(gen_local)); } } + gen->discard(size * 
rng_calls_per_val); #endif - gen->discard(size); } void fillUniform(Tensor *t) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 7ac2b75a6..6b9514d38 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -19,6 +19,7 @@ #else #include #include "amd_detail/hip_float8.h" +#include #endif #include @@ -461,6 +462,14 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const size_t row_blocks, const size_t col_blocks, const size_t stride); void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, const size_t N); +#ifdef USE_ROCM +void compare_e8m0_scaling_factors(const std::string &name, Tensor &output, const uint8_t *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + double tol, bool rowwise, std::vector> &mismatch_idx); + +void adjust_ref(std::vector> mismatch_idx, void *ref, const size_t row_blocks, + const size_t col_blocks, const size_t rows, const size_t cols, DType otype); +#endif std::array get_scale_tensor_dims(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols); diff --git a/tests/pytorch/debug/conftest.py b/tests/pytorch/debug/conftest.py new file mode 100644 index 000000000..20edc6aab --- /dev/null +++ b/tests/pytorch/debug/conftest.py @@ -0,0 +1,27 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--feature_dirs", nargs="+", action="store", default="", help="List of feature directories" + ) + parser.addoption( + "--configs_dir", + action="store", + default="", + type=str, + help="Path to the directory with configs.", + ) + + +@pytest.fixture +def feature_dirs(request): + return request.config.getoption("--feature_dirs") + + +@pytest.fixture +def configs_dir(request): + return request.config.getoption("--configs_dir") diff --git a/tests/pytorch/debug/run_distributed.py b/tests/pytorch/debug/run_distributed.py new file mode 100644 index 000000000..640fdf9c5 --- /dev/null +++ b/tests/pytorch/debug/run_distributed.py @@ -0,0 +1,647 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
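+
+# This script exercises the debug (nvdlfw_inspect) features under tensor/data parallelism.
+# It expects the standard torch.distributed launcher environment (RANK, WORLD_SIZE,
+# LOCAL_RANK, LOCAL_WORLD_SIZE), runs on a single node only, and receives the debug
+# feature directory through --feature_dirs. A typical launch (exact launcher and path
+# are assumptions) would look like:
+#   torchrun --nproc_per_node=<num_gpus> run_distributed.py \
+#       --feature_dirs <path to transformer_engine/debug/features>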
+ +import tempfile +import functools +import os +import itertools +import random +import argparse +import re + +import torch +import torch.distributed as dist +import transformer_engine +import transformer_engine_torch as tex +import nvdlfw_inspect.api as debug_api +from transformer_engine.debug import set_weight_tensor_tp_group_reduce + + +from test_numerics import ( + _emulate_linear, + _init_debug, + disable_fp8_gemms_create_config, + DISABLE_FP8_LAYER_CONFIG, + _cmp, + IN_SIZE, + OUT_SIZE, + _init_model, + SEED, + SEQ_LEN, + BATCH_SIZE, + FP8_RECIPE, + fake_quant_fp8_create_config, + _get_current_scale, + _prepare_per_tensor_scaling_config, + AMAX_HISTORY_LEN, + set_scaling_factors, + set_current_scaling_factors, +) + +WORLD_RANK, WORLD_SIZE = None, None +NCCL_WORLD = None +FEATURE_DIRS = None +all_boolean = [True, False] +TEST_NR = 0 + + +def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None): + if tp_size is None: + tp_size = WORLD_SIZE + tp_rank = WORLD_RANK + torch.manual_seed(weight_seed) + weight = torch.randn((OUT_SIZE, IN_SIZE)).cuda() + torch.manual_seed(data_seed) + in_split_size = IN_SIZE // tp_size + out_split_size = OUT_SIZE // tp_size + x = torch.randn((SEQ_LEN * BATCH_SIZE, IN_SIZE), requires_grad=True).cuda() + if parallel_mode == "row": + x = x[:, tp_rank * in_split_size : (tp_rank + 1) * in_split_size] + x.retain_grad() + + with torch.no_grad(): + if parallel_mode == "column": + weight = weight[tp_rank * out_split_size : (tp_rank + 1) * out_split_size, :] + else: + weight = weight[:, tp_rank * in_split_size : (tp_rank + 1) * in_split_size] + + return x, weight.contiguous() + + +def _init_model(weight, parallel_mode=None, tp_group=None, name="linear"): + model = transformer_engine.pytorch.Linear( + IN_SIZE, + OUT_SIZE, + name=name, + parallel_mode=parallel_mode, + tp_group=(tp_group or NCCL_WORLD if parallel_mode else None), + ) + with torch.no_grad(): + model.weight.copy_(weight) + return model + + +class AllGather(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor, dim, group=None): + if group is None: + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + world_size = torch.distributed.get_world_size(group=group) + rank = torch.distributed.get_rank(group=group) + dist.barrier() + + # Create a list to gather tensors from all processes + y_list = [torch.zeros_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(y_list, tensor, group=group) + + # Save the world size and rank for backward computation + ctx.world_size = world_size + ctx.rank = rank + ctx.dim = dim + + # Concatenate the gathered tensors along the feature dimension + y_full = torch.cat(y_list, dim=dim) + + return y_full + + @staticmethod + def backward(ctx, grad_output): + # Split the gradient output and return the portion corresponding to this rank + grad_input = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)[ctx.rank] + return grad_input, None, None + + +def _run_forward_backward(x, model, parallel_mode=None, group=None): + with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + y = model(x) + + y.requires_grad_(True) + y.retain_grad() + if parallel_mode == "column": + y = AllGather.apply(y, -1, group) + y.requires_grad_(True) + y.retain_grad() + l = y.sum() + l.backward() + elif parallel_mode == "row": + l = y.sum() + l.backward() + debug_api.step() + return y + + +def _emulate_linear_distributed(*args, parallel_mode=None, **kwargs): + assert 
parallel_mode in ["column", "row"] + + def split(gradient): + split_size = OUT_SIZE // WORLD_SIZE + gradient = gradient[:, WORLD_RANK * split_size : (WORLD_RANK + 1) * split_size] + return gradient + + activation_sync = None + gradient_sync = None + if parallel_mode == "column": + activation_sync = lambda x: AllGather.apply(x, -1) + gradient_sync = split + else: + activation_sync = ( + lambda activation: dist.all_reduce(activation, op=dist.ReduceOp.SUM) or activation + ) + + output = _emulate_linear( + *args, activation_sync=activation_sync, gradient_sync=gradient_sync, **kwargs + ) + + if parallel_mode == "column": + dist.all_reduce(output["dgrad"], op=dist.ReduceOp.SUM) + + return output + + +def check_debug_log(msg): + with open(f"log/debug_logs/debug_log_globalrank-{WORLD_RANK}.log", "r") as f: + for line in f.readlines(): + if msg in line: + return True + return False + + +def run_debug_test(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank = dist.get_rank() + temp_file_name = None + temp_logdir_name = None + + if rank == 0: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: + temp_file_name = temp_file.name + temp_dir_obj = tempfile.TemporaryDirectory() + temp_logdir_name = temp_dir_obj.name + + # Store the TemporaryDirectory object to prevent it from being deleted + wrapper.temp_dir_obj = temp_dir_obj + + temp_file_name_list = [temp_file_name] + temp_logdir_name_list = [temp_logdir_name] + + # Broadcast the temporary file and directory names to all processes + dist.broadcast_object_list(temp_file_name_list, src=0) + dist.broadcast_object_list(temp_logdir_name_list, src=0) + + temp_file_name = temp_file_name_list[0] + temp_logdir_name = temp_logdir_name_list[0] + + dist.barrier() + + config_file = open(temp_file_name, mode="r+", buffering=1) + + try: + kwargs["config_file"] = config_file + kwargs["log_dir"] = temp_logdir_name + + if rank == 0: + global TEST_NR + print(f"Running test {TEST_NR} {func.__name__} with args = {args}.") + TEST_NR += 1 + + func(*args, **kwargs) + finally: + if rank == 0 and temp_file_name is not None: + os.unlink(temp_file_name) + + debug_api.end_debug() + + if rank == 0 and hasattr(wrapper, "temp_dir_obj"): + wrapper.temp_dir_obj.cleanup() + + return wrapper + + +CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed: + layers: + layer_types: [linear] + enabled: + True + transformer_engine: + LogTensorStats: + enabled: True + tensors: [activation, gradient, weight, output, wgrad, dgrad] + stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range] + start_step : 0 + end_step: 1 + LogFp8TensorStats: + enabled: True + tensors: [activation, gradient, weight] + stats: [underflows%] + start_step : 0 + end_step: 1 +""" + + +def _prepare_config_test_log_distributed(config_file): + if WORLD_RANK != 0: + return + config_file.write(CONFIG_LOG_TEST_DISTRIBUTED) + config_file.flush() + + +def _compute_dynamic_range(tensor): + tensor_abs = tensor.abs() + tensor_abs = tensor_abs[tensor_abs != 0] + if tensor_abs.any(): + amin = tensor_abs.min().float() + else: + amin = torch.tensor(1, device=tensor.device).to(torch.float) + amax = tensor_abs.max().float() + if not amax.all(): + amax = torch.tensor(1, device=tensor.device).to(torch.float) + dynamic_range = torch.log2(amax) - torch.log2(amin) + return dynamic_range + + +@run_debug_test +def test_log_distributed(parallel_mode, gather_weight, **kwargs): + _prepare_config_test_log_distributed(kwargs["config_file"]) + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], 
FEATURE_DIRS) + set_weight_tensor_tp_group_reduce(gather_weight) + if WORLD_SIZE % 2 != 0: + return # skip + TP_SIZE = WORLD_SIZE // 2 + DP_SIZE = 2 + TP_RANK = WORLD_RANK % TP_SIZE + DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE + + debug_api.set_tensor_reduction_group(NCCL_WORLD) + + x, weight = _get_tensors( + parallel_mode, + weight_seed=TP_RANK * 1234, + data_seed=DP_RANK * 1234, + tp_size=TP_SIZE, + tp_rank=TP_RANK, + ) + + tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)] + tp_group = dist.new_group(ranks=tp_group_ranks) + + dp_group_ranks = [i for i in range(TP_RANK, WORLD_SIZE, TP_SIZE)] + dp_group = dist.new_group(ranks=dp_group_ranks) + + model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group) + output = _run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group) + + gathered_activation = AllGather.apply(x.contiguous(), 0) + gathered_weight = AllGather.apply(weight.contiguous(), 0, tp_group) + gathered_gradient = AllGather.apply(output.grad.contiguous(), 0, dp_group) + if parallel_mode == "row": + gathered_gradient = AllGather.apply(gathered_gradient, 0, tp_group) + + log_file = kwargs["log_dir"] + "/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log" + + dist.barrier() + if WORLD_RANK != 0: + return # stats are gathered on node 0 + with open(log_file) as f: + content = f.read() + + def get_stat(tensor, stat): + regex = r".*_{tensor}_{stat}\s+.*iteration=(\d+)\s+.*value=([-+]?\d*\.?\d+)".format( + tensor=tensor, stat=stat + ) + for line in content.splitlines(): + match = re.search(regex, line) + if match: + value = float(match.group(2)) + return value + + rf = lambda x: round(float(x), 4) + stats = [] + tensors = { + "activation": gathered_activation, + "weight": gathered_weight if gather_weight else weight, + "gradient": gathered_gradient, + } + stats = { + "min": torch.min, + "max": torch.max, + "mean": torch.mean, + "std": torch.std, + "l1_norm": lambda x: torch.norm(x, p=1), + "l2_norm": lambda x: torch.norm(x, p=2), + "cur_amax": lambda x: x.abs().max(), + "dynamic_range": _compute_dynamic_range, + } + for stat_key in stats.keys(): + for tensor_key in tensors.keys(): + torch.testing.assert_close( + get_stat(tensor_key, stat_key), + rf(stats[stat_key](tensors[tensor_key])), + atol=0.0001, + rtol=0.0001, + ) + set_weight_tensor_tp_group_reduce(True) # reset + + +@run_debug_test +def test_log_expert_parallel(**kwargs): + """ + This test tests the scenario, when one of the node of data parallel does not invoke the debug layer. + It naturally occurs in the expert parallelism, when one expert doesn't get input on one node, + but gets it on other nodes. If there were all_gather inside forward(), this would result in deadlock. 
+ """ + _prepare_config_test_log_distributed(kwargs["config_file"]) + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS) + debug_api.set_tensor_reduction_group(NCCL_WORLD) + x, weight = _get_tensors( + "row", weight_seed=WORLD_RANK * 1234, data_seed=WORLD_RANK * 1234, tp_size=1, tp_rank=0 + ) # data parallel + model = _init_model(weight, parallel_mode=None, name="linear1") + model1 = _init_model(weight, parallel_mode=None, name="linear2") + with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + y1 = model(x) + y2 = model1(x) + y = y1 + y2 + y.sum().backward() + debug_api.step() + with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + y = model(x) + if WORLD_RANK != 0: + y = y + model1(x) + + y.sum().backward() + + +@run_debug_test +def test_disable_fp8_gemms(fprop_fp8, dgrad_fp8, wgrad_fp8, parallel_mode, **kwargs): + disable_fp8_gemms_create_config(fprop_fp8, dgrad_fp8, wgrad_fp8, kwargs["config_file"]) + fp8_kwargs = { + "fprop_fp8": fprop_fp8, + "dgrad_fp8": dgrad_fp8, + "wgrad_fp8": wgrad_fp8, + } + + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS) + x, weight = _get_tensors(parallel_mode) + model = _init_model(weight, parallel_mode=parallel_mode) + y = _run_forward_backward(x, model, parallel_mode=parallel_mode) + + output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} + + x.grad.zero_() + ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs) + _cmp(ground_truth, output) + + +@run_debug_test +def test_disable_fp8_layer(parallel_mode, **kwargs): + if WORLD_RANK == 0: + kwargs["config_file"].write(DISABLE_FP8_LAYER_CONFIG) + kwargs["config_file"].flush() + dist.barrier() + + x, weight = _get_tensors(parallel_mode) + + ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode) + x.grad.zero_() + + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS) + + model = _init_model(weight, parallel_mode) + y = _run_forward_backward(x, model, parallel_mode) + + output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} + _cmp(ground_truth, output) + + +@run_debug_test +def test_per_tensor_scaling( + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + parallel_mode, + **kwargs, +): + input_kwargs = { + "fprop_inp": fprop_inp, + "fprop_weight": fprop_weight, + "dgrad_weight": dgrad_weight, + "dgrad_grad": dgrad_grad, + "wgrad_input": wgrad_input, + "wgrad_grad": wgrad_grad, + } + fp8_kwargs = { + "fprop_fp8": True, + "dgrad_fp8": True, + "wgrad_fp8": True, + } + """ + Runs a test to validate per-tensor (current) scaling in FP8 computations. + The function performs warm-up iterations to populate the amax buffer of the model and compute scaling factors based on delayed scaling. + Subsequently, weights and inputs are switched to ensure their current scaling factors differ from those based on delayed scaling; + similarly, the loss is multiplied by a large factor to alter the gradient's magnitude, + creating a discrepancy between the original (delayed) and per-tensor (current) scaling factors. 
+    Finally, a linear pass is emulated, and the results are compared.
+    """
+    _prepare_per_tensor_scaling_config(
+        fprop_inp,
+        fprop_weight,
+        dgrad_weight,
+        dgrad_grad,
+        wgrad_input,
+        wgrad_grad,
+        kwargs["config_file"],
+    )
+    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
+
+    warmup_input, warmup_weight = _get_tensors(parallel_mode=parallel_mode)
+    model = _init_model(warmup_weight, parallel_mode=parallel_mode)
+
+    # Warmup run to set up amax and scaling factors.
+    for _ in range(AMAX_HISTORY_LEN):
+        _run_forward_backward(warmup_input, model, parallel_mode=parallel_mode)
+
+    x, weight = _get_tensors(
+        parallel_mode=parallel_mode, weight_seed=WORLD_RANK * 2137, data_seed=WORLD_RANK * 2137
+    )
+    model.weight.data = weight.data
+    x.retain_grad()
+
+    # Delayed scaling factors
+    # need to be collected before the forward pass with the test data,
+    # because this forward pass changes the scaling factors.
+    set_scaling_factors(model, input_kwargs, fp8_kwargs)
+
+    LOSS_MULTIPLIER = 100
+
+    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
+        y = model(x)
+    model.zero_grad()
+    if parallel_mode == "column":
+        y = AllGather.apply(y, -1)
+    y.retain_grad()
+
+    (
+        LOSS_MULTIPLIER * y.sum()
+    ).backward()  # Loss multiplication to change the gradient's order of magnitude
+
+    output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
+    # Per-tensor (current) scaling factors
+    # need to be collected after the forward pass with the test data,
+    # because the gradient (y.grad) cannot be accessed before that pass,
+    # but it still needs to be collected.
+
+    set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs)
+    ground_truth = _emulate_linear_distributed(
+        x, weight, parallel_mode=parallel_mode, loss_multiplier=LOSS_MULTIPLIER, **fp8_kwargs
+    )
+
+    _cmp(ground_truth, output)
+
+
+@run_debug_test
+def test_fake_quant_fp8(
+    fprop_inp,
+    fprop_weight,
+    dgrad_weight,
+    dgrad_grad,
+    wgrad_input,
+    wgrad_grad,
+    parallel_mode,
+    **kwargs,
+):
+
+    fp8_kwargs = {
+        "fprop_input_fake_quant": fprop_inp,
+        "fprop_weight_fake_quant": fprop_weight,
+        "dgrad_gradient_fake_quant": dgrad_grad,
+        "dgrad_weight_fake_quant": dgrad_weight,
+        "wgrad_gradient_fake_quant": wgrad_grad,
+        "wgrad_input_fake_quant": wgrad_input,
+        "fprop_fp8": not (fprop_inp or fprop_weight),
+        "dgrad_fp8": not (dgrad_weight or dgrad_grad),
+        "wgrad_fp8": not (wgrad_grad or wgrad_input),
+    }
+    if WORLD_RANK == 0:
+        fake_quant_fp8_create_config(
+            fprop_inp,
+            fprop_weight,
+            dgrad_weight,
+            dgrad_grad,
+            wgrad_input,
+            wgrad_grad,
+            kwargs["config_file"],
+        )
+    dist.barrier()
+    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
+
+    x, weight = _get_tensors(parallel_mode)
+    model = _init_model(weight, parallel_mode)
+    y = _run_forward_backward(x, model, parallel_mode)
+
+    output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
+    fp8_kwargs["fprop_input_scale"] = (
+        _get_current_scale(x, fprop_inp) if not fp8_kwargs["fprop_fp8"] else None
+    )
+    fp8_kwargs["fprop_weight_scale"] = (
+        _get_current_scale(weight, fprop_weight) if not fp8_kwargs["fprop_fp8"] else None
+    )
+    fp8_kwargs["dgrad_gradient_scale"] = (
+        _get_current_scale(y.grad, dgrad_grad) if not fp8_kwargs["dgrad_fp8"] else None
+    )
+    fp8_kwargs["dgrad_weight_scale"] = (
+        _get_current_scale(weight, dgrad_weight) if not fp8_kwargs["dgrad_fp8"] else None
+    )
+    fp8_kwargs["wgrad_gradient_scale"] = (
+        _get_current_scale(y.grad, wgrad_grad) if not
fp8_kwargs["wgrad_fp8"] else None + ) + fp8_kwargs["wgrad_input_scale"] = ( + _get_current_scale(x, wgrad_input) if not fp8_kwargs["wgrad_fp8"] else None + ) + ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs) + _cmp(ground_truth, output) + + +def _init_distributed(): + global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, FP8 + + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + dist_init_kwargs["init_method"] = "env://" + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + torch.cuda.set_device(LOCAL_RANK) + dist.init_process_group(**dist_init_kwargs) + + NCCL_WORLD = dist.new_group(backend="nccl") + + WORLD_SIZE = dist.get_world_size() + + +def _run_test_with_combinations( + test_function, values_list, num_repeat, extra_args, sample_size=None +): + combinations = itertools.product(values_list, repeat=num_repeat) + total_combinations = itertools.product(combinations, extra_args) + + if sample_size is not None: + total_combinations = random.sample(list(total_combinations), sample_size) + + for comb, arg in total_combinations: + test_function(*comb, arg) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--feature_dirs", type=str) + args = parser.parse_args() + FEATURE_DIRS = args.feature_dirs + random.seed(SEED) + _init_distributed() + + test_log_expert_parallel() + for parallel_mode in ["column", "row"]: + for gather_weight in [True, False]: + test_log_distributed(parallel_mode, gather_weight) + + for parallel_mode in ["row", "column"]: + test_disable_fp8_layer(parallel_mode) + + # test_disable_fp8_gemms + _run_test_with_combinations( + test_disable_fp8_gemms, all_boolean, num_repeat=3, extra_args=["column", "row"] + ) + + # test_fake_quant_fp8 + dtype_options = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None] + _run_test_with_combinations( + test_fake_quant_fp8, + dtype_options, + num_repeat=6, + extra_args=["column", "row"], + sample_size=20, + ) + + _run_test_with_combinations( + test_per_tensor_scaling, + all_boolean, + num_repeat=6, + extra_args=["column"], + sample_size=20, + ) diff --git a/tests/pytorch/debug/test_api_features.py b/tests/pytorch/debug/test_api_features.py new file mode 100644 index 000000000..f9cd234ba --- /dev/null +++ b/tests/pytorch/debug/test_api_features.py @@ -0,0 +1,398 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
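+
+# These tests call the nvdlfw_inspect debug API directly (fp8_gemm_enabled,
+# modify_tensor_enabled, modify_tensor, inspect_tensor_enabled, ...) against the YAML
+# configs in the test configs directory. The QA scripts invoke them roughly as follows
+# (paths are illustrative):
+#   pytest -v -s tests/pytorch/debug/test_api_features.py \
+#       --feature_dirs=<debug features dir> --configs_dir=<test configs dir>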
+ +import torch +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer + +import nvdlfw_inspect.api as debug_api + +try: + import transformer_engine + import transformer_engine_torch as tex +except (ImportError, ModuleNotFoundError): + print("Could not find TransformerEngine package.") + exit(1) + + +def test_transformer_engine_no_config(feature_dirs): + debug_api.initialize("", feature_dirs=feature_dirs) + try: + + tensor = torch.rand(24, 2046).cuda() + + # FP8 enabled - true by the default + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="fprop", iteration=0 + ) + + # modify_tensor_enabled - False by default + assert not debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0 + ) + + # inspect_tensor_enabled - False by default + assert not debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.1.attn.qkv", tensor_name="activation", iteration=0 + ) + + # inspect_tensor_postquantize - False by default + assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0 + ) + + finally: + debug_api.end_debug() + + +def test_disable_fp8_gemm(configs_dir, feature_dirs): + try: + debug_api.initialize(configs_dir + "disable_fp8_gemms.yaml", feature_dirs=feature_dirs) + + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="fprop", iteration=0 + ) + assert not debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="dgrad", iteration=0 + ) + assert not debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="wgrad", iteration=0 + ) + + # caching + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="fprop", iteration=0 + ) + assert not debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="dgrad", iteration=0 + ) + assert not debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="wgrad", iteration=0 + ) + + finally: + debug_api.end_debug() + + +def test_disable_fp8_layer(configs_dir, feature_dirs): + try: + debug_api.initialize(configs_dir + "disable_fp8_layer.yaml", feature_dirs=feature_dirs) + + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.mlp.fc1", gemm="fprop", iteration=0 + ) + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.mlp.fc1", gemm="wgrad", iteration=0 + ) + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.mlp.fc1", gemm="dgrad", iteration=0 + ) + assert not debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="fprop", iteration=0 + ) + assert not debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="wgrad", iteration=0 + ) + assert not debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.attn.qkv", gemm="dgrad", iteration=0 + ) + + finally: + debug_api.end_debug() + + +def test_per_tensor_scaling(configs_dir, feature_dirs): + try: + + debug_api.initialize(configs_dir + "per_tensor_scaling.yaml", feature_dirs=feature_dirs) + + tensor = torch.rand(24, 2046).cuda() + + # check modify_tensor_enabled + assert debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0 + ) + assert debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0 + ) + assert 
debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0 + ) + assert not debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0 + ) + assert not debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0 + ) + assert not debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0 + ) + + # check modify_tensor + + default_quantizer1 = Float8Quantizer( + scale=torch.tensor([1]).cuda(), + amax=torch.tensor([0]).cuda(), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + default_quantizer2 = Float8Quantizer( + scale=torch.tensor([1]).cuda(), + amax=torch.tensor([0]).cuda(), + fp8_dtype=tex.DType.kFloat8E5M2, + ) + + output1 = debug_api.transformer_engine.modify_tensor( + layer_name="decoder.1.mlp.fc1", + gemm="fprop", + tensor_name="activation", + default_quantizer=default_quantizer1, + iteration=0, + tensor=tensor, + ) + assert type(output1) == Float8Tensor + assert output1._fp8_dtype == tex.DType.kFloat8E4M3 + + output2 = debug_api.transformer_engine.modify_tensor( + "decoder.1.mlp.fc1", + gemm="dgrad", + tensor=tensor, + tensor_name="gradient", + default_quantizer=default_quantizer2, + iteration=0, + ) + assert type(output2) == Float8Tensor + assert output2._fp8_dtype == tex.DType.kFloat8E5M2 + + assert not debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", + gemm="wgrad", + tensor_name="gradient", + iteration=0, + ) + + assert not debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc4", + gemm="fprop", + tensor_name="activation", + iteration=0, + ) + finally: + debug_api.end_debug() + + +def test_fake_quant(configs_dir, feature_dirs): + try: + debug_api.initialize( + configs_dir + "fake_quantization_config.yaml", feature_dirs=feature_dirs + ) + + tensor = torch.rand(24, 2046).cuda() + + # modify_tensor_enabled + assert debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0 + ) + + assert debug_api.transformer_engine.modify_tensor_enabled( + "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0 + ) + + # modify_tensor + debug_api.transformer_engine.modify_tensor( + "decoder.1.mlp.fc1", + gemm="fprop", + tensor=tensor, + tensor_name="activation", + iteration=0, + default_quantizer=None, + ) + + debug_api.transformer_engine.modify_tensor( + "decoder.1.mlp.fc1", + gemm="dgrad", + tensor=tensor, + tensor_name="gradient", + iteration=0, + default_quantizer=None, + ) + + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.fc2", gemm="wgrad", iteration=0 + ) + # caching + assert debug_api.transformer_engine.fp8_gemm_enabled( + "decoder.1.fc2", gemm="wgrad", iteration=0 + ) + finally: + debug_api.end_debug() + + +def test_statistics_collection(configs_dir, feature_dirs): + try: + debug_api.initialize( + config_file=configs_dir + "stats_collection_test_config.yaml", + feature_dirs=feature_dirs, + default_logging_enabled=False, + ) + + tensor = torch.randn((100, 100, 5)).cuda() + tensor_fp8 = Float8Tensor( + data=tensor.to(torch.uint8).cuda(), + fp8_scale_inv=torch.full([1], 1.0).cuda(), + fp8_dtype=tex.DType.kFloat8E4M3, + shape=tensor.shape, + dtype=torch.float32, + ) + + def log(): + from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS + + return STATS_BUFFERS.log_stats() + 
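+        # log() flushes the statistics gathered by the debug features via
+        # STATS_BUFFERS.log_stats(). As used in the assertions below, the returned
+        # dict is keyed by (layer_name, tensor_name, stat_name, iteration) tuples,
+        # for example (a sketch based on the lookups in this test):
+        #
+        #     stats = log()
+        #     amax = stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)]
+        #
+        # so an empty dict means no feature collected anything at that iteration.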
+ def assert_empty(): + stats = log() + assert len(stats) == 0 + + # TE tensor stats -- + debug_api.transformer_engine.inspect_tensor( + "decoder.1.mlp.fc1", + tensor=tensor, + tensor_name="activation", + iteration=200, + tp_group=None, + ) + stats = log() + assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max() + assert not debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.1.mlp.fc1", tensor_name="activation", iteration=201 + ) + assert not debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.2.mlp.fc1", tensor_name="activation", iteration=200 + ) + assert not debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200 + ) + + expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5) + expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5) + + # TE FP8 tensor stats -- + assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + "decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200 + ) + debug_api.transformer_engine.inspect_tensor_postquantize( + "decoder.1.mlp.fc1", + tensor=tensor_fp8, + tensor_name="gradient", + iteration=200, + rowwise=True, + tp_group=None, + ) + stats = log() + torch.testing.assert_close( + stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows + ) + + assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + "decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201 + ) + assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + "decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200 + ) + + # Second config in same yaml + tensor = torch.rand((100, 100, 5)) + debug_api.transformer_engine.inspect_tensor( + "decoder.6.mlp.fc1", + tensor=tensor, + tensor_name="activation", + iteration=200, + tp_group=None, + ) + stats = log() + stats_names = [x[3] for x in stats.keys()] + all(s in stats_names for s in ["cur_amax", "dynamic_range", "mean", "std", "l1_norm"]) + assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean() + + debug_api.transformer_engine.inspect_tensor( + "decoder.7.mlp.fc1", + tensor=tensor, + tensor_name="weight", + iteration=200, + tp_group=None, + ) + stats = log() + stats_names = [x[3] for x in stats.keys()] + all(s in stats_names for s in ["mean", "std", "l1_norm", "min", "max"]) + assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max() + + assert not debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.7.mlp.fc1", tensor_name="weight", iteration=201 + ) + assert_empty() + + finally: + debug_api.end_debug() + + +def test_statistics_multi_run(configs_dir, feature_dirs): + try: + debug_api.initialize( + config_file=configs_dir + "stats_collection_test_config.yaml", + feature_dirs=feature_dirs, + default_logging_enabled=False, + ) + + def feed(tensor, tensor_fp8): + debug_api.transformer_engine.inspect_tensor( + "decoder.5.mlp.fc1", + tensor=tensor, + tensor_name="activation", + iteration=1, + tp_group=None, + ) + debug_api.transformer_engine.inspect_tensor_postquantize( + "decoder.5.mlp.fc1", + tensor=tensor_fp8, + tensor_name="activation", + iteration=1, + rowwise=True, + tp_group=None, + ) + + def log_stats(): + from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS + + return STATS_BUFFERS.log_stats() + + def fp8_tensor(t): + return Float8Tensor( + data=t.to(torch.uint8).cuda(), + 
fp8_scale_inv=torch.ones([1]).cuda(), + fp8_dtype=tex.DType.kFloat8E4M3, + shape=t.shape, + dtype=torch.float32, + ) + + shape = [1024, 1024] + tensors = [torch.randn(shape) for _ in range(2)] + tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)] + + feed(tensors[0], tensors_fp8[0]) + feed(tensors[1], tensors_fp8[1]) + stats1 = log_stats() + + tensor2 = torch.cat((tensors[0], tensors[1])).cuda() + fp8tensor2 = fp8_tensor(tensor2) + feed(tensor2, fp8tensor2) + stats2 = log_stats() + + assert len(stats1.keys()) > 0 + for k in stats1.keys(): + torch.testing.assert_close(stats1[k], stats2[k]) + finally: + debug_api.end_debug() + + +if __name__ == "__main__": + pass diff --git a/tests/pytorch/debug/test_config.py b/tests/pytorch/debug/test_config.py new file mode 100644 index 000000000..71715a686 --- /dev/null +++ b/tests/pytorch/debug/test_config.py @@ -0,0 +1,151 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +import pathlib, os + +from nvdlfw_inspect.config_manager import ConfigManager + +import nvdlfw_inspect.api as debug_api + +try: + import transformer_engine + from transformer_engine.debug.features.api import TEConfigAPIMapper +except (ImportError, ModuleNotFoundError): + print("Could not find TransformerEngine debug module.") + exit(1) + + +def test_transformer_engine_config_parsing(feature_dirs): + debug_api.initialize( + config_file=pathlib.Path(__file__).resolve().parent + / "test_configs/tensor_manipulation_transformer_engine.yaml", + feature_dirs=feature_dirs, + log_dir="./log", + ) + + cfg_fc1 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc1")["transformer_engine"] + cfg_fc2 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc2")["transformer_engine"] + assert cfg_fc1 and cfg_fc2 + + gemm_parsing = True + tensor_parsing = True + + # Per tensor scaling set for dgrad, filter based on gemm + ret, _ = TEConfigAPIMapper().parse_config_and_api( + cfg_fc1["PerTensorScaling"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="wgrad", + tensor_name="activation", + ) + assert not ret + + # per tensor scaling set for gradient, filter based on tensor name + ret, _ = TEConfigAPIMapper().parse_config_and_api( + cfg_fc1["PerTensorScaling"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="dgrad", + tensor_name="activation", + ) + assert not ret + + ret, parsed_cfg_fc1 = TEConfigAPIMapper().parse_config_and_api( + cfg_fc1["PerTensorScaling"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="dgrad", + tensor_name="gradient", + ) + assert ret + assert parsed_cfg_fc1 == {"gemm": "dgrad", "tensor": "gradient"} + + # Test tensor struct + ret, parsed_cfg_fc1_act = TEConfigAPIMapper().parse_config_and_api( + cfg_fc1["FakeQuant"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="fprop", + tensor_name="activation", + ) + ret, parsed_cfg_fc1_wei = TEConfigAPIMapper().parse_config_and_api( + cfg_fc1["FakeQuant"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="fprop", + tensor_name="weight", + ) + assert ret + assert parsed_cfg_fc1_act == { + "gemm": "fprop", + "tensor": "activation", + "quant_format": "FP8E4M3", + } + assert parsed_cfg_fc1_wei == { + "gemm": "fprop", + "tensor": "weight", + "quant_format": "FP8E4M3", + } + + # Test gemms struct + ret, parsed_cfg_fc2_grad = TEConfigAPIMapper().parse_config_and_api( + cfg_fc2["FakeQuant"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + 
gemm="dgrad", + tensor_name="gradient", + ) + assert ret + assert parsed_cfg_fc2_grad == {"gemm": "dgrad", "tensor": "gradient", "quant_format": "FP8E5M2"} + ret, parsed_cfg_fc2_wei = TEConfigAPIMapper().parse_config_and_api( + cfg_fc2["FakeQuant"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="dgrad", + tensor_name="weight", + ) + assert ret + assert parsed_cfg_fc2_wei == {"gemm": "dgrad", "tensor": "weight", "quant_format": "FP8E5M2"} + + # Test gemm + tensor struct + ret, parsed_cfg_fc2_fprop_act = TEConfigAPIMapper().parse_config_and_api( + cfg_fc2["PerTensorScaling"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="fprop", + tensor_name="activation", + ) + assert ret + assert parsed_cfg_fc2_fprop_act == {"gemm": "fprop", "tensor": "activation"} + + ret, parsed_cfg_fc2_fprop_wei = TEConfigAPIMapper().parse_config_and_api( + cfg_fc2["PerTensorScaling"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="fprop", + tensor_name="weight", + ) + assert ret + assert parsed_cfg_fc2_fprop_wei == {"gemm": "fprop", "tensor": "weight"} + + ret, parsed_cfg_fc2_wgrad_act = TEConfigAPIMapper().parse_config_and_api( + cfg_fc2["PerTensorScaling"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="wgrad", + tensor_name="activation", + ) + assert ret + assert parsed_cfg_fc2_wgrad_act == {"gemm": "wgrad", "tensor": "activation"} + + ret, parsed_cfg_fc2_wgrad_grad = TEConfigAPIMapper().parse_config_and_api( + cfg_fc2["PerTensorScaling"], + gemm_parsing=gemm_parsing, + tensor_parsing=tensor_parsing, + gemm="wgrad", + tensor_name="gradient", + ) + assert ret + assert parsed_cfg_fc2_wgrad_grad == {"gemm": "wgrad", "tensor": "gradient"} + + ConfigManager.reset() diff --git a/tests/pytorch/debug/test_configs/disable_fp8_gemms.yaml b/tests/pytorch/debug/test_configs/disable_fp8_gemms.yaml new file mode 100644 index 000000000..b832f26d8 --- /dev/null +++ b/tests/pytorch/debug/test_configs/disable_fp8_gemms.yaml @@ -0,0 +1,8 @@ +test_disable_fp8_gemm_1: + enabled: True + layers: + layer_types: [qkv, fc2] + transformer_engine: + DisableFP8GEMM: + enabled: True + gemms: [dgrad, wgrad] \ No newline at end of file diff --git a/tests/pytorch/debug/test_configs/disable_fp8_layer.yaml b/tests/pytorch/debug/test_configs/disable_fp8_layer.yaml new file mode 100644 index 000000000..39bfc7a25 --- /dev/null +++ b/tests/pytorch/debug/test_configs/disable_fp8_layer.yaml @@ -0,0 +1,7 @@ +test_disable_fp8_layer: + enabled: True + layers: + layer_types: [qkv] + transformer_engine: + DisableFP8Layer: + enabled: True \ No newline at end of file diff --git a/tests/pytorch/debug/test_configs/dummy_feature.yaml b/tests/pytorch/debug/test_configs/dummy_feature.yaml new file mode 100644 index 000000000..540e3ac42 --- /dev/null +++ b/tests/pytorch/debug/test_configs/dummy_feature.yaml @@ -0,0 +1,9 @@ +deummy_feature_everywhere: + enabled: True + layers: + layer_name_regex_pattern: .* + transformer_engine: + TestDummyFeature: + enabled: True + tensors: [weight, activation, gradient, output, wgrad, dgrad] + gemms: [wgrad, dgrad, fprop] \ No newline at end of file diff --git a/tests/pytorch/debug/test_configs/fake_quantization_config.yaml b/tests/pytorch/debug/test_configs/fake_quantization_config.yaml new file mode 100644 index 000000000..62feace6d --- /dev/null +++ b/tests/pytorch/debug/test_configs/fake_quantization_config.yaml @@ -0,0 +1,14 @@ +test_fake_quant_fp8: + enabled: True + layers: + layer_numbers: [1] + layer_types: [fc1, fc2] + 
transformer_engine: + FakeQuant: + enabled: True + gemms: [fprop, dgrad] + tensors_struct: + - tensor: activation + quant_format: FP8E4M3 + - tensor: gradient + quant_format: FP8E5M2 \ No newline at end of file diff --git a/tests/pytorch/debug/test_configs/per_tensor_scaling.yaml b/tests/pytorch/debug/test_configs/per_tensor_scaling.yaml new file mode 100644 index 000000000..c17f2f7d2 --- /dev/null +++ b/tests/pytorch/debug/test_configs/per_tensor_scaling.yaml @@ -0,0 +1,19 @@ +test_per_tensor_scaling: + enabled: True + layers: + layer_numbers: [1] + layer_types: [fc1, fc2] + transformer_engine: + DisableFP8GEMM: + enabled: True + gemms: [wgrad] + PerTensorScaling: + enabled: True + gemms_struct: + - gemm: fprop + tensors_struct: + - tensor: activation + - tensor: weight + - gemm: dgrad + tensors_struct: + - tensor: gradient \ No newline at end of file diff --git a/tests/pytorch/debug/test_configs/stats_collection_test_config.yaml b/tests/pytorch/debug/test_configs/stats_collection_test_config.yaml new file mode 100644 index 000000000..8f01b2d62 --- /dev/null +++ b/tests/pytorch/debug/test_configs/stats_collection_test_config.yaml @@ -0,0 +1,59 @@ +stat_collection_test_1: + enabled: True + layers: + layer_numbers: [1, 3] + LogTensorStats: + enabled: True + stats: [mean, std, l1_norm, l2_norm] + tensors: [activation] + freq: 1 + start_step: 100 + end_step: 500 + transformer_engine: + LogTensorStats: + enabled: True + stats: [cur_amax, dynamic_range] + tensors: [activation] + freq: 2 + start_step: 100 + end_step: 500 + LogFp8TensorStats: + enabled: True + stats: [underflows%] + tensors: [gradient] + freq: 5 + start_step: 100 + end_step: 500 + +stat_collection_test_2: + enabled: True + layers: + layer_numbers: [6, 7] + transformer_engine: + LogTensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [cur_amax, dynamic_range, mean, std, l1_norm] + freq: 2 + start_step: 100 + end_step: 500 + - tensor: weight + stats: [mean, std, l1_norm, min, max] + freq: 5 + start_step: 100 + end_step: 500 + +stat_collection_test_4: + enabled: True + layers: + layer_numbers: [5] + transformer_engine: + LogTensorStats: + enabled: True + tensors: [activation] + stats: [cur_amax, dynamic_range, mean, std, l1_norm] + LogFp8TensorStats: + enabled: True + stats: [underflows%] + tensors: [activation] \ No newline at end of file diff --git a/tests/pytorch/debug/test_configs/tensor_manipulation_transformer_engine.yaml b/tests/pytorch/debug/test_configs/tensor_manipulation_transformer_engine.yaml new file mode 100644 index 000000000..e86486366 --- /dev/null +++ b/tests/pytorch/debug/test_configs/tensor_manipulation_transformer_engine.yaml @@ -0,0 +1,45 @@ +# This config is used when FP8 training is ON + +transformer_engine_fc1_manipulation: + enabled: True + layers: + layer_name_regex_pattern: .*(fc1) # Select layers if they end in fc1 + transformer_engine: # namespace + DisableFP8GEMM: # Disable FP8 GEMM. FProp run in high precision + enabled: True + gemms: [fprop] + PerTensorScaling: # Scale DGrad gradients using per tensor current scaling and run FP8 GEMM + enabled: True + gemms: [dgrad] + tensors: [gradient] + FakeQuant: # Disable FP8 GEMM for Wgrad. 
Fake quantize activations to Wgrad and run high precision GEMM + enabled: True + gemms: [fprop] + tensors_struct: + - tensor: activation + quant_format: FP8E4M3 + - tensor: weight + quant_format: FP8E4M3 + +transformer_engine_fc2_manipulation: + enabled: True + layers: + layer_name_regex_pattern: .*(fc2) # Select layers if they end in fc2 + transformer_engine: # namespace + PerTensorScaling: # Scale WGrad and Fprop inputs using per tensor current scaling and run FP8 GEMM + enabled: True + gemms_struct: + - gemm: fprop + tensors_struct: + - tensor: activation + - tensor: weight + - gemm: wgrad + tensors_struct: + - tensor: activation + - tensor: gradient + FakeQuant: # Disable FP8 GEMM for DGrad. Fake quantize weights and gradients to DGrad and run high precision GEMM + enabled: True + gemms_struct: + - gemm: dgrad + tensors: [weight, gradient] + quant_format: FP8E5M2 \ No newline at end of file diff --git a/tests/pytorch/debug/test_distributed.py b/tests/pytorch/debug/test_distributed.py new file mode 100644 index 000000000..7c072a054 --- /dev/null +++ b/tests/pytorch/debug/test_distributed.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import subprocess +from pathlib import Path + +import pytest +import torch + +""" + Distributed numerics tests + + These tests check the numerical correctness of the TransformerEngine layers. + Tests are parametrized by the layer and FP8 precision. + Each test runs multiple configurations from the file run_distributed.py. + This design is used because initializing a test is expensive - the worker + processes need to start and load torch and TE - so running multiple + configurations in one test reduces the initialization overhead. + +""" + + +if torch.cuda.device_count() < 2: + pytest.skip("Distributed training needs at least 2 GPUs.") + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS: int = min(4, torch.cuda.device_count()) +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +def test_debug_distributed(feature_dirs): + test_path = TEST_ROOT / "run_distributed.py" + test_cmd = LAUNCH_CMD + [str(test_path), f"--feature_dirs={feature_dirs[0]}"] + + result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) + if result.returncode != 0: + raise AssertionError(result.stderr.decode()) diff --git a/tests/pytorch/debug/test_numerics.py b/tests/pytorch/debug/test_numerics.py new file mode 100644 index 000000000..55c3ab9b7 --- /dev/null +++ b/tests/pytorch/debug/test_numerics.py @@ -0,0 +1,718 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information.
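+# The tests in this file check the debug features numerically: a te.Linear layer is
+# run under a generated debug config (FP8 GEMMs disabled, per-tensor current scaling,
+# or fake quantization), and its outputs are compared against an emulated linear pass
+# built from explicit casts and GEMM calls. The comparison pattern, roughly (using
+# the helpers defined further down in this file):
+#
+#     output = {"activation": y, "wgrad": model.weight.grad, "dgrad": x.grad}
+#     ground_truth = _emulate_linear(x, weight, **fp8_kwargs)
+#     _cmp(ground_truth, output)  # torch.testing.assert_close on each entry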
+ +import functools +import itertools +import os +import random +import tempfile +from string import Template + +import pytest +import torch + +import nvdlfw_inspect.api as debug_api +import transformer_engine.debug +import transformer_engine.pytorch as tepytorch +import transformer_engine_torch as tex +from transformer_engine.common.recipe import DelayedScaling, Format +from transformer_engine.pytorch.fp8 import _default_sf_compute +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.module.base import ( + _2X_ACC_DGRAD, + _2X_ACC_FPROP, + _2X_ACC_WGRAD, +) + +all_boolean = [True, False] +FP8_FORMAT = Format.HYBRID +AMAX_HISTORY_LEN = 16 +FP8_RECIPE = DelayedScaling( + fp8_format=FP8_FORMAT, amax_history_len=AMAX_HISTORY_LEN, amax_compute_algo="max" +) +SEED = 1234 +IN_SIZE = 128 +OUT_SIZE = 64 +BATCH_SIZE = 16 +SEQ_LEN = 128 +LOSS_FN = torch.nn.functional.cross_entropy + + +def _cast_to_fp8(tensor, scale, dtype): + tensor = tensor.contiguous() + if type(scale) == torch.Tensor: + amax = scale.abs().max().float() + quantizer = Float8Quantizer(scale, amax, dtype) + else: + quantizer = Float8CurrentScalingQuantizer(scale, device=tensor.device) + + return quantizer(tensor) + + +def _get_current_scale(tensor, fp8_dtype): + if fp8_dtype == tex.DType.kFloat8E4M3: + fp8_max = Format.E4M3.value.max_fwd + else: + fp8_max = Format.E5M2.value.max_fwd + + amax = tensor.abs().max().float() + one = torch.ones(1, device=tensor.device) + + return _default_sf_compute(amax, one, fp8_max, 0).detach() + + +def _fake_cast(tensor, fp8_dtype, scale): + scale = scale or _get_current_scale(tensor, fp8_dtype) + fp8_tensor = _cast_to_fp8(tensor, scale, fp8_dtype) + + return fp8_tensor.dequantize() + + +def _fp8_gemm_kernel(tensor1, scale1, dtype1, tensor2, scale2, dtype2, use_split_accumulator): + fp8_tensor1 = _cast_to_fp8(tensor1, scale1, dtype1) + fp8_tensor2 = _cast_to_fp8(tensor2, scale2, dtype2) + + out, *_ = tepytorch.cpp_extensions.general_gemm( + fp8_tensor1, + fp8_tensor2, + tepytorch.module.base.get_workspace(), + torch.float32, + use_split_accumulator=use_split_accumulator, + ) + out.requires_grad = True + return out.T + + +def _emulate_linear( + input: torch.Tensor, + weight: torch.Tensor, + fprop_fp8: bool = False, + fprop_input_fake_quant: tex.DType = None, + fprop_input_scale: torch.Tensor = None, + fprop_weight_fake_quant: tex.DType = None, + fprop_weight_scale: torch.Tensor = None, + dgrad_fp8: bool = False, + dgrad_gradient_fake_quant: tex.DType = None, + dgrad_gradient_scale: torch.Tensor = None, + dgrad_weight_fake_quant: tex.DType = None, + dgrad_weight_scale: torch.Tensor = None, + wgrad_fp8: bool = False, + wgrad_gradient_fake_quant: tex.DType = None, + wgrad_gradient_scale: torch.Tensor = None, + wgrad_input_fake_quant: tex.DType = None, + wgrad_input_scale: torch.Tensor = None, + loss_multiplier: float = 1.0, + activation_sync=None, + gradient_sync=None, +): + _scalar = lambda x: torch.Tensor([x]).cuda() if type(x) in [float, torch.Tensor] else x + if fprop_fp8: + activation = _fp8_gemm_kernel( + input, + _scalar(fprop_input_scale or 1.0), + tex.DType.kFloat8E4M3, + weight, + _scalar(fprop_weight_scale or 1.0), + tex.DType.kFloat8E4M3, + _2X_ACC_FPROP, + ) + activation = activation.clone().detach().contiguous().requires_grad_(True) + else: + fprop_input = ( + _fake_cast(input, fprop_input_fake_quant, _scalar(fprop_input_scale)) + if fprop_input_fake_quant is not None + else input + ) + 
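+        # Fake quantization sketch: _fake_cast() quantizes the tensor to FP8 with
+        # Float8Quantizer (or a current-scaling quantizer) and immediately calls
+        # dequantize(), so the GEMM below still runs in high precision but sees the
+        # same rounding error that a real FP8 cast would introduce, i.e. roughly
+        # x_fq = dequantize(quantize(x, scale, fp8_dtype)).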
fprop_weight = ( + _fake_cast(weight, fprop_weight_fake_quant, _scalar(fprop_weight_scale)) + if fprop_weight_fake_quant is not None + else weight + ) + + activation = (fprop_input @ fprop_weight.T).contiguous() + + if activation_sync: + activation = activation_sync(activation) + + activation.retain_grad() + + (loss_multiplier * activation.sum()).backward(retain_graph=True) + gradient = activation.grad.clone() + + if gradient_sync: + gradient = gradient_sync(gradient) + + if dgrad_fp8: + dgrad = _fp8_gemm_kernel( + weight.T, + _scalar(dgrad_weight_scale or 1.0), + tex.DType.kFloat8E4M3, + gradient, + _scalar(dgrad_gradient_scale or 1.0), + tex.DType.kFloat8E5M2, + _2X_ACC_DGRAD, + ).T + else: + dgrad_gradient = ( + _fake_cast(gradient, dgrad_gradient_fake_quant, _scalar(dgrad_gradient_scale)) + if dgrad_gradient_fake_quant is not None + else gradient + ) + + dgrad_weight = ( + _fake_cast(weight, dgrad_weight_fake_quant, _scalar(dgrad_weight_scale)) + if dgrad_weight_fake_quant is not None + else weight + ) + dgrad = dgrad_gradient @ dgrad_weight + + if wgrad_fp8: + wgrad = _fp8_gemm_kernel( + input.T, + _scalar(wgrad_input_scale or 1.0), + tex.DType.kFloat8E4M3, + gradient.T, + _scalar(wgrad_gradient_scale or 1.0), + tex.DType.kFloat8E5M2, + _2X_ACC_WGRAD, + ).T + else: + wgrad_gradient = ( + _fake_cast(gradient, wgrad_gradient_fake_quant, _scalar(wgrad_gradient_scale)) + if wgrad_gradient_fake_quant is not None + else gradient + ) + wgrad_input = ( + _fake_cast(input, wgrad_input_fake_quant, _scalar(wgrad_input_scale)) + if wgrad_input_fake_quant is not None + else input + ) + wgrad_input = wgrad_input.contiguous() + wgrad_gradient = wgrad_gradient.contiguous() + wgrad, *_ = tepytorch.cpp_extensions.general_gemm( + wgrad_input, + wgrad_gradient, + tepytorch.module.base.get_workspace(), + torch.float32, + layout="NT", + grad=True, + use_split_accumulator=_2X_ACC_WGRAD, + ) + + return {"activation": activation, "wgrad": wgrad, "dgrad": dgrad} + + +def _init_debug(config_name, log_dir, feature_dirs): + debug_api.initialize( + config_file=config_name, + feature_dirs=feature_dirs, + log_dir=log_dir, + default_logging_enabled=True, + ) + + +def create_config_file(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: + with tempfile.TemporaryDirectory() as temp_dir: + try: + kwargs["config_file"] = temp_file + kwargs["log_dir"] = temp_dir + result = func(*args, **kwargs) + finally: + temp_file_name = temp_file.name + debug_api.end_debug() + os.unlink(temp_file_name) + return result + + return wrapper + + +def _cmp(ground_truth, output): + torch.testing.assert_close(ground_truth["activation"], output["activation"]) + torch.testing.assert_close(ground_truth["wgrad"], output["wgrad"]) + torch.testing.assert_close(ground_truth["dgrad"], output["dgrad"]) + + +def _init_model(weight): + model = transformer_engine.pytorch.Linear(IN_SIZE, OUT_SIZE, name="linear") + with torch.no_grad(): + model.weight.copy_(weight.contiguous()) + return model + + +def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None): + with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + y = model(x, is_first_microbatch=is_first_microbatch) + (y.sum() * loss_scale).backward() + debug_api.step() + return y + + +def _get_tensors(): + torch.manual_seed(SEED) + x = torch.randn((SEQ_LEN * BATCH_SIZE, IN_SIZE), requires_grad=True).cuda() + x.retain_grad() + weight = torch.randn((OUT_SIZE, IN_SIZE)).cuda() + return x, 
weight + + +DISABLE_FP8_CONFIG = Template( + """disable_fp8_config: + enabled: True + layers: + layer_types: [linear] + transformer_engine: + DisableFP8GEMM: + enabled: True + gemms: [$gemms] +""" +) + + +@pytest.mark.parametrize("fprop_fp8", all_boolean) +@pytest.mark.parametrize("dgrad_fp8", all_boolean) +@pytest.mark.parametrize("wgrad_fp8", all_boolean) +def test_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8): + run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8) + + +def disable_fp8_gemms_create_config(fprop_fp8, dgrad_fp8, wgrad_fp8, config_file): + gemms = "" + if not fprop_fp8: + gemms += "fprop," + if not dgrad_fp8: + gemms += "dgrad," + if not wgrad_fp8: + gemms += "wgrad," + if len(gemms) > 0: + gemms = gemms[:-1] # remove last ',' + config_file.write(DISABLE_FP8_CONFIG.safe_substitute(gemms=gemms)) + config_file.flush() + + +@create_config_file +def run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8, **kwargs): + disable_fp8_gemms_create_config(fprop_fp8, dgrad_fp8, wgrad_fp8, kwargs["config_file"]) + fp8_kwargs = { + "fprop_fp8": fprop_fp8, + "dgrad_fp8": dgrad_fp8, + "wgrad_fp8": wgrad_fp8, + } + + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs) + x, weight = _get_tensors() + model = _init_model(weight) + y = _run_forward_backward(x, model) + + output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} + + x.grad.zero_() + ground_truth = _emulate_linear(x, weight, **fp8_kwargs) + _cmp(ground_truth, output) + + +def test_disable_fp8_layer(feature_dirs): + run_disable_fp8_layer(feature_dirs) + + +DISABLE_FP8_LAYER_CONFIG = """disable_fp8_config: + enabled: True + layers: + layer_types: [linear] + transformer_engine: + DisableFP8Layer: + enabled: True +""" + + +@create_config_file +def run_disable_fp8_layer(feature_dirs, **kwargs): + kwargs["config_file"].write(DISABLE_FP8_LAYER_CONFIG) + kwargs["config_file"].flush() + + x, weight = _get_tensors() + + ground_truth = _emulate_linear(x, weight) + x.grad.zero_() + + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs) + + model = _init_model(weight) + y = _run_forward_backward(x, model) + + output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} + _cmp(ground_truth, output) + + +random.seed(1234) + +all_combinations = list(itertools.product(all_boolean, repeat=6)) +subset_combinations = random.sample(all_combinations, 20) + + +@pytest.mark.parametrize( + "fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad", + subset_combinations, +) +def test_per_tensor_scaling( + feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad +): + if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]): + pytest.skip("Skipping test because all parameters are False") + run_per_tensor_scaling( + feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad + ) + + +PER_TENSOR_SCALING_CONFIG = Template( + """per_tensor_scaling_config: + enabled: True + layers: + layer_types: [linear] + transformer_engine: + PerTensorScaling: + enabled: True + gemms_struct: +$gemms +""" +) + + +def _prepare_per_tensor_scaling_config( + fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad, config_file +): + gemms = "" + title = lambda x: f" - gemm: {x}\n tensors: [" + + def add_tensor(if_add, gemm_name): + nonlocal gemms + if if_add: + gemms += gemm_name + "," + + if 
fprop_inp or fprop_weight: + gemms += title("fprop") + add_tensor(fprop_inp, "activation") + add_tensor(fprop_weight, "weight") + gemms = gemms[:-1] + "]\n" + if dgrad_weight or dgrad_grad: + gemms += title("dgrad") + add_tensor(dgrad_weight, "weight") + add_tensor(dgrad_grad, "gradient") + gemms = gemms[:-1] + "]\n" + if wgrad_input or wgrad_grad: + gemms += title("wgrad") + add_tensor(wgrad_input, "activation") + add_tensor(wgrad_grad, "gradient") + gemms = gemms[:-1] + "]\n" + config_file.write(PER_TENSOR_SCALING_CONFIG.safe_substitute(gemms=gemms)) + config_file.flush() + + +def set_scaling_factors(model, input_kwargs, fp8_kwargs): + # Copy fp8 scaling factors into fp8_kwargs dict if respective flag in input_kwargs is set. + if not input_kwargs["fprop_inp"]: + fp8_kwargs["fprop_input_scale"] = model.fp8_meta["scaling_fwd"].scale[0].clone() + if not input_kwargs["fprop_weight"]: + fp8_kwargs["fprop_weight_scale"] = model.fp8_meta["scaling_fwd"].scale[1].clone() + if not input_kwargs["dgrad_grad"]: + fp8_kwargs["dgrad_gradient_scale"] = model.fp8_meta["scaling_bwd"].scale[0].clone() + if not input_kwargs["dgrad_weight"]: + fp8_kwargs["dgrad_weight_scale"] = model.fp8_meta["scaling_fwd"].scale[1].clone() + if not input_kwargs["wgrad_grad"]: + fp8_kwargs["wgrad_gradient_scale"] = model.fp8_meta["scaling_bwd"].scale[0].clone() + if not input_kwargs["wgrad_input"]: + fp8_kwargs["wgrad_input_scale"] = model.fp8_meta["scaling_fwd"].scale[0].clone() + + +def set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs): + # Compute per tensor scaling factor if respective flag in input_kwargs is set. + if input_kwargs["fprop_inp"]: + fp8_kwargs["fprop_input_scale"] = tex.DType.kFloat8E4M3 + if input_kwargs["fprop_weight"]: + fp8_kwargs["fprop_weight_scale"] = tex.DType.kFloat8E4M3 + if input_kwargs["dgrad_grad"]: + fp8_kwargs["dgrad_gradient_scale"] = tex.DType.kFloat8E5M2 + if input_kwargs["dgrad_weight"]: + fp8_kwargs["dgrad_weight_scale"] = tex.DType.kFloat8E4M3 + if input_kwargs["wgrad_grad"]: + fp8_kwargs["wgrad_gradient_scale"] = tex.DType.kFloat8E5M2 + if input_kwargs["wgrad_input"]: + fp8_kwargs["wgrad_input_scale"] = tex.DType.kFloat8E4M3 + + +@create_config_file +def run_per_tensor_scaling( + feature_dirs, + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + **kwargs, +): + input_kwargs = { + "fprop_inp": fprop_inp, + "fprop_weight": fprop_weight, + "dgrad_weight": dgrad_weight, + "dgrad_grad": dgrad_grad, + "wgrad_input": wgrad_input, + "wgrad_grad": wgrad_grad, + } + fp8_kwargs = { + "fprop_fp8": True, + "dgrad_fp8": True, + "wgrad_fp8": True, + } + """ + Runs a test to validate per-tensor (current) scaling in FP8 computations. + The function performs warm-up iterations to populate the amax buffer of the model and compute scaling factors based on delayed scaling. + Subsequently, weights and inputs are switched to ensure their current scaling factors differ from those based on delayed scaling; + similarly, the loss is multiplied by a large factor to alter the gradient's magnitude, + creating a discrepancy between the original (delayed) and per-tensor (current) scaling factors. 
+ Finally, a linear pass is emulated, and the results are compared.” + """ + _prepare_per_tensor_scaling_config( + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + kwargs["config_file"], + ) + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs) + + warmup_input, warmup_weight = _get_tensors() + model = _init_model(warmup_weight) + + # Warmup run to setup amax and scaling factors. + for _ in range(AMAX_HISTORY_LEN): + _run_forward_backward(warmup_input, model) + + x = torch.randn_like(warmup_input, requires_grad=True).cuda() + weight = torch.randn_like(warmup_weight, requires_grad=True).cuda() + model.weight.data = weight.data + x.retain_grad() + + # delayed scaling factor + # need to be collected before forward pass with test data, + # because this forward pass changes scaling factors + set_scaling_factors(model, input_kwargs, fp8_kwargs) + + LOSS_MULTIPLIER = 100 + + with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + y = model(x, is_first_microbatch=True) + model.zero_grad() + y.retain_grad() + ( + LOSS_MULTIPLIER * y.sum() + ).backward() # Loss multiplication to change gradient's order of magintude + + output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} + + # per tensor - current - scaling factors + # need to be collected after forward pass with test data, + # because gradient(y.grad) cannot be accessed before forward, + # but it needs to be collected. + set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs) + + ground_truth = _emulate_linear(x, weight, loss_multiplier=LOSS_MULTIPLIER, **fp8_kwargs) + _cmp(ground_truth, output) + + +@pytest.mark.parametrize( + "fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad", + subset_combinations, +) +def test_microbatching_per_tensor_scaling( + feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad +): + if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]): + pytest.skip("Skipping test because all parameters are False") + + @create_config_file + def run_microbatching_test( + feature_dirs, + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + **kwargs, + ): + # Prepare the configuration file + _prepare_per_tensor_scaling_config( + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + kwargs["config_file"], + ) + + # Initialize debug + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs) + + # Get data + x_full, weight = _get_tensors() + microbatch_size = x_full.size(0) // 2 + x_mb1 = x_full[:microbatch_size, ...].clone().detach().requires_grad_(True) + x_mb2 = x_full[microbatch_size:, ...].clone().detach().requires_grad_(True) + + def init_and_warmup(): + model = _init_model(weight) + _run_forward_backward(x_mb1, model, loss_scale=0.5) + _run_forward_backward(x_mb2, model, loss_scale=0.5) + return model + + # Run without is_first_microbatch + + model = init_and_warmup() # running next 2 iters does not change amaxes and scaling factors + y_mb1 = _run_forward_backward(x_mb1, model, loss_scale=0.5) + y_mb2 = _run_forward_backward(x_mb2, model, loss_scale=0.5) + + # Collect outputs + output1 = { + "activation": torch.cat([y_mb1.clone(), y_mb2.clone()], dim=0), + "wgrad": model.weight.grad.clone(), + "dgrad": torch.cat([x_mb1.grad.clone(), x_mb2.grad.clone()], dim=0), + } + + # Run with is_first_microbatch + model = init_and_warmup() # 
running next 2 iters does not change amaxes and scaling factors + y_mb1 = _run_forward_backward(x_mb1, model, loss_scale=0.5, is_first_microbatch=True) + y_mb2 = _run_forward_backward(x_mb2, model, loss_scale=0.5, is_first_microbatch=False) + + # Collect outputs + output2 = { + "activation": torch.cat([y_mb1.clone(), y_mb2.clone()], dim=0), + "wgrad": model.weight.grad.clone(), + "dgrad": torch.cat([x_mb1.grad.clone(), x_mb2.grad.clone()], dim=0), + } + + # Compare outputs + torch.testing.assert_close(output1["activation"], output2["activation"], atol=1.0, rtol=0.5) + torch.testing.assert_close(output1["dgrad"], output2["dgrad"], atol=1.0, rtol=0.5) + torch.testing.assert_close(output1["wgrad"], output2["wgrad"], atol=1.0, rtol=0.5) + + # Run the test + run_microbatching_test( + feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad + ) + + +all_combinations = list( + itertools.product([tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None], repeat=6) +) +subset_combinations = random.sample(all_combinations, 10) + + +@pytest.mark.parametrize( + "fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad", + subset_combinations, +) +def test_fake_quant_fp8( + feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad +): + run_fake_quant_fp8( + feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad + ) + + +FAKE_QUANT_CONFIG = Template( + """fake_quant_config: + enabled: True + layers: + layer_types: [linear] + transformer_engine: + FakeQuant: + enabled: True + gemms_struct: +$gemms +""" +) + + +def fake_quant_fp8_create_config( + fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad, config_file +): + format_to_str = {tex.DType.kFloat8E4M3: "FP8E4M3", tex.DType.kFloat8E5M2: "FP8E5M2"} + gemms = "" + + def _add_tensor(quant_format, tensor): + nonlocal gemms + if quant_format: + gemms += " " * 8 + "- tensor: " + tensor + "\n" + gemms += " " * 8 + " quant_format: " + format_to_str[quant_format] + "\n" + + title = lambda x: f" - gemm: {x}\n tensors_struct:\n" + if fprop_inp or fprop_weight: + gemms += title("fprop") + _add_tensor(fprop_inp, "activation") + _add_tensor(fprop_weight, "weight") + gemms = gemms[:-1] + "\n" + if dgrad_weight or dgrad_grad: + gemms += title("dgrad") + _add_tensor(dgrad_weight, "weight") + _add_tensor(dgrad_grad, "gradient") + gemms = gemms[:-1] + "\n" + if wgrad_input or wgrad_grad: + gemms += title("wgrad") + _add_tensor(wgrad_input, "activation") + _add_tensor(wgrad_grad, "gradient") + gemms = gemms[:-1] + "\n" + config = FAKE_QUANT_CONFIG.safe_substitute(gemms=gemms) + config_file.write(config) + config_file.flush() + + +@create_config_file +def run_fake_quant_fp8( + feature_dirs, + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + **kwargs, +): + fp8_kwargs = { + "fprop_input_fake_quant": fprop_inp, + "fprop_weight_fake_quant": fprop_weight, + "dgrad_gradient_fake_quant": dgrad_grad, + "dgrad_weight_fake_quant": dgrad_weight, + "wgrad_gradient_fake_quant": wgrad_grad, + "wgrad_input_fake_quant": wgrad_input, + "fprop_fp8": not (fprop_inp or fprop_weight), + "dgrad_fp8": not (dgrad_weight or dgrad_grad), + "wgrad_fp8": not (wgrad_grad or wgrad_input), + } + fake_quant_fp8_create_config( + fprop_inp, + fprop_weight, + dgrad_weight, + dgrad_grad, + wgrad_input, + wgrad_grad, + kwargs["config_file"], + ) + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs) + + x, weight = 
_get_tensors() + model = _init_model(weight) + y = _run_forward_backward(x, model) + + output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} + ground_truth = _emulate_linear(x, weight, **fp8_kwargs) + _cmp(ground_truth, output) diff --git a/tests/pytorch/debug/test_sanity.py b/tests/pytorch/debug/test_sanity.py new file mode 100644 index 000000000..6b0883b14 --- /dev/null +++ b/tests/pytorch/debug/test_sanity.py @@ -0,0 +1,107 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import functools +import itertools +import os +import random +import tempfile +from string import Template + +import pytest +import torch + +import nvdlfw_inspect.api as debug_api +import transformer_engine.debug +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.common.recipe import DelayedScaling, Format +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.fp8 import _default_sf_compute +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer + +from test_numerics import create_config_file + +B, S, H, D = 64, 64, 64, 64 + +model_keys = ["linear", "layernorm_linear", "layernorm_mlp", "mha_attention", "transformer_layer"] + +configs = { + "": "", + "log": """log: + layers: + layer_types: [linear] + enabled: + True + transformer_engine: + LogTensorStats: + enabled: True + tensors: [activation, gradient, weight, output, wgrad, dgrad] + stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range] + start_step : 0 + end_step: 1 + LogFp8TensorStats: + enabled: True + tensors: [activation, gradient, weight] + stats: [underflows, overflows] + start_step : 0 + end_step: 1 +""", + "fake_quant": """ +fake_quant_config: + enabled: True + layers: + layer_types: [linear] + transformer_engine: + FakeQuant: + enabled: True + gemms: [fprop, dgrad, wgrad] + quant_format: FP8E5M2 +""", +} + + +def _get_model(model_key): + if model_key == "linear": + return te.Linear(D, D) + if model_key == "layernorm_linear": + return te.LayerNormLinear(D, D) + if model_key == "layernorm_mlp": + return te.LayerNormMLP(D, D, D) + if model_key == "mha_attention": + return te.MultiheadAttention(D, H) + if model_key == "transformer_layer": + return te.TransformerLayer(D, D, H) + + +def _run_forward_backward(model, fp8): + for _ in range(3): + inp = torch.randn((S, B, H)).cuda() + with te.fp8_autocast(enabled=fp8): + out = model(inp) + out.sum().backward() + debug_api.step() + + +@create_config_file +def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir): + try: + if config != "": + config_file.write(config) + config_file.flush() + config_file_name = config_file.name if config != "" else "" + debug_api.initialize(feature_dirs=feature_dirs, config_file=config_file_name) + model = _get_model(model_key) + _run_forward_backward(model, fp8) + except Exception as error: + raise error + finally: + debug_api.end_debug() + + +@pytest.mark.parametrize("model_key", model_keys) +@pytest.mark.parametrize("fp8", [False, True]) +@pytest.mark.parametrize("config_key", configs.keys()) +def test_sanity_debug(model_key, fp8, config_key, feature_dirs): + _run_test(model_key, fp8, configs[config_key], feature_dirs) diff --git a/tests/pytorch/debug/utils.py b/tests/pytorch/debug/utils.py new file mode 100644 index 000000000..f03ee56b5 --- /dev/null +++ b/tests/pytorch/debug/utils.py @@ -0,0 +1,22 @@ +# Copyright (c) 
2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os + +LOG_FILE = os.path.join("nvdlfw_inspect_logs", "nvdlfw_inspect_globalrank-0.log") + + +def reset_debug_log(): + if os.path.isfile(LOG_FILE): + # delete all content + with open(LOG_FILE, "w") as f: + pass + + +def check_debug_log(msg): + with open(LOG_FILE, "r") as f: + for line in f.readlines(): + if msg in line: + return True + return False diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index ac72960c4..b7af78832 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -40,6 +40,18 @@ LOSS_FN = nn.MSELoss() QUANTIZATION = None +if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False): + # The numerics of all the layers should work the same, + # when debug=True. I fed them with dummy feature + # to prevent switching off debug, which can happen if + # no feature is active. + import nvdlfw_inspect.api as debug_api + + debug_api.initialize( + os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"], + feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"], + ) + # Disable TF32 torch.backends.cuda.matmul.allow_tf32 = False @@ -195,7 +207,7 @@ def _get_tolerances(dtype): if dtype == torch.bfloat16: return {"rtol": 1.6e-2, "atol": 1e-5} if dtype == torch.float32: - return {"rtol": 1.3e-6, "atol": 4e-5} + return {"rtol": 1e-4, "atol": 1e-4} raise ValueError(f"Unsupported dtype ({dtype})") diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index ab4b7634b..816df12f6 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -29,7 +29,7 @@ # Flash attention saves some internal tensor for the backward pass # that cannot be offloaded to CPU. -assert os.getenv("NVTE_FLASH_ATTN") == "0" +assert os.getenv("NVTE_FLASH_ATTN", "1") == "0" # Offloading is supported for attention only for fused and flash attention backends, # so the use of bfloat16 is required. diff --git a/tests/pytorch/test_gemm_autotune.py b/tests/pytorch/test_gemm_autotune.py index 562581364..1b54e8464 100644 --- a/tests/pytorch/test_gemm_autotune.py +++ b/tests/pytorch/test_gemm_autotune.py @@ -34,7 +34,7 @@ def analyse_storage(fname): next(reader) head = reader.fieldnames assert ("m" in head and "algo_id" in head and "ws_min" in head and "ws_max" in head - and "aidx" in head), "Invalid CSV format" + ), "Invalid CSV format" return head def read_storage(fname): diff --git a/tests/pytorch/test_hf_integration.py b/tests/pytorch/test_hf_integration.py new file mode 100644 index 000000000..0b2468510 --- /dev/null +++ b/tests/pytorch/test_hf_integration.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
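+# These tests verify that Transformer Engine modules can be embedded inside a
+# Hugging Face PreTrainedModel and round-tripped through the standard
+# save_pretrained()/from_pretrained() flow. A short usage sketch with the wrapper
+# defined below (the target directory is an arbitrary temporary path):
+#
+#     model = SimpleTEModel(PretrainedConfig())
+#     model.save_pretrained("/tmp/simple_te_model")
+#     reloaded = SimpleTEModel.from_pretrained("/tmp/simple_te_model")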
+ +import pytest +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import PreTrainedModel + +from transformer_engine.pytorch.transformer import TransformerLayer +from transformer_engine.pytorch.utils import is_bf16_compatible + + +class SimpleTEModel(PreTrainedModel): + config_class = PretrainedConfig + + def __init__(self, config: PretrainedConfig): + super().__init__(config) + self.my_layer = TransformerLayer( + hidden_size=320, + num_attention_heads=16, + ffn_hidden_size=1024, + layer_number=None, + ) + + def forward(self, hidden_states, attention_mask): + return self.my_layer(hidden_states, attention_mask) + + +def test_save_hf_model(tmp_path): + model = SimpleTEModel(PretrainedConfig()) + model.save_pretrained(tmp_path / "simple_te_model") + + +@pytest.mark.xfail(reason="This test is failing until huggingface/transformers#38155 is merged.") +def test_save_and_load_hf_model(tmp_path): + model = SimpleTEModel(PretrainedConfig()) + model.save_pretrained(tmp_path / "simple_te_model") + del model + model = SimpleTEModel.from_pretrained(tmp_path / "simple_te_model") + assert model is not None diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 3b56796cc..6d9a4412e 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -120,6 +120,20 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq mask_types = ["causal", "no_mask"] +NVTE_TEST_NVINSPECT_ENABLED = os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False) + +if NVTE_TEST_NVINSPECT_ENABLED: + # The numerics of all the layers should work the same, + # when debug=True. I fed them with dummy feature + # to prevent switching off debug, which can happen if + # no feature is active. 
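+    # A rough sketch of how this gate is expected to be driven from the shell (the
+    # config is a placeholder; any config enabling a no-op debug feature works):
+    #
+    #   NVTE_TEST_NVINSPECT_ENABLED=1 \
+    #   NVTE_TEST_NVINSPECT_CONFIG_FILE=<path to dummy_feature.yaml> \
+    #   NVTE_TEST_NVINSPECT_FEATURE_DIRS=<path to debug features dir> \
+    #   pytest tests/pytorch/test_numerics.py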
+ import nvdlfw_inspect.api as debug_api + + debug_api.initialize( + os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"], + feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"], + ) + fp8_recipes = [ recipe.MXFP8BlockScaling(), recipe.DelayedScaling(), @@ -621,6 +635,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) @@ -741,6 +757,8 @@ def test_gpt_full_activation_recompute( use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) if fp8 and recipe.float8_current_scaling() and use_cast_transpose_triton: pytest.skip("Float8 Current Scaling unsupported for full recompute.") + if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) @@ -1957,6 +1975,8 @@ def test_grouped_linear_accuracy( pytest.skip(reason_for_no_fp8) if fp8 and recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) @@ -2155,6 +2175,8 @@ def test_padding_grouped_linear_accuracy( pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) @@ -2276,6 +2298,8 @@ def test_gpt_cuda_graph(dtype, bs, model): if use_fa: pytest.skip(f"ROCm flash attention does not support cuda graph with {dtype}") + if NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("Cuda Graphs are not supported in debug mode.") config = model_configs[model] sigma = 0.023 @@ -2373,6 +2397,8 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe): pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 02ff9367a..8d379be7c 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -8,22 +8,32 @@ import pytest import torch +import warnings import transformer_engine.common.recipe import transformer_engine.pytorch as te +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import ( FP8GlobalStateManager, _amax_and_scale_update, - get_default_fp8_recipe, + fp8_model_init, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine.pytorch.ops as te_ops from 
transformer_engine.pytorch.utils import is_fp8_fnuz +from transformer_engine.pytorch import Linear +from transformer_engine.pytorch.distributed import fp8_autocast +from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling import transformer_engine_torch as tex # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) # FP8 per tensor delayed scaling @@ -370,3 +380,96 @@ def setup_fp8_meta(): ) torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale) + + @pytest.mark.parametrize( + "model_init_recipe", + [ + pytest.param( + MXFP8BlockScaling(), + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + ), + pytest.param( + Float8BlockScaling(), + marks=pytest.mark.skipif( + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling + ), + ), + ], + ) + def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe): + with fp8_model_init(enabled=True, recipe=model_init_recipe): + linear = Linear(32, 32).cuda() + + x = torch.randn(32, 32, device="cuda") + with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()): + with pytest.raises(RuntimeError) as excinfo: + _ = linear(x) + assert "Recipe mismatch for " in str(excinfo.value) + + @pytest.mark.parametrize( + "target_recipe_class, expected_quantizer_type, available_flag, reason", + [ + pytest.param( + MXFP8BlockScaling, + MXFP8Quantizer, + mxfp8_available, + reason_for_no_mxfp8, + id="DelayedScaling->MXFP8BlockScaling", + ), + pytest.param( + Float8BlockScaling, + Float8BlockQuantizer, + fp8_block_scaling_available, + reason_for_no_fp8_block_scaling, + id="DelayedScaling->Float8BlockScaling", + ), + ], + ) + def test_dynamic_recipe_update( + self, target_recipe_class, expected_quantizer_type, available_flag, reason + ): + if not available_flag: + pytest.skip(reason) + + in_features = 32 + out_features = 32 + batch_size = 32 + linear = Linear(in_features, out_features).cuda() + initial_recipe = DelayedScaling() + + # Run initial iterations with DelayedScaling + for _ in range(3): + x = torch.randn(batch_size, in_features, device="cuda") + with fp8_autocast(enabled=True, fp8_recipe=initial_recipe): + y = linear(x) + loss = y.mean() + loss.backward() + + for quantizer in linear.quantizers["scaling_fwd"]: + assert isinstance(quantizer, Float8Quantizer) + + # Change recipe + target_recipe = target_recipe_class() + + # Run subsequent iterations with the target recipe + for i in range(3): + x = torch.randn(batch_size, in_features, device="cuda") + if i == 0: + # Expect a warning on the first iteration with the new recipe + with pytest.warns(UserWarning, match="Recipe type changed"): + with fp8_autocast(enabled=True, fp8_recipe=target_recipe): + y = linear(x) + for quantizer in linear.quantizers["scaling_fwd"]: + assert isinstance(quantizer, expected_quantizer_type) + else: + # No warning expected on subsequent iterations + with warnings.catch_warnings(): + warnings.simplefilter("error") # Raise error if unexpected warning occurs + with fp8_autocast(enabled=True, fp8_recipe=target_recipe): + y = linear(x) + loss = y.mean() + loss.backward() + + # Final check + for quantizer in linear.quantizers["scaling_fwd"]: + assert isinstance(quantizer, expected_quantizer_type) diff --git 
a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a9e2e056e..9a4187378 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -351,18 +351,7 @@ else() endif() if(USE_FUSED_ATTN_CK) - if(NOT DEFINED CK_FUSED_ATTN_PATH) - set(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT ${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} CACHE STRING "ck float to bf16 conversion rounding") - add_subdirectory(ck_fused_attn ${CMAKE_CURRENT_BINARY_DIR}/ck_fused_attn) - else() - # Use CK built during initial TE building/installation - # When only need rebuild TE library itself - unset(CK_FUSED_ATTN_LIB CACHE) - find_library(CK_FUSED_ATTN_LIB NAMES ck_fused_attn PATHS ${CK_FUSED_ATTN_PATH}/lib REQUIRED NO_DEFAULT_PATH) - add_library( ck_fused_attn STATIC IMPORTED ) - set_target_properties( ck_fused_attn PROPERTIES IMPORTED_LOCATION ${CK_FUSED_ATTN_LIB} ) - target_include_directories(ck_fused_attn INTERFACE ${CK_FUSED_ATTN_PATH}/include) - endif() + add_subdirectory(ck_fused_attn ${CMAKE_CURRENT_BINARY_DIR}/ck_fused_attn) endif() find_package(hip) @@ -498,8 +487,16 @@ install(TARGETS transformer_engine DESTINATION .) set_target_properties(transformer_engine PROPERTIES INSTALL_RPATH "$ORIGIN/lib;$ORIGIN/transformer_engine/lib") if (USE_ROCM) + if("$ENV{ROCM_PATH}" STREQUAL "") + set(ROCM_PATH "/opt/rocm") + else() + set(ROCM_PATH "$ENV{ROCM_PATH}") + endif() + file(READ "${ROCM_PATH}/.info/version" ROCM_VER) + string(STRIP "${ROCM_VER}" ROCM_VER) + string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER}") file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/build_info.txt" - "ROCM_VERSION: ${hip_VERSION_MAJOR}.${hip_VERSION_MINOR}\n" + "ROCM_VERSION: ${ROCM_VER}\n" "GPU_TARGETS: ${CMAKE_HIP_ARCHITECTURES}\n" ) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/build_info.txt" DESTINATION "transformer_engine/") diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 871723a0e..49395fa23 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -113,9 +113,10 @@ def _get_shared_object_file(library: str) -> Path: # Case 1: Typical user workflow: Both locations are the same, return any result. if te_install_dir == site_packages_dir: - assert ( - so_path_in_install_dir is not None - ), f"Could not find shared object file for Transformer Engine {library} lib." + if so_path_in_install_dir is None: + raise FileNotFoundError( + f"Could not find shared object file for Transformer Engine {library} lib." + ) return so_path_in_install_dir # Case 2: ERR! Both locations are different but returned a valid result. @@ -123,13 +124,12 @@ def _get_shared_object_file(library: str) -> Path: # editable builds. In case developers are executing inside a TE directory via # an inplace build, and then move to a regular build, the local shared object # file will be incorrectly picked up without the following logic. - if so_path_in_install_dir is not None and so_path_in_default_dir is not None: - raise RuntimeError( - f"Found multiple shared object files: {so_path_in_install_dir} and" - f" {so_path_in_default_dir}. Remove local shared objects installed" - f" here {so_path_in_install_dir} or change the working directory to" - "execute from outside TE." - ) + assert so_path_in_install_dir is None or so_path_in_default_dir is None, ( + f"Found multiple shared object files: {so_path_in_install_dir} and" + f" {so_path_in_default_dir}. 
Remove local shared objects installed" + f" here {so_path_in_install_dir} or change the working directory to" + "execute from outside TE." + ) # Case 3: Typical dev workflow: Editable install if so_path_in_install_dir is not None: @@ -139,7 +139,9 @@ def _get_shared_object_file(library: str) -> Path: if so_path_in_default_dir is not None: return so_path_in_default_dir - raise RuntimeError(f"Could not find shared object file for Transformer Engine {library} lib.") + raise FileNotFoundError( + f"Could not find shared object file for Transformer Engine {library} lib." + ) @functools.lru_cache(maxsize=None) @@ -207,6 +209,7 @@ def load_framework_extension(framework: str): @functools.lru_cache(maxsize=None) def _get_sys_extension(): system = platform.system() + if system == "Linux": extension = "so" elif system == "Darwin": diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index 2a2afa328..c44a930e6 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -1,20 +1,15 @@ # Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT -#TODO: compile to a shared library -cmake_minimum_required(VERSION 3.28) -set(CMAKE_CXX_STANDARD 20) -#TODO: remove after figuring out how to install clang-scan-deps -set(CMAKE_CXX_SCAN_FOR_MODULES OFF) +cmake_minimum_required(VERSION 3.21) +set(CMAKE_CXX_STANDARD 17) project(ck_fused_attn LANGUAGES HIP CXX) -# remove files that should be regenerated -file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp ${CMAKE_CURRENT_BINARY_DIR}/gen_src/blob_list.txt) -# create gen_src and gen_src/tmp directories if needed -file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp) +set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE") set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/aiter") +set(__AITER_TEST_DIR "${__AITER_SOURCE_DIR}/op_tests/cpp/mha") set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel") # so far, there are only gfx942 and gfx950 v3 kernels @@ -37,82 +32,22 @@ message(STATUS "AITER V3_ASM_ARCHS: ${V3_ASM_ARCHS}") list(JOIN V3_ASM_ARCHS ";" V3_ASM_ARCHS_STR) set(ENV{GPU_ARCHS} "${V3_ASM_ARCHS_STR}") -# generate v2 (CK) kernels -# fwd kernels list -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_blob_list.txt --receipt 600 -) -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_splitkv_blob_list.txt --receipt 600 -) -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api batch_prefill --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_batch_prefill_blob_list.txt --receipt 600 -) - -# bwd kernels list -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/bwd_blob_list.txt --receipt 600 -) - -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS) -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_splitkv_blob_list.txt FMHA_FWD_SPLITKV_GEN_BLOBS) -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_batch_prefill_blob_list.txt FMHA_FWD_BATCH_PREFILL_GEN_BLOBS) -file(STRINGS 
${CMAKE_CURRENT_BINARY_DIR}/gen_src/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) - -# generate the actual fwd kernel cpp files -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600 -) - -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600 -) - -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api batch_prefill --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600 -) - -# generate the aiter fwd interface cpp file -execute_process( - COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/cpp_itfs/mha_fwd_generate.py - --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 5 -) - -# generate the actual bwd kernel cpp files -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600 -) - -# generate the aiter bwd interface cpp file -execute_process( - COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py - --filter *@*_ndeterministic@*_nbias*_dropout*_ndeterministic* --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp -) - -execute_process( - COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/cpp_itfs/mha_bwd_generate.py - --receipt 3 --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp -) - -# generate fwd/bwd v3 kernels for each requested rocm arch -foreach(CK_TARGET_ARCH IN LISTS V3_ASM_ARCHS) - execute_process( - COMMAND python3 ${__AITER_SOURCE_DIR}/hsa/${CK_TARGET_ARCH}/fmha_v3_fwd/codegen.py - --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp - ) +if(NOT DEFINED AITER_MHA_PATH) + # delete the existing aiter/jit/build dir for a clean build + file(REMOVE_RECURSE "${__AITER_SOURCE_DIR}/aiter/jit/build") + # compile the libmha_fwd.so and libmha_bwd.so + set(ENV{AITER_LOG_MORE} 1) + # fp32 to bf16 cvt env still required for MI300X + set(ENV{CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT} ${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT}) execute_process( - COMMAND python3 ${__AITER_SOURCE_DIR}/hsa/${CK_TARGET_ARCH}/fmha_v3_bwd/codegen.py - --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp + COMMAND python3 ${__AITER_TEST_DIR}/compile.py ) -endforeach() + # libmha_fwd.so and libmha_bwd.so will be under 3rdparty/aiter/op_tests/cpp/mha + set(__AITER_MHA_PATH ${__AITER_TEST_DIR}) +else() + # use pre-built libmha_fwd.so libmha_bwd.so + set(__AITER_MHA_PATH ${AITER_MHA_PATH}) +endif() set(ck_fused_attn_SOURCES) list(APPEND ck_fused_attn_SOURCES @@ -120,75 +55,18 @@ list(APPEND ck_fused_attn_SOURCES src/ck_fused_attn_bwd.cpp src/ck_fused_attn_utils.cpp) -foreach(blob ${FMHA_FWD_GEN_BLOBS}) - file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob}) - file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT) -endforeach() -list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_GEN_BLOBS}) - -foreach(blob ${FMHA_FWD_SPLITKV_GEN_BLOBS}) - file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob}) - file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT) -endforeach() -list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_SPLITKV_GEN_BLOBS}) - -foreach(blob ${FMHA_FWD_BATCH_PREFILL_GEN_BLOBS}) - file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob}) - file(COPY_FILE 
${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT) -endforeach() -list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_BATCH_PREFILL_GEN_BLOBS}) - -foreach(blob ${FMHA_BWD_GEN_BLOBS}) - file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob}) - file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT) -endforeach() -list(APPEND ck_fused_attn_SOURCES ${FMHA_BWD_GEN_BLOBS}) - -# add generated cpp files into ck_fused_attn_sources -set(MHA_BWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/mha_bwd.cpp") -set(MHA_FWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/mha_fwd.cpp") - -file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${MHA_BWD_SRC}) -file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${MHA_BWD_SRC} ONLY_IF_DIFFERENT) - -file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${MHA_FWD_SRC}) -file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${MHA_FWD_SRC} ONLY_IF_DIFFERENT) - -list(APPEND ck_fused_attn_SOURCES ${MHA_BWD_SRC} ${MHA_FWD_SRC}) - -foreach(CK_TARGET_ARCH IN LISTS V3_ASM_ARCHS) - set(ASM_MHA_FWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/asm_fmha_fwd_v3_${CK_TARGET_ARCH}.cpp") - set(ASM_MHA_BWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/asm_fmha_bwd_v3_${CK_TARGET_ARCH}.cpp") - - file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${ASM_MHA_BWD_SRC}) - file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${ASM_MHA_BWD_SRC} ONLY_IF_DIFFERENT) - - file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${ASM_MHA_FWD_SRC}) - file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${ASM_MHA_FWD_SRC} ONLY_IF_DIFFERENT) - list(APPEND ck_fused_attn_SOURCES ${ASM_MHA_BWD_SRC} ${ASM_MHA_FWD_SRC}) -endforeach() - -# remove all previously generated temporary files -file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp) - message(STATUS "Found the following fused attention files:") foreach(file ${ck_fused_attn_SOURCES}) message(STATUS " ${file}") endforeach() -add_library(ck_fused_attn STATIC ${ck_fused_attn_SOURCES}) +add_library(ck_fused_attn SHARED ${ck_fused_attn_SOURCES}) set(CK_FUSED_ATTN_COMPILE_OPTIONS) list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS - -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -DCK_TILE_FMHA_FWD_SPLITKV_API=1-DCK_TILE_FMHA_FWD_APPENDKV_API=0 - -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} - -fgpu-flush-denormals-to-zero -ftemplate-backtrace-limit=0 -fPIC - -Wno-undefined-func-template -Wno-float-equal -Wno-gnu-line-marker -Wunused-variable -Wuninitialized - "SHELL:-mllvm -enable-post-misched=0" "SHELL:-mllvm -amdgpu-early-inline-all=true" - "SHELL:-mllvm -amdgpu-function-calls=false" "SHELL:-mllvm -amdgpu-coerce-illegal-types=1" - "SHELL:-mllvm --amdgpu-kernarg-preload-count=16") + -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT}) -foreach(CK_TARGET_ARCH IN LISTS CMAKE_HIP_ARCHITECTURES) - list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS --offload-arch=${CK_TARGET_ARCH}) +foreach(ARCH IN LISTS V3_ASM_ARCHS) + list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS --offload-arch=${ARCH}) endforeach() set(CK_INCLUDE_DIR "${__CK_SOURCE_DIR}/include") @@ -216,18 +94,22 @@ target_include_directories(ck_fused_attn PRIVATE ${CK_INCLUDE_DIR} ${__CK_SOURCE target_include_directories(ck_fused_attn PRIVATE ${AITER_INCLUDE_DIR}) find_package(hip) -list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64) +list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device 
roctx64 ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so) target_link_libraries(ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS}) target_compile_options(ck_fused_attn PRIVATE ${CK_FUSED_ATTN_COMPILE_OPTIONS}) +set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN") +install(FILES ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) +install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) # copy v3 kernels to destination foreach(ARCH IN LISTS V3_ASM_ARCHS) install(DIRECTORY ${__AITER_SOURCE_DIR}/hsa/${ARCH}/fmha_v3_fwd - DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine/aiter/${ARCH}/ + DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib/aiter/${ARCH}/ PATTERN "codegen.py" EXCLUDE) install(DIRECTORY ${__AITER_SOURCE_DIR}/hsa/${ARCH}/fmha_v3_bwd - DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine/aiter/${ARCH}/ + DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib/aiter/${ARCH}/ PATTERN "codegen.py" EXCLUDE) endforeach() + diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 840db7b86..2b717ace0 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -920,8 +920,8 @@ hipError_t ck_attn_varlen_bwd( cu_seqlen_q_ptr,//cu_seqlen_q cu_seqlen_kv_ptr,//cu_seqlen_kv nullptr, /* seqlen_k_ptr */ - 0, //seqlen_q, unused in group mode - 0, //seqlen_kv, unused in group mode + max_seqlen_q, //seqlen_q, unused in group mode + max_seqlen_k, //seqlen_kv, unused in group mode batch, max_seqlen_q, max_seqlen_k, diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 2829175ab..c87a3db6c 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -209,9 +209,13 @@ hipError_t ck_attn_fwd( nullptr,//rand_val_ptr lse_ptr, o_ptr, - nullptr,//cu_seqlen_q - nullptr,//cu_seqlen_kv - nullptr, /* seqlen_k_ptr */ + nullptr, //cu_seqlen_q + nullptr, //cu_seqlen_kv + nullptr, //seqstart_q_ptr + nullptr, //seqstart_k_ptr + nullptr, //seqlen_k_ptr + nullptr, //seqstart_padded_q_ptr + nullptr, //seqstart_padded_k_ptr max_seqlen_q, max_seqlen_k, batch, @@ -308,6 +312,7 @@ hipError_t ck_attn_varlen_fwd( ck_tile::index_t nhead_k = hg; ck_tile::index_t hdim_v = d_v; ck_tile::index_t max_seqlen_q = s_q; + ck_tile::index_t max_seqlen_kv = s_kv; float scale_s = scaling_factor; float scale_p = 1.f; @@ -379,11 +384,15 @@ hipError_t ck_attn_varlen_fwd( nullptr,//rand_val_ptr lse_thd_ptr, o_ptr, - cu_seqlen_q_ptr,//cu_seqlen_q - cu_seqlen_kv_ptr,//cu_seqlen_kv - nullptr, /* seqlen_k_ptr */ - 0, //seqlen_q, unused in group mode - 0, //seqlen_kv, unused in group mode + nullptr, //cu_seqlen_q + nullptr, //cu_seqlen_kv + cu_seqlen_q_ptr, //seqstart_q_ptr + cu_seqlen_kv_ptr, //seqstart_k_ptr + nullptr, //seqlen_k_ptr + nullptr, //seqstart_padded_q_ptr + nullptr, //seqstart_padded_k_ptr + max_seqlen_q, //seqlen_q, unused in group mode + max_seqlen_kv, //seqlen_kv, unused in group mode batch, max_seqlen_q, hdim_q, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 
72696fbd9..b38249f5b 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -557,6 +557,7 @@ void fused_attn_ck_fwd_impl( nvte_log_ck_config = true; } bool nvte_ck_uses_fwd_v3 = getenv("NVTE_CK_USES_FWD_V3", 0); + bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; // extract the qkv and o storage bytes to allocate buffer for padding removing diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 36cbcd330..07b256972 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -94,7 +94,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla A.scaling_mode == B.scaling_mode || (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) || (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D), - "Inputs A and B to GEMM need to have compatible scaling modes!"); + "Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = " + + to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode)); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); GemmParam ret; diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index dcba674e4..9de4cfad7 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -750,8 +750,8 @@ protected: std::getline(is, type_b, csv_sep); std::getline(is, type_d, csv_sep); std::getline(is, bias_type, csv_sep); - is >> cfg.lda >> c >> cfg.ldb >> c >> cfg.ldd >> c >> cfg.scaling_mode >> c; std::getline(is, aux_type, csv_sep); + is >> cfg.lda >> c >> cfg.ldb >> c >> cfg.ldd >> c >> cfg.scaling_mode >> c; std::getline(is, epi, csv_sep); std::getline(is, comp, csv_sep); std::getline(is, scale, csv_sep); @@ -1089,7 +1089,7 @@ void hipblaslt_gemm(const Tensor *inputA, // Note: gelu fusion is available for certain config from rocm 7.0 // amax(D) either (next op is high precision). #if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15 - hipblasLtMatmulMatrixScale_t scaling_mode; + hipblasLtMatmulMatrixScale_t scaling_mode = (hipblasLtMatmulMatrixScale_t)0; #else constexpr int scaling_mode = 0; #endif diff --git a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h index 678ffe919..649b5ced5 100644 --- a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h +++ b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h @@ -17,23 +17,21 @@ extern "C" { #endif -/*! \brief Transposes the input, providing the option to immediately exit the kernel - * based on the value of the 'noop' tensor. +/*! \brief Transposes the input. * - * \param[in] input Input tensor. - * \param[in] noop Noop tensor. + * \param[in] input Input tensor to be cast. + * \param[in] noop If this single element tensor has non-zero value, kernel will exit immediately. * \param[in,out] output Output tensor. * \param[in] stream CUDA stream used for the operation. */ void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream); -/*! 
\brief Casts and transposes the input, providing the option to immediately exit the kernel - * based on the value of the 'noop' tensor. +/*! \brief Casts and transposes the input. * - * \param[in] input Input tensor. - * \param[in] noop Noop tensor. - * \param[in,out] output Output tensor. + * \param[in] input Input tensor to be cast. + * \param[in] noop If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] output Output quantized tensor. * \param[in] stream CUDA stream used for the operation. */ void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 3400eaaeb..f63ee636d 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -634,6 +634,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #ifndef __HIP_PLATFORM_AMD__ /*! \brief Update the RNG state with the seed and calculated offset. + * + * \warning This API is **experimental** and subject to change. * * \param[in] rng_state_dst RNG state to store seed and offset. * \param[in] seed Seed for RNG state. @@ -666,6 +668,8 @@ void nvte_populate_rng_state_async(void *rng_state_dst, const void *const seed, #endif /*! \brief Get KV format for a given QKV layout. + * + * \warning This API is **experimental** and subject to change. * * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. * \param[in] workspace Workspace tensor. @@ -675,48 +679,187 @@ void nvte_populate_rng_state_async(void *rng_state_dst, const void *const seed, uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len, cudaStream_t stream); +/*! \brief Set the seed and offset for RNG state. + * + * \warning This API is **experimental** and subject to change. + * + * \param[out] rng_state_ptr A size 2 array storing the RNG's seed and offset respectively. + * \param[in] captured Whether a CUDA graph is being captured. + * \param[in] seed_ptr Seed pointer. + * \param[in] seed_val Seed value. + * \param[in] offset_ptr Offset pointer. + * \param[in] offset_val Offset value. + * \param[in] offset_intragraph Intragraph offset in RNG states. For use with CUDA Graphs. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr, uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val, uint32_t offset_intragraph, cudaStream_t stream); +/*! \brief Copy keys and values into the KV cache. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] new_k Key tensor. + * \param[in] new_v Value tensor. + * \param[out] k_cache Key cache. + * \param[out] v_cache Value cache. + * \param[in] page_table Page table for K cache, [batch_size, max_pages_per_seq]. + * \param[in] cu_new_lens Cumulative sequence lengths. + * \param[in] cu_cached_lens Cached cumulative sequence lengths. + * \param[in] qkv_format QKV format, e.g. sbhd. + * \param[in] b Batch size. + * \param[in] max_ctx_len Maximum context length. + * \param[in] max_seq_len Maximum sequence length. + * \param[in] max_pages_per_seq Maximum number of pages per sequence. + * \param[in] is_non_paged Whether the cache is paged or not. + * \param[in] stream CUDA stream used for this operation. 
+ */ void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache, NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens, NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, int is_non_paged, cudaStream_t stream); +/*! \brief Extract the first half (half_idx=0) or second half (half_idx=1) of a THD tensor. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] tensor Input tensor. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[out] half Output tensor. + * \param[in] half_idx Whether to read first or second half of input tensor. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens, NVTETensor half, int half_idx, cudaStream_t stream); +/*! \brief Correct the second half of the softmax LSE (LogSumExp) for context parallelism. + * + * \warning This API is **experimental** and subject to change. + * + * \param[out] lse Output tensor. + * \param[in] lse_per_step Input tensor. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[in] lse_packed Whether or not lse_per_step is packed. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step, const NVTETensor &cu_seqlens, int lse_packed, cudaStream_t stream); +/*! \brief Read the second half of the softmax LSE (LogSumExp) for context parallelism. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] lse Input tensor. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[out] half_lse Output tensor. + * \param[in] lse_packed Whether or not the softmax LSE is in packed format. + * \param[in] second_half_lse_seqlen Sequence length. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens, NVTETensor half_lse, int lse_packed, int second_half_lse_seqlen, cudaStream_t stream); +/*! \brief Correct the THD format output of context parallelism in forward pass. + * + * \warning This API is **experimental** and subject to change. + * + * \param[out] out Output tensor. + * \param[in] out_per_step THD format output of context parallelism in forward pass. + * \param[in] lse Softmax LSE. + * \param[in] lse_per_step Softmax LSE per step. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[in] only_second_half Whether or not to correct only second half. + * \param[in] lse_packed Whether or not the softmax LSE is in packed format. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step, const NVTETensor &lse, const NVTETensor &lse_per_step, const NVTETensor &cu_seqlens, int only_second_half, int lse_packed, cudaStream_t stream); +/*! \brief Correct the THD format gradients of context parallelism in backward pass. + * + * \warning This API is **experimental** and subject to change. + * + * \param[out] grad Output tensor. + * \param[in] grad_per_step THD format gradient of context parallelism. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[in] first_half One of ("add", "copy", "none") correction op for first half.
+ * \param[in] second_half One of ("add", "copy", "none") correction op for second half. + Must be different from first_half. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step, const NVTETensor &cu_seqlens, const char *first_half, const char *second_half, cudaStream_t stream); +/*! \brief Generate partitioned indices for inputs in THD format. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[out] output Output tensor. + * \param[in] total_tokens Total number of tokens. + * \param[in] world_size Total number of devices for context parallelism. + * \param[in] rank Device ID for current device. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output, int total_tokens, int world_size, int rank, cudaStream_t stream); +/*! \brief Convert tensor from THD to BSHD format. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] tensor Input tensor. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[out] new_tensor Output tensor. + * \param[in] b Batch size. + * \param[in] max_seq_len Maximum sequence length. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor, int b, int max_seq_len, cudaStream_t stream); +/*! \brief Convert tensor from BSHD to THD format. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] tensor Input tensor. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[out] new_tensor Output tensor. + * \param[in] b Batch size. + * \param[in] max_seq_len Maximum sequence length. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor, int t, cudaStream_t stream); +/*! \brief Prepare QKV tensor for Flash Attention forward kernel. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] qkvi Input tensor. + * \param[out] qkv Output tensor. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream); +/*! \brief Prepare QKV tensor for Flash Attention backward kernel. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] q Input query tensor. + * \param[in] k Input key tensor. + * \param[in] v Input value tensor. + * \param[out] qkv Output tensor. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv, cudaStream_t stream); diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h index e78b31d77..c21fd2627 100644 --- a/transformer_engine/common/include/transformer_engine/multi_tensor.h +++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h @@ -17,6 +17,25 @@ extern "C" { #endif +/*! \brief Computes L2 norm for a list of tensors. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. 
+ * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] output Scratch space. Required size grows with number of inputs. + * \param[in] output_per_tensor Fixed size auxiliary scratch space. + * \param[out] ret L2 norm of all inputs. + * \param[out] ret_per_tensor L2 norm for each tensor. + * \param[in] per_tensor Whether to calculate per tensor or cumulative norm. + * \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret, @@ -24,6 +43,28 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen int max_chunks_per_tensor, const int device_id, cudaStream_t stream); +/*! \brief Computes L2 norm for a list of tensors after unscaling. + * + * Unscaling is only done for computing the L2 norm. The tensors themselves are not updated. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] output Scratch space. Required size grows with number of inputs. + * \param[in] output_per_tensor Fixed size auxiliary scratch space. + * \param[out] ret L2 norm of all inputs. + * \param[out] ret_per_tensor L2 norm for each tensor. + * \param[in] inv_scale Scalar for the unscaling operation. + * \param[in] per_tensor Whether to calculate per tensor or cumulative norm. + * \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor output, @@ -32,6 +73,27 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, int per_tensor, int max_chunks_per_tensor, const int device_id, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for Adam optimizer. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists.
+ * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] lr Learning rate. + * \param[in] beta1 Coefficient for first moment of gradient. + * \param[in] beta2 Coefficient for second moment of gradient. + * \param[in] epsilon Term added to the denominator for numerical stability. + * \param[in] step Iteration counter. + * \param[in] mode Whether to use AdamW (L2 penalty applied to params). + * \param[in] bias_correction Whether to apply correction factor for moment estimates. + * \param[in] weight_decay L2 penalty for weight decay. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2, @@ -39,12 +101,57 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso const int bias_correction, const float weight_decay, const int device_id, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for Adam optimizer + * where the master parameters only store the remainder bits. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] lr Learning rate. + * \param[in] beta1 Coefficient for first moment of gradient. + * \param[in] beta2 Coefficient for second moment of gradient. + * \param[in] epsilon Term added to the denominator for numerical stability. + * \param[in] step Iteration counter. + * \param[in] mode Whether to use AdamW (L2 penalty applied to params). + * \param[in] bias_correction Whether to apply correction factor for moment estimates. + * \param[in] weight_decay L2 penalty for weight decay. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_adam_param_remainder_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, const int device_id, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for Adam optimizer + * when model parameters are in Float8 precision. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] lr Learning rate. + * \param[in] beta1 Coefficient for first moment of gradient. 
+ * \param[in] beta2 Coefficient for second moment of gradient. + * \param[in] epsilon Term added to the denominator for numerical stability. + * \param[in] step Iteration counter. + * \param[in] mode Whether to use AdamW (L2 penalty applied to params). + * \param[in] bias_correction Whether to apply correction factor for moment estimates. + * \param[in] weight_decay L2 penalty for weight decay. + * \param[in] fp8_dtype FP8 data type for model parameters. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, const float lr, @@ -53,28 +160,125 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, const float weight_decay, const NVTEDType fp8_dtype, const int device_id, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for Adam optimizer + * with CUDA graph support and LR scheduling. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] lr Learning rate. + * \param[in] beta1 Coefficient for first moment of gradient. + * \param[in] beta2 Coefficient for second moment of gradient. + * \param[in] epsilon Term added to the denominator for numerical stability. + * \param[in] step Iteration counter. + * \param[in] mode Whether to use AdamW (L2 penalty applied to params). + * \param[in] bias_correction Whether to apply correction factor for moment estimates. + * \param[in] weight_decay L2 penalty for weight decay. + * \param[in] inv_scale Scalar for the unscaling operation. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_adam_capturable_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode, const int bias_correction, const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for Adam optimizer + * with CUDA graph support, LR scheduling, and FP32 master weights. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] lr Learning rate. + * \param[in] beta1 Coefficient for first moment of gradient. 
+ * \param[in] beta2 Coefficient for second moment of gradient. + * \param[in] epsilon Term added to the denominator for numerical stability. + * \param[in] step Iteration counter. + * \param[in] mode Whether to use AdamW (L2 penalty applied to params). + * \param[in] bias_correction Whether to apply correction factor for moment estimates. + * \param[in] weight_decay L2 penalty for weight decay. + * \param[in] inv_scale Scalar for the unscaling operation. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_adam_capturable_master_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode, const int bias_correction, const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for SGD optimizer. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] wd Weight decay (L2 penalty). + * \param[in] momentum Momentum factor. + * \param[in] dampening Dampening factor. + * \param[in] lr Learning rate. + * \param[in] nesterov Whether or not to enable nesterov momentum. + * \param[in] first_run Whether momentum buffers have been initialized. + * \param[in] wd_after_momentum Whether to applied weight decay after momentum update. + * \param[in] scale Scalar for the scaling operation. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, float wd, float momentum, float dampening, float lr, int nesterov, int first_run, int wd_after_momentum, float scale, const int device_id, cudaStream_t stream); +/*! \brief Check overflow and scale a list of tensors. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] scale Scalar for the scaling operation. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, float scale, const int device_id, cudaStream_t stream); +/*! 
\brief Check overflow and scale a list of tensors. + * + * \warning This API is **experimental** and subject to change. + * \warning Argument device_id is deprecated and will be removed in a future release. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] max_fp8 Maximum representible value in underlying FP8 format. + * \param[in] force_pow_2_scales Ensure scaling factors are a power of 2. + * \param[in] epsilon Term added to the denominator for numerical stability. + * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. + * \param[in] stream CUDA stream used for this operation. + */ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon, diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 3be7d5004..442e11216 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -42,8 +42,6 @@ namespace transformer_engine { namespace normalization { #ifndef __HIP_PLATFORM_AMD__ -bool& use_zero_centered_gamma_in_weight_dtype(); - cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { return training ? cudnn_frontend::NormFwdPhase_t::TRAINING : cudnn_frontend::NormFwdPhase_t::INFERENCE; @@ -53,13 +51,17 @@ cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, - bool is_tuned, NVTEScalingMode mode, bool training) { - // TODO: Add scaling_mode to general_key is needed - uint64_t general_key = static_cast(itype) | (static_cast(otype) << 3) | - (static_cast(ctype) << 6) | (static_cast(wtype) << 9) | - (uint32_t(NormType) << 12) | (uint32_t(NormStage)) << 14 | - (uint32_t(NormBackend) << 16) | (uint32_t(zero_centered_gamma) << 18) | - (uint32_t(mode) << 19) | (uint32_t(training) << 22); + bool is_tuned, NVTEScalingMode mode, bool training, + bool gamma_in_weight_dtype) { + static_assert(NVTE_INVALID_SCALING < 1024, + "This function assumes at most 10 bits used in the scaling mode."); + static_assert(kNVTENumTypes < 32, "This function assumes at most 5 bits used in the NVTEDType"); + uint64_t general_key = static_cast(itype) | (static_cast(otype) << 5) | + (static_cast(ctype) << 10) | + (static_cast(wtype) << 15) | (uint64_t(NormType) << 20) | + (uint64_t(NormStage)) << 22 | (uint64_t(NormBackend) << 24) | + (uint64_t(zero_centered_gamma) << 26) | (uint64_t(mode) << 27) | + (uint64_t(training) << 37) | (uint64_t(gamma_in_weight_dtype) << 38); return std::make_tuple(general_key, batch_size, hidden_size, is_tuned); } @@ -502,11 +504,12 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size, 
const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned, - const NVTEScalingMode mode, const bool training) { + const NVTEScalingMode mode, const bool training, const bool gamma_in_weight_dtype) { const DType ctype = DType::kFloat32; bool is_tuned = is_aligned && (batch_size % 4 == 0); - auto key = get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size, - hidden_size, zero_centered_gamma, is_tuned, mode, training); + auto key = + get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, + zero_centered_gamma, is_tuned, mode, training, gamma_in_weight_dtype); auto it = normalizationPlanMap.find(key); if (it != normalizationPlanMap.end()) { @@ -578,6 +581,7 @@ void nvte_enable_cudnn_norm_bwd(bool enable) { transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable; } +// Only for testing, not thread-safe void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) { NVTE_API_CALL(nvte_enable_zero_centered_gamma_in_weight_dtype); transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable; diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 241a3b77b..d1fe6868e 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -196,7 +196,7 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, - bool training = true); + bool training = true, bool gamma_in_weight_dtype = false); template class TeNormalizationRegistry { @@ -350,7 +350,8 @@ class NormalizationPlanRegistry { NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size, const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned, - const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true); + const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true, + const bool gamma_in_weight_dtype = false); private: NormalizationPlanRegistry() {} @@ -471,6 +472,8 @@ void rocm_norm_mxfp8_quantize(LaunchParams &launch_params) } #endif +bool& use_zero_centered_gamma_in_weight_dtype(); + } // namespace normalization } // namespace transformer_engine diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index f660ca5b7..9b689ec88 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -17,6 +17,7 @@ #include "../../common.h" #include "../common.h" +#include "transformer_engine/transformer_engine.h" namespace transformer_engine { @@ -66,13 +67,18 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size bool is_aligned = true; #ifndef __HIP_PLATFORM_AMD__ bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); +#endif + bool gamma_in_weight_dtype = false; +#ifndef __HIP_PLATFORM_AMD__ if (cudnn_backend) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; - } else -#endif + gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); + } else +#endif { + norm_backend = 
NVTE_Norm_Backend::Te; is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr, mu->data.dptr, rsigma->data.dptr); @@ -88,7 +94,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size z->data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training); + multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training, + gamma_in_weight_dtype); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); @@ -155,12 +162,14 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te NVTE_Norm_Backend norm_backend; bool is_aligned = true; -#ifndef __HIP_PLATFORM_AMD__ + bool gamma_in_weight_dtype = false; + #ifndef __HIP_PLATFORM_AMD__ if (use_cudnn_norm_bwd()) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; - } else -#endif + gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); + } else +#endif { norm_backend = NVTE_Norm_Backend::Te; is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, @@ -173,7 +182,8 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te gamma.data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned); + multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true, + gamma_in_weight_dtype); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index eabed2bd5..4eb5f7496 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -15,6 +15,7 @@ #include "../../common.h" #include "../common.h" #include "transformer_engine/normalization.h" +#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transpose.h" namespace transformer_engine { @@ -57,12 +58,14 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens bool training = is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; + bool gamma_in_weight_dtype = false; #ifndef __HIP_PLATFORM_AMD__ if (cudnn_backend) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; + gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); } else -#endif +#endif { norm_backend = NVTE_Norm_Backend::Te; is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr); @@ -75,7 +78,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens z->data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training); + multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training, + gamma_in_weight_dtype); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); @@ -133,12 +137,14 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const NVTE_Norm_Backend norm_backend; bool is_aligned = true; + bool gamma_in_weight_dtype = false; #ifndef __HIP_PLATFORM_AMD__ if (use_cudnn_norm_bwd()) { // TODO: add check for GPU ARCH norm_backend = 
NVTE_Norm_Backend::Cudnn; - } else -#endif + gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); + } else +#endif { norm_backend = NVTE_Norm_Backend::Te; is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr, @@ -151,7 +157,8 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const gamma.data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned); + multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true, + gamma_in_weight_dtype); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 9426d1621..466c2e605 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -196,6 +196,7 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return ( + f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " f"amax_history_len={self.amax_history_len}, " @@ -208,42 +209,12 @@ def __repr__(self) -> str: class Float8CurrentScaling(Recipe): """ Use the per-tensor current scaling factor strategy. + Parameters ---------- fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID Controls the FP8 data format used during forward and backward pass. - fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} - used for quantization of input tensor x - fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} - used for quantization of weight tensor w - fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} - used for quantization of gradient tensor dY - fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False - used for calculating output y in forward pass - fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True - use for calculating dgrad in backward pass - fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True - use for calculating dgrad in backward pass - fp8_dpa: bool, default = `False` - Whether to enable FP8 dot product attention (DPA). When the model is placed in an - `fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the - inputs from higher precision to FP8, performs attention in FP8, and casts tensors - back to higher precision as outputs. FP8 DPA currently is only supported in the - `FusedAttention` backend. - fp8_mha: bool, default = `False` - Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting - operations mentioned above at the DPA boundaries. Currently only standard MHA modules - i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When - `fp8_mha = False, fp8_dpa = True`, a typical MHA module works as - `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. - When `fp8_mha = True, fp8_dpa = True`, it becomes - `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. - - Notes - ----- - * `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are - subject to change in future Transformer Engine releases. 
""" fp8_format: Format = Format.HYBRID @@ -258,9 +229,13 @@ class Float8CurrentScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert ( + not self.fp8_dpa and not self.fp8_mha + ), "FP8 attention is not supported for Float8CurrentScaling." def __repr__(self) -> str: return ( + f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " @@ -307,7 +282,11 @@ def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." def __repr__(self) -> str: - return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," + return ( + f"recipe_type={self.__class__.__name__}, " + f"margin={self.margin}, " + f"format={str(self.fp8_format).split('.')[1]}" + ) @dataclass() @@ -329,32 +308,12 @@ class Float8BlockScaling(Recipe): NOTE: To relax the default constraint that scales be powers of 2, set env variable NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it for the recipe defaults. - export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 - Or initialize the Recipe with non-default QParams in code for increased control. Parameters ---------- fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. - fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} - used for quantization of input tensor x - fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} - used for quantization of weight tensor w - fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0} - used for quantization of gradient tensor dY - x_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) - qblock scaling for x. - w_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) - qblock scaling for w. - grad_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional) - qblock scaling for grad. - fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False - used for calculating output y in forward pass - fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True - use for calculating dgrad in backward pass - fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True - use for calculating dgrad in backward pass """ use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" @@ -388,9 +347,13 @@ def __post_init__(self) -> None: assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop." assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad." assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad." + assert ( + not self.fp8_dpa and not self.fp8_mha + ), "FP8 attention is not supported for Float8BlockScaling." 
def __repr__(self) -> str: return ( + f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " diff --git a/transformer_engine/debug/features/api.py b/transformer_engine/debug/features/api.py index 887043c42..13ab6040d 100644 --- a/transformer_engine/debug/features/api.py +++ b/transformer_engine/debug/features/api.py @@ -12,7 +12,7 @@ import torch from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS -from transformer_engine.pytorch.tensor import all_tensor_types +from transformer_engine.pytorch.tensor import get_all_tensor_types from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor @@ -424,7 +424,7 @@ def output_assertions_hook(self, api_name, ret, **kwargs): if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]: assert ret is None if api_name == "modify_tensor": - assert type(ret) in all_tensor_types + assert type(ret) in get_all_tensor_types() if ( type(ret) == torch.Tensor # pylint: disable=unidiomatic-typecheck and "dtype" in kwargs @@ -438,4 +438,4 @@ def step(self): def end_debug(self): """This function is called by the nvidia-dlframework-inspect after every debug_api.end_debug()""" - TEDebugState.reset() + TEDebugState._reset() diff --git a/transformer_engine/debug/features/fake_quant.py b/transformer_engine/debug/features/fake_quant.py index bab4b4dcf..4a5b6c34a 100644 --- a/transformer_engine/debug/features/fake_quant.py +++ b/transformer_engine/debug/features/fake_quant.py @@ -49,7 +49,7 @@ def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None): fp8_dtype = tex.DType.kFloat8E5M2 amax = tensor.abs().max().float() one = torch.ones(1, device=tensor.device) - scale = _default_sf_compute(amax, one, fp8_max) + scale = _default_sf_compute(amax, one, fp8_max, 0) quantizer = Float8Quantizer(scale, amax, fp8_dtype) else: diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index 4ca2a8ed3..e5c84a9bd 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -120,7 +120,6 @@ def inspect_tensor_postquantize( if not rowwise: return # tensor was already seen rowwise in the other gemm - tensor = tensor._data options = ( config.get("start_step", None), config.get("end_step", None), diff --git a/transformer_engine/debug/features/per_tensor_scaling.py b/transformer_engine/debug/features/per_tensor_scaling.py index eabb6304a..7b4de0a18 100644 --- a/transformer_engine/debug/features/per_tensor_scaling.py +++ b/transformer_engine/debug/features/per_tensor_scaling.py @@ -15,6 +15,7 @@ from transformer_engine.pytorch.tensor import Quantizer from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Tensor, + Float8Quantizer, Float8CurrentScalingQuantizer, ) from transformer_engine.debug.features.api import TEConfigAPIMapper @@ -39,7 +40,7 @@ def per_tensor_cast( }, "[NVTORCH INSPECT ERROR] Only 2 FP8 types: E4M3 and E5M2 are supported in TE." 
tensor = tensor.contiguous() - quantizer = Float8CurrentScalingQuantizer(fp8_dtype) + quantizer = Float8CurrentScalingQuantizer(fp8_dtype, device=tensor.device) if out is not None: quantizer.update_quantized(tensor, out) @@ -81,7 +82,6 @@ class PerTensorScaling(TEConfigAPIMapper): transformer_engine: PerTensorScaling: enabled: True - margin: 1 gemms: [dgrad] tensors: [weight, activation] """ @@ -118,7 +118,7 @@ def modify_tensor( if key not in ["gemm", "tensor"]: raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".') - assert isinstance(default_quantizer, Float8CurrentScalingQuantizer), ( + assert isinstance(default_quantizer, Float8Quantizer), ( f"[NVTORCH INSPECT ERROR] Feature={self.__class__.__name__}, API=process_tensor: " "Per-tensor current scaling can be used only within `DelayedScaling` recipe autocast." f" {layer_name}" diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 84a740161..d111e4890 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -96,7 +96,10 @@ def _get(buffers, stat_name): "max": (torch.max, lambda buffers: max(_get(buffers, "max"))), "sum": (torch.sum, lambda buffers: sum(_get(buffers, "sum"))), "mean": (torch.mean, lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel"))), - "numel": (lambda x: x.numel(), lambda buffers: sum(_get(buffers, "numel"))), + "numel": ( + lambda x: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(), + lambda buffers: sum(_get(buffers, "numel")), + ), "l1_norm": (lambda x: torch.norm(x, p=1), lambda buffers: sum(_get(buffers, "l1_norm"))), "l2_norm_square": ( lambda x: torch.sum(x**2), @@ -137,7 +140,7 @@ def _get(buffers, stat_name): - min(_get(buffers, "dynamic_range_bottom")), ), "underflows%": ( - lambda x: (x == 0).sum() / x.numel() * 100, + lambda x: (x.get_data_tensors()[0] == 0).sum() / x.get_data_tensors()[0].numel() * 100, lambda buffers: 100 * sum(_get(buffers, "underflows_num")) / sum(_get(buffers, "numel")), ), } diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 4a7a156a0..4d61757e1 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -14,10 +14,11 @@ import transformer_engine_torch as tex - +from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch.tensor.quantized_tensor import ( QuantizedTensor, Quantizer, + QuantizedTensorBase, prepare_for_saving, restore_from_saved, ) @@ -299,8 +300,9 @@ def quantize( iteration=self.iteration, dtype=dtype, ) - if columnwise_gemm_tensor.dtype != dtype: - raise ValueError("Dtype does not match the output of the modify_tensor call") + if dtype is not None: + if columnwise_gemm_tensor.dtype != dtype: + raise ValueError("Dtype does not match the output of the modify_tensor call") if self.rowwise_tensor_plan == API_CALL_MODIFY: rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor( layer_name=self.layer_name, @@ -311,8 +313,9 @@ def quantize( iteration=self.iteration, dtype=dtype, ) - if rowwise_gemm_tensor.dtype != dtype: - raise ValueError("Dtype does not match the output of the modify_tensor call") + if dtype is not None: + if rowwise_gemm_tensor.dtype != dtype: + raise ValueError("Dtype does not match the output of the modify_tensor call") # 3. 
If some tensors still are not defined we use high precision tensor. if self.rowwise_tensor_plan == HIGH_PRECISION: @@ -332,6 +335,7 @@ def quantize( quantizer=self, layer_name=self.layer_name, tensor_name=self.tensor_name, + original_tensor=tensor, ) def process_gemm_output(self, tensor: torch.Tensor): @@ -455,8 +459,12 @@ def any_feature_enabled(self) -> bool: return True return False + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + """Probably not needed for debug quantizer""" + return None + -class DebugQuantizedTensor: +class DebugQuantizedTensor(QuantizedTensorBase): """ Class containing quantized tensors after debug. Depending on configuration it can contain one or two different objects. These objects can be accessed by the method @@ -470,6 +478,7 @@ def __init__( quantizer, layer_name=None, tensor_name=None, + original_tensor=None, ): self.rowwise_gemm_tensor = rowwise_gemm_tensor @@ -477,6 +486,7 @@ def __init__( self.quantizer = quantizer self._layer_name = layer_name self._tensor_name = tensor_name + self._original_tensor = original_tensor def prepare_for_saving(self): """ " Prepare for saving method override""" @@ -524,5 +534,5 @@ def size(self): """Size of the tensor.""" return self.rowwise_gemm_tensor.size() - def update_usage(self, rowwise_usage: bool, columnwise_usage: bool): + def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None): """Update usage of the tensor.""" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index b52d1003f..2b8d332f4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3487,7 +3487,64 @@ def attn_forward_func_with_cp( use_flash_attn_3=False, ) -> torch.Tensor: """ - Attention implementation with context parallelism. + Attention implementation with context parallelism (CP). CP partitions tensors along the sequence + dimension, and by reducing the memory and computational pressure on each GPU, it enables long-context + LLMs in a distributed fashion. Transformer Engine's PyTorch CP implementation currently utilizes + the DualChunkSwap strategy to ensure load balancing across CP ranks. It is applied to all `attn_mask_type`s + and all `qkv_format`s, and it requires sequence lengths to be, or are padded to be, divisible by + (cp_size * 2). It also requires tokens to be re-ordered before entering this function. + + For qkv_format = {'bshd', 'sbhd'}, the token re-ordering is illustrated as below, for an example + use case of s = 12, attn_mask_type = 'causal', and cp_size = 2. seq_pos indicates each token's position + in their corresponding sequence. 
+ + GPU0 | GPU1 GPU0 | GPU1 + seq_pos | 0 1 2 3 4 5 | 6 7 8 9 10 11 seq_pos | 0 1 2 9 10 11 | 3 4 5 6 7 8 + ---------------------------|----------------- ---------------------------|------------------ + 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + U 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 9 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 1, 1, + 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 10 | 1, 1, 1, 1, 1, 0,| 1, 1, 1, 1, 1, 1, + 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1, + ---------------------------|----------------- ---------------------------|------------------ + 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 3 | 1, 1, 1, 0, 0, 0,| 1, 0, 0, 0, 0, 0, + G 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 4 | 1, 1, 1, 0, 0, 0,| 1, 1, 0, 0, 0, 0, + P 8 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 0, 0, 0, P 5 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 0, 0, 0, + U 9 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 0, 0, U 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + 1 10 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 0, 1 7 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 0, + 11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1, 8 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 1, + + For qkv_format = 'thd', multiple sequences may be packed into the batch, and they may be of different + lengths. DualChunkSwap divides each sequence into (cp_size * 2) chunks and distributes 2 chunks of + every sequence onto a CP rank. The token matrix transformation is shown as follows, for an example of + batch_size = 2, seq_ids = [0, 1], seq_lens = [8, 4], t = 12, attn_mask_type = 'padding_causal', and + cp_size = 2. + + GPU0 | GPU1 GPU0 | GPU1 + seq_id | 0 0 0 0 0 0 | 0 0 1 1 1 1 seq_id | 0 0 0 0 1 1 | 0 0 0 0 1 1 + seq_pos | 0 1 2 3 4 5 | 6 7 0 1 2 3 seq_pos | 0 1 6 7 0 3 | 2 3 4 5 1 2 + ---------------------------|----------------- ---------------------------|------------------ + 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + P 0 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 0 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + U 0 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 0 7 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 0, 0, + 0 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 1 0 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 0, 0, + 0 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 1 3 | 0, 0, 0, 0, 2, 2,| 0, 0, 0, 0, 2, 2, + ---------------------------|----------------- ---------------------------|------------------ + 0 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 0 2 | 1, 1, 0, 0, 0, 0,| 1, 0, 0, 0, 0, 0, + G 0 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 0 3 | 1, 1, 0, 0, 0, 0,| 1, 1, 0, 0, 0, 0, + P 1 0 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 0, 0, 0 P 0 4 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 0, 0, 0, + U 1 1 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 0, 0 U 0 5 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + 1 1 2 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 0 1 1 1 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 0, + 1 3 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 2 1 2 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 2, + + When all transformer layers in a model share the same CP configuration, i.e. cp_group, cp_global_ranks, + cp_comm_type and cp_stream, token re-ordering can take place in the dataloader, i.e. only once for + all the layers. An example of the re-ordering code is `get_batch_on_this_cp_rank + `_ + in Megatron-LM. 
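As a companion to the illustration above, here is a minimal re-ordering sketch for `qkv_format = 'bshd'`. It is a simplified, single-tensor version of what `get_batch_on_this_cp_rank` does in Megatron-LM; the function name and layout assumptions are illustrative, not part of the TE API.

```python
import torch

def reorder_for_cp(t: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    """Select the two sequence chunks owned by this CP rank (DualChunkSwap).

    Assumes 'bshd' layout with the sequence length divisible by 2 * cp_size.
    Rank r keeps chunk r and chunk (2 * cp_size - 1 - r), which balances the
    causal-attention workload across ranks.
    """
    seq_dim = 1  # the 's' dimension in 'bshd'
    chunks = t.chunk(2 * cp_size, dim=seq_dim)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=seq_dim)

# Matches the s = 12, cp_size = 2 example above:
x = torch.arange(12).view(1, 12, 1, 1)  # batch = 1, seq = 12, h = d = 1
print(reorder_for_cp(x, cp_size=2, cp_rank=0).flatten().tolist())  # [0, 1, 2, 9, 10, 11]
print(reorder_for_cp(x, cp_size=2, cp_rank=1).flatten().tolist())  # [3, 4, 5, 6, 7, 8]
```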
+ """ if cp_comm_type == "a2a+p2p": diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index ea601397a..1d788148d 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1243,12 +1243,18 @@ def gather_along_first_dim( final_quantizer = ( None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer ) + # Temporary fix for TP communication of Float8BlockwiseQTensorBase + if isinstance(rowwise, Float8BlockwiseQTensorBase): + rowwise = inp._original_tensor rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0] out_obj.rowwise_gemm_tensor = rowwise_total if rowwise is not columnwise: final_quantizer_columnwise = ( None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer ) + # Temporary fix for TP communication of Float8BlockwiseQTensorBase + if isinstance(columnwise, Float8BlockwiseQTensorBase): + columnwise = inp._original_tensor columnwise_total, _ = gather_along_first_dim( columnwise, process_group, False, final_quantizer_columnwise ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a8b110690..e86ccd172 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -51,7 +51,7 @@ from ..utils import is_non_tn_fp8_gemm_supported from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase -from ...common.recipe import Recipe +from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor @@ -746,7 +746,7 @@ def reset(key): reset("scaling_fwd") reset("scaling_bwd") - def get_extra_state(self) -> Optional[torch.Tensor]: + def get_extra_state(self) -> torch.Tensor: """Save before checkpointing.""" # This implementation is working around a few issues: @@ -781,7 +781,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: state = None fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration if not fp8_checkpoint: - return None + return torch.empty(0, dtype=torch.uint8) # Copy tensors to CPU and store state = {} @@ -807,13 +807,13 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8) return state_serialized - def set_extra_state(self, state: Optional[torch.Tensor]) -> None: + def set_extra_state(self, state: torch.Tensor) -> None: """Load previous state.""" - if state is None: - return - # Load state if isinstance(state, torch.Tensor): + # No FP8 is indicated by an empty tensor we don't need to unpickle. 
+ if state.numel() == 0: + return # Default format: byte tensor with pickled data state = pickle.loads(state.detach().cpu().numpy().tobytes()) elif isinstance(state, io.BytesIO): @@ -826,6 +826,14 @@ def set_extra_state(self, state: Optional[torch.Tensor]) -> None: if state is None: return + # TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing + if "recipe" not in state: + # TE 1.x only supported delayed scaling, which was the default recipe + state["recipe"] = DelayedScaling() + # TE 1.x also saved scale_inv, which is not needed with Recipe object + state.pop("scale_inv_fwd", None) + state.pop("scale_inv_bwd", None) + # Load extra items self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta["recipe"] = state["recipe"] @@ -899,6 +907,8 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: # assume FP8 execution. def init_fp8_metadata(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" + _original_recipe = self.fp8_meta.get("recipe", None) + self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() self.fp8 = FP8GlobalStateManager.is_fp8_enabled() self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() @@ -937,6 +947,19 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + _current_recipe = self.fp8_meta["recipe"] + if _original_recipe is not None and not ( + issubclass(_current_recipe.__class__, _original_recipe.__class__) + or issubclass(_original_recipe.__class__, _current_recipe.__class__) + ): + warnings.warn( + f"Recipe type changed from {_original_recipe.__class__.__name__} " + f"to {_current_recipe.__class__.__name__}. " + "This may affect model behavior." + ) + # Clear cached workspaces as they were created with the old recipe/quantizer type + self._fp8_workspaces.clear() + @contextmanager def prepare_forward( self, @@ -961,6 +984,7 @@ def prepare_forward( self.set_activation_dtype(inp) self.init_fp8_metadata(num_gemms=num_gemms) + self._check_weight_tensor_recipe_correspondence() if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): assert self.fp8_meta["recipe"].reduce_amax, ( @@ -1072,7 +1096,12 @@ def grad_output_preprocess( if ( isinstance( grad_output_.get_tensor(True), - (QuantizedTensor, Float8TensorBase, MXFP8TensorBase), + ( + QuantizedTensor, + Float8TensorBase, + MXFP8TensorBase, + Float8BlockwiseQTensorBase, + ), ) and ctx.use_bias ): @@ -1380,6 +1409,43 @@ def _validate_name(self): ) self.name = f"Layer_{TEDebugState.get_layer_count()}" + def _check_weight_tensor_recipe_correspondence(self) -> None: + """ + Verify that the weight tensor types match their corresponding recipe type. + This is invoked in the forward(). + + This establishes a 1:1 correspondence between recipe types and tensor types: + - DelayedScaling → Float8Tensor + - Float8CurrentScaling → Float8Tensor + - MXFP8BlockScaling → MXFP8Tensor + - Float8BlockScaling → Float8BlockTensor + + Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()), + but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()). 
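A hedged sketch of the failure mode this check guards against. It assumes an FP8/MXFP8-capable GPU and that `fp8_model_init` accepts the recipe as described in the docstring above; the exact argument names are illustrative.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling

# Weight tensors are created as MXFP8 tensors at initialization...
with te.fp8_model_init(recipe=MXFP8BlockScaling()):
    layer = te.Linear(256, 256)

inp = torch.randn(16, 256, device="cuda")

# ...but the autocast region requests delayed scaling. The new check raises a
# clear RuntimeError in forward() instead of failing deep inside a GEMM.
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    out = layer(inp)  # RuntimeError: Recipe mismatch for 'weight': ...
```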
+ """ + if not self.fp8 and not self.fp8_calibration: + return + if not hasattr(self, "weight_names") or not self.weight_names: + return + + recipe = self.fp8_meta["recipe"] + weight_tensors = [getattr(self, name) for name in self.weight_names] + for i, tensor in enumerate(weight_tensors): + if isinstance(tensor, QuantizedTensorBase): + quantizer = tensor._get_quantizer() + if quantizer is None: + continue + compatible_recipe_class = quantizer._get_compatible_recipe() + if compatible_recipe_class is None: + continue + if not isinstance(recipe, compatible_recipe_class): + raise RuntimeError( + f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe" + f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}." + " Please check the recipes assigned during fp8_model_init() and" + " fp8_autocast() calls." + ) + def _turn_off_unsupported_features_in_debug(self): if ( getattr(self, "ub_bulk_wgrad", False) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index a31823641..53f399d3d 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -205,6 +205,7 @@ def forward( # or if a gather of ln_out must be in high precision. with_quantized_norm = ( fp8 + and not debug and not return_layernorm_output and not return_layernorm_output_gathered and not force_hp_blockwise_ln_out_gather diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ce4137c66..4ab04da83 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -4,13 +4,14 @@ """Tensor class with FP8 data quantized with NxN tiles""" from __future__ import annotations -from typing import Optional, Tuple, Iterable +from typing import Optional, Tuple, Iterable, Union import math import torch import transformer_engine_torch as tex - from transformer_engine_torch import DType as TE_DType + +from transformer_engine.common.recipe import Float8BlockScaling, Recipe from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc from ..utils import devices_match, round_up_to_nearest_multiple @@ -229,6 +230,9 @@ def calibrate(self, tensor: torch.Tensor) -> None: # where state from an estimator influences distribution parameters. pass + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return Float8BlockScaling + class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. 
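Each quantizer below gains the same `_get_compatible_recipe()` hook, and the correspondence check in `base.py` then reduces to an `isinstance` test against the returned recipe class. A minimal sketch of that logic (the helper name is illustrative):

```python
from transformer_engine.common.recipe import Recipe

def recipe_matches(quantizer, active_recipe: Recipe) -> bool:
    """True if the active autocast recipe can drive tensors built by `quantizer`."""
    compatible = quantizer._get_compatible_recipe()
    # Quantizers that opt out of the check (e.g. DebugQuantizer) return None.
    return compatible is None or isinstance(active_recipe, compatible)
```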
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index b55ac577c..f43a6dd28 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -7,14 +7,15 @@ """Tensor class with FP8 data""" from __future__ import annotations import os -from typing import Optional, Tuple, Iterable +from typing import Optional, Tuple, Iterable, Union import warnings from torch.utils.cpp_extension import IS_HIP_EXTENSION import torch import transformer_engine_torch as tex - from transformer_engine_torch import DType as TE_DType + +from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe from ..utils import canonicalize_process_group, devices_match from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc @@ -177,6 +178,9 @@ def create_tensor_from_data( quantizer=self, ) + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return DelayedScaling + class Float8CurrentScalingQuantizer(Quantizer): """Builder class for FP8 tensors with per-tensor current scaling @@ -339,6 +343,9 @@ def _canonicalized_amax_reduction_group(self) -> dist_group_type: """Get process group for amax reduction""" return canonicalize_process_group(self.amax_reduction_group) + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return Float8CurrentScaling + class Float8Tensor(Float8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 920b7d6b0..8f3c73eb9 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -7,16 +7,17 @@ from collections.abc import Iterable import math import os -from typing import Optional, Tuple from torch.utils.cpp_extension import IS_HIP_EXTENSION +from typing import Optional, Tuple, Union import torch if IS_HIP_EXTENSION: from ..triton_kernels.cast import te_quantize_triton import transformer_engine_torch as tex - from transformer_engine_torch import DType as TE_DType + +from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe from ..constants import MXFP8_BLOCK_SCALING_SIZE from ..utils import devices_match, round_up_to_nearest_multiple @@ -145,6 +146,9 @@ def calibrate(self, tensor: torch.Tensor) -> None: # TODO(ksivamani): No calibration needed for mxfp8? pass + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return MXFP8BlockScaling + class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index e521d4279..9b0adcc22 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -17,6 +17,7 @@ from torch.utils._pytree import tree_map import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe class QuantizedTensorBase: @@ -242,6 +243,10 @@ def copy(self) -> Quantizer: """Create shallow copy""" return copy.copy(self) + @abc.abstractmethod + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + """Returns recipe class that is compatible with this quantizer""" + class _QuantizeFunc(torch.autograd.Function): """Cast to FP8 from other dtype"""
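Since `Quantizer._get_compatible_recipe` is declared abstract here, out-of-tree quantizer subclasses will need to provide it. A toy sketch of the override (the class name and recipe choice are illustrative; a real quantizer also implements the existing abstract methods):

```python
from typing import Union

from transformer_engine.common.recipe import DelayedScaling, Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer

class MyQuantizer(Quantizer):
    """Toy subclass showing only the new hook."""

    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
        # Recipe class whose autocast regions this quantizer's tensors support;
        # return None to opt out of the correspondence check.
        return DelayedScaling
```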