diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py index dafafdff4..1df16cc01 100644 --- a/benchmarks/attention/benchmark_attention.py +++ b/benchmarks/attention/benchmark_attention.py @@ -9,11 +9,11 @@ import torch import nvtx import transformer_engine -from tests.pytorch.fused_attn.test_fused_attn import ( +from tests.pytorch.utils import ( ModelConfig, - _get_attention_backends, - _run_dot_product_attention, + get_available_attention_backends, ) +from tests.pytorch.attention.test_attention import _run_dot_product_attention pd.set_option("display.precision", 4) @@ -197,7 +197,7 @@ def main(): ) for model in model_configs.keys(): config = model_configs[model] - available_backends, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/benchmarks/attention/benchmark_attention_rocm.py b/benchmarks/attention/benchmark_attention_rocm.py index b126fb022..6d4aa6404 100644 --- a/benchmarks/attention/benchmark_attention_rocm.py +++ b/benchmarks/attention/benchmark_attention_rocm.py @@ -13,15 +13,19 @@ import transformer_engine from transformer_engine_torch import NVTE_Fused_Attn_Backend -# Add test_fused_attn to the sys path +# Add paths tests/pytorch/ and tests/pytorch/attention to the sys path tests_path = os.path.abspath( - os.path.join(os.path.dirname(__file__), "../../tests/pytorch/fused_attn") + os.path.join(os.path.dirname(__file__), "../../tests") ) -sys.path.append(tests_path) +sys.path.append(tests_path + "/pytorch") +sys.path.append(tests_path + "/pytorch/attention") -from test_fused_attn import ( +# Add tests/pytorch/utils.py path into sys path +from utils import ( ModelConfig, - _get_attention_backends, + get_available_attention_backends, +) +from test_attention import ( _run_dot_product_attention, ) @@ -46,12 +50,12 @@ is_training = True model_configs = { - # test: b, h, hg, d, sq, skv, p, mask, bias - "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq - "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask - "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias - "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA - "test_4": ModelConfig(2, 128, 8, 128, 8192, 8192, 0.0, "causal_bottom_right", "no_bias") + # test: b, sq, h, d + "test_0": ModelConfig(2, 512, 16, 64), # short seq + "test_1": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), # longer seq, mask + "test_2": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"), # bias + "test_3": ModelConfig(2, 8192, 32, 128, num_gqa_groups=8, attn_mask_type="causal"), # GQA + "test_4": ModelConfig(2, 8192, 128, 128, num_gqa_groups=16, attn_mask_type="causal_bottom_right") } # DataFrame indices and columns for results @@ -303,7 +307,7 @@ def sanity_checks( } for model, cfg in model_configs.items(): - avail, _, fused_bes = _get_attention_backends( + avail, _, fused_bes = get_available_attention_backends( cfg, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -364,7 +368,7 @@ def main(args): # Benchmarking starts.. 
for model in model_configs.keys(): config = model_configs[model] - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 2a45a8a5c..e70b4523a 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.6.0.dev0 +2.6.0 diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 4056a5fa7..bb084293f 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -27,20 +27,7 @@ def install_requirements() -> List[str]: """Install dependencies for TE/PyTorch extensions.""" - reqs = ["einops"] - if not rocm_build(): - reqs.append( - "nvdlfw-inspect @" - " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect" - ) - reqs.extend( - [ - "torch>=2.1", - "onnx", - "onnxscript@git+https://github.com/microsoft/onnxscript.git@51ecf47523ef079c53b0e620c62d56d70cfd3871", - ] - ) - return reqs + return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"] def test_requirements() -> List[str]: diff --git a/ci/pytorch.sh b/ci/pytorch.sh index 1b3aefd36..c86d6e2fc 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -65,7 +65,7 @@ run_test_config(){ run_default_fa 1 test_recipe.py run 1 test_sanity.py run_default_fa 1 test_sanity_import.py - run_default_fa 1 fused_attn/test_fused_attn.py # Backend selection is controlled by the test + run_default_fa 1 attention/test_attention.py # Backend selection is controlled by the test run_default_fa 1 triton_kernels/test_cast.py run_default_fa 1 triton_kernels/test_cast_mxfp8.py run_default_fa 1 triton_kernels/test_norm_common.py @@ -88,8 +88,8 @@ run_test_config_mgpu(){ run_default_fa 2 distributed/test_numerics.py run_default_fa 1 distributed/test_torch_fsdp2.py run_default_fa 2 distributed/test_torch_fsdp2_fp8.py - run_default_fa_lbl "flash" 3 fused_attn/test_fused_attn_with_cp.py -k "with_flash" - run_default_fa_lbl "fused" 2 fused_attn/test_fused_attn_with_cp.py -k "with_fused" + run_default_fa_lbl "flash" 3 attention/test_attention_with_cp.py -k "with_flash" + run_default_fa_lbl "fused" 2 attention/test_attention_with_cp.py -k "with_fused" } run_benchmark() { diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index bc2b95057..555b9b4b8 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -21,7 +21,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea There are 4 things one needs to do to use Transformer Engine debug features: 1. Create a configuration YAML file to configure the desired features. -2. Import, and initialize the `Nvidia-DL-Framework-Inspect `_ tool, which is installed as the dependency of the Transformer Engine. +2. Import, initialize, and install the `Nvidia-DL-Framework-Inspect `_ tool. 3. One can pass ``name="..."`` when creating TE layers to easier identify layer names. If this is not provided, names will be inferred automatically. 4. Invoke ``debug_api.step()`` at the end of one forward-backward pass. @@ -238,4 +238,4 @@ Let's run training and open TensorBoard by ``tensorboard --logdir=./tensorboard_ .. figure:: ./img/tensorboard.png :align: center - Fig 2: TensorBoard with plotted stats. \ No newline at end of file + Fig 2: TensorBoard with plotted stats. 
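The hunks above relocate the attention test helpers: ModelConfig and get_available_attention_backends now live in tests/pytorch/utils.py, _run_dot_product_attention in tests/pytorch/attention/test_attention.py, and get_available_attention_backends returns a three-element tuple instead of two. A minimal sketch of the updated call pattern, with names taken from this diff (the qkv_layout value and the ModelConfig arguments are illustrative only, mirroring the "test_1" config above, not a prescribed setup):

    import torch
    from tests.pytorch.utils import ModelConfig, get_available_attention_backends

    dtype = torch.bfloat16
    qkv_layout = "bshd_bshd_bshd"  # assumption: any QKV layout string the helper accepts

    # New ModelConfig signature: (b, sq, h, d) positionally, i.e. batch, sequence length,
    # number of heads, head dim, plus keyword args such as attn_mask_type, attn_bias_type
    # and num_gqa_groups, as used in the updated model_configs dicts above.
    config = ModelConfig(2, 2048, 16, 128, attn_mask_type="causal")

    # The helper now returns three values; callers that only need the first and last
    # unpack it as `available_backends, _, fused_attn_backends = ...`, as in this diff.
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=qkv_layout,
    )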
diff --git a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py index e9eec14d9..97f1bcd7e 100644 --- a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py +++ b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py @@ -5,7 +5,7 @@ import os import torch from typing import Tuple -from tests.pytorch.fused_attn.test_fused_attn import ModelConfig +from tests.pytorch.utils import ModelConfig from transformer_engine.pytorch.attention import DotProductAttention # Initialize RNG state diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 53a5eede7..6cd56d23d 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -375,7 +375,7 @@ "\n", "Our [unit tests](https://github.com/NVIDIA/TransformerEngine/tree/main/tests) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n", "\n", - "For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts." + "For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts." ] }, { @@ -394,10 +394,10 @@ "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n", "\n", "Some unit tests are provided to serve as a starting point for integrating such features into users' models. 
For example,\n", - "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", - "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", - "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", - "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)" + "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n", + "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n", + "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n", + "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py)" ] }, { @@ -458,7 +458,7 @@ " \n", "\n", "\n", - "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", + "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", "\n", "
\n", "Note\n", @@ -548,7 +548,7 @@ "id": "dda4a589", "metadata": {}, "source": [ - "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n", + "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py).\n", "\n", "### 3.3 Attention Bias\n", "\n", @@ -594,7 +594,7 @@ "\n", "The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n", "\n", - "More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)." + "More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)." ] }, { @@ -612,7 +612,7 @@ "\n", "- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n", "\n", - "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`." + "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`." 
] } ], diff --git a/docs/examples/attention/example_attention.py b/docs/examples/attention/example_attention.py index 2c32e8b5f..cf650265b 100644 --- a/docs/examples/attention/example_attention.py +++ b/docs/examples/attention/example_attention.py @@ -9,11 +9,11 @@ import torch import nvtx import transformer_engine -from tests.pytorch.fused_attn.test_fused_attn import ( +from tests.pytorch.utils import ( ModelConfig, - _get_attention_backends, - _run_dot_product_attention, + get_available_attention_backends, ) +from tests.pytorch.attention.test_attention import _run_dot_product_attention # data type dtype = torch.bfloat16 @@ -90,7 +90,7 @@ def main(): models = ["test_0"] for model in models: config = model_configs[model] - available_backends, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 3d00e0346..e4a3f4630 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -25,7 +25,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" @@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" # Test without custom calls export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -NVTE_JAX_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" +NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 7fe439b37..9a924282b 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -45,8 +45,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" NVTE_FLASH_ATTN=0 python3 -m 
pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 09ef661c4..d7a4f054f 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -23,12 +23,13 @@ mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" diff --git a/qa/L2_jax_unittest/test.sh b/qa/L2_jax_unittest/test.sh index c5c193351..f933a0732 100644 --- a/qa/L2_jax_unittest/test.sh +++ b/qa/L2_jax_unittest/test.sh @@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} 
--xla_gpu_deterministic_ops" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" # Test without custom calls export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -NVTE_JAX_CUSTOM_CALLS_RE="" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" +NVTE_JAX_CUSTOM_CALLS="false" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 547849e95..7e9616cd0 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -41,6 +41,6 @@ do fi # Run tests - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py done diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index b0a847d7d..aaa40f356 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -38,98 +38,38 @@ enum ActivationType { SReLU }; -template -void scale_block(const ProcessingMethod processing_method, +template +void compute_ref(const ProcessingMethod processing_method, + float (*OP)(const float), + const bool rowwise, + const bool colwise, const InputType* input, const InputType* grad, - OutputType* output_c, - float* dbias, - fp8e8m0* output_scales, - const size_t scale_idx, - const size_t i_min, - const size_t i_max, - const size_t j_min, - const size_t j_max, - const size_t cols) { + OutputType* output_rowwise, + OutputType* output_colwise, + fp8e8m0* output_scales_rowwise, + fp8e8m0* output_scales_colwise, + InputType* output_dbias, + const size_t rows, + const size_t cols, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) +{ #ifdef __HIP_PLATFORM_AMD__ using std::isnan, std::isinf; #endif - float amax = 0.0f; - - // Find the absolute maximum value in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - const size_t idx = i * cols + j; - float elt = static_cast(input[idx]); - if (processing_method == ProcessingMethod::CAST_DBIAS) { - // grad is the input - elt = static_cast(grad[idx]); - } - if (processing_method != ProcessingMethod::CAST_ONLY - && processing_method != ProcessingMethod::CAST_DBIAS) { - elt = OP(elt); - } - if (processing_method == ProcessingMethod::CAST_DACT || - processing_method == ProcessingMethod::CAST_DBIAS_DACT) { - elt *= static_cast(grad[idx]); - } - dbias[j] += elt; - if (isinf(elt) || isnan(elt)) { - continue; - } - amax = std::max(amax, std::abs(elt)); - } - } - - const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_reciprocal()); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - output_scales[scale_idx] = biased_exponent; - - // Quantize elements in 
the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - const size_t idx = i * cols + j; - float elt = static_cast(input[idx]); - if (processing_method == ProcessingMethod::CAST_DBIAS) { - // grad is the input - elt = static_cast(grad[idx]); - } - if (processing_method != ProcessingMethod::CAST_ONLY - && processing_method != ProcessingMethod::CAST_DBIAS) { - elt = OP(elt); - } - if (processing_method == ProcessingMethod::CAST_DACT || - processing_method == ProcessingMethod::CAST_DBIAS_DACT) { - elt *= static_cast(grad[idx]); - } - output_c[idx] = static_cast(elt * scale_reciprocal); - } - } -} -template -void compute_ref_x1(const ProcessingMethod processing_method, - const InputType* input, - const InputType* grad, - OutputType* output_c, - fp8e8m0* output_scales, - InputType* output_dbias, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride) -{ - const size_t tile_size_Y = std::max(32lu, block_size_Y); - const size_t tile_size_X = std::max(64lu, block_size_X); + const size_t tile_size_Y = 32; + const size_t tile_size_X = 32; const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; - const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y; - const size_t blocks_per_tile_X = tile_size_X / block_size_X; std::vector output_dbias_fp32(cols, 0); #pragma omp parallel proc_bind(spread) { + // Buffers to cache intermediate computations + std::vector cache_buffer(tile_size_Y * tile_size_X); + std::vector thread_dbias(cols, 0); #pragma omp for schedule(static) for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { @@ -138,24 +78,82 @@ void compute_ref_x1(const ProcessingMethod processing_method, const size_t tile_offset_Y = tile_Y * tile_size_Y; const size_t tile_offset_X = tile_X * tile_size_X; - for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { - const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; - const size_t block_offset_Y = ii * block_size_Y; - const size_t i_min = tile_offset_Y + block_offset_Y; - if (i_min >= rows) continue; - const size_t i_max = std::min(i_min + block_size_Y, rows); - - for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { - const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; - const size_t block_offset_X = jj * block_size_X; - const size_t j_min = tile_offset_X + block_offset_X; - if (j_min >= cols) continue; - const size_t j_max = std::min(j_min + block_size_X, cols); - - const size_t scale_idx = block_idx_Y * scales_stride + block_idx_X; - scale_block( - processing_method, input, grad, output_c, thread_dbias.data(), - output_scales, scale_idx, i_min, i_max, j_min, j_max, cols); + const size_t i_min = tile_offset_Y; + const size_t i_max = std::min(i_min + tile_size_Y, rows); + + const size_t j_min = tile_offset_X; + const size_t j_max = std::min(j_min + tile_size_X, cols); + + // Cache computations + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const int idx = i * cols + j; + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + + float elt = static_cast(input[idx]); + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } + if (processing_method != ProcessingMethod::CAST_ONLY + && processing_method != ProcessingMethod::CAST_DBIAS) { + elt = OP(elt); + } + if (processing_method == ProcessingMethod::CAST_DACT || + processing_method == 
ProcessingMethod::CAST_DBIAS_DACT) { + elt *= static_cast(grad[idx]); + } + thread_dbias[j] += elt; + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + elt = static_cast(static_cast(elt)); + cache_buffer[cache_idx] = elt; + if (isinf(elt) || isnan(elt)) { + continue; + } + } + } + + if (rowwise) { + for (size_t i = i_min; i < i_max; ++i) { + float block_amax = 0.0f; + + for (size_t j = j_min; j < j_max; ++j) { + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const int scale_idx = i * scales_stride_rowwise + tile_X; + output_scales_rowwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t j = j_min; j < j_max; ++j) { + const int idx = i * cols + j; + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_rowwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } + } + } + if (colwise) { + for (size_t j = j_min; j < j_max; ++j) { + float block_amax = 0.0f; + + for (size_t i = i_min; i < i_max; ++i) { + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const int scale_idx = tile_Y * scales_stride_colwise + j; + output_scales_colwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t i = i_min; i < i_max; ++i) { + const int idx = i * cols + j; + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_colwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } } } } @@ -171,29 +169,6 @@ void compute_ref_x1(const ProcessingMethod processing_method, } } -template -void compute_ref_x2(const ProcessingMethod processing_method, - const InputType* input, - const InputType* grad, - OutputType* output_rowwise, - OutputType* output_colwise, - fp8e8m0* scales_rowwise, - fp8e8m0* scales_colwise, - InputType* output_dbias, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride_rowwise, - const size_t scales_stride_colwise) { - compute_ref_x1( - processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias, - rows, cols, 1, block_size_X, scales_stride_rowwise); - compute_ref_x1( - processing_method, input, grad, output_colwise, scales_colwise, output_dbias, - rows, cols, block_size_Y, 1, scales_stride_colwise); -} - /** * Scaling along single dimension (either rows or columns) * Produces one set of output data and the corresponding data of the fused operation (dbias): @@ -202,8 +177,9 @@ void compute_ref_x2(const ProcessingMethod processing_method, * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x1(const ProcessingMethod processing_method, + float (*OP)(const float), const std::vector& shape, const bool rowwise, const bool colwise, @@ -266,28 +242,46 @@ void performTest_x1(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS_DACT: { - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output_c.data(), - output_dbias.data(), - workspace.data(), - 0); + auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu; + if (OP == &dsilu) { 
nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; } + else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; } + else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; } + else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; } + + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output_c.data(), - output_dbias.data(), - workspace.data(), - 0); + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); break; } case ProcessingMethod::CAST_DACT: { - nvte_dgelu(grad.data(), input.data(), output_c.data(), 0); + auto nvte_dact = &nvte_dgelu; + if (OP == &dsilu) { nvte_dact = &nvte_dsilu; } + else if (OP == &drelu) { nvte_dact = &nvte_drelu; } + else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; } + else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; } + + nvte_dact(grad.data(), input.data(), output_c.data(), 0); break; } case ProcessingMethod::CAST_ACT: { - nvte_gelu(input.data(), output_c.data(), 0); + auto nvte_act = &nvte_gelu; + if (OP == &silu) { nvte_act = &nvte_silu; } + else if (OP == &relu) { nvte_act = &nvte_relu; } + else if (OP == &qgelu) { nvte_act = &nvte_qgelu; } + else if (OP == &srelu) { nvte_act = &nvte_srelu; } + + nvte_act(input.data(), output_c.data(), 0); break; } } @@ -296,47 +290,70 @@ void performTest_x1(const ProcessingMethod processing_method, auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - compute_ref_x1(processing_method, - input.rowwise_cpu_dptr(), - grad.rowwise_cpu_dptr(), - ref_output_c.get(), - ref_output_scales.get(), - ref_output_dbias.get(), - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride); - - -#ifdef __HIP_PLATFORM_AMD__ - if (processing_method != ProcessingMethod::CAST_ONLY) { - std::vector> mismatch_idx; - compare_e8m0_scaling_factors("scales", output_c, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, rowwise, mismatch_idx); - - if (mismatch_idx.size()) { - adjust_ref(mismatch_idx, ref_output_c.get(), unpadded_blocks_Y, unpadded_blocks_X, rows, cols, otype); - } - - auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol); - } - else -#endif // #ifdef __HIP_PLATFORM_AMD__ - { - auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol); + compute_ref(processing_method, + OP, + rowwise, + colwise, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c.get(), + ref_output_c.get(), + ref_output_scales.get(), + ref_output_scales.get(), + ref_output_dbias.get(), + rows, + cols, + scales_stride, + scales_stride); const uint8_t * const gpu_scales_ptr = rowwise ? 
output_c.rowwise_cpu_scale_inv_ptr() : output_c.columnwise_cpu_scale_inv_ptr(); + const size_t scale_diff_abs_tolerance = 0; +#ifdef __HIP_PLATFORM_AMD__ + double abs_tolerable_mismatches_limit = 0.0; + double rel_tolerable_mismatches_limit = 0.0; + if (processing_method != ProcessingMethod::CAST_ONLY) { + abs_tolerable_mismatches_limit = 1; + rel_tolerable_mismatches_limit = 1.0e-4; + } +#else + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; +#endif // #ifdef __HIP_PLATFORM_AMD__ + + size_t mismatches_scales = 0; +#ifdef __HIP_PLATFORM_AMD__ + std::vector mismatches_scales_indices; +#endif // #ifdef __HIP_PLATFORM_AMD__ compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices, +#endif // #ifdef __HIP_PLATFORM_AMD__ + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); +#ifdef __HIP_PLATFORM_AMD__ + if (processing_method != ProcessingMethod::CAST_ONLY) { + adjust_ref_for_e8m0_scale_error("scales", mismatches_scales_indices, gpu_scales_ptr, + ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, + scales_stride, rows, cols, ref_output_c.get(), otype); + mismatches_scales = 0; + }else{ + // should not have scale mismatch for cast only cases + ASSERT_EQ(mismatches_scales, 0) <<"expect no scale mismatches for cast only cases"< +template void performTest_x2(const ProcessingMethod processing_method, + float (*OP)(const float), const std::vector& shape, const size_t block_size_rows, const size_t block_size_cols, @@ -424,28 +442,46 @@ void performTest_x2(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS_DACT: { - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output.data(), - output_dbias.data(), - workspace.data(), - 0); + auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu; + if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; } + else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; } + else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; } + else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; } + + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output.data(), - output_dbias.data(), - workspace.data(), - 0); + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); break; } case ProcessingMethod::CAST_DACT: { - nvte_dgelu(grad.data(), input.data(), output.data(), 0); + auto nvte_dact = &nvte_dgelu; + if (OP == &dsilu) { nvte_dact = &nvte_dsilu; } + else if (OP == &drelu) { nvte_dact = &nvte_drelu; } + else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; } + else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; } + + nvte_dact(grad.data(), input.data(), output.data(), 0); break; } case ProcessingMethod::CAST_ACT: { - nvte_gelu(input.data(), output.data(), 0); + auto nvte_act = &nvte_gelu; + if (OP == &silu) { nvte_act = &nvte_silu; } + else if (OP == &relu) { nvte_act = &nvte_relu; } + else if (OP == &qgelu) { nvte_act = &nvte_qgelu; } + else if (OP == 
&srelu) { nvte_act = &nvte_srelu; } + + nvte_act(input.data(), output.data(), 0); break; } } @@ -454,55 +490,89 @@ void performTest_x2(const ProcessingMethod processing_method, auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - compute_ref_x2(processing_method, - input.rowwise_cpu_dptr(), - grad.rowwise_cpu_dptr(), - ref_output_c_rowwise.get(), - ref_output_c_colwise.get(), - ref_scales_rowwise.get(), - ref_scales_colwise.get(), - ref_output_dbias.get(), - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride_rowwise, - scales_stride_colwise); + compute_ref(processing_method, + OP, + true, + true, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c_rowwise.get(), + ref_output_c_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_output_dbias.get(), + rows, + cols, + scales_stride_rowwise, + scales_stride_colwise); + + const size_t scale_diff_abs_tolerance = 0; #ifdef __HIP_PLATFORM_AMD__ - if (processing_method != ProcessingMethod::CAST_ONLY) { - std::vector> mismatch_idx_r; - compare_e8m0_scaling_factors("scales_rowwise", output, ref_scales_rowwise.get(), - unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise, 0.01, true, mismatch_idx_r); - - if (mismatch_idx_r.size()) { - adjust_ref(mismatch_idx_r, ref_output_c_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, rows, cols, otype); - } - std::vector> mismatch_idx_c; - compare_e8m0_scaling_factors("scales_colwise", output, ref_scales_colwise.get(), - unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise, 0.01, false, mismatch_idx_c); - - if (mismatch_idx_c.size()) { - adjust_ref(mismatch_idx_c, ref_output_c_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, rows, cols, otype); - } + const double abs_tolerable_mismatches_limit = 1.0; + const double rel_tolerable_mismatches_limit = 1.0e-4; +#else + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; +#endif // #ifdef __HIP_PLATFORM_AMD__ - auto [atol, rtol] = getTolerances(otype); - compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol); - compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol); - } else + size_t mismatches_scales_rowwise = 0; +#ifdef __HIP_PLATFORM_AMD__ + std::vector mismatches_scales_indices_rowwise; #endif // #ifdef __HIP_PLATFORM_AMD__ - { - auto [atol, rtol] = getTolerances(otype); - compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol); - compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol); compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise); + unpadded_blocks_X_rowwise, scales_stride_rowwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_rowwise, +#endif // #ifdef __HIP_PLATFORM_AMD__ + + mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + size_t mismatches_scales_colwise = 0; +#ifdef __HIP_PLATFORM_AMD__ + std::vector mismatches_scales_indices_colwise; +#endif // #ifdef __HIP_PLATFORM_AMD__ compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise); + 
unpadded_blocks_X_colwise, scales_stride_colwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_colwise, +#endif // #ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + +#ifdef __HIP_PLATFORM_AMD__ + if (processing_method != ProcessingMethod::CAST_ONLY) { + adjust_ref_for_e8m0_scale_error("scales_rowwise", mismatches_scales_indices_rowwise, output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, + scales_stride_rowwise, rows, cols, ref_output_c_rowwise.get(), otype); + adjust_ref_for_e8m0_scale_error("scales_colwise", mismatches_scales_indices_colwise, output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, + scales_stride_colwise, rows, cols, ref_output_c_colwise.get(), otype); + + mismatches_scales_rowwise = 0; + mismatches_scales_colwise = 0; + }else{ + // should not have scale mismatch for cast only cases + ASSERT_EQ(mismatches_scales_rowwise, 0) <<"expect no scale mismatches for cast only cases"<> matrix_sizes = { {128, 128}, {256, 256}, {993, 512}, - {256, 65536}, - {2048, 6144}, - {16384, 128}, - {32768, 160}, - {4096, 1632}, + {511, 6144}, + {8192, 128}, + {2048, 160}, + {577, 1632}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, @@ -574,26 +643,6 @@ class FusedCastMXFP8TestSuite : public ::testing::TestWithParam transformer_engine::DType, InputsFillCase>> {}; -#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ -switch (OP_FUNC_TYPE) { \ - case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \ - case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \ - case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \ - case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \ - case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \ - case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \ -} - -#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) 
\ -switch (OP_FUNC_TYPE) { \ - case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \ - case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \ - case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \ - case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \ - case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \ - case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \ -} - TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { #ifndef __HIP_PLATFORM_AMD__ // Skip tests for pre-Blackwell architectures @@ -629,35 +678,48 @@ TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { const bool colwise = block_size.first != 1; if (processing_method == ProcessingMethod::CAST_ACT) { // Forward activations - ACT_FUNC_SWITCH(Act_type, OP, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, - if (block_size.first == 1 || block_size.second == 1) { - performTest_x1( - processing_method, matrix_size, - rowwise, colwise, fill_case); - } else { - performTest_x2( - processing_method, matrix_size, - block_size.first, block_size.second, fill_case); - } - ); + auto OP = &identity; + switch (Act_type) { + case ActivationType::GeLU: OP = &gelu; break; + case ActivationType::SiLU: OP = &silu; break; + case ActivationType::ReLU: OP = &relu; break; + case ActivationType::QGeLU: OP = &qgelu; break; + case ActivationType::SReLU: OP = &srelu; break; + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + if (block_size.first == 1 || block_size.second == 1) { + performTest_x1( + processing_method, OP, matrix_size, + rowwise, colwise, fill_case); + } else { + performTest_x2( + processing_method, OP, matrix_size, + block_size.first, block_size.second, fill_case); + } ); ); } else { - DACT_FUNC_SWITCH(Act_type, OP, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, - if (block_size.first == 1 || block_size.second == 1) { - performTest_x1( - processing_method, matrix_size, - rowwise, colwise, fill_case); - } else { - performTest_x2( - processing_method, matrix_size, - block_size.first, block_size.second, fill_case); - } - ); + auto OP = &identity; + switch (Act_type) { + case ActivationType::GeLU: OP = &dgelu; break; + case ActivationType::SiLU: OP = &dsilu; break; + case ActivationType::ReLU: OP = &drelu; break; + case ActivationType::QGeLU: OP = &dqgelu; break; + case ActivationType::SReLU: OP = &dsrelu; break; + } + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + if (block_size.first == 1 || block_size.second == 1) { + performTest_x1( + processing_method, OP, matrix_size, + rowwise, colwise, fill_case); + } else { + performTest_x2( + processing_method, OP, matrix_size, + block_size.first, block_size.second, fill_case); + } ); ); } diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 96663e752..435fc6c02 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -24,108 +24,32 @@ using namespace test; namespace { -template -void scale_block(const IType* grad, +template 
+void compute_ref(const IType* grad, const IType* input, - OType* output, - fp8e8m0* output_scales, - const size_t scale_idx, - const size_t scale_idx_gate, - float& thread_amax, - const size_t i_min, - const size_t i_max, - const size_t j_min, - const size_t j_max, - const size_t cols) { - - float block_amax = 0.0f; - float block_amax_gate = 0.0f; - const size_t stride = cols * 2; - - - // Find the absolute maximum value in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - float silu_elt = static_cast(input[i * stride + j]); - float gate_elt = static_cast(input[i * stride + cols + j]); - float gated_amax_act = 0; - float gated_amax_gate = 0; - - if constexpr (IS_DGATED) { - const float grad_elt = static_cast(grad[i * cols + j]); - const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; - const float after_dgate = silu(silu_elt) * grad_elt; - gated_amax_act = abs(after_dsilu); - gated_amax_gate = abs(after_dgate); - } else { - const float after_silu = silu(silu_elt) * gate_elt; - gated_amax_act = abs(after_silu); - } - - if (gated_amax_act > block_amax) { block_amax = gated_amax_act; } - if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; } - } - } - - const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * - Quantized_Limits::max_reciprocal()); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - output_scales[scale_idx] = biased_exponent; - float scale_reciprocal_gate = 1; - if constexpr (IS_DGATED) { - const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate * - Quantized_Limits::max_reciprocal()); - scale_reciprocal_gate = exp2f_rcp(biased_exponent); - output_scales[scale_idx_gate] = biased_exponent; - } - - - // Quantize elements in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - float silu_elt = static_cast(input[i * stride + j]); - float gate_elt = static_cast(input[i * stride + cols + j]); - - if constexpr (IS_DGATED) { - const float grad_elt = static_cast(grad[i * cols + j]); - const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; - const float after_dgate = silu(silu_elt) * grad_elt; - output[i * stride + j] = static_cast(after_dsilu * scale_reciprocal); - output[i * stride + cols + j] = static_cast(after_dgate * - scale_reciprocal_gate); - } else { - const float after_silu = silu(silu_elt) * gate_elt; - output[i * cols + j] = static_cast(after_silu * scale_reciprocal); - } - - } - } - thread_amax = std::max(thread_amax, block_amax); - thread_amax = std::max(thread_amax, block_amax_gate); -} - -template -void compute_ref_x1(const IType* grad, - const IType* input, - OType* output, - fp8e8m0* output_scales, - float& ref_amax, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride) { - const size_t tile_size_Y = std::max(32lu, block_size_Y); - const size_t tile_size_X = std::max(64lu, block_size_X); + OType* output_rowwise, + OType* output_colwise, + fp8e8m0* output_scales_rowwise, + fp8e8m0* output_scales_colwise, + float& ref_amax, + const bool IS_DGATED, + const size_t rows, + const size_t cols, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise, + const bool is_rowwise, + const bool is_colwise) { + constexpr size_t tile_size_Y = 32; + constexpr size_t tile_size_X = 32; const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; - const size_t 
blocks_per_tile_Y = tile_size_Y / block_size_Y; - const size_t blocks_per_tile_X = tile_size_X / block_size_X; - float amax = 0; #pragma omp parallel reduction(max: amax) proc_bind(spread) { - float thread_amax = 0; + // Buffers to cache intermediate computations + std::vector cache_buffer_act(tile_size_Y * tile_size_X); + std::vector cache_buffer_gate(tile_size_Y * tile_size_X); + float thread_amax = 0.0f; #pragma omp for schedule(static) for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { const size_t tile_Y = t / tiles_num_X; @@ -133,26 +57,124 @@ void compute_ref_x1(const IType* grad, const size_t tile_offset_Y = tile_Y * tile_size_Y; const size_t tile_offset_X = tile_X * tile_size_X; - for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { - const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; - const size_t block_offset_Y = ii * block_size_Y; - const size_t i_min = tile_offset_Y + block_offset_Y; - if (i_min >= rows) continue; - const size_t i_max = std::min(i_min + block_size_Y, rows); - - for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { - const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; - const size_t block_offset_X = jj * block_size_X; - const size_t j_min = tile_offset_X + block_offset_X; - if (j_min >= cols) continue; - const size_t j_max = std::min(j_min + block_size_X, cols); - - const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X; - const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X + - cols / block_size_X; - scale_block( - grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate, - thread_amax, i_min, i_max, j_min, j_max, cols); + const size_t stride = cols * 2; + + const size_t i_min = tile_offset_Y; + const size_t i_max = std::min(rows, tile_offset_Y + tile_size_Y); + const size_t j_min = tile_offset_X; + const size_t j_max = std::min(cols, tile_offset_X + tile_size_X); + + // Compute and cache activations for the entire tile + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + float silu_elt = static_cast(input[i * stride + j]); + float gate_elt = static_cast(input[i * stride + cols + j]); + + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + + if (IS_DGATED) { + const float x = silu_elt; + const float s = sigmoid(x); + const float act_x = x * s; + const float dact_x = x * s * (1 - s) + s; + + const float grad_elt = static_cast(grad[i * cols + j]); + float after_dsilu = dact_x * grad_elt * gate_elt; + float after_dgate = act_x * grad_elt; + + // Numerical truncation: after downcast to IType (BF16/FP16), upcast it back to FP32 + after_dsilu = static_cast(static_cast(after_dsilu)); + after_dgate = static_cast(static_cast(after_dgate)); + + cache_buffer_act[cached_idx] = after_dsilu; + cache_buffer_gate[cached_idx] = after_dgate; + thread_amax = std::max(thread_amax, std::abs(after_dsilu)); + thread_amax = std::max(thread_amax, std::abs(after_dgate)); + } else { + float after_silu = silu(silu_elt) * gate_elt; + + // Numerical truncation: after downcast to IType (BF16/FP16), upcast it back to FP32 + after_silu = static_cast(static_cast(after_silu)); + + cache_buffer_act[cached_idx] = after_silu; + thread_amax = std::max(thread_amax, std::abs(after_silu)); + } + } + } + + if (is_rowwise) { + for (size_t i = i_min; i < i_max; ++i) { + float block_amax_act = 0.0f; + float block_amax_gate = 0.0f; + for (size_t j = j_min; j < j_max; ++j) { + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax_act = std::max(block_amax_act, 
std::abs(cache_buffer_act[cached_idx])); + if (IS_DGATED) { + block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx])); + } + } + const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits::max_reciprocal()); + const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act); + const int scale_idx_act = i * scales_stride_rowwise + tile_X; + output_scales_rowwise[scale_idx_act] = biased_exponent_act; + + float scale_reciprocal_gate; + if (IS_DGATED) { + const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits::max_reciprocal()); + scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate); + const int scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32; + output_scales_rowwise[scale_idx_gate] = biased_exponent_gate; + } + for (size_t j = j_min; j < j_max; ++j) { + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act; + + if (IS_DGATED) { + const float after_gate = cache_buffer_gate[cached_idx] * scale_reciprocal_gate; + output_rowwise[i * stride + j] = static_cast(after_act); + output_rowwise[i * stride + cols + j] = static_cast(after_gate); + } else { + output_rowwise[i * cols + j] = static_cast(after_act); + } + } + } + } + + if (is_colwise) { + for (size_t j = j_min; j < j_max; ++j) { + float block_amax_act = 0.0f; + float block_amax_gate = 0.0f; + for (size_t i = i_min; i < i_max; ++i) { + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx])); + if (IS_DGATED) { + block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx])); + } + } + const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits::max_reciprocal()); + const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act); + const int scale_idx_act = tile_Y * scales_stride_colwise + j; + output_scales_colwise[scale_idx_act] = biased_exponent_act; + + float scale_reciprocal_gate; + if (IS_DGATED) { + const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits::max_reciprocal()); + const int scale_idx_gate = scale_idx_act + cols; + scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate); + output_scales_colwise[scale_idx_gate] = biased_exponent_gate; + } + for (size_t i = i_min; i < i_max; ++i) { + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act; + + if (IS_DGATED) { + const float after_gate = cache_buffer_gate[cached_idx] * scale_reciprocal_gate; + output_colwise[i * stride + j] = static_cast(after_act); + output_colwise[i * stride + cols + j] = static_cast(after_gate); + } else { + output_colwise[i * cols + j] = static_cast(after_act); + } + } } } } @@ -163,26 +185,6 @@ void compute_ref_x1(const IType* grad, ref_amax = amax; } -template -void compute_ref_x2(const IType* grad, - const IType* input, - OType* output_rowwise, - OType* output_colwise, - fp8e8m0* scales_rowwise, - fp8e8m0* scales_colwise, - float& ref_amax, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride_rowwise, - const size_t scales_stride_colwise) { - compute_ref_x1( - grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise); - compute_ref_x1( - grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, 
block_size_Y, 1, scales_stride_colwise); -} - /** * Scaling along single dimension (either rows or columns) * Produces one set of output data and the corresponding data of the fused operation (dbias): @@ -190,12 +192,13 @@ void compute_ref_x2(const IType* grad, * OR * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x1(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols, - InputsFillCase fill_case) { + InputsFillCase fill_case, + const bool IS_DGATED) { using namespace test; using EncodingType = fp32; DType itype = TypeInfo::dtype; @@ -205,12 +208,6 @@ void performTest_x1(const size_t rows, const bool colwise = (block_size_rows == 32) && (block_size_cols == 1); NVTE_CHECK(rowwise || colwise); - // std::cout << "unpadded_blocks_Y: " << unpadded_blocks_Y << std::endl; - // std::cout << "unpadded_blocks_X: " << unpadded_blocks_X << std::endl; - // std::cout << "blocks_Y: " << blocks_Y << std::endl; - // std::cout << "blocks_X: " << blocks_X << std::endl; - // std::cout << "scales_stride: " << scales_stride << std::endl; - Tensor grad("grad", std::vector{ rows, cols }, itype); Tensor input("input", std::vector{ rows, cols * 2 }, itype); @@ -236,12 +233,12 @@ void performTest_x1(const size_t rows, } // fillCase(&grad, fill_case); - if constexpr (IS_DGATED) { + if (IS_DGATED) { fillUniform(&grad); } fillUniform(&input); - if constexpr (IS_DGATED) { + if (IS_DGATED) { nvte_dswiglu(grad.data(), input.data(), output.data(), 0); } else { nvte_swiglu(input.data(), output.data(), 0); @@ -252,46 +249,69 @@ void performTest_x1(const size_t rows, ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); float ref_amax = 0; - compute_ref_x1(grad.rowwise_cpu_dptr(), - input.rowwise_cpu_dptr(), - ref_output.get(), - ref_output_scales.get(), - ref_amax, - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride); + compute_ref(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output.get(), + ref_output_scales.get(), + ref_output_scales.get(), + ref_amax, + IS_DGATED, + rows, + cols, + scales_stride, + scales_stride, + rowwise, + colwise); + + size_t mismatches_scales = 0; #ifdef __HIP_PLATFORM_AMD__ - std::vector> mismatch_idx; - if (rowwise) { - compare_e8m0_scaling_factors("rowwise scales", output, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, true, mismatch_idx); - } else { - compare_e8m0_scaling_factors("colwise scales", output, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, false, mismatch_idx); - } - if (mismatch_idx.size()) { - adjust_ref(mismatch_idx, ref_output.get(), unpadded_blocks_Y, unpadded_blocks_X, rows, cols, otype); - } + std::vector mismatches_scales_indices; +#endif // #ifdef __HIP_PLATFORM_AMD__ - auto [atol, rtol] = getTolerances(otype); - compareResults("output", output, ref_output.get(), rowwise, atol, rtol); -#else // #ifdef __HIP_PLATFORM_AMD__ - auto [atol, rtol] = getTolerances(otype); - compareResults("output", output, ref_output.get(), rowwise, atol, rtol); + const size_t scale_diff_abs_tolerance = 0; +#ifdef __HIP_PLATFORM_AMD__ + const double abs_tolerable_mismatches_limit = 1.0; + const double rel_tolerable_mismatches_limit = 1.0e-4; +#else + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; +#endif // #ifdef __HIP_PLATFORM_AMD__ const uint8_t * const gpu_scales_ptr = rowwise ? 
output.rowwise_cpu_scale_inv_ptr() : output.columnwise_cpu_scale_inv_ptr(); if (rowwise) { compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices, +#endif // #ifdef __HIP_PLATFORM_AMD__ + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); } else { compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices, +#endif // #ifdef __HIP_PLATFORM_AMD__ + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); } +#ifdef __HIP_PLATFORM_AMD__ + adjust_ref_for_e8m0_scale_error("scales", mismatches_scales_indices, gpu_scales_ptr, + ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, + scales_stride, rows, cols, ref_output.get(), otype); + mismatches_scales = 0; #endif // #ifdef __HIP_PLATFORM_AMD__ + + const size_t mismatches_elts = 32 * mismatches_scales; + auto [atol, rtol] = getTolerances(otype); + compareResults("output", output, ref_output.get(), rowwise, atol, rtol, true, mismatches_elts); } /** @@ -301,12 +321,13 @@ void performTest_x1(const size_t rows, * AND * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x2(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols, - InputsFillCase fill_case) { + InputsFillCase fill_case, + const bool IS_DGATED) { using namespace test; using EncodingType = fp32; DType itype = TypeInfo::dtype; @@ -348,12 +369,12 @@ void performTest_x2(const size_t rows, } // fillCase(&grad, fill_case); - if constexpr (IS_DGATED) { + if (IS_DGATED) { fillUniform(&grad); } fillUniform(&input); - if constexpr (IS_DGATED) { + if (IS_DGATED) { nvte_dswiglu(grad.data(), input.data(), output.data(), 0); } else { nvte_swiglu(input.data(), output.data(), 0); @@ -364,54 +385,80 @@ void performTest_x2(const size_t rows, ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); float ref_amax = 0; - compute_ref_x2(grad.rowwise_cpu_dptr(), - input.rowwise_cpu_dptr(), - ref_output_rowwise.get(), - ref_output_colwise.get(), - ref_scales_rowwise.get(), - ref_scales_colwise.get(), - ref_amax, - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride_rowwise, - scales_stride_colwise); + compute_ref(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + ref_output_rowwise.get(), + ref_output_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_amax, + IS_DGATED, + rows, + cols, + scales_stride_rowwise, + scales_stride_colwise, + true, + true); + + const size_t scale_diff_abs_tolerance = 0; #ifdef __HIP_PLATFORM_AMD__ - std::vector> mismatch_idx_r; - compare_e8m0_scaling_factors("scales_rowwise", output, - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise, 0.01, true, mismatch_idx_r); - - if (mismatch_idx_r.size()) { - adjust_ref(mismatch_idx_r, ref_output_colwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, rows, cols, otype); - } - - std::vector> mismatch_idx_c; - compare_e8m0_scaling_factors("scales_colwise", output, - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - 
unpadded_blocks_X_colwise, scales_stride_colwise, 0.01, false, mismatch_idx_c); + const double abs_tolerable_mismatches_limit = 1.0; + const double rel_tolerable_mismatches_limit = 1.0e-4; +#else + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; +#endif // #ifdef __HIP_PLATFORM_AMD__ - if (mismatch_idx_c.size()) { - adjust_ref(mismatch_idx_c, ref_output_rowwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, rows, cols, otype); - } + size_t mismatches_scales_rowwise = 0; +#ifdef __HIP_PLATFORM_AMD__ + std::vector mismatches_scales_indices_rowwise; +#endif // #ifdef __HIP_PLATFORM_AMD__ - auto [atol, rtol] = getTolerances(otype); - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); - compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol); -#else // #ifdef __HIP_PLATFORM_AMD__ - auto [atol, rtol] = getTolerances(otype); - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); - compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol); compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise); + unpadded_blocks_X_rowwise, scales_stride_rowwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_rowwise, +#endif // #ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + size_t mismatches_scales_colwise = 0; +#ifdef __HIP_PLATFORM_AMD__ + std::vector mismatches_scales_indices_colwise; +#endif // #ifdef __HIP_PLATFORM_AMD__ compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise); + unpadded_blocks_X_colwise, scales_stride_colwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_colwise, #endif // #ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + +#ifdef __HIP_PLATFORM_AMD__ + adjust_ref_for_e8m0_scale_error("scales_rowwise", mismatches_scales_indices_rowwise, output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, + scales_stride_rowwise, rows, cols, ref_output_rowwise.get(), otype); + adjust_ref_for_e8m0_scale_error("scales_colwise", mismatches_scales_indices_colwise, output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, + scales_stride_colwise, rows, cols, ref_output_colwise.get(), otype); + + mismatches_scales_rowwise = 0; + mismatches_scales_colwise = 0; +#endif // #ifdef __HIP_PLATFORM_AMD__ + + const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; + const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; + + auto [atol, rtol] = getTolerances(otype); + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise); + compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol, true, 
mismatches_elts_colwise); } std::vector> matrix_sizes = { @@ -422,8 +469,8 @@ std::vector> matrix_sizes = { {256, 256}, {993, 512}, {768, 1024}, - {65504, 128}, - {16384, 1632}, + {8192, 128}, + {577, 1632}, }; std::vector> block_sizes = { @@ -440,9 +487,9 @@ std::vector input_scenarios = { // InputsFillCase::maxNorm_to_inf }; -std::vector is_dgated_op = { - true, - false +std::vector is_bwd_op = { + false, + true }; } // namespace @@ -456,10 +503,10 @@ class CastMXFP8_GatedActTestSuite : public ::testing::TestWithParam bool>> {}; TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) { - #ifdef __HIP_PLATFORM_AMD__ +#ifdef __HIP_PLATFORM_AMD__ omp_set_num_threads(std::min(128, omp_get_max_threads())); // Using threads = # of vcpus causes occasional errors. #else // #ifdef __HIP_PLATFORM_AMD__ - // Skip tests for pre-Blackwell architectures + // Skip tests for pre-Blackwell architectures if (getDeviceComputeCapability() < blackwellComputeCapability) { GTEST_SKIP(); } @@ -479,21 +526,11 @@ TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType, if (block_size.first == 1 || block_size.second == 1) { - if (IS_DGATED) { - performTest_x1(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } else { - performTest_x1(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } + performTest_x1(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case, IS_DGATED); } else { - if (IS_DGATED) { - performTest_x2(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } else { - performTest_x2(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } + performTest_x2(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case, IS_DGATED); } ); ); @@ -508,7 +545,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), - ::testing::ValuesIn(is_dgated_op)), + ::testing::ValuesIn(is_bwd_op)), [](const testing::TestParamInfo& info) { std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + @@ -517,6 +554,6 @@ INSTANTIATE_TEST_SUITE_P( test::typeName(std::get<2>(info.param)) + "X" + test::typeName(std::get<3>(info.param)) + "X" + test::caseName(std::get<4>(info.param)) + "X" + - (std::get<5>(info.param) ? "DGATED" : "GATED"); + (std::get<5>(info.param) ? "BWD" : "FWD"); return name; }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index a608f6ef2..fca792c5e 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -528,10 +528,13 @@ std::vector unravel(const size_t i, const NVTEShape &shape) { void compareResults_sequential(const std::string &name, const Tensor &test, const void *ref, const bool rowwise, - double atol, double rtol, bool if_on_gpus) { + double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { if (if_on_gpus) test.to_cpu(); const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); const size_t N = product(shape); + size_t mismatches_num = 0; + int first_mismatch_idx = -1; TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, const T *test_data = rowwise ? 
test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); const T *ref_data = reinterpret_cast(ref); @@ -557,85 +560,106 @@ void compareResults_sequential(const std::string &name, const Tensor &test, assertion = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r)); } std::string direction = rowwise ? "rowwise" : "columnwise"; - ASSERT_FALSE(assertion) << "Error in tensor " << name << " in " - << direction << " direction." << std::endl - << "Mismatch at place " << to_string(unravel(i, shape)) - << " (" << std::to_string(i) << "): " << t << " vs " << r; + if (assertion) { + mismatches_num++; + if (first_mismatch_idx == -1) { + first_mismatch_idx = i; + } + } + if (mismatches_num > tolerable_mismatches_limit) { + const double first_mismatch_t = static_cast(test_data[first_mismatch_idx]); + const double first_mismatch_r = static_cast(ref_data[first_mismatch_idx]); + + GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "." << std::endl + << "Error in tensor " << name << " in " + << direction << " direction." << std::endl + << "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape)) + << " (" << std::to_string(first_mismatch_idx) << "): " + << first_mismatch_t << " vs " << first_mismatch_r; + } } ); } template static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data, - const size_t N, const double atol, const double rtol) { + const size_t N, const double atol, const double rtol, + size_t& mismatches) { int first_mismatch_idx = N; - bool is_mismatch_found = false; - #pragma omp parallel for schedule(static) firstprivate(is_mismatch_found) \ - reduction(min: first_mismatch_idx) proc_bind(spread) - for (size_t i = 0; i < N; ++i) { - if (is_mismatch_found) { // early escape of the omp thread - continue; - } - + #pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread) + { + size_t thread_mismatches = 0; + #pragma omp for schedule(static) + for (size_t i = 0; i < N; ++i) { #ifndef __HIP_PLATFORM_AMD__ - double t = static_cast(test_data[i]); - double r = static_cast(ref_data[i]); + double t = static_cast(test_data[i]); + double r = static_cast(ref_data[i]); #else - double t = static_cast(static_cast(test_data[i])); - double r = static_cast(static_cast(ref_data[i])); + double t = static_cast(static_cast(test_data[i])); + double r = static_cast(static_cast(ref_data[i])); #endif - - bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); - /* For Float32 the floating point comparison is enough to error out */ - bool assertion = mismatch && (data_type == DType::kFloat32); - if (mismatch && !assertion) { - /* Check if it is just a failure of round to nearest choosing different - side of the real value */ - const double mean = (t + r) / 2; - const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); - const double mean_m = mean >= 0 ? 
mean * (1 - 1e-6) : mean * (1 + 1e-6); - const double cast_mean_p = static_cast(static_cast(mean_p)); - const double cast_mean_m = static_cast(static_cast(mean_m)); - assertion = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r)); - } - if (assertion && i < first_mismatch_idx) { - first_mismatch_idx = i; - is_mismatch_found = true; + + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = mismatch && (data_type == DType::kFloat32); + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + if (assertion) { + if (i < first_mismatch_idx) { + first_mismatch_idx = i; + } + thread_mismatches++; + } } + mismatches += thread_mismatches; } return first_mismatch_idx; } void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref, - const bool rowwise, double atol, double rtol, bool if_on_gpus) { + const bool rowwise, double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { if (if_on_gpus) test.to_cpu(); const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); const size_t N = product(shape); + size_t mismatches = 0; TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); const T *ref_data = reinterpret_cast(ref); - - const size_t i = getFirstMismatchIdx(test.dtype(), test_data, ref_data, N, atol, rtol); - if (i != N) { + const size_t i = getFirstMismatchIdx(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches); + if ((i != N) && (mismatches > tolerable_mismatches_limit)) { const double t = static_cast(test_data[i]); const double r = static_cast(ref_data[i]); std::string direction = rowwise ? "rowwise" : "columnwise"; - ASSERT_FALSE(true) << "Error in tensor " << name << " in " - << direction << " direction." << std::endl - << "Mismatch at place " << to_string(unravel(i, shape)) - << " (" << std::to_string(i) << "): " << t << " vs " << r; + + GTEST_FAIL() << mismatches << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "." << std::endl + << "Error in tensor " << name << " in " + << direction << " direction." 
<< std::endl + << "Mismatch at place " << to_string(unravel(i, shape)) + << " (" << std::to_string(i) << "): " << t << " vs " << r; } ); } void compareResults(const std::string &name, const Tensor &test, const void *ref, - const bool rowwise, double atol, double rtol, bool if_on_gpus) { + const bool rowwise, double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { constexpr bool sequential = false; if constexpr (sequential) { - compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus); + compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit); } else { - compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus); + compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit); } } @@ -672,93 +696,89 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t } void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride) + const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef __HIP_PLATFORM_AMD__ + std::vector &mismatch_indices, +#endif + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit) { + const size_t N = row_blocks * col_blocks; + const size_t tolerable_mismatches_limit = +#ifndef __HIP_PLATFORM_AMD__ + std::min( +#else + std::max( +#endif + abs_tolerable_mismatches_limit, + std::floor(N * rel_tolerable_mismatches_limit)); + mismatches_num = 0; +#ifndef __HIP_PLATFORM_AMD__ + std::vector mismatch_indices; +#endif + for (int i = 0; i < row_blocks; ++i) { for (int j = 0; j < col_blocks; ++j) { const int idx = i * stride + j; - ASSERT_FALSE(test[idx] != ref[idx]) << "Error in " << name << std::endl - << "Mismatch: " << static_cast(test[idx]) << " vs " - << static_cast(ref[idx]) << " at index " << idx; - } - } -} - -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t N) -{ - for (int i = 0; i < N; i++) { - ASSERT_FALSE(test[i] != ref[i]) << "Error in " << name << std::endl - << "Mismatch: " << static_cast(test[i]) << " vs " - << static_cast(ref[i]) << " at index " << i; - } -} + const int test_val = static_cast(test[idx]); + const int ref_val = static_cast(ref[idx]); + const int abs_delta = std::abs(test_val - ref_val); -#ifdef __HIP_PLATFORM_AMD__ -void compare_e8m0_scaling_factors(const std::string &name, Tensor &output, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - double tol, bool rowwise, std::vector> &mismatch_idx) { - const uint8_t *const test = rowwise ? 
output.rowwise_cpu_scale_inv_ptr() - : output.columnwise_cpu_scale_inv_ptr(); - - const double scale_tol = std::max(1., row_blocks * col_blocks * tol); - - for (int i = 0; i < row_blocks; i++) { - for (int j = 0; j < col_blocks; j++) { - const int idx = i * stride + j; - if (test[idx] != ref[idx]) { - int t_scale = static_cast(test[idx]); - int r_scale = static_cast(ref[idx]); - if (std::abs(t_scale - r_scale) == 1) { - mismatch_idx.emplace_back(i, j, r_scale-t_scale); - } else { - GTEST_FAIL() << "Error in " << name << std::endl - << "Mismatch: " << t_scale << " vs " - << r_scale << " at index " << idx; + if (abs_delta > atol) { + mismatches_num++; + mismatch_indices.push_back(idx); + } + if (mismatches_num > tolerable_mismatches_limit) { + std::cout << "Error in " << name << std::endl; + for (const int index : mismatch_indices) { + std::cout << "Mismatch at (" << index << "):" + << static_cast(test[index]) << " vs " + << static_cast(ref[index]) << std::endl; } + GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "."; } } } - const size_t scale_mismatches = mismatch_idx.size(); - - ASSERT_FALSE(scale_mismatches > scale_tol) - << "Error in " << name << std::endl << std::setprecision(4) - << "Total scale mismatches: " << scale_mismatches << " (" << 100.*(double)scale_mismatches/(double)(row_blocks*col_blocks) - << "%) Exceeds tolerance of " << scale_tol << " (" << 100.*tol << "%) mismatches"; - - if (scale_mismatches) { - std::cout << "\x1b[33mWARNING:\x1b[0m " << scale_mismatches - << " scale mismatches were found. This does not imply an accuracy issue." << std::endl; - } } -void adjust_ref(std::vector> mismatch_idx, void *ref, const size_t row_blocks, - const size_t col_blocks, const size_t rows, const size_t cols, DType otype) { - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( otype, T, - T *ref_data = reinterpret_cast(ref); - double scale_val; - const size_t col_blocks_size = cols / col_blocks; - const size_t row_blocks_size = rows / row_blocks; - for (const auto &[i, j, scale_diff] : mismatch_idx) { - if (scale_diff == 1) { - scale_val = 2.; - } else if (scale_diff == -1) { - scale_val = .5; - } else { // Shouldn't ever reach this - GTEST_FAIL() << "Error in adjust_ref, |scale_diff| > 1"; - } - size_t ii_min = i * row_blocks_size; - const size_t ii_max = std::min(ii_min + row_blocks_size, rows); - for (; ii_min < ii_max; ii_min++) { - size_t jj_min = j * col_blocks_size; - const size_t jj_max = std::min(jj_min + col_blocks_size, cols); - for (; jj_min < jj_max; jj_min++) { - const size_t data_idx = ii_min * cols + jj_min; +#ifdef __HIP_PLATFORM_AMD__ +void adjust_ref_for_e8m0_scale_error(const std::string &name, + const std::vector &mismatch_idx, + const uint8_t *test_scale, const uint8_t *ref_scale, + const size_t row_blocks, const size_t col_blocks, + const size_t stride, const size_t rows, const size_t cols, + void *ref_ptr, DType otype) { + double scale_val; + const size_t col_blocks_size = cols / col_blocks; + const size_t row_blocks_size = rows / row_blocks; + for (const auto scale_idx : mismatch_idx) { + const int scale_diff = ref_scale[scale_idx] - test_scale[scale_idx]; + if (scale_diff == 1) { + scale_val = 2.; + } else if (scale_diff == -1) { + scale_val = .5; + } else { + GTEST_FAIL() << "Error in " << name << ": mismatch " << test_scale[scale_idx] << " vs " + << ref_scale[scale_idx] << " at index " << scale_idx; + } + const int i = scale_idx / stride; + const int j = scale_idx % stride; + size_t 
ii_min = i * row_blocks_size; + const size_t ii_max = std::min(ii_min + row_blocks_size, rows); + for (; ii_min < ii_max; ii_min++) { + size_t jj_min = j * col_blocks_size; + const size_t jj_max = std::min(jj_min + col_blocks_size, cols); + for (; jj_min < jj_max; jj_min++) { + const size_t data_idx = ii_min * cols + jj_min; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(otype, T, { + T *ref_data = reinterpret_cast(ref_ptr); ref_data[data_idx] = static_cast(static_cast(ref_data[data_idx]) * scale_val); - } + }); // NOLINT(*) } } - ); // NOLINT(*) + } } #endif // #ifdef __HIP_PLATFORM_AMD__ @@ -917,11 +937,10 @@ bool isFp8Type(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; } -int32_t getDeviceComputeCapability() -{ - cudaDeviceProp deviceProp; - (void)cudaGetDeviceProperties(&deviceProp, 0); - return 10 * deviceProp.major + deviceProp.minor; +int32_t getDeviceComputeCapability() { + cudaDeviceProp deviceProp; + (void)cudaGetDeviceProperties(&deviceProp, 0); + return 10 * deviceProp.major + deviceProp.minor; } size_t first_dimension(const std::vector &shape) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index a7290a535..980a5a70f 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -439,7 +439,12 @@ inline fp8e8m0 float_to_e8m0(float val) { } inline float exp2f_rcp(fp8e8m0 biased_exp) { - return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); + if (biased_exp == 0) { + return 1.0f; + } + int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS; // 127 - (biased_exp - 127) + float fp32_val = *reinterpret_cast(&int_val); + return fp32_val; } inline float identity(const float x) { return x; } @@ -471,22 +476,29 @@ size_t last_dimension(const std::vector &shape); bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2); void compareResults(const std::string &name, const Tensor &test, const void *ref, - bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true); + bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, + const size_t tolerable_mismatches_limit = 0); void compareResults(const std::string &name, const float test, const float ref, double atol = 1e-5, double rtol = 1e-8); void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, size_t N, float mismatch_rate_tol = 0.); void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride); -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t N); -#ifdef USE_ROCM -void compare_e8m0_scaling_factors(const std::string &name, Tensor &output, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - double tol, bool rowwise, std::vector> &mismatch_idx); - -void adjust_ref(std::vector> mismatch_idx, void *ref, const size_t row_blocks, - const size_t col_blocks, const size_t rows, const size_t cols, DType otype); + const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef __HIP_PLATFORM_AMD__ + std::vector& mismatches_scales_indices, +#endif + size_t& mismatches_num, + const size_t scale_diff_abs_tolerance = 0, + const double abs_tolerable_mismatches_limit = 0, + const double rel_tolerable_mismatches_limit = 0); + +#ifdef __HIP_PLATFORM_AMD__ +void adjust_ref_for_e8m0_scale_error(const std::string &name, + const std::vector 
&mismatch_idx, + const uint8_t *test_scale, const uint8_t *ref_scale, + const size_t row_blocks, const size_t col_blocks, + const size_t stride, const size_t rows, const size_t cols, + void *ref_ptr, DType otype); #endif std::array get_scale_tensor_dims(const size_t rows, const size_t cols, diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 20a8037eb..f5c45dd9a 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -86,8 +86,14 @@ def is_shape_supported_by_mxfp8(input_shape): return False -def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): +def assert_bitwise_scaled_tensors( + a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True +): if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x): + if not precise_comparison: + assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype) + return + assert a.scaling_mode == b.scaling_mode assert a.scale_inv.dtype == b.scale_inv.dtype if a.scaling_mode.is_tensor_scaling(): @@ -102,8 +108,12 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): assert_allclose(a.data, b.data) elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x): - assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor) - assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor) + assert_bitwise_scaled_tensors( + a.rowwise_tensor, b.rowwise_tensor, precise_comparison=precise_comparison + ) + assert_bitwise_scaled_tensors( + a.colwise_tensor, b.colwise_tensor, precise_comparison=precise_comparison + ) else: pytest.fail("Unsupported input types") @@ -489,24 +499,7 @@ def _test_norm_forward( # if the input dtype is not float32 precise_comparison = False - if precise_comparison: - assert_bitwise_scaled_tensors(output, ref_out) - else: - if isinstance(ref_out, ScaledTensor1x): - assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype) - elif isinstance(ref_out, ScaledTensor2x): - assert_allclose( - output.rowwise_tensor.dequantize(), - ref_out.rowwise_tensor.dequantize(), - dtype=out_dtype, - ) - assert_allclose( - output.colwise_tensor.dequantize(), - ref_out.colwise_tensor.dequantize(), - dtype=out_dtype, - ) - else: - pytest.fail("Unsupported output type") + assert_bitwise_scaled_tensors(output, ref_out, precise_comparison=precise_comparison) assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype) if norm_type == "layernorm": @@ -776,12 +769,24 @@ def _test_quantize_dact_dbias( )(dz, x) if is_casted_output: - assert_bitwise_scaled_tensors(te_output, jax_output) + # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation + precise_comparison = not ( + in_dtype != jnp.float32 and scaling_mode.is_1d_block_scaling() + ) + assert_bitwise_scaled_tensors( + te_output, jax_output, precise_comparison=precise_comparison + ) else: assert_allclose(te_output, jax_output) if is_dbias: - assert_allclose(te_dbias, jax_dbias) + # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16. 
+ precise_comparison = not ( + in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling() + ) + assert_allclose( + te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype + ) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @@ -866,15 +871,6 @@ def test_quantize_dact_dbias_mxfp8_scaling( ] -def _use_jax_fp8_gemm(enabled=False): - import os - - if enabled: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" - elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: - os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") - - class TestDense: def _ref_gemm_with_jnp_dot(self, a, b, data_layout): if data_layout[0] == "T": diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index e08c3a1b9..13adc8394 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -388,8 +388,12 @@ def _check_configs(self): self.head_dim_v, (-1, -1) if self.window_size is None else self.window_size, ).get_fused_attn_backend() - if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: - pytest.skip("Unsupported inputs combination or device compute capability.") + if is_hip_extension(): + if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: + pytest.skip("Unsupported inputs combination or device compute capability.") + else: + if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: + pytest.skip("Unsupported inputs combination or device compute capability.") if ( self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index e237318a4..d0a3efd27 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -58,7 +58,6 @@ def _compare_delay_scaling(self, ref, test): self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo) def _compare_current_scaling(self, test): - self.assertEqual(QuantizeConfig.MARGIN, test.margin) self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format) self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING) @@ -91,7 +90,7 @@ def test_fp8_autocast_delayed_scaling(self): self._check_default_state() - @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) + @unittest.skipIf(not is_fp8_supported, reason=reason) def test_fp8_autocast_current_scaling(self): QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. 
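+ # The current-scaling recipes below are constructed without a margin; _compare_current_scaling checks only fp8_format and the scaling mode.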
self._check_default_state() @@ -101,14 +100,14 @@ def test_fp8_autocast_current_scaling(self): self._check_default_state() - cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3) + cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3) with fp8_autocast(enabled=True, fp8_recipe=cs): self.assertTrue(QuantizeConfig.is_fp8_enabled()) self._compare_current_scaling(cs) self._check_default_state() - cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID) + cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID) with fp8_autocast(enabled=True, fp8_recipe=cs): self.assertTrue(QuantizeConfig.is_fp8_enabled()) self._compare_current_scaling(cs) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index f34fb5448..56d5df8e3 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1607,16 +1607,18 @@ def print_debug_tensor_stats(prefix, tensor, hist=False): @contextmanager def use_jax_gemm(enabled=False): - orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None) + orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS", None) try: if enabled: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false" + else: + os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=true" yield finally: if enabled: if orig_custom_calls_filter is None: - os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") + os.environ.pop("NVTE_JAX_CUSTOM_CALLS") else: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter + os.environ["NVTE_JAX_CUSTOM_CALLS"] = orig_custom_calls_filter diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py similarity index 99% rename from tests/pytorch/fused_attn/run_fused_attn_with_cp.py rename to tests/pytorch/attention/run_attention_with_cp.py index 672950f50..10bb066a4 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -16,7 +16,7 @@ get_cu_seqlens_on_cp_rank, ) import transformer_engine_torch as tex -from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn +from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.common.recipe import DelayedScaling diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/attention/test_attention.py similarity index 77% rename from tests/pytorch/fused_attn/test_fused_attn.py rename to tests/pytorch/attention/test_attention.py index dde15c9b3..b3e098236 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/attention/test_attention.py @@ -7,8 +7,9 @@ import math import os from torch.utils.cpp_extension import IS_HIP_EXTENSION +import sys +import pathlib from typing import Any, Dict, List, Tuple, Union, Optional -from contextlib import contextmanager import pytest import torch @@ -24,7 +25,6 @@ FlashAttentionUtils, get_attention_backend, check_set_window_size, - AttentionParams, ) from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import RotaryPositionEmbedding @@ -51,21 +51,22 @@ restore_from_saved, ) +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ( + reset_rng_states, + ModelConfig, + dtype_tols, + logging_context, + 
get_available_attention_backends, +) + # Only run FP8 tests on H100 fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() -# Initialize RNG state seed = 1234 -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() - - -def reset_rng_states() -> None: - """Revert back to initial RNG state""" - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) +# Reset RNG states +reset_rng_states() @pytest.fixture(autouse=True) @@ -73,209 +74,29 @@ def reset_global_fp8_state(): yield fp8.FP8GlobalStateManager.reset() - -class EnvVarCleaner: - def __init__(self, envs_): - self.envs = envs_ - self.flags = {} - for env in self.envs: - if env in os.environ: - self.flags[env] = os.environ[env] - def __del__(self): - for env in self.envs: - if env in self.flags: - os.environ[env] = self.flags[env] - else: - os.environ.pop(env, None) - - -@pytest.fixture(autouse=True) -def reset_attn_backend(): - env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", - "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON", - "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3"]) - yield - - -class ModelConfig: - def __init__( - self, - batch_size: int, - num_heads: int, - num_gqa_groups: int, - head_dim_qk: int, - max_seqlen_q: int, - max_seqlen_kv: int, - dropout_p: float, - attn_mask_type: str, - attn_bias_type: str, - head_dim_v: int = None, - alibi_type: str = "none", - num_layers: int = 1, - bias_shape: str = "1hss", - window_size: Tuple[int, int] = (-1, -1), - total_requests: int = None, - max_ctx_len: int = None, - ): - self.batch_size = batch_size - self.num_heads = num_heads - self.num_gqa_groups = num_gqa_groups - self.head_dim_qk = head_dim_qk - self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v - self.hidden_size = num_heads * head_dim_qk - self.hidden_size_kv = num_gqa_groups * self.head_dim_v - self.max_seqlen_q = max_seqlen_q - self.max_seqlen_kv = max_seqlen_kv - self.dropout_p = dropout_p - self.attn_mask_type = attn_mask_type - self.attn_bias_type = attn_bias_type - self.alibi_type = alibi_type - self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross" - self.num_layers = num_layers - self.bias_shape = bias_shape - self.window_size = window_size - self.total_requests = total_requests - self.max_ctx_len = max_ctx_len - - -@contextmanager -def logging_context(highest_level=logging.WARNING): - previous_level = logging.root.manager.disable - logging.disable(highest_level) - try: +if IS_HIP_EXTENSION: + from utils import EnvVarCleaner + @pytest.fixture(autouse=True) + def reset_attn_backend(): + env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", + "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON", + "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3"]) yield - finally: - logging.disable(previous_level) - - -def _get_attention_backends( - config: ModelConfig, - qkv_dtype: torch.dtype, - qkv_layout: str, - window_size: Tuple[int, int] = (-1, -1), - pad_between_seqs: bool = False, - context_parallel: bool = False, - deterministic: bool = False, - fp8: bool = False, - fp8_meta: Optional[Dict[str, Any]] = None, - is_training: bool = True, - inference_params: Optional[InferenceParams] = None, -) -> Tuple[List, List]: - """Check if what attention backends support a model configuration""" - - os.environ["NVTE_FLASH_ATTN"] = "1" - os.environ["NVTE_FUSED_ATTN"] = "1" - os.environ["NVTE_UNFUSED_ATTN"] = "1" - 
_attention_backends["backend_selection_requires_update"] = True - - alibi_slopes_shape = None - if config.attn_bias_type == "alibi" and config.alibi_type == "custom": - if config.bias_shape == "1hss": - alibi_slopes_shape = [config.num_heads] - if config.bias_shape == "bhss": - alibi_slopes_shape = [config.batch_size, config.num_heads] - - core_attention_bias_shape = ( - config.bias_shape if config.attn_bias_type == "post_scale_bias" else None - ) - core_attention_bias_requires_grad = False - # d=256 is supported by cuDNN 9.0+ for inference but not training - if ( - config.attn_bias_type == "post_scale_bias" - and config.head_dim_qk <= 128 - and config.head_dim_v <= 128 - ): - core_attention_bias_requires_grad = True - - fused_attn_backends = [] - available_backends = None - flash_attention_backend = None - fused_attention_backend = None - - def test(): - attention_params = AttentionParams( - qkv_dtype=qkv_dtype, - qkv_layout=qkv_layout, - batch_size=config.batch_size, - num_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - max_seqlen_q=config.max_seqlen_q, - max_seqlen_kv=config.max_seqlen_kv, - head_dim_qk=config.head_dim_qk, - head_dim_v=config.head_dim_v, - attn_mask_type=config.attn_mask_type, - window_size=window_size, - alibi_slopes_shape=alibi_slopes_shape, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias_shape=core_attention_bias_shape, - core_attention_bias_requires_grad=core_attention_bias_requires_grad, - pad_between_seqs=pad_between_seqs, - attention_dropout=config.dropout_p, - context_parallel=context_parallel, - deterministic=deterministic, - fp8=fp8, - fp8_meta=fp8_meta, - is_training=is_training, - inference_params=inference_params, - ) - ( - use_flash_attention, - flash_attention_backend, - use_fused_attention, - fused_attention_backend, - use_unfused_attention, - available_backends, - ) = get_attention_backend(attention_params) - # Set attention.py _attention_backends var using return value - # from get_attention_backend() - _attention_backends["use_flash_attention"] = use_flash_attention - _attention_backends["use_fused_attention"] = use_fused_attention - _attention_backends["flash_attention_backend"] = flash_attention_backend - _attention_backends["fused_attention_backend"] = fused_attention_backend - _attention_backends["use_unfused_attention"] = use_unfused_attention - _attention_backends["backend_selection_requires_update"] = False - return available_backends, flash_attention_backend, fused_attention_backend - - if IS_HIP_EXTENSION: - backends = {"AOTriton": "AOTRITON", "CK": "CK"} - with logging_context(): - for i in backends.keys(): - for k in backends.keys(): - os.environ["NVTE_FUSED_ATTN_"+backends[k]] = "0" - os.environ["NVTE_FUSED_ATTN_"+backends[i]] = "1" - _attention_backends["backend_selection_requires_update"] = True - available_backends, flash_attention_backend, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[i]: - fused_attn_backends.append(fused_attention_backend) - for i in backends.keys(): - del os.environ["NVTE_FUSED_ATTN_"+backends[i]] - available_backends[1] = len(fused_attn_backends) > 0 - else: - backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} - with logging_context(): - for i in range(len(backends)): - os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) - _attention_backends["backend_selection_requires_update"] = True - available_backends, flash_attention_backend, fused_attention_backend = test() - if fused_attention_backend == 
FusedAttnBackend[backends[i]]: - fused_attn_backends.append(fused_attention_backend) - return available_backends, flash_attention_backend, fused_attn_backends - model_configs_base = { # test: b, h, hg, d, sq, skv, p, mask, bias - "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), - "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), - "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), - "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), - "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), - "base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"), - "base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"), - "base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"), - "base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"), - "base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"), - "base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"), + "base_1_0": ModelConfig(8, 128, 16, 64), + "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256), + "base_2_0": ModelConfig(2, 2048, 24, 128), + "base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096), + "base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048), + "base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048), + "base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048), + "base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048), + "base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048), + "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048), + "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048), + "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048), } @@ -294,10 +115,10 @@ def test_dot_product_mem_calc(): if not is_bf16_compatible(): pytest.skip("This test requires bf16 support.") dtype = torch.bfloat16 - config = ModelConfig(16, 128, 8, 128, 8192, 8192, 0.0, "causal", "no_bias") + config = ModelConfig(16, 8192, 128, 128, num_gqa_groups=16, attn_mask_type="causal") is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128 qkv_layout = "sbhd_sbhd_sbhd" - _, _, fused_attn_backends = _get_attention_backends( + _, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -359,7 +180,7 @@ def test_dot_product_attention( config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) is_training = config.head_dim_qk <= 192 and config.head_dim_v <= 128 - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -370,7 +191,7 @@ def test_dot_product_attention( flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: is_training = False - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -525,40 +346,29 @@ def test_dpa_checkpoint(dtype, model_configs, model): model_configs_mla = { # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend - "mla_1_0": ModelConfig( - 8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128 - 
), # self , 0 - "mla_1_1": ModelConfig( - 4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), # cross, 0 - "mla_1_2": ModelConfig( - 4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), # cross, 0 - "mla_2_0": ModelConfig( - 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64 - ), # self , 1 + "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0 + "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0 + "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0 + "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1 "mla_2_1": ModelConfig( - 1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64 + 1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64 ), # cross, 1 "mla_2_2": ModelConfig( - 1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128 + 1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128 ), # cross, 1 - "mla_3_0": ModelConfig( - 8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64 - ), # inference - "mla_3_1": ModelConfig( - 8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), # inference - "mla_3_2": ModelConfig( - 8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), # inference + "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference + "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference + "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference +} +if IS_HIP_EXTENSION: + model_configs_mla.update({ "mla_4_0": ModelConfig( - 10, 16, 16, 192, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=128 + 10, 4096, 16, 192, attn_mask_type="causal", head_dim_v=128 ), "mla_4_1": ModelConfig( - 10, 16, 16, 192, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=128 + 10, 4096, 16, 192, head_dim_v=128 ), -} + }) @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @@ -572,40 +382,46 @@ def test_dpa_mla(dtype, model_configs, model): model_configs_mask = { # test: b, h, hg, d, sq, skv, p, mask, bias - "mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), - "mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "mask_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "mask_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), - "mask_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "mask_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "mask_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "mask_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "mask_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "mask_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "mask_5_0": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"), + "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"), + "mask_1_2": 
ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"), + "mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"), + "mask_2_1": ModelConfig( + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right" ), + "mask_2_2": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right" + ), + "mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"), + "mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"), + "mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), + "mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"), + "mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"), + "mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"), + "mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"), "mask_5_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right" ), "mask_5_2": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right" + ), + "mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"), + "mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"), + "mask_7_0": ModelConfig( + 2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right" ), - "mask_6_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"), - "mask_6_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"), - "mask_7_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_7_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_8_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding", "no_bias"), - "mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"), - "mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"), - "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"), + "mask_7_1": ModelConfig( + 2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right" + ), + "mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"), + "mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"), + "mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"), + "mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"), "mask_10_0": ModelConfig( - 2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right" ), "mask_10_1": ModelConfig( - 2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right" ), } @@ -621,44 +437,102 @@ def test_dpa_mask(dtype, model_configs, model): model_configs_bias = { # test: b, h, hg, d, sq, skv, p, mask, bias - "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), - "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"), - "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"), - 
"bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"), - "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped - "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped - "bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped - "bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped + "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"), + "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"), + "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"), + "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"), + "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), # skipped + "bias_1_5": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi" + ), # skipped + "bias_2_0": ModelConfig( + 4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias" + ), # skipped + "bias_2_1": ModelConfig( + 2, + 128, + 16, + 64, + max_seqlen_kv=256, + attn_mask_type="padding", + attn_bias_type="post_scale_bias", + ), # skipped "bias_2_2": ModelConfig( - 4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias" + 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias" ), # skipped "bias_2_3": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias" + 2, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="padding", + attn_bias_type="post_scale_bias", + ), # skipped + "bias_2_4": ModelConfig( + 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi" ), # skipped - "bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped - "bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped - "bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), - "bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"), - "bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), + "bias_2_5": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi" + ), # skipped + "bias_3_0": ModelConfig( + 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "bias_3_1": ModelConfig( + 2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "bias_3_2": ModelConfig( + 4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), "bias_3_3": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias" + 2, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="causal", + attn_bias_type="post_scale_bias", + ), # skipped + "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"), + "bias_3_5": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi" ), # skipped - "bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"), - "bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped "bias_4_0": ModelConfig( - 4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias" + 4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias" ), # skipped "bias_4_1": ModelConfig( - 2, 16, 16, 64, 128, 256, 0.0, 
"padding_causal", "post_scale_bias" + 2, + 128, + 16, + 64, + max_seqlen_kv=256, + attn_mask_type="padding_causal", + attn_bias_type="post_scale_bias", ), # skipped "bias_4_2": ModelConfig( - 4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias" + 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias" ), # skipped "bias_4_3": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias" + 2, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="padding_causal", + attn_bias_type="post_scale_bias", + ), # skipped + "bias_4_4": ModelConfig( + 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi" + ), # skipped + "bias_4_5": ModelConfig( + 2, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="padding_causal", + attn_bias_type="alibi", ), # skipped - "bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped - "bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped } @@ -673,33 +547,29 @@ def test_dpa_bias(dtype, model_configs, model): model_configs_bias_shapes = { # test: b, h, hg, d, sq, skv, p, - "bias_1_0": ModelConfig( + "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"), + "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"), + "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"), + "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"), + "bias_1_4": ModelConfig( 4, - 16, - 16, - 64, - 128, + 2048, + 24, 128, - 0.0, - # mask, bias, bias_shape, - "no_mask", - "post_scale_bias", - bias_shape="11ss", - ), - "bias_1_1": ModelConfig( - 2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="1hss" - ), - "bias_1_2": ModelConfig( - 4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="b1ss" - ), - "bias_1_3": ModelConfig( - 2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="bhss" - ), - "bias_1_4": ModelConfig( - 4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="1hss", alibi_type="custom" + attn_mask_type="causal", + attn_bias_type="alibi", + bias_shape="1hss", + alibi_type="custom", ), "bias_1_5": ModelConfig( - 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="bhss", alibi_type="custom" + 2, + 2048, + 24, + 128, + attn_mask_type="causal", + attn_bias_type="alibi", + bias_shape="bhss", + alibi_type="custom", ), } @@ -715,29 +585,31 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { # test: b, h, hg, d, sq, skv, p, mask, bias - "swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), - "swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), - "swa_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), - "swa_2_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "swa_3_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), - "swa_4_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", 
"no_bias"), - "swa_4_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "swa_5_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "swa_5_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "swa_5_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "swa_6_1": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + "swa_1_1": ModelConfig(2, 2048, 16, 64), + "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4), + "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096), + "swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"), + "swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"), + "swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"), + "swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"), + "swa_3_2": ModelConfig( + 2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right" + ), + "swa_3_3": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right" ), + "swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"), + "swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"), + "swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), + "swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"), + "swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), + "swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"), + "swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"), "swa_6_2": ModelConfig( - 2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right" ), "swa_6_3": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right" ), } @@ -753,13 +625,31 @@ def test_dpa_sliding_window(dtype, model_configs, model): model_configs_alibi_slopes = { # test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type - "alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"), - "alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"), + "alibi_1_0": ModelConfig( + 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla" + ), + "alibi_1_1": ModelConfig( + 1, + 128, + 16, + 64, + max_seqlen_kv=256, + attn_mask_type="causal", + attn_bias_type="alibi", + alibi_type="vanilla", + ), "alibi_2_0": ModelConfig( - 2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type="custom" + 2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom" ), "alibi_2_1": ModelConfig( - 1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type="custom" + 1, + 1024, + 24, + 128, + max_seqlen_kv=2048, + attn_mask_type="causal", + attn_bias_type="alibi", + alibi_type="custom", ), } @@ -789,16 +679,38 @@ def test_dpa_alibi_slopes(dtype, model_configs, model): model_configs_layout = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), - "layout_0_1": ModelConfig(2, 16, 16, 64, 
128, 128, 0.0, "causal", "post_scale_bias"), - "layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"), - "layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), - "layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), - "layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), - "layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"), + "layout_0_0": ModelConfig(2, 128, 16, 64), + "layout_0_1": ModelConfig( + 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"), + "layout_0_3": ModelConfig( + 1, + 128, + 16, + 64, + max_seqlen_kv=256, + attn_mask_type="padding_causal", + attn_bias_type="post_scale_bias", + ), + "layout_1_0": ModelConfig(2, 2048, 24, 128), + "layout_1_1": ModelConfig( + 2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), + "layout_1_3": ModelConfig( + 1, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="padding_causal", + attn_bias_type="post_scale_bias", + ), + "layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048), + "layout_2_1": ModelConfig( + 2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), } @@ -815,55 +727,54 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_0_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "layout_2_0": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"), + "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"), + "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), + "layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"), + "layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"), + "layout_1_2": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal" ), + "layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"), "layout_2_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right" ), "layout_2_2": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" - ), - "layout_3_0": ModelConfig( - 2, 16, 
16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right" ), + "layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)), "layout_3_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4) ), "layout_3_2": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4) - ), - "layout_4_0": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4) ), + "layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)), "layout_4_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0) ), "layout_4_2": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0) ), "layout_5_0": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) + 2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0) ), "layout_5_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) + 2, + 2048, + 24, + 128, + num_gqa_groups=1, + attn_mask_type="padding_causal_bottom_right", + window_size=(4, 0), ), "layout_5_2": ModelConfig( 2, - 24, + 2048, 24, 128, - 2048, - 4096, - 0.0, - "padding_causal_bottom_right", - "no_bias", + max_seqlen_kv=4096, + attn_mask_type="padding_causal_bottom_right", window_size=(4, 0), ), } @@ -890,7 +801,7 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout, pad_between if (pad_between_seqs==False and get_cudnn_version() < (9, 3, 0)): pytest.skip("cuDNN 9.3.0+ is required to run pad_between_seqs = False"); - _, _, fused_attn_backends = _get_attention_backends( + _, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -912,7 +823,7 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout, pad_between def test_dpa_qkv_layout_thd_mqa_gqa(dtype, model_configs, model, qkv_layout, pad_between_seqs, share_cu_seqlens_ref): config = model_configs[model] - _, _, fused_attn_backends = _get_attention_backends( + _, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -1302,16 +1213,22 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: model_configs_te_layer = { # test: b, h, hg, d, sq, skv, p, mask, bias - "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), - "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), - "te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), - "te_1_3": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), - "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), - "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "te_2_3": ModelConfig(1, 16, 16, 64, 2048, 4096, 0.0, "padding_causal_bottom_right", 
"no_bias"), - "te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"), - "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"), + "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"), + "te_1_1": ModelConfig( + 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "te_1_2": ModelConfig( + 2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias" + ), + "te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"), + "te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"), + "te_2_1": ModelConfig(2, 2048, 16, 64), + "te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"), + "te_2_3": ModelConfig( + 1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right" + ), + "te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"), + "te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"), } @@ -1335,7 +1252,7 @@ def test_transformer_layer( # Test backend availability is_training = True - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=( @@ -1346,7 +1263,7 @@ def test_transformer_layer( flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: is_training = False - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=( @@ -1659,20 +1576,164 @@ def _run_transformer_layer( return out, inp.grad +model_configs_fp8_extra_state = { + "large": ModelConfig(2, 128, 4, 128, num_layers=1), +} + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.") +@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") +@pytest.mark.parametrize("model", ["large"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_sanity_attention_extra_state(model, dtype): + config = model_configs_fp8_extra_state[model] + # Test backend availability + is_training = True + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout="sb3hd", + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported and not flash_attn_supported: + pytest.skip("No attention backend available.") + + outputs = _run_attention_extra_state(dtype, config, checkpoint=False) + outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True) + outputs_checkpoint_v1_6 = _run_attention_extra_state( + dtype, config, mimic_v1_6=True, checkpoint=True + ) + + # Check that results match + tols = dtype_tols(dtype) + if dtype in (torch.float16, torch.bfloat16): + tols.update(dict(rtol=2e-2, atol=2e-3)) + for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)): + torch.testing.assert_close( + test, + ref, + **tols, + ) + for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)): + torch.testing.assert_close( + test, + ref, + **tols, + ) + + +def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): + steps = 10 + path = "checkpoint.pt" + 
fp8_enabled = True + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=fp8_enabled, + fp8_mha=False, + ) + + reset_rng_states() + hidden_states = torch.randn( + (config.max_seqlen_q, config.batch_size, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + + def get_model(dtype, config): + sigma = 0.023 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) + + with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe): + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.0, + attention_dropout=0.0, + fuse_qkv_params=True, + params_dtype=dtype, + device="cuda", + ) + return block + + block = get_model(dtype, config) + for i in range(steps // 2): + with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + output = block(hidden_states, None) + loss = output.sum() + loss.backward() + + if checkpoint: + sd = block.state_dict() + if mimic_v1_6: + sd["self_attention.core_attention.fused_attention._extra_state"] = sd[ + "self_attention.core_attention._extra_state" + ] + del sd["self_attention.core_attention._extra_state"] + torch.save(sd, path) + + param_grads = [] + for p in block.parameters(): + if p.requires_grad: + param_grads.append(p.grad.clone()) + + _cpu_rng_state_new = torch.get_rng_state() + _cuda_rng_state_new = torch.cuda.get_rng_state() + + del block + block = get_model(dtype, config) + block.load_state_dict(torch.load(path, weights_only=False)) + torch.set_rng_state(_cpu_rng_state_new) + torch.cuda.set_rng_state(_cuda_rng_state_new) + + for p in block.parameters(): + if p.requires_grad: + p.grad = param_grads.pop(0) + + assert not param_grads, "Oops!" 
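+    # All saved gradients have been restored onto the reloaded parameters;
+    # the loop below resumes training from the checkpointed state for the
+    # remaining steps so its outputs can be compared with the uninterrupted run.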
+ + for i in range((steps + 1) // 2): + with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + output = block(hidden_states, None) + loss = output.sum() + loss.backward() + + torch.cuda.synchronize() + + if os.path.exists(path): + os.remove(path) + + outputs = [output, hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + outputs.append(p.grad) + + return outputs + + model_configs_fp8_vs_f16 = { # test: b, h, hg, d, sq, skv, p, mask, bias - "fp8_9": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_10": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_11": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), - "fp8_12": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_13": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), - "fp8_15": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "fp8_16": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "fp8_17": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding", "no_bias"), - "fp8_18": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "fp8_19": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "fp8_20": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding_causal", "no_bias"), + "fp8_9": ModelConfig(2, 2048, 16, 128), + "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), + "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), + "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), + "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), + "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), + "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), + "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), + "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), + "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), + "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), } param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] @@ -1722,18 +1783,30 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] - if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < ( - 9, - 7, - 0, - ): - pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + # Test backend availability + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout=qkv_format.replace("hd", "h3d"), + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + # Skip if only unfused backend is supported + if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: + pytest.skip("Less than two backends to compare.") + if not fp8_dpa_bwd: + available_backends, _, 
fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_format.replace("hd", "h3d"), + is_training=is_training, + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("No attention backend available.") + + if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1759,11 +1832,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, rtol = 5e-1 rmse_tol = 0.15 logging.debug("========== {:^25s} ==========".format("forward output")) - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + if flash_attn_supported: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1937,23 +2006,34 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): # if get_device_compute_capability() >= (10, 0): # config.dropout_p = 0.1 - if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < ( - 9, - 7, - 0, - ): - pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") - if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: - pytest.skip("qkv_layout not applicable for MQA/GQA") - os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + # Test backend availability + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout=qkv_layout, + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + # Skip if only unfused backend is supported + if flash_attn_supported + fused_attn_supported < 1: + pytest.skip("No FP8 attention backend available.") + if not fp8_dpa_bwd: + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + is_training=is_training, + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("No attention backend available.") + if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: + pytest.skip("qkv_layout not applicable for MQA/GQA") + + if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1982,11 +2062,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): rmse_tol = 0.11 bwd_names = ["dq", "dk", "dv"] logging.debug("========== {:^25s} ==========".format("forward output")) - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + if flash_attn_supported: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -2160,14 +2236,14 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: model_configs_fp8 = { # test: b, h, hg, d, sq, skv, p, mask, bias - "fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"), - "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), - "fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_5": ModelConfig(1, 1, 1, 64, 
512, 512, 0.0, "causal", "no_bias"), - "fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"), - "fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "fp8_1": ModelConfig(1, 512, 1, 64), + "fp8_2": ModelConfig(4, 512, 16, 64), + "fp8_3": ModelConfig(1, 2048, 1, 128), + "fp8_4": ModelConfig(2, 2048, 24, 128), + "fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"), + "fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"), + "fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"), + "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"), } param_types_fp8 = [torch.float16, torch.bfloat16] cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1")) @@ -2197,6 +2273,18 @@ def test_custom_mha_fp8_vs_f16(dtype, model): config = model_configs_fp8[model] + # Test backend availability + is_training = True + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd", + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not (fused_attn_backends and unfused_attn_supported): + pytest.skip("Not enough backends to run this test with.") + fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention") unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention") diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py similarity index 73% rename from tests/pytorch/fused_attn/test_fused_attn_with_cp.py rename to tests/pytorch/attention/test_attention_with_cp.py index edf518d6b..3f98aa318 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -6,6 +6,9 @@ import os import subprocess +import sys +import pathlib + import pytest import torch from torch.utils.cpp_extension import IS_HIP_EXTENSION @@ -14,26 +17,28 @@ get_cudnn_version, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils -from test_fused_attn import ModelConfig + +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ModelConfig, get_available_attention_backends + +# Initialize RNG state +seed = 1234 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) model_configs_flash_attn = { # test: b, h, hg, d, sq, skv, p, mask, bias - "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA - "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA - "cp_1_2": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) - ), # MHA - "cp_1_3": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) - ), # MHA - "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA + "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA + "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA + "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA + "cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)), # MHA + "cp_2_0": 
ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA + "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA "cp_2_2": ModelConfig( - 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) - ), # GQA - "cp_2_3": ModelConfig( - 2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) + 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) ), # GQA + "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA } @@ -45,7 +50,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): "--nproc-per-node=" + str(num_gpus_per_node), ] te_path = os.getenv("TE_PATH", "/opt/transformerengine") - script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py") + script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py") args.append(script_path) for k, v in kwargs.items(): args.append(f"{k}={v}") @@ -94,32 +99,36 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): model_configs_fused_attn = { # test: b, h, hg, d, sq, skv, p, mask, bias - "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA - "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA - "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA - "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA - "cp_1_4": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA + "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA + "cp_1_2": ModelConfig( + 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), # MHA - "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA - "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA - "cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA + "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA + "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA + "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA + "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA + "cp_2_2": ModelConfig( + 2, + 4096, + 12, + 128, + num_gqa_groups=2, + attn_mask_type="causal", + attn_bias_type="post_scale_bias", + ), # GQA + "cp_2_3": ModelConfig( + 2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias" + ), # GQA "cp_2_4": ModelConfig( - 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) ), # GQA - "cp_3_0": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64 - ), # MLA - "cp_3_1": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64 - ), # MLA + "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA + "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA "cp_3_2": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64 - ), # MLA - "cp_3_3": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", 
head_dim_v=64 + 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA + "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA } @@ -178,6 +187,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("MLA CP currently only support KV P2P!") if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently does not support FP8 attention!") + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtypes[dtype], + qkv_layout="_".join([qkv_format] * 3), + window_size=config.window_size, + context_parallel=True, + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("No attention backend available.") subprocess.run( get_bash_arguments( diff --git a/tests/pytorch/fused_attn/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py similarity index 97% rename from tests/pytorch/fused_attn/test_kv_cache.py rename to tests/pytorch/attention/test_kv_cache.py index 967309459..288c5382e 100644 --- a/tests/pytorch/fused_attn/test_kv_cache.py +++ b/tests/pytorch/attention/test_kv_cache.py @@ -5,18 +5,14 @@ from collections import OrderedDict from typing import List import os +import sys +import pathlib import logging import math import pytest import torch -from test_fused_attn import ( - ModelConfig, - reset_rng_states, - _get_attention_backends, -) - from torch.distributions import Exponential from transformer_engine.pytorch import make_graphed_callables from transformer_engine.common import recipe @@ -34,26 +30,25 @@ is_bf16_compatible, ) -# Initialize RNG state -seed = 1234 -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ( + ModelConfig, + reset_rng_states, + get_available_attention_backends, +) +# Reset RNG states +reset_rng_states() param_types = [torch.float16] if is_bf16_compatible(): param_types.append(torch.bfloat16) model_configs_infer = { - # test: b, h, hg, d, sq, skv, p, mask, bias - "infer_0": ModelConfig( - 4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16 - ), - "infer_1": ModelConfig( - 2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16 - ), + # test: b, sq, hq, dqk, + "infer_0": ModelConfig(4, 64, 16, 128, total_requests=8, max_ctx_len=16), + "infer_1": ModelConfig(2, 66, 16, 256, num_gqa_groups=4, total_requests=6, max_ctx_len=16), } qkv_formats = ["bshd", "sbhd", "thd"] @@ -470,7 +465,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2) if is_paged: qkv_layout = "paged_kv_" + qkv_layout - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/tests/pytorch/distributed/test_sanity.py b/tests/pytorch/distributed/test_sanity.py new file mode 100644 index 000000000..39494a92b --- /dev/null +++ b/tests/pytorch/distributed/test_sanity.py @@ -0,0 +1,121 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# +# See LICENSE for license information. + +import pathlib +import sys +import pytest +import torch +import transformer_engine +from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention +from transformer_engine.pytorch import TransformerLayer, Linear + +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ModelConfig + +model_configs = { + "small": ModelConfig(2, 10, 2, 16), +} + + +@pytest.mark.parametrize("model", ["small"]) +@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"]) +def test_current_device(model, module): + """Test cases where current device is different from tensor device""" + + num_devices = torch.cuda.device_count() + assert num_devices > 1, "This test requires more than one GPU!" + tensor_device = num_devices - 1 + dtype = torch.bfloat16 + config = model_configs[model] + + args = [] + kwargs = {} + bwd_args = [] + if module == "TransformerLayer": + model = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_heads, + params_dtype=dtype, + attn_input_format="thd", + self_attn_mask_type="padding", + device=f"cuda:{tensor_device}", + ) + num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item() + args = [ + torch.randn( + (num_tokens, config.hidden_size), + dtype=dtype, + device=f"cuda:{tensor_device}", + requires_grad=True, + ) + ] + cu_seqlens_q, cu_seqlens_kv = [ + torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2) + ] + kwargs["cu_seqlens_q"] = cu_seqlens_q + kwargs["cu_seqlens_kv"] = cu_seqlens_kv + kwargs["max_seqlen_q"] = config.max_seqlen_q + kwargs["max_seqlen_kv"] = config.max_seqlen_kv + if module == "DotProductAttention": + model = DotProductAttention( + config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding" + ) + num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item() + args = [ + torch.randn( + num_tokens, + config.num_heads, + config.head_dim_qk, + dtype=dtype, + device=tensor_device, + requires_grad=True, + ) + for _ in range(3) + ] + cu_seqlens_q, cu_seqlens_kv = [ + torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2) + ] + kwargs["cu_seqlens_q"] = cu_seqlens_q + kwargs["cu_seqlens_kv"] = cu_seqlens_kv + kwargs["max_seqlen_q"] = config.max_seqlen_q + kwargs["max_seqlen_kv"] = config.max_seqlen_kv + bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)] + elif module == "Linear": + model = Linear( + config.hidden_size, + 4 * config.hidden_size, + params_dtype=dtype, + device=f"cuda:{tensor_device}", + ) + args = [ + torch.randn( + (config.max_seqlen_q, config.batch_size, config.hidden_size), + dtype=dtype, + device=f"cuda:{tensor_device}", + requires_grad=True, + ) + ] + + current_device_before = torch.cuda.current_device() + out = model(*args, **kwargs) + if module == "DotProductAttention": + out.backward(*bwd_args) + else: + loss = out.sum() + loss.backward() + current_device_after = torch.cuda.current_device() + tensor_device_out = out.get_device() + tensor_device_grad = args[0].grad.get_device() + + assert ( + current_device_after == current_device_before + ), "The current device should not have changed!" + assert ( + tensor_device_out == tensor_device + ), "The output tensor should be the same as the input tensors!" 
+ assert ( + tensor_device_grad == tensor_device + ), "The gradient tensor should be the same as the input tensors!" diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 59383f21b..322c57c5d 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -10,6 +10,8 @@ import transformer_engine.pytorch as te from transformer_engine.common import recipe from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from utils import ModelConfig, get_available_attention_backends # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -22,10 +24,13 @@ recipe.DelayedScaling(), ] -SIZE = 512 -NUM_HEADS = 8 -NUM_LAYERS = 5 -EPSILON = 0.1 +model_config = { + "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1), +} +SIZE = model_config["small"].hidden_size +NUM_HEADS = model_config["small"].num_heads +NUM_LAYERS = model_config["small"].num_layers +EPSILON = model_config["small"].eps # Flash attention saves some internal tensor for the backward pass # that cannot be offloaded to CPU. @@ -130,6 +135,18 @@ def test_cpu_offload(fp8_recipe, model_key) -> None: if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if model_key in ["multihead_attention", "transformer_layer"]: + available_backends, *_ = get_available_attention_backends( + model_config["small"], + qkv_dtype=torch.bfloat16, + qkv_layout="sbhd_sbhd_sbhd", + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("Fused attention backend not available.") + os.environ["NVTE_FLASH_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + without_offloading = _measure_memory_between_forward_and_backward( models_list, fp8_recipe, False ) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 7bfe506f2..83837eafd 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -23,7 +23,7 @@ from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe - +from utils import ModelConfig, reset_rng_states # Check if FP8 is supported. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -32,27 +32,12 @@ ) mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +# Reset RNG states. +reset_rng_states() -# Record initial RNG state. 
-seed = 1234 -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() - - -@dataclass -class ModelConfig: - """Data tensor dimensions within Transformer model""" - - sequence_length: int - batch_size: int - hidden_size: int - num_heads: int - kv_channels: int - - -model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)} +model_configs = { + "small": ModelConfig(32, 2, 2, 32), +} fp8_recipes = [ recipe.DelayedScaling(), @@ -67,12 +52,6 @@ class ModelConfig: dtypes.append(torch.bfloat16) -def reset_rng_states() -> None: - """Revert to initial RNG state.""" - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) - - @pytest.fixture(autouse=True) def reset_global_fp8_state(): yield @@ -107,7 +86,7 @@ def generate_data( """Generate synthetic data.""" gen_func = torch.ones if warmup else torch.randn return gen_func( - model_config.sequence_length, + model_config.max_seqlen_q, model_config.batch_size, model_config.hidden_size, device="cuda", @@ -389,7 +368,7 @@ def generate_data_for_dot_product_attention( gen_func = torch.ones if warmup else torch.randn return [ gen_func( - model_config.sequence_length, + model_config.max_seqlen_q, model_config.batch_size, model_config.num_heads, model_config.kv_channels, @@ -483,8 +462,8 @@ def _test_cuda_graphs_with_kwargs( ( model_config.batch_size, 1, - model_config.sequence_length, - model_config.sequence_length, + model_config.max_seqlen_q, + model_config.max_seqlen_kv, ), dtype=torch.bool, device="cuda", @@ -510,8 +489,8 @@ def _test_cuda_graphs_with_kwargs( ( model_config.batch_size, 1, - model_config.sequence_length, - model_config.sequence_length, + model_config.max_seqlen_q, + model_config.max_seqlen_kv, ), dtype=torch.bool, device="cuda", diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 1787ab191..a519a711a 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -47,11 +47,13 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils as fa_utils from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm +from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version from transformer_engine.common import recipe import transformer_engine_torch as tex +from utils import ModelConfig, reset_rng_states, get_available_attention_backends # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -63,11 +65,8 @@ sm_80plus = get_device_compute_capability() >= (8, 0) seed = 1234 -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -# Record initial RNG state from script run. -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() +# Reset RNG states. 
+reset_rng_states() if torch.__version__ >= '2.7.0': torch._dynamo.config.recompile_limit = 16 @@ -83,24 +82,12 @@ def rocm_attn_backend() -> tuple[bool, bool, bool]: int(os.getenv("NVTE_FUSED_ATTN_CK", "1")) != 0) -class ModelConfig: - def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len): - self.hidden_size = hidden_size - self.eps = eps - self.num_attention_heads = num_attention_heads - self.embed = embed - self.num_layers = num_layers - self.seq_len = seq_len - - model_configs = { - "small": ModelConfig(128, 1e-5, 8, 36, 4, 128), - "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048), + "small": ModelConfig(1, 128, 8, 16, num_layers=4), + "126m": ModelConfig(1, 2048, 12, 64, num_layers=12), } - model_configs_inference = { - # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len - "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256), + "126m": ModelConfig(1, 256, 12, 64, num_layers=12), } backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"] module_inference = ["TransformerLayer", "MultiheadAttention"] @@ -142,6 +129,26 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq ] +def is_fused_attn_available( + config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True +): + # backup the NVTE_FUSED_ATTN_* envs + if IS_HIP_EXTENSION: + from utils import EnvVarCleaner + env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", + "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON", + "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3"]) + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + is_training=is_training, + ) + if IS_HIP_EXTENSION: + return (FusedAttnBackend["AOTriton"] in fused_attn_backends) or (FusedAttnBackend["CK"] in fused_attn_backends) + return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends + + def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() @@ -204,12 +211,6 @@ def assert_allclose( raise AssertionError(msg) -def reset_rng_states() -> None: - """revert back to initial RNG state.""" - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) - - @pytest.fixture(autouse=True) def reset_global_fp8_state(): yield @@ -584,13 +585,13 @@ def _test_e2e_selective_recompute( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -599,13 +600,13 @@ def _test_e2e_selective_recompute( ) te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_hidden_states.retain_grad() - te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): te_out = block( @@ -680,13 +681,13 @@ def _test_e2e_full_recompute( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, 
output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -695,14 +696,14 @@ def _test_e2e_full_recompute( ) te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=use_reentrant, ) if use_reentrant: te_inp_hidden_states.retain_grad() - te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if recompute: @@ -818,13 +819,13 @@ def _test_e2e_checkpointing_get_model(config, dtype): return TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -836,7 +837,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= reset_rng_states() te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -866,14 +867,14 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= if p.requires_grad: param_grads.append(p.grad.clone()) - global _cpu_rng_state, _cuda_rng_state _cpu_rng_state = torch.get_rng_state() _cuda_rng_state = torch.cuda.get_rng_state() del block block = _test_e2e_checkpointing_get_model(config, dtype) block.load_state_dict(torch.load(path, weights_only=False)) - reset_rng_states() + torch.set_rng_state(_cpu_rng_state) + torch.cuda.set_rng_state(_cuda_rng_state) for p in block.parameters(): if p.requires_grad: @@ -906,6 +907,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= @pytest.mark.parametrize("model", ["126m"]) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] + if not is_fused_attn_available(config, dtype): + pytest.skip("No attention backend available.") outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) @@ -932,13 +935,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): reset_rng_states() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) inp_hidden_states.retain_grad() - inp_attn_mask = get_causal_attn_mask(config.seq_len) + inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) out = block(inp_hidden_states, attention_mask=inp_attn_mask) loss = out.sum() @@ -958,11 +961,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): config = model_configs[model] + if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False): + pytest.skip("No attention backend available.") te_gpt = TransformerLayer( hidden_size=config.hidden_size, ffn_hidden_size=4 * config.hidden_size, - 
num_attention_heads=config.num_attention_heads, + num_attention_heads=config.num_heads, layernorm_epsilon=config.eps, attention_dropout=0.1, hidden_dropout=0.1, @@ -977,7 +982,7 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): TorchGPT( config.hidden_size, config.eps, - config.num_attention_heads, + config.num_heads, parallel_attention_mlp=parallel_attention_mlp, ) .to(dtype=dtype) @@ -1038,13 +1043,13 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): reset_rng_states() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) inp_hidden_states.retain_grad() - inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None + inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) if mask_type == "causal" else None forward_kwargs = {} if te: @@ -1069,10 +1074,12 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): @pytest.mark.parametrize("mask_type", mask_types) def test_mha_accuracy(dtype, bs, model, mask_type): config = model_configs[model] + if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False): + pytest.skip("No attention backend available.") te_mha = MultiheadAttention( config.hidden_size, - config.num_attention_heads, + config.num_heads, fuse_qkv_params=True, params_dtype=dtype, qkv_weight_interleaved=False, @@ -1083,7 +1090,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type): torch_mha = ( TorchMHA( config.hidden_size, - config.num_attention_heads, + config.num_heads, ) .to(dtype=dtype) .cuda() @@ -1129,7 +1136,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False, FP8GlobalStateManager.reset() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -1161,7 +1168,7 @@ def _test_granular_accuracy_with_fp8(block, bs, dtype, config): reset_rng_states() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -1185,11 +1192,12 @@ def _test_dpa_accuracy(block, bs, dtype, config): reset_rng_states() mask = torch.triu( - torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1 + torch.ones(config.max_seqlen_q, config.max_seqlen_kv, dtype=torch.bool, device="cuda"), + diagonal=1, ) query, key, value = [ torch.randn( - (config.seq_len, bs, config.num_attention_heads, config.embed), + (config.max_seqlen_q, bs, config.num_heads, config.kv_channels), dtype=dtype, device="cuda", requires_grad=True, @@ -1218,8 +1226,8 @@ def test_dpa_accuracy(dtype, bs, model): te_dpa = ( DotProductAttention( - config.num_attention_heads, - config.embed, + config.num_heads, + config.kv_channels, attention_dropout=0.0, # disable dropout, FU uses rng differently ) .to(dtype=dtype) @@ -1228,7 +1236,7 @@ def test_dpa_accuracy(dtype, bs, model): torch_dpa = ( TorchDotProductAttention( - config.embed, + config.kv_channels, 0.0, # dropout ) .to(dtype=dtype) @@ -1434,7 +1442,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): pytest.skip("DelayedScaling recipe is not supported with save_original_input") config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be 
divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -1955,7 +1963,7 @@ def _test_grouped_linear_accuracy( FP8GlobalStateManager.reset() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -1968,14 +1976,14 @@ def _test_grouped_linear_accuracy( split_size = 16 if recipe.mxfp8(): split_size = 128 - m = config.seq_len // split_size + m = config.max_seqlen_q // split_size dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() dist.append(dist[-1]) # Manually add a zero m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) m_splits = m_splits * split_size - assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms + assert m_splits.sum() == config.max_seqlen_q and len(m_splits) == num_gemms else: - m_splits = torch.tensor([config.seq_len]) + m_splits = torch.tensor([config.max_seqlen_q]) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, GroupedLinear): @@ -2045,7 +2053,7 @@ def test_grouped_linear_accuracy( pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -2149,7 +2157,7 @@ def test_grouped_linear_accuracy_save_original_input( pytest.skip("DelayedScaling recipe is not supported with save_original_input") config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -2297,14 +2305,14 @@ def _generate_random_numbers(n, total_sum): FP8GlobalStateManager.reset() inp_hidden_states = torch.randn( - (config.seq_len * bs, config.hidden_size), + (config.max_seqlen_q * bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) inp_hidden_states.retain_grad() - m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs) + m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, TorchGroupedLinearWithPadding): @@ -2357,7 +2365,7 @@ def test_padding_grouped_linear_accuracy( pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -2434,7 +2442,7 @@ def test_padding_grouped_linear_accuracy_save_original_input( pytest.skip("DelayedScaling recipe is not supported with save_original_input") config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -2491,9 +2499,11 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph): # Placeholders used for graph capture. 
static_input = torch.randn( - config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True + config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True + ) + static_target = torch.randn( + config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype ) - static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype) real_input = torch.rand_like(static_input) real_target = torch.rand_like(static_target) @@ -2563,7 +2573,7 @@ def test_gpt_cuda_graph(dtype, bs, model): block_args = ( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, ) block_kwargs = dict( layernorm_epsilon=config.eps, @@ -2571,7 +2581,7 @@ def test_gpt_cuda_graph(dtype, bs, model): output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2606,13 +2616,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -2621,13 +2631,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): ) te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_hidden_states.retain_grad() - te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) with fp8_autocast(enabled=True, fp8_recipe=recipe): te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) @@ -2689,13 +2699,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): block_sbhd = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0, attention_dropout=0, - kv_channels=config.embed, + kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2710,13 +2720,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): block_bshd = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0, attention_dropout=0, - kv_channels=config.embed, + kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2728,13 +2738,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): block_thd = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0, attention_dropout=0, - kv_channels=config.embed, + 
kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2749,15 +2759,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical" x_sbhd = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) x_bshd = x_sbhd.transpose(0, 1).contiguous() - x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous() - x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len + x_thd = x_bshd.reshape(bs * config.max_seqlen_q, config.hidden_size).contiguous() + x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.max_seqlen_q # To make sure forward is also identical (just in case some module decides # to act fancy) @@ -2802,180 +2812,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): x_thd, cu_seqlens_q=x_thd_cumsum, cu_seqlens_kv=x_thd_cumsum, - max_seqlen_q=config.seq_len, - max_seqlen_kv=config.seq_len, + max_seqlen_q=config.max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, ) torch.testing.assert_close( y_bshd, - y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), - ) - - -@pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model_key", model_configs_inference.keys()) -@pytest.mark.parametrize("use_RoPE", all_boolean) -@pytest.mark.parametrize("input_format", input_formats_inference) -@pytest.mark.parametrize("module", module_inference) -@pytest.mark.parametrize("backend", backends_inference) -@pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.usefixtures("reset_test_envs") -def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged): - if ((backend == "FlashAttention" and os.getenv("NVTE_FLASH_ATTN", "1") == "0") or - (backend == "FusedAttention" and os.getenv("NVTE_FUSED_ATTN", "1") == "0")): - pytest.skip(f"{backend} is disabled") - - if backend == "FlashAttention" and not fa_utils.is_installed: - pytest.skip("FlashAttention is not installed") - - if IS_HIP_EXTENSION and backend == "FusedAttention": - if is_paged: - pytest.skip("FusedAttention does not support KV cache with paging on ROCm") - if os.getenv("NVTE_FUSED_ATTN_CK", "1") == "0": - pytest.skip("CK FusedAttention backend is disabled") - - reset_rng_states() - - if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32: - pytest.skip("FusedAttention and FlashAttention do not support FP32") - if use_RoPE: - pytest.skip("KV cache does not support starting positions for RoPE") - if ( - not IS_HIP_EXTENSION and - backend == "FusedAttention" - and get_device_compute_capability() == (8, 9) - and get_cudnn_version() < (9, 12, 0) - ): - pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12") - - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "0" - os.environ["NVTE_UNFUSED_ATTN"] = "0" - - if backend == "FlashAttention": - os.environ["NVTE_FLASH_ATTN"] = "1" - elif backend == "FusedAttention": - os.environ["NVTE_FUSED_ATTN"] = "1" - elif backend == "UnfusedAttention": - os.environ["NVTE_UNFUSED_ATTN"] = "1" - - config = model_configs_inference[model_key] - - S = config.seq_len - B = bs - H = config.num_attention_heads - D = config.hidden_size - head_size = config.embed - layer_number = 1 - - # Limits the max 
size of KV-cache - B_max = B - S_max = S - - if module == "TransformerLayer": - model = TransformerLayer( - hidden_size=D, - ffn_hidden_size=4 * D, - num_attention_heads=H, - attn_input_format=input_format, - self_attn_mask_type="causal", - enc_dec_attn_mask_type="causal", - layer_number=layer_number, - attention_dropout=0.0, - params_dtype=dtype, - device="cuda", - ).eval() - else: - model = ( - MultiheadAttention( - hidden_size=D, - num_attention_heads=H, - qkv_format=input_format, - layer_number=layer_number, - attention_dropout=0.0, - attn_mask_type="causal", - params_dtype=dtype, - ) - .cuda() - .eval() + y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(), ) - inference_params = InferenceParams( - max_batch_size=B_max, - max_sequence_length=S_max, - num_heads_kv=H, - head_dim_k=head_size, - dtype=dtype, - is_paged=is_paged, - total_num_pages=int(B_max * S_max / 256), - page_size=256, - ) - - rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda") - - input = torch.randn((S, B, D), dtype=dtype, device="cuda") - if input_format == "bshd": - input = input.transpose(0, 1).contiguous() - - incremental_output = torch.zeros_like(input) - - # Generate output for the entire sequence - full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None) - - # Incrementaly generate outputs using KV-cache - step_dict = OrderedDict(zip(list(range(B)), [1] * B)) - for i in range(S): - inference_params.pre_step(step_dict) - - if input_format == "sbhd": - incremental_input = input[i].view(1, B, D) - else: - incremental_input = input[:, i, :].view(B, 1, D) - - seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda") - cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda") - cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) - cu_seqlens_kv = cu_seqlens_q.clone() - - mask_type = "padding" - kwargs = {} - if module == "TransformerLayer": - kwargs["self_attn_mask_type"] = mask_type - else: - kwargs["attn_mask_type"] = mask_type - line_output = model( - hidden_states=incremental_input, - inference_params=inference_params, - rotary_pos_emb=rotary_freqs if use_RoPE else None, - **kwargs, - max_seqlen_q=1, - max_seqlen_kv=S, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - ) - - if input_format == "sbhd": - incremental_output[i, :, :] = line_output.view(B, D) - else: - incremental_output[:, i, :] = line_output.view(B, D) - - if module == "TransformerLayer": - atol = { - torch.float32: 5e-3, - torch.half: 5e-3, - torch.bfloat16: 5e-2, - } - else: - atol = { - torch.float32: 1e-3, - torch.half: 1e-3, - torch.bfloat16: 1e-2, - } - - # Check if the fully generated output matches the one generated incrementally - assert_allclose(full_output, incremental_output, atol[dtype]) - @pytest.mark.parametrize( "shape", diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 9fbadd4b9..b26ca05d4 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -50,7 +50,7 @@ from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.distributed import checkpoint -from utils import dtype_tols +from utils import ModelConfig, dtype_tols # Only run FP8 tests on supported devices. 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -63,8 +63,6 @@ seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0")) @@ -109,37 +107,22 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor: return torch.min(amax_history, dim=0).values -def reset_rng_states() -> None: - """revert back to initial RNG state.""" - global _cpu_rng_state, _cuda_rng_state - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) - - -@dataclass -class ModelConfig: - """Transformer model configuration""" - - num_layers: int - seq_len: int - batch_size: int - hidden_size: int - num_attention_heads: int - kv_channels: Optional[int] = None - - def is_fp8_supported(self): - if self.seq_len * self.batch_size % 16: - return False - if self.hidden_size % 16: - return False - return True +def is_fp8_supported(config: ModelConfig): + if ( + config.max_seqlen_q * config.batch_size % 16 + or config.max_seqlen_kv * config.batch_size % 16 + ): + return False + if config.hidden_size % 16 or config.hidden_size_kv % 16: + return False + return True model_configs = { - "126m": ModelConfig(12, 2048, 2, 768, 12), - "small": ModelConfig(2, 32, 2, 64, 2), - "weird": ModelConfig(2, 37, 3, 69, 3), - "large": ModelConfig(1, 128, 2, 512, 4, 128), + "126m": ModelConfig(2, 2048, 12, 64, num_layers=12), + "small": ModelConfig(2, 32, 2, 32, num_layers=2), + "weird": ModelConfig(3, 37, 3, 23, num_layers=2), + "large": ModelConfig(2, 128, 4, 128, num_layers=1), } fp8_recipes = [ @@ -188,7 +171,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): # Placeholders used for capture. 
static_input = torch.randn( - config.seq_len, + config.max_seqlen_q, config.batch_size, config.hidden_size, device="cuda", @@ -196,7 +179,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): requires_grad=True, ) static_target = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype + config.max_seqlen_q, config.batch_size, config.hidden_size, device="cuda", dtype=dtype ) real_input = torch.rand_like(static_input) @@ -241,7 +224,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=torch.float32, device="cuda", requires_grad=True, @@ -249,7 +232,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states.retain_grad() te_inp_attn_mask = torch.randint( 2, - (1, 1, config.seq_len, config.seq_len), + (1, 1, config.max_seqlen_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) @@ -276,14 +259,14 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_attn_mask = torch.randint( 2, - (1, 1, config.seq_len, config.seq_len), + (1, 1, config.max_seqlen_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) @@ -316,7 +299,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -342,7 +325,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -350,7 +333,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_attn_mask = torch.randint( 2, - (config.batch_size, 1, 1, config.seq_len), + (config.batch_size, 1, 1, config.max_seqlen_q), dtype=torch.bool, device="cuda", ) @@ -368,21 +351,21 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_attn_mask = torch.randint( 2, - (1, 1, config.seq_len, config.seq_len), + (1, 1, config.max_seqlen_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) enc_dec_attn_mask = torch.randint( 2, - (config.batch_size, 1, 1, config.seq_len), + (config.batch_size, 1, 1, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) @@ -410,7 +393,7 @@ def _test_sanity_common( pytest.skip("No gradient 
computation; Skipping to avoid PyTorch RuntimeError.") te_inp = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=not skip_dgrad, @@ -438,7 +421,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad) pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.") te_inp = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), device="cuda", requires_grad=True, ) @@ -499,7 +482,7 @@ def test_sanity_layernorm_linear( pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -533,7 +516,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -560,7 +543,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ pytest.skip("Quantized model parameters are not supported in debug mode.") config = model_configs[model] ffn_hidden_size = 4 * config.hidden_size - num_tokens = bs * config.seq_len + num_tokens = bs * config.max_seqlen_q if fp8_recipe is not None: if not fp8_available: @@ -569,7 +552,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") use_fp8 = fp8_recipe is not None @@ -605,7 +588,7 @@ def test_sanity_grouped_linear( ffn_hidden_size = 4 * config.hidden_size # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. 
bs = bs * 16 - num_tokens = bs * config.seq_len * (num_gemms - 1) + num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) if fp8_recipe is not None: if not fp8_available: @@ -614,7 +597,7 @@ def test_sanity_grouped_linear( pytest.skip(reason_for_no_mxfp8) if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") use_fp8 = fp8_recipe is not None @@ -626,7 +609,7 @@ def test_sanity_grouped_linear( inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True ).cuda() - m_splits = [bs * config.seq_len] * num_gemms + m_splits = [bs * config.max_seqlen_q] * num_gemms if empty_split == "first": m_splits[0] = 0 elif empty_split == "last": @@ -670,7 +653,7 @@ def test_sanity_layernorm_mlp( pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -727,7 +710,7 @@ def test_sanity_gpt( pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -737,7 +720,7 @@ def test_sanity_gpt( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -796,7 +779,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -806,7 +789,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -857,7 +840,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -867,7 +850,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -916,7 +899,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -926,7 +909,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): block = TransformerLayer( 
config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -953,7 +936,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -963,7 +946,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -993,7 +976,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -1003,7 +986,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -1036,7 +1019,7 @@ def test_sanity_gradient_accumulation_fusion( pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -1046,7 +1029,7 @@ def test_sanity_gradient_accumulation_fusion( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -1083,7 +1066,7 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm pytest.skip(reason_for_no_mxfp8) if fp8_recipe.float8_block_scaling(): pytest.skip("cuda graph not supported for float8_block_scaling recipe") - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -1093,7 +1076,7 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -1165,134 +1148,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype): torch.cuda.synchronize() -#TODO: rocm fused_attn backends does not support fp8 yet -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.skipif(IS_HIP_EXTENSION or get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.") -@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") -@pytest.mark.parametrize("model", ["large"]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_sanity_attention_extra_state(model, dtype): - config = model_configs[model] - outputs = _run_attention_extra_state(dtype, config, checkpoint=False) - outputs_checkpoint = 
_run_attention_extra_state(dtype, config, checkpoint=True) - outputs_checkpoint_v1_6 = _run_attention_extra_state( - dtype, config, mimic_v1_6=True, checkpoint=True - ) - - # Check that results match - tols = dtype_tols(dtype) - if dtype in (torch.float16, torch.bfloat16): - tols.update(dict(rtol=2e-2, atol=2e-3)) - for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)): - torch.testing.assert_close( - test, - ref, - **tols, - ) - for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)): - torch.testing.assert_close( - test, - ref, - **tols, - ) - - -def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): - steps = 10 - path = "checkpoint.pt" - fp8_enabled = True - fp8_recipe = recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.HYBRID, - amax_history_len=1, - amax_compute_algo="most_recent", - fp8_dpa=fp8_enabled, - fp8_mha=False, - ) - - reset_rng_states() - hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), - dtype=dtype, - device="cuda", - requires_grad=True, - ) - - def get_model(dtype, config): - sigma = 0.023 - init_method = init_method_normal(sigma) - output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - - with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe): - block = TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.0, - attention_dropout=0.0, - fuse_qkv_params=True, - params_dtype=dtype, - device="cuda", - ) - return block - - block = get_model(dtype, config) - for i in range(steps // 2): - with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): - output = block(hidden_states, None) - loss = output.sum() - loss.backward() - - if checkpoint: - sd = block.state_dict() - if mimic_v1_6: - sd["self_attention.core_attention.fused_attention._extra_state"] = sd[ - "self_attention.core_attention._extra_state" - ] - del sd["self_attention.core_attention._extra_state"] - torch.save(sd, path) - - param_grads = [] - for p in block.parameters(): - if p.requires_grad: - param_grads.append(p.grad.clone()) - - _cpu_rng_state_new = torch.get_rng_state() - _cuda_rng_state_new = torch.cuda.get_rng_state() - - del block - block = get_model(dtype, config) - block.load_state_dict(torch.load(path, weights_only=False)) - torch.set_rng_state(_cpu_rng_state_new) - torch.cuda.set_rng_state(_cuda_rng_state_new) - - for p in block.parameters(): - if p.requires_grad: - p.grad = param_grads.pop(0) - - assert not param_grads, "Oops!" 
- - for i in range((steps + 1) // 2): - with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): - output = block(hidden_states, None) - loss = output.sum() - loss.backward() - - torch.cuda.synchronize() - - if os.path.exists(path): - os.remove(path) - - outputs = [output, hidden_states.grad] - for p in block.parameters(): - if p.requires_grad: - outputs.append(p.grad) - - return outputs - - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_replace_raw_data_for_float8tensor(): """Test the functionality of replace_raw_data""" diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 0c50592bd..48a49be58 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -6,6 +6,11 @@ from __future__ import annotations +import logging +import os +from contextlib import contextmanager + +import pytest import torch import transformer_engine @@ -13,6 +18,14 @@ import transformer_engine.pytorch as te from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type import transformer_engine_torch as tex +from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from transformer_engine.pytorch.attention.dot_product_attention.utils import ( + get_attention_backend, + AttentionParams, + AttentionLogging, +) +from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend +from torch.utils.cpp_extension import IS_HIP_EXTENSION torch_float8_e4m3_type = get_torch_float8_e4m3_type() torch_float8_e5m2_type = get_torch_float8_e5m2_type() @@ -111,3 +124,209 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]: if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling() raise ValueError(f"Unsupported quantization scheme ({name})") + + +# Cached RNG state +_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + +def reset_rng_states() -> None: + """Revert to deterministic RNG state""" + global _rng_states + if _rng_states is None: + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + _rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state()) + else: + cpu_rng_state, cuda_rng_state = _rng_states + torch.set_rng_state(cpu_rng_state) + torch.cuda.set_rng_state(cuda_rng_state) + + +class ModelConfig: + def __init__( + self, + batch_size: int, + max_seqlen_q: int, + num_heads: int, + head_dim_qk: int, + max_seqlen_kv: int = None, + num_gqa_groups: int = None, + head_dim_v: int = None, + dropout_p: float = 0.0, + attn_mask_type: str = "no_mask", + attn_bias_type: str = "no_bias", + alibi_type: str = "none", + bias_shape: str = "1hss", + window_size: Tuple[int, int] = (-1, -1), + total_requests: int = None, + max_ctx_len: int = None, + num_layers: int = 1, + eps: float = 1e-5, + ): + self.batch_size = batch_size + self.max_seqlen_q = max_seqlen_q + self.max_seqlen_kv = max_seqlen_q if max_seqlen_kv is None else max_seqlen_kv + self.num_heads = num_heads + self.num_gqa_groups = num_heads if num_gqa_groups is None else num_gqa_groups + self.head_dim_qk = head_dim_qk + self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v + if self.head_dim_qk == self.head_dim_v: + self.kv_channels = self.head_dim_qk + else: + self.kv_channels = (self.head_dim_qk, self.head_dim_v) + self.hidden_size = self.num_heads * self.head_dim_qk + self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v + self.dropout_p = dropout_p + self.attn_mask_type = attn_mask_type + self.attn_bias_type = attn_bias_type + self.alibi_type = alibi_type + 
self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross" + self.bias_shape = bias_shape + self.window_size = window_size + self.total_requests = total_requests + self.max_ctx_len = max_ctx_len + self.num_layers = num_layers + self.eps = eps + + +@contextmanager +def logging_context(highest_level=logging.WARNING): + previous_level = logging.root.manager.disable + logging.disable(highest_level) + try: + yield + finally: + logging.disable(previous_level) + +if IS_HIP_EXTENSION: + class EnvVarCleaner: + def __init__(self, envs_): + self.envs = envs_ + self.flags = {} + for env in self.envs: + if env in os.environ: + self.flags[env] = os.environ[env] + def __del__(self): + for env in self.envs: + if env in self.flags: + os.environ[env] = self.flags[env] + else: + os.environ.pop(env, None) + +def get_available_attention_backends( + config: ModelConfig, + qkv_dtype: torch.dtype, + qkv_layout: str, + window_size: Tuple[int, int] = (-1, -1), + pad_between_seqs: bool = False, + context_parallel: bool = False, + deterministic: bool = False, + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + is_training: bool = True, + inference_params: Optional[InferenceParams] = None, +) -> Tuple[List, List]: + """Check for all available attention backends that support a model configuration""" + + os.environ["NVTE_FLASH_ATTN"] = "1" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "1" + _attention_backends["backend_selection_requires_update"] = True + + alibi_slopes_shape = None + if config.attn_bias_type == "alibi" and config.alibi_type == "custom": + if config.bias_shape == "1hss": + alibi_slopes_shape = [config.num_heads] + if config.bias_shape == "bhss": + alibi_slopes_shape = [config.batch_size, config.num_heads] + + core_attention_bias_shape = ( + config.bias_shape if config.attn_bias_type == "post_scale_bias" else None + ) + core_attention_bias_requires_grad = False + # d=256 is supported by cuDNN 9.0+ for inference but not training + if ( + config.attn_bias_type == "post_scale_bias" + and config.head_dim_qk <= 128 + and config.head_dim_v <= 128 + ): + core_attention_bias_requires_grad = True + + fused_attn_backends = [] + available_backends = None + flash_attention_backend = None + fused_attention_backend = None + + def test(): + attention_params = AttentionParams( + qkv_dtype=qkv_dtype, + qkv_layout=qkv_layout, + batch_size=config.batch_size, + num_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + max_seqlen_q=config.max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, + head_dim_qk=config.head_dim_qk, + head_dim_v=config.head_dim_v, + attn_mask_type=config.attn_mask_type, + window_size=window_size, + alibi_slopes_shape=alibi_slopes_shape, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias_shape=core_attention_bias_shape, + core_attention_bias_requires_grad=core_attention_bias_requires_grad, + pad_between_seqs=pad_between_seqs, + attention_dropout=config.dropout_p, + context_parallel=context_parallel, + deterministic=deterministic, + fp8=fp8, + fp8_meta=fp8_meta, + is_training=is_training, + inference_params=inference_params, + ) + ( + use_flash_attention, + use_fused_attention, + flash_attention_backend, + fused_attention_backend, + use_unfused_attention, + available_backends, + ) = get_attention_backend(attention_params) + # Set attention.py _attention_backends var using return value + # from get_attention_backend() + _attention_backends["use_flash_attention"] = use_flash_attention + 
_attention_backends["use_fused_attention"] = use_fused_attention + _attention_backends["flash_attention_backend"] = flash_attention_backend + _attention_backends["fused_attention_backend"] = fused_attention_backend + _attention_backends["use_unfused_attention"] = use_unfused_attention + _attention_backends["backend_selection_requires_update"] = False + return available_backends, flash_attention_backend, fused_attention_backend + + if IS_HIP_EXTENSION: + backends = {"AOTriton": "AOTRITON", "CK": "CK"} + if AttentionLogging._is_logging_setup is False: + AttentionLogging.setup_logging() + with logging_context(highest_level=AttentionLogging._log_level): + for i in backends.keys(): + for k in backends.keys(): + os.environ["NVTE_FUSED_ATTN_"+backends[k]] = "0" + os.environ["NVTE_FUSED_ATTN_"+backends[i]] = "1" + _attention_backends["backend_selection_requires_update"] = True + available_backends, flash_attention_backend, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[i]: + fused_attn_backends.append(fused_attention_backend) + for i in backends.keys(): + del os.environ["NVTE_FUSED_ATTN_"+backends[i]] + available_backends[1] = len(fused_attn_backends) > 0 + else: + backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} + if AttentionLogging._is_logging_setup is False: + AttentionLogging.setup_logging() + with logging_context(highest_level=AttentionLogging._log_level): + for i in range(3): + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) + _attention_backends["backend_selection_requires_update"] = True + available_backends, flash_attention_backend, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[backends[i]]: + fused_attn_backends.append(fused_attention_backend) + return available_backends, flash_attention_backend, fused_attn_backends diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index bdbc97517..c424b3672 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -373,6 +373,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) set_source_files_properties(activation/gelu.cu activation/relu.cu activation/swiglu.cu + util/cast.cu PROPERTIES COMPILE_OPTIONS "--use_fast_math") endif() diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 8a73138e3..db39faa38 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -251,6 +251,18 @@ def _load_cudnn(): if found: return handle + # Attempt to locate libcudnn via ldconfig + libs = subprocess.check_output( + f"ldconfig -p | grep 'libcudnn{_get_sys_extension()}'", shell=True + ) + libs = libs.decode("utf-8").split("\n") + sos = [] + for lib in libs: + if "libcudnn" in lib and "=>" in lib: + sos.append(lib.split(">")[1].strip()) + if sos: + return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) + # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) @@ -272,12 +284,12 @@ def _load_nvrtc(): return handle # Attempt to locate NVRTC via ldconfig - libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True) + libs = subprocess.check_output( + f"ldconfig -p | grep 'libnvrtc{_get_sys_extension()}'", shell=True + ) libs = libs.decode("utf-8").split("\n") sos = [] for lib in libs: - if "stub" in lib or "libnvrtc-builtins" in lib: - continue if "libnvrtc" in lib and "=>" in lib: 
sos.append(lib.split(">")[1].strip()) if sos: diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 483444751..b8b516f7c 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -165,10 +165,10 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, void *dataPtr = reinterpret_cast(reinterpret_cast(tensor.dptr) + (offset_elems * type_num_bits) / 8); - NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment), + NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_GMEM_ALIGNMENT), "Tensor data pointer must be 16B aligned"); - const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits; + const int TMA_needed_size = (TMA_GMEM_ALIGNMENT * 8) / type_num_bits; NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits, "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 39038724a..32e50b337 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -693,7 +693,8 @@ constexpr size_t scale_tensor_alignment_X_colwise = 128; constexpr size_t scale_tensor_alignment_Y_colwise = 4; // Alignment requirements for the Tensor Memory Accelerator (TMA) -constexpr int TMA_gmem_alignment = 16; // global memory address alignment +constexpr size_t TMA_GMEM_ALIGNMENT = 16; // global memory address alignment +constexpr size_t TMA_SHMEM_ALIGNMENT = 128; // shared memory address alignment inline bool is_aligned_ptr(const void *ptr, size_t alignment) { return reinterpret_cast(ptr) % alignment == 0; diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 9d4701730..bb30261b9 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -183,7 +183,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - !requires_64bit_ragged_offset) { + !requires_64bit_ragged_offset && + // 9.10.0: known bugs with SDPA FP8 + (cudnn_runtime_version != 91000)) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else { @@ -239,20 +241,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || - // 9.10: any head_dim + any arch + fprop + paged - // 9.10: any head_dim + any arch + fprop + non_paged + sq > 1 - // 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} - (!is_training && cudnn_runtime_version >= 91000 && + // 9.10.2: any head_dim + any arch + fprop + paged + // 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1 + // 9.10.2: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} + (!is_training && cudnn_runtime_version >= 91002 && (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 || (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || // 9.11: d_qk = 192, d_v = 
128 + Blackwell + bprop + non-paged (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 91100)) && - // 9.11 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA - (!(cudnn_runtime_version == 91100 && is_training && sm_arch_ == 90 && head_dim_qk >= 128 && - head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && - head_dim_qk != head_dim_v))) && + // 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA + (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200) && is_training && + sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 && + !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || (cudnn_runtime_version >= 8906 && @@ -358,7 +360,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)))) && // check 64-bit ragged offset support - (supported_ragged_offset_size)) { + (supported_ragged_offset_size) && + // 9.10.0/9.10.1: known bugs with SDPA F16 + (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 31d2c0b74..68cc9ed60 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -29,6 +29,7 @@ #ifndef __HIP_PLATFORM_AMD__ #include "../cudnn_utils.h" #else +#include "../util/ptx.cuh" #include "../util/rocm_cast_kernels.cuh" #endif #include "../util/system.h" diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index a5a23c1c0..f051596a2 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -34,15 +34,9 @@ namespace transformer_engine { -template -__device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(T1 N, T2 M) { - return DIVUP(static_cast(N), static_cast(M)) * M; -} - namespace gated_kernels { #ifndef __HIP_PLATFORM_AMD__ -constexpr size_t ALIGNMENT_SIZE = 128; constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; constexpr size_t THREADS_PER_CHUNK = 512; @@ -84,18 +78,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float amax = 0; const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - extern __shared__ char dshmem_unaligned[]; - const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); - const uint64_t dshmem_aligned_as_uint = - DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; - char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
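+  // Rounding up: (ptr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1) yields the first
+  // 128-byte-aligned address at or above base_shmem_ptr, as required for TMA shared-memory buffers.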
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; constexpr size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); constexpr size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; @@ -104,8 +99,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t in_mem = in_act_mem + in_gate_mem; constexpr size_t out_act_mem = buff_size_aligned_out; - - // const size_t in_transaction_size = grad_mem + in_mem; constexpr size_t in_transaction_size = buff_elems * sizeof(IType); // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned @@ -277,9 +270,34 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +namespace mxfp8_kernel { + +constexpr size_t CHUNK_DIM_Y = 64; +constexpr size_t CHUNK_DIM_X = 64; +constexpr size_t THREADS_PER_CHUNK_COLWISE = 128; +constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = CHUNK_DIM_X; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = 32; +constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; +constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; +static_assert(BUFF_DIM_Y == 32); + +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 + template + bool ROWWISE_SCALING, bool COLWISE_SCALING, size_t THREADS_PER_CHUNK> __global__ void __launch_bounds__(THREADS_PER_CHUNK) cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, const __grid_constant__ CUtensorMap tensor_map_input_act, @@ -292,43 +310,73 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; - constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; - constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + static_assert(STAGES >= 1); - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + constexpr bool IS_CACHED_ACT_OP = ROWWISE_SCALING && COLWISE_SCALING; + constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING); - const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; - const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; - const int 
scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; - const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; + // # of rows covered by one wave. Equal to the # of columnwise threads in Y dimension. + constexpr int COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X); - const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * CHUNK_DIM_X; + const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; - const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + constexpr size_t THREADS_X_ROWWISE = CHUNK_DIM_X / SCALE_DIM_X; - const int thread_offset_Y = tid_Y; - const int thread_offset_X = tid_X; + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const int tid_Y_colwise = threadIdx.x / CHUNK_DIM_X; + const int tid_X_colwise = threadIdx.x % CHUNK_DIM_X; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const int thread_offset_Y_colwise = tid_Y_colwise; + const int thread_offset_X_colwise = tid_X_colwise; + + const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const int col_base_rowwise = block_offset_X + thread_offset_X_rowwise; + const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols); + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - const bool col_out_of_bounds = (chunk_offset_X + thread_offset_X >= cols); + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - extern __shared__ char dshmem_unaligned[]; - const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); - const uint64_t dshmem_aligned_as_uint = - DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; - char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + const int gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X; + const int gate_scale_idx_offset_colwise = cols; - const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_elems_total = BUFFERS_NUM * buff_elems; - const size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; - const size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr int SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? 
COLWISE_WAVEFRONT_SIZE - 1 : 1; + __shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][CHUNK_DIM_X]; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); @@ -337,12 +385,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t in_mem = in_act_mem + in_gate_mem; const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; + const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); const size_t out_mem = out_act_mem + out_gate_mem; - // const size_t in_transaction_size = grad_mem + in_mem; - const size_t in_transaction_size = (IS_DGATED ? 3 : 2) * buff_elems * sizeof(IType); - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned IType *in_grad_sh = reinterpret_cast(dshmem); IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); @@ -354,374 +399,493 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) OType *out_act_colwise_sh = out_act_rowwise_sh; OType *out_gate_colwise_sh = out_gate_rowwise_sh; - if constexpr (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + if constexpr (ROWWISE_SCALING && COLWISE_SCALING) { out_act_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); out_gate_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem + out_act_mem); } - const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); - const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); - const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); - const uint64_t *TMAP_output_act_rowwise = - reinterpret_cast(&tensor_map_output_act_rowwise); - const uint64_t *TMAP_output_gate_rowwise = - reinterpret_cast(&tensor_map_output_gate_rowwise); - const uint64_t *TMAP_output_act_colwise = - reinterpret_cast(&tensor_map_output_act_colwise); - const uint64_t *TMAP_output_gate_colwise = - reinterpret_cast(&tensor_map_output_gate_colwise); + IType *cached_act_sh = in_act_sh; // in_act_sh is used as a cache buffer for activations + IType *cached_gate_sh = in_gate_sh; // in_gate_sh is used as a cache buffer for gated values + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - __shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y][CHUNK_DIM_X]; + const bool is_master_thread = (threadIdx.x == 0); // Initialize shared memory barrier with the number of threads participating in the barrier. #pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + __shared__ alignas(8) uint64_t mbar[STAGES]; - const bool is_master_thread = (threadIdx.x == 0); - - if (is_master_thread) { -// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. 
-#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_init(&mbar[it], THREADS_PER_CHUNK); - } - ptx::fence_proxy_async_shared_cta(); - } - // Syncthreads so initialized barrier is visible to all threads. - __syncthreads(); + initialize_barriers(mbar, is_master_thread); int parity = 0; - // Prefetch data of the first stage - if (is_master_thread) { - // Initiate bulk tensor copy - // Grad - if constexpr (IS_DGATED) { - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_grad_sh[0]), - TMAP_grad_in, chunk_offset_X, chunk_offset_Y, - &mbar[0]); - } - - // Act - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_act_sh[0]), - TMAP_in_act, chunk_offset_X, chunk_offset_Y, - &mbar[0]); - - // Gate - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_gate_sh[0]), - TMAP_in_gate, chunk_offset_X, chunk_offset_Y, - &mbar[0]); - - // Arrive on the barrier and tell how many bytes are expected to come in. - ptx::mbarrier_arrive_expect_tx(&mbar[0], in_transaction_size); + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(&in_grad_sh[0], &tensor_map_grad, block_offset_X, block_offset_Y, + &in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y, + &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y, + shmem_buff_size, &mbar[0], is_master_thread); } else { - // Other threads just arrive - ptx::mbarrier_arrive(&mbar[0]); + copy_2d_to_sharedx2(&in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y, + &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y, + shmem_buff_size, &mbar[0], is_master_thread); } #pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - const int buff = it % BUFFERS_NUM; - const int next_it = it + 1; - const size_t row_base = chunk_offset_Y + it * BUFFER_DIM_Y; - if (next_it < ITERATIONS) { - if (is_master_thread) { - const int next_buff = next_it % BUFFERS_NUM; - const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - // Initiate bulk tensor copy - if constexpr (IS_DGATED) { - // Grad - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_grad_sh[next_buff * buff_elems]), TMAP_grad_in, - chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); - } - // Act - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_act_sh[next_buff * buff_elems]), TMAP_in_act, - chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); - // Gate - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_gate_sh[next_buff * buff_elems]), TMAP_in_gate, - chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); - - // Arrive on the barrier and tell how many bytes are expected to come in. - ptx::mbarrier_arrive_expect_tx(&mbar[next_it], in_transaction_size); + for (int stage = 0; stage < STAGES; ++stage) { + const int buff = stage % BUFFS_NUM; + const int next_stage = stage + 1; + const int stage_offset_Y = stage * BUFF_DIM_Y; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const int next_buff = next_stage % BUFFS_NUM; + const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const int global_offset_Y = block_offset_Y + next_stage_offset_Y; + const int global_offset_X = block_offset_X; + const int next_buff_offset = next_buff * BUFF_DIM; + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X, + global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act, + global_offset_X, global_offset_Y, &in_gate_sh[next_buff_offset], + &tensor_map_input_gate, global_offset_X, global_offset_Y, + shmem_buff_size, &mbar[next_stage], is_master_thread); } else { - // Other threads just arrive - ptx::mbarrier_arrive(&mbar[next_it]); + copy_2d_to_sharedx2(&in_act_sh[next_buff_offset], &tensor_map_input_act, global_offset_X, + global_offset_Y, &in_gate_sh[next_buff_offset], &tensor_map_input_gate, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); } } ptx::fence_proxy_async_shared_cta(); // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[it], parity); + ptx::mbarrier_wait_parity(&mbar[stage], parity); - IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; - IType *in_act_sh_curr = in_act_sh + buff * buff_elems; - IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; - OType *out_act_rowwise_sh_curr = out_act_rowwise_sh + buff * buff_elems; - OType *out_gate_rowwise_sh_curr = out_gate_rowwise_sh + buff * buff_elems; - OType *out_act_colwise_sh_curr = out_act_colwise_sh + buff * buff_elems; - OType *out_gate_colwise_sh_curr = out_gate_colwise_sh + buff * buff_elems; - - // Assuming one iteration covers exactly 32 rows - const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; - const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y; - - float after_dact_reg[BUFFER_STAGES_NUM]; - float after_dgate_reg[BUFFER_STAGES_NUM]; - float thread_Y_mx_block_amax = 0.0f; - float thread_Y_mx_block_amax_gate = 0.0f; + if constexpr (COLWISE_SCALING) { + const int shmem_offset_base_colwise = + buff * BUFF_DIM + tid_Y_colwise * BUFF_DIM_X + tid_X_colwise; + float thread_amax_act = 0.0f; + float thread_amax_gate = 0.0f; + float after_act_colwise[BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE]; + float after_gate_colwise[BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE]; +// 1. Read/Compute elements. 
Find MXFP8-block AMAX #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = (row >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { + const int shmem_offset_colwise = + shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; - float act_elt = static_cast(in_act_sh_curr[shmem_idx]); - float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + float act_elt = static_cast(in_act_sh[shmem_offset_colwise]); + float gate_elt = static_cast(in_gate_sh[shmem_offset_colwise]); + float after_act_elt; + float after_gate_elt; - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); - const float x = act_elt; - float act_x; - float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + after_act_elt = dact_x * grad_elt * gate_elt; + after_gate_elt = act_x * grad_elt; } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + after_act_elt = ActOP(act_elt, {}) * gate_elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + after_act_elt = static_cast(static_cast(after_act_elt)); + if constexpr (IS_DGATED) { + after_gate_elt = static_cast(static_cast(after_gate_elt)); + } } - after_dact_reg[stage] = dact_x * grad_elt * gate_elt; - after_dgate_reg[stage] = act_x * grad_elt; - } else { - after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; - } - if constexpr (USE_ROWWISE_SCALING) { + after_act_colwise[i] = after_act_elt; if constexpr (IS_DGATED) { - // dgate - float amax = fabsf(after_dgate_reg[stage]); - const float mx_block_X_amax = warp_reduce_max_broadcast(amax); - const e8m0_t biased_exponent_X = - float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); - - out_gate_rowwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal_X * after_dgate_reg[stage]); - - // Only single thread writes the computed scaling factor - if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; - const int global_scales_offset_X = - scales_rowwise_chunk_offset_X + (tid_X + cols) / SCALE_DIM_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent_X; - } + after_gate_colwise[i] = after_gate_elt; } - float amax = fabsf(after_dact_reg[stage]); - const float mx_block_X_amax = warp_reduce_max_broadcast(amax); - const e8m0_t biased_exponent_X = - float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = 
exp2f_rcp(biased_exponent_X); - - out_act_rowwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal_X * after_dact_reg[stage]); - - // Only single thread writes the computed scaling factor - if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; - const int global_scales_offset_X = scales_rowwise_chunk_offset_X + tid_X / SCALE_DIM_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent_X; + + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(after_act_elt); + if constexpr (IS_DGATED) { + cached_gate_sh[shmem_offset_colwise] = static_cast(after_gate_elt); + } } - } - if constexpr (USE_COLWISE_SCALING) { - __builtin_assume(thread_Y_mx_block_amax >= 0); - __builtin_assume(thread_Y_mx_block_amax_gate >= 0); - thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage])); - if constexpr (IS_DGATED) { - thread_Y_mx_block_amax_gate = - fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_dgate_reg[stage])); + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + + if (!out_of_bounds) { + thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); + if constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); + } } } - } - - if constexpr (USE_COLWISE_SCALING) { - const bool row_out_of_bounds = (row_base >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); - if constexpr (IS_DGATED) { - // Colwise max reduction of the amax element - if (tid_Y > 0) { - stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax_gate; + if constexpr (ONLY_COLWISE_SCALING) { + // Threads, whose id along Y-dim is 0, don't need to store to shared memory, + // as they manage the columwise reduction of the amax + if (tid_Y_colwise > 0) { + subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_act; } __syncthreads(); - if (tid_Y == 0) { + if (tid_Y_colwise == 0) { #pragma unroll - for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { - thread_Y_mx_block_amax_gate = - fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_X]); + for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) { + const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise]; + __builtin_assume(thread_amax_act >= 0); + __builtin_assume(other_thread_amax >= 0); + + thread_amax_act = fmaxf(thread_amax_act, other_thread_amax); } - stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax_gate; // write mx column-block amax + subamax_colwise_buff[0][tid_X_colwise] = thread_amax_act; } __syncthreads(); - const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + // All threads read the reduced amax (ACT) + thread_amax_act = subamax_colwise_buff[0][tid_X_colwise]; - // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section - if constexpr (!USE_ROWWISE_SCALING) { - __builtin_assume(mx_block_Y_amax >= 0); + if constexpr (IS_DGATED) { + // Make sure the previous read of the ACT values has been completed, + // so the data are not rewritten + __syncthreads(); + if (tid_Y_colwise > 0) { + subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_gate; + } + 
__syncthreads(); + if (tid_Y_colwise == 0) { +#pragma unroll + for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) { + const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise]; + __builtin_assume(thread_amax_gate >= 0); + __builtin_assume(other_thread_amax >= 0); + + thread_amax_gate = fmaxf(thread_amax_gate, other_thread_amax); + } + subamax_colwise_buff[0][tid_X_colwise] = thread_amax_gate; + } + __syncthreads(); + + // All threads read the reduced amax (GATE) + thread_amax_gate = subamax_colwise_buff[0][tid_X_colwise]; } + } - const e8m0_t biased_exponent = - float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - - // Only single thread writes the computed scaling factor - // Also assuming one iteration covers exactly 32 rows - if ((tid_Y == 0) && !out_of_bounds) { - const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X + cols; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent_act = + ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_offset_Y_colwise + stage; + const int global_scales_offset_X = scales_offset_X_colwise; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows; + const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise; + + if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { + scales_colwise[scale_idx] = biased_exponent_act; + } + + float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); + float block_scale_inverse_gate; + + if constexpr (IS_DGATED) { + const e8m0_t biased_exponent_gate = + ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); + // const int scale_idx_gate = scale_idx + scale_stride_colwise / 2; + const int scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; + if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { + scales_colwise[scale_idx_gate] = biased_exponent_gate; } + block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate); + } +// 3. 
Scale elements #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - out_gate_colwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal * after_dgate_reg[stage]); + for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { + const int shmem_offset_elt = + shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; + if constexpr (IS_DGATED) { + OType2 out_pair; + ptx::floatx2 in_pair = {after_act_colwise[i], after_gate_colwise[i]}; + const ptx::floatx2 block_scale_inverse_2x_pair = {block_scale_inverse_act, + block_scale_inverse_gate}; + ptx::mul_cvt_2x(out_pair, in_pair, block_scale_inverse_2x_pair); + out_act_colwise_sh[shmem_offset_elt] = out_pair.x; + out_gate_colwise_sh[shmem_offset_elt] = out_pair.y; + } else { + const float scaled_out_act = block_scale_inverse_act * after_act_colwise[i]; + out_act_colwise_sh[shmem_offset_elt] = static_cast(scaled_out_act); } } - // Colwise max reduction of the amax element - if (tid_Y > 0) { - stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax; - } - __syncthreads(); - if (tid_Y == 0) { + } + + if constexpr (ROWWISE_SCALING) { + const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + + float thread_amax_act = 0.0f; + float thread_amax_gate = 0.0f; + + Vec in_cached_act[WAVES]; + Vec in_cached_gate[WAVES]; + + float after_act_rowwise[SCALE_DIM_X]; + float after_gate_rowwise[SCALE_DIM_X]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x_act = {static_cast(0.0f), static_cast(0.0f)}; + IType2 thread_amax_2x_gate = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached_act[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + if constexpr (IS_DGATED) { + in_cached_gate[w].load_from(&cached_gate_sh[shmem_offset_rowwise]); + } + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { #pragma unroll - for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { - thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_X]); + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax_act = fmaxf(thread_amax_act, fabsf(in_cached_act[w].data.elt[e])); + if constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, fabsf(in_cached_gate[w].data.elt[e])); + } + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x_act = {in_cached_act[w].data.elt[e], + in_cached_act[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x_act, thread_amax_2x_act, in_cached_2x_act); + if constexpr (IS_DGATED) { + const IType2 in_cached_2x_gate = {in_cached_gate[w].data.elt[e], + in_cached_gate[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x_gate, thread_amax_2x_gate, in_cached_2x_gate); + } + } + } + } } - stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax; // write mx column-block amax - } - __syncthreads(); + if constexpr (!std::is_same_v) { + thread_amax_act = static_cast( + __hmax(__habs(thread_amax_2x_act.x), __habs(thread_amax_2x_act.y))); + if constexpr (IS_DGATED) { + thread_amax_gate = static_cast( + __hmax(__habs(thread_amax_2x_gate.x), __habs(thread_amax_2x_gate.y))); + } + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + Vec in_grad; + Vec in_act; + Vec in_gate; - // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section - if constexpr (!USE_ROWWISE_SCALING) { - __builtin_assume(mx_block_Y_amax >= 0); + in_act.load_from(&in_act_sh[shmem_offset_rowwise]); + in_gate.load_from(&in_gate_sh[shmem_offset_rowwise]); + if constexpr (IS_DGATED) { + in_grad.load_from(&in_grad_sh[shmem_offset_rowwise]); + } + +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + + float act_elt = static_cast(in_act.data.elt[e]); + float gate_elt = static_cast(in_gate.data.elt[e]); + float after_act_elt; + float after_gate_elt; + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad.data.elt[e]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + after_act_elt = dact_x * grad_elt * gate_elt; + after_gate_elt = act_x * grad_elt; + after_act_rowwise[j] = after_act_elt; + after_gate_rowwise[j] = after_gate_elt; + } else { + after_act_elt = ActOP(act_elt, {}) * gate_elt; + after_act_rowwise[j] = after_act_elt; + } + + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + after_act_elt = static_cast(static_cast(after_act_elt)); + if constexpr (IS_DGATED) { + after_gate_elt = static_cast(static_cast(after_gate_elt)); + } + } + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const 
bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); + if constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); + } + } + } + } } - const e8m0_t biased_exponent = - float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - - // Only single thread writes the computed scaling factor - // Also assuming one iteration covers exactly 32 rows - if ((tid_Y == 0) && !out_of_bounds) { - const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent_act = + ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows; + const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise; + if (!out_of_bounds_rowwise) { + scales_rowwise[scale_idx] = biased_exponent_act; } + const float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); + const ptx::floatx2 block_scale_inverse_2x_act = {block_scale_inverse_act, + block_scale_inverse_act}; + + float block_scale_inverse_gate; + ptx::floatx2 block_scale_inverse_2x_gate; + if constexpr (IS_DGATED) { + const e8m0_t biased_exponent_gate = + ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); + const int scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; + if (!out_of_bounds_rowwise) { + scales_rowwise[scale_idx_gate] = biased_exponent_gate; + } + block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate); + block_scale_inverse_2x_gate = {block_scale_inverse_gate, block_scale_inverse_gate}; + } + +// 3. 
Scale elements #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - out_act_colwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal * after_dact_reg[stage]); + for (int w = 0; w < WAVES; ++w) { + Vec out_act; + Vec out_gate; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in_act; + OType2 &out_act_pair = reinterpret_cast(out_act.data.elt[e]); + + if constexpr (IS_CACHED_ACT_OP) { + in_act.x = in_cached_act[w].data.elt[2 * e]; + in_act.y = in_cached_act[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in_act.x = after_act_rowwise[j]; + in_act.y = after_act_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_act_pair, in_act, block_scale_inverse_2x_act); + + if constexpr (IS_DGATED) { + IType2 in_gate; + OType2 &out_gate_pair = reinterpret_cast(out_gate.data.elt[e]); + + if constexpr (IS_CACHED_ACT_OP) { + in_gate.x = in_cached_gate[w].data.elt[2 * e]; + in_gate.y = in_cached_gate[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in_gate.x = after_gate_rowwise[j]; + in_gate.y = after_gate_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate); + } + } + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]); + if constexpr (IS_DGATED) { + out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]); + } } - } // endif USE_COLWISE_SCALING + } - // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + // Wait for shared memory writes to be visible to TMA engine. ptx::fence_proxy_async_shared_cta(); __syncthreads(); // After syncthreads, writes by all threads are visible to TMA engine. 
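Steps 2 and 3 above derive a power-of-two (E8M0) scale for each 32-element MXFP8 block from the block amax and then multiply the elements by its reciprocal before the FP8 cast. A host-side sketch of that arithmetic, assuming an E4M3 output with max-normal 448 and approximating the float_to_e8m0 rounding with a ceil(log2) round-up; all names below are local to the sketch:

#include <cmath>
#include <cstdint>
#include <cstdio>

constexpr int kScaleDim = 32;          // elements per MXFP8 scaling block
constexpr float kFp8MaxNorm = 448.0f;  // E4M3 max-normal value (assumption for this sketch)

int main() {
  float block[kScaleDim];
  for (int i = 0; i < kScaleDim; ++i) block[i] = 0.37f * (i + 1);

  // 1. Block amax
  float amax = 0.0f;
  for (float v : block) amax = std::fmax(amax, std::fabs(v));

  // 2. E8M0 scale: smallest power of two 2^e such that amax / 2^e fits in the FP8 range,
  //    stored as a biased exponent (bias 127), roughly what float_to_e8m0(amax * max_norm_rcp) does
  const int e = static_cast<int>(std::ceil(std::log2(amax / kFp8MaxNorm)));
  const std::uint8_t biased_exponent = static_cast<std::uint8_t>(e + 127);

  // 3. Quantize with the reciprocal scale (the kernel then casts to FP8);
  //    dequantization later multiplies by 2^e recovered from the stored exponent
  const float scale_inv = std::ldexp(1.0f, -e);
  for (float &v : block) v *= scale_inv;

  std::printf("amax = %.3f, biased exponent = %u, scale_inv = %g\n", amax,
              static_cast<unsigned>(biased_exponent), scale_inv);
  return 0;
}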
// Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { - const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; - // dGeLU - if constexpr (USE_ROWWISE_SCALING) { + if constexpr (ROWWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_act_rowwise_sh_curr)); - + reinterpret_cast(&tensor_map_output_act_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_act_rowwise_sh[buff_offset])); if constexpr (IS_DGATED) { - // dGate ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_gate_rowwise_sh_curr)); + reinterpret_cast(&tensor_map_output_gate_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_gate_rowwise_sh[buff_offset])); } } - - // dGeLU - if constexpr (USE_COLWISE_SCALING) { + if constexpr (COLWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_act_colwise_sh_curr)); - + reinterpret_cast(&tensor_map_output_act_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_act_colwise_sh[buff_offset])); if constexpr (IS_DGATED) { - // dGate ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_gate_colwise_sh_curr)); + reinterpret_cast(&tensor_map_output_gate_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_gate_colwise_sh[buff_offset])); } } // Create a "bulk async-group" out of the previous bulk copy operation. ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); } } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - // Destroy the barriers. This invalidates the memory region of the barrier. - // If further computations were to take place in the kernel, this allows the - // memory location of the shared memory barrier to be reused. - if (is_master_thread) { -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_invalid(&mbar[it]); - } - } + parity ^= 1; + destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +} // namespace mxfp8_kernel template @@ -779,17 +943,16 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; const size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); const size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); const size_t in_act_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in; const size_t out_act_mem = buff_size_aligned_out; const size_t out_gate_mem = buff_size_aligned_out; - // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); - const size_t shmem_size = ALIGNMENT_SIZE + grad_mem + (in_act_mem + in_gate_mem) + - (out_act_mem + out_gate_mem); // + mbar_mem; + const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; cudaFuncSetAttribute( cast_fp8_gated_kernel, @@ -802,7 +965,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cols);); // NOLINT(*) ); // NOLINT(*) } -#endif //#ifdef __HIP_PLATFORM_AMD__ +#endif //#ifndef __HIP_PLATFORM_AMD__ template @@ -818,16 +981,47 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); } - // TODO: Make more general - const size_t scale_dim_X_rowwise = USE_ROWWISE_SCALING ? 32 : 1; - const size_t scale_dim_Y_colwise = USE_COLWISE_SCALING ? 32 : 1; +#ifndef __HIP_PLATFORM_AMD__ + ScalingType scaling_type; + if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!USE_ROWWISE_SCALING) && USE_COLWISE_SCALING) { + scaling_type = ScalingType::COLWISE; + } else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + scaling_type = ScalingType::BIDIMENSIONAL; + } +#endif const size_t rows = gated_input.flat_first_dim(); const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; +#ifdef __HIP_PLATFORM_AMD__ + constexpr size_t TMA_SHMEM_ALIGNMENT = ALIGNMENT_SIZE; + + constexpr size_t BUFF_DIM_Y = BUFFER_DIM_Y; + constexpr size_t BUFF_DIM_X = BUFFER_DIM_X; + constexpr size_t BUFFS_NUM = BUFFERS_NUM; + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); +#else + constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; + constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; + constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; + + const size_t blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X); + + constexpr size_t THREADS_PER_CHUNK_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_COLWISE; + constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_NON_COLWISE; + const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE) + ? THREADS_PER_CHUNK_COLWISE + : THREADS_PER_CHUNK_NON_COLWISE; +#endif + + const dim3 grid(blocks_X, blocks_Y); + const dim3 block_size(THREADS_PER_CHUNK); size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1; @@ -837,116 +1031,162 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out e8m0_t *const scales_colwise_ptr = USE_COLWISE_SCALING ? 
reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - const dim3 block_dim(THREADS_PER_CHUNK); - const dim3 grid_dim(blocks_X, blocks_Y); + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_Y_colwise, SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_X_rowwise, SCALE_DIM_X, - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - gated_input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, #ifdef __HIP_PLATFORM_AMD__ - TRANSFORMER_ENGINE_SWITCH_CONDITION( - !(cols % (32 * sizeof(IType))), IS_ALIGNED, - const IType *tensor_map_grad = IS_DGATED ? reinterpret_cast(grad.data.dptr) : nullptr; - const IType *tensor_map_input_act = reinterpret_cast(gated_input.data.dptr); - const IType *tensor_map_input_gate = reinterpret_cast(gated_input.data.dptr) + cols; - OType *tensor_map_output_act_rowwise = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) : nullptr; - OType *tensor_map_output_gate_rowwise = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) + cols : nullptr; - OType *tensor_map_output_act_colwise = USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_data.dptr) : nullptr; - OType *tensor_map_output_gate_colwise = USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_data.dptr) + cols : nullptr; + const IType *tensor_map_grad = IS_DGATED ? reinterpret_cast(grad.data.dptr) : nullptr; + const IType *tensor_map_input_act = reinterpret_cast(gated_input.data.dptr); + const IType *tensor_map_input_gate = reinterpret_cast(gated_input.data.dptr) + cols; + OType *tensor_map_output_act_rowwise = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) : nullptr; + OType *tensor_map_output_gate_rowwise = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) + cols : nullptr; + OType *tensor_map_output_act_colwise = USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_data.dptr) : nullptr; + OType *tensor_map_output_gate_colwise = USE_COLWISE_SCALING ? 
reinterpret_cast(output->columnwise_data.dptr) + cols : nullptr; #else // #ifdef __HIP_PLATFORM_AMD__ - alignas(64) CUtensorMap tensor_map_grad{}; - alignas(64) CUtensorMap tensor_map_input_act{}; - alignas(64) CUtensorMap tensor_map_input_gate{}; - alignas(64) CUtensorMap tensor_map_output_act_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_gate_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_act_colwise{}; - alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; - - if constexpr (IS_DGATED) { - create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype())); - } - - const uint32_t tensor_stride_elems = output_cols; - create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, - typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, - typeToNumBits(gated_input.dtype())); - - if (USE_ROWWISE_SCALING) { - create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0, - typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols, - typeToNumBits(output->dtype())); - } - - if (USE_COLWISE_SCALING) { - create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, - rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, - 0, typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, - rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, - cols, typeToNumBits(output->dtype())); - } + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_act_colwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; #endif // #ifdef __HIP_PLATFORM_AMD__ - const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; - const size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + +#ifndef __HIP_PLATFORM_AMD__ + if constexpr (IS_DGATED) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, input_type_bit_size); + } - const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; + const uint32_t tensor_stride_elems = output_cols; + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols * 2, 0, input_type_bit_size); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols * 2, cols, input_type_bit_size); + + if (USE_ROWWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0, + output_type_bit_size); + create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols, + output_type_bit_size); + } - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - size_t out_mem = out_act_mem + out_gate_mem; - if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } + if (USE_COLWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0, + output_type_bit_size); + create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows, + cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols, + output_type_bit_size); + } +#endif // #ifdef __HIP_PLATFORM_AMD__ - // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); - // const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem + mbar_mem; + const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X; + const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + const size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + const size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem; + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - (const void*)cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + const size_t out_gate_mem = buff_size_aligned_out; +#else + const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); +#endif // #ifdef __HIP_PLATFORM_AMD__ + size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } + + const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - cast_mxfp8_gated_kernel - <<>>( + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (USE_COLWISE_SCALING ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (USE_ROWWISE_SCALING ? 
32 : 1), SCALE_DIM_X, + TRANSFORMER_ENGINE_SWITCH_CONDITION(!(cols % (32 * sizeof(IType))), IS_ALIGNED, { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + (const void*)cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + }))); // NOLINT(*) +#else + switch (scaling_type) { + case ScalingType::ROWWISE: + cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + mxfp8_kernel::cast_mxfp8_gated_kernel + <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise);); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) -#ifdef __HIP_PLATFORM_AMD__ - ); // NOLINT(*) + scale_stride_colwise); + break; + case ScalingType::COLWISE: + cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + mxfp8_kernel::cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + case ScalingType::BIDIMENSIONAL: + cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + mxfp8_kernel::cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + } #endif + ); // NOLINT(*) + ); // NOLINT(*) } template diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 468d31690..ae38e74cc 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -36,36 +36,25 @@ namespace transformer_engine { #ifndef __HIP_PLATFORM_AMD__ -constexpr size_t MXFP8_CHUNK_DIM_Y = 64; -constexpr size_t MXFP8_CHUNK_DIM_X = 64; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; -constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; -constexpr size_t MXFP8_BUFFERS_NUM = 2; -constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1; -static_assert(MXFP8_PREFETCH_BUFFERS_NUM < MXFP8_BUFFERS_NUM); - -constexpr size_t ELEMS_PER_THREAD = 16; -constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported -constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 -constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 -constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64 - -constexpr size_t THREADS_PER_CHUNK_X_ROWWISE 
= - MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 4 = 64 / 16 -constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = - MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4 -constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 -constexpr size_t MXFP8_BUFF_STAGES_NUM = - MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 -constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 -static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM); +namespace mxfp8_kernel { + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 template -__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING, + bool COLWISE_SCALING, size_t CHUNK_DIM_Y, size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK> +__global__ void __launch_bounds__(THREADS_PER_CHUNK) cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_act_input, const __grid_constant__ CUtensorMap tensor_map_output_rowwise, @@ -75,201 +64,341 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - if (noop != nullptr && noop[0] == 1.0f) return; + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } } + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; + static_assert(BUFF_DIM_Y == 32); + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + static_assert(STAGES >= 1); + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * CHUNK_DIM_X; + const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X; + const int tid_X_rowwise = threadIdx.x % THREADS_X; + const int tid_Y_colwise = 0; + const int tid_X_colwise = threadIdx.x; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const int thread_offset_Y_colwise = tid_Y_colwise; + const int thread_offset_X_colwise = tid_X_colwise; + + const int row_base_rowwise = block_offset_Y + 
thread_offset_Y_rowwise; + const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); + OType *out_rowwise_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; - constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; - constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; - - constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y; // 2 = 64 / 32 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 64 = 64 / 1 - constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = - SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 - constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = - SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 - - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; // 64 = 64 / 1 - constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = - SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 - constexpr size_t SCALES_COLWISE_PER_BLOCK_X = - SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 - - constexpr size_t THREADS_PER_SCALE_X_ROWWISE = - DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 - constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 - - const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; - const int block_offset_X = blockIdx.x * 
MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; - const int scales_rowwise_block_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y; - const int scales_rowwise_block_offset_X = blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X; - const int scales_colwise_block_offset_Y = blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; - const int scales_colwise_block_offset_X = blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; - - const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; - const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; - // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; - const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; - - const int thread_offset_Y = tid_rowwise_Y; - const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; - // const int thread_offset_X_colwise = tid_colwise_X; - - const int dbias_rowwise_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y + tid_rowwise_Y; - const int dbias_rowwise_block_offset_X = - blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + thread_offset_X_rowwise; - const int dbias_colwise_offset_Y = blockIdx.y; - const int dbias_colwise_block_offset_X = - blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + tid_colwise_X; - const int dbias_stride = cols; + const bool is_master_thread = (threadIdx.x == 0); - Vec partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X]; - float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X]; + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; if constexpr (IS_DBIAS) { - if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { -#pragma unroll - for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { - partial_dbias_rowwise[i].clear(); - } - } else { #pragma unroll - for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { - partial_dbias_colwise[i] = 0; - } + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; } } - // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - __shared__ alignas(128) IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - __shared__ alignas(128) IType act_in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - __shared__ alignas(128) - OType out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - __shared__ alignas(128) - OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - - constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - - float block_amax = 0; + float block_amax = 0.0f; // Initialize shared memory barrier with the number of threads participating in the barrier. 
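The bank_group value computed above feeds the swizzled index ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X used in the rowwise passes, so that in any given wave the threads of different bank groups touch different 4-element column groups of shared memory. A host-side sketch of that access pattern, with the constants assumed to match the ones defined earlier in this file:

#include <cstdio>

constexpr int SCALE_DIM_X = 32;                       // elements per rowwise scaling block
constexpr int PACK_SIZE = 4;                          // elements loaded per wave
constexpr int WAVES = SCALE_DIM_X / PACK_SIZE;        // 8
constexpr int THREADS_PER_BANK = 4;                   // threads spanning the 128-byte bank width
constexpr int THREADS_PER_WARP = 32;

int main() {
  for (int w = 0; w < WAVES; ++w) {
    std::printf("wave %d:", w);
    // One representative lane per bank group is enough to show the pattern.
    for (int lane = 0; lane < THREADS_PER_WARP; lane += THREADS_PER_BANK) {
      const int bank_group = lane / THREADS_PER_BANK;
      const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
      std::printf(" g%d->col%2d", bank_group, swizzled_group_idx);
    }
    std::printf("\n");
  }
  return 0;
}

Each printed wave is a rotation of the column groups {0, 4, ..., 28}, so the eight bank groups never start from the same group in the same wave, which spreads their shared-memory accesses across banks.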
#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; + __shared__ alignas(8) uint64_t mbar[STAGES]; - initialize_barriers(mbar, is_master_thread); + initialize_barriers(mbar, is_master_thread); int parity = 0; -#pragma unroll - for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { - const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X; - const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; - const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; - const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], + &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } - const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; - const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int buff = stage % BUFFS_NUM; + const int next_stage = stage + 1; + const int stage_offset_Y = stage * BUFF_DIM_Y; - const int scales_rowwise_chunk_offset_Y = - scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; - const int scales_rowwise_chunk_offset_X = - scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X; - const int scales_colwise_chunk_offset_Y = - scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; - const int scales_colwise_chunk_offset_X = - scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); -#pragma unroll - for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { - const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; - const int chunk_stage_offset_X = chunk_offset_X; + const int next_buff = next_stage % BUFFS_NUM; + const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const int global_offset_Y = block_offset_Y + next_stage_offset_Y; + const int global_offset_X = block_offset_X; + const int next_buff_offset = next_buff * BUFF_DIM; if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, - chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, - &mbar[prefetch_buff], is_master_thread); + copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); } else { - copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], - is_master_thread); + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); } } + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], parity); + + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const int shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + + // 1. Read/Compute elements. 
Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); #pragma unroll - for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { - const int buff = iter % MXFP8_BUFFERS_NUM; - const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM; - const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; - - if (next_iter < MXFP8_ITERATIONS) { - const int next_buff = next_iter % MXFP8_BUFFERS_NUM; - const int chunk_it_offset_y = chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, - chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, - &mbar[next_iter], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); } - } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - ptx::fence_proxy_async_shared_cta(); + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[iter], parity); + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - if constexpr (USE_ROWWISE_SCALING) { - Vec in; - Vec act_in; - Vec out_c; + const int global_scales_offset_Y = scales_offset_Y_colwise + stage; + const int global_scales_offset_X = scales_offset_X_colwise; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; - const int iteration_scale_rowwise_offset_Y = - scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; +// 3. 
Scale elements #pragma unroll - for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X_rowwise; + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = (row >= rows); + const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } - in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[buff][shmem_offset_y][shmem_offset_x]); - } + if constexpr (ROWWISE_SCALING) { + const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; - float thread_amax = 0; - float in_compute[ELEMS_PER_THREAD]; + // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY + Vec in_IType[WAVES]; + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; ++j) { - const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in; + Vec act_in; - float elt = static_cast(in.data.elt[j]); + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); if constexpr (IS_ACT) { elt = OP(elt, {}); } if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[j]); + float act_in_elt = static_cast(act_in.data.elt[e]); elt *= OP(act_in_elt, {}); } - if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { - if (!out_of_bounds) { - partial_dbias_rowwise[chunk_X].data.elt[j] += elt; - } - } - in_compute[j] = elt; - if constexpr (IS_ACT || IS_DACT) { + // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); if (!out_of_bounds) { thread_amax = fmaxf(thread_amax, fabsf(elt)); } @@ -277,196 +406,141 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) // If no activation, elt is 0 so we can safely do this thread_amax = fmaxf(thread_amax, fabsf(elt)); } + in_compute_rowwise[j] = elt; } - - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); - - const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); - const e8m0_t biased_exponent = - float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); - - // Only single thread writes the computed scaling factor - if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y; - const int global_scales_offset_X = - scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + 
global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent; - } - - const float block_scale_inverse = exp2f_rcp(biased_exponent); - -#pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; ++j) { - out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); - } - out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); } } - if constexpr (USE_COLWISE_SCALING) { - const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); - float in_compute[SCALE_DIM_Y]; + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent; - float amax = 0; -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - const size_t row = row_base + i; - const bool row_out_of_bounds = (row >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - float elt = static_cast(in_sh[buff][i][tid_colwise_X]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[buff][i][tid_colwise_X]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - if (!out_of_bounds) { - partial_dbias_colwise[chunk_X] += elt; - } - } - in_compute[i] = elt; - if constexpr (IS_ACT || IS_DACT) { - if (!out_of_bounds) { - amax = fmaxf(amax, fabsf(elt)); - } + // 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; } else { - // If no activation, elt is 0 so we can safely do this - amax = fmaxf(amax, fabsf(elt)); + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); } + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_sh[shmem_offset_rowwise]); + } + } - __builtin_assume(block_amax >= 0); - __builtin_assume(amax >= 0); - block_amax = fmaxf(block_amax, amax); - - const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); - const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; + // Wait for shared memory writes to be visible to TMA engine. 
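The E8M0 block-scale math used in the scaling passes above can be checked on the host with a short sketch. The Python below mirrors the round-up conversion and reciprocal-scale decoding that the patch adds to ptx.cuh; the E4M3 max-norm of 448 and the sample amax are illustrative assumptions, not values taken from the patch.

import struct

FP32_MANTISSA_BITS = 23
FP32_EXPONENT_BIAS = 127

def float_to_e8m0(val: float) -> int:
    # Round-up ("rp") conversion of an FP32 value to a biased E8M0 exponent,
    # following the software fallback path in ptx.cuh (NaN/Inf/zero cases omitted).
    bits = struct.unpack("<I", struct.pack("<f", val))[0]
    exponent = bits >> FP32_MANTISSA_BITS
    mantissa = bits & 0x7FFFFF
    if mantissa > 0 and exponent != 0xFE and not (exponent == 0 and mantissa <= 0x400000):
        exponent += 1          # round up, saturating below 0xFF
    return exponent

def exp2f_rcp(biased_exp: int) -> float:
    # Reciprocal of the decoded block scale: 2**-(biased_exp - 127)
    return 1.0 if biased_exp == 0 else 2.0 ** (FP32_EXPONENT_BIAS - biased_exp)

amax, e4m3_max_norm = 3.5, 448.0                        # assumed example values
biased_exponent = float_to_e8m0(amax / e4m3_max_norm)   # kernel uses amax * max_norm_rcp
block_scale_inverse = exp2f_rcp(biased_exponent)
print(biased_exponent, amax * block_scale_inverse)      # 120 448.0 -> fits the E4M3 range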
+ ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. - const float block_scale_inverse = exp2f_rcp(biased_exponent); -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - out_colwise_sh[buff][i][tid_colwise_X] = - static_cast(in_compute[i] * block_scale_inverse); - } + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_sh[buff_offset])); } - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - if constexpr (USE_ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), chunk_it_offset_x, - chunk_it_offset_y, reinterpret_cast(&out_rowwise_sh[buff])); - } - if constexpr (USE_COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), chunk_it_offset_x, - chunk_it_offset_y, reinterpret_cast(&out_colwise_sh[buff])); - } - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_sh[buff_offset])); } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - parity ^= 1; + // Create a "bulk async-group" out of the previous bulk copy operation. 
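The row-wise pass indexes shared memory through a bank-group swizzle, ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X, so that threads in different bank groups start their waves at different columns. A minimal sketch of that index pattern follows; SCALE_DIM_X = 32 matches the MXFP8 block size, while PACK_SIZE = 8, WAVES = SCALE_DIM_X // PACK_SIZE, and the four bank groups are assumptions chosen only to illustrate the rotation.

SCALE_DIM_X, PACK_SIZE = 32, 8          # assumed illustrative constants
WAVES = SCALE_DIM_X // PACK_SIZE

def wave_start_columns(bank_group: int):
    # Same expression as the kernel's swizzled_group_idx computation
    return [((w + bank_group) * PACK_SIZE) % SCALE_DIM_X for w in range(WAVES)]

for bank_group in range(4):
    print(bank_group, wave_start_columns(bank_group))
# 0 [0, 8, 16, 24]
# 1 [8, 16, 24, 0]
# 2 [16, 24, 0, 8]
# 3 [24, 0, 8, 16]
# Each bank group covers the same set of columns exactly once, just in a rotated
# order, which spreads simultaneous accesses across shared-memory banks.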
+ ptx::cp_async_bulk_commit_group(); + } } - if constexpr (IS_DBIAS) { - if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { - constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X; - constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; - constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; - __shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD]; - - if (tid_rowwise_Y > 0) { -#pragma unroll - for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { - partial_dbias_rowwise[c].store_to( - &shmem_partial_dbias_rowwise[c][tid_rowwise_Y - 1][tid_rowwise_X]); - } - } - __syncthreads(); + parity ^= 1; - if (tid_rowwise_Y == 0) { -#pragma unroll - for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { - Vec other_row_dbias; - const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X; - const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X; + if constexpr (IS_DBIAS) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] + // HEIGHT = THREADS_Y + // WIDTH = THREADS_X * (SCALE_DIM_X + 1) + // Added extra 1-element padding per thread_X to reduce bank conflicts + float *partial_dbias_rowwise = reinterpret_cast(dshmem); - const int left_bound = dbias_rowwise_offset_X; - const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1; + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); #pragma unroll - for (int i = 0; i < Y; ++i) { - other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]); + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; ++j) { - partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j]; - } - } - - // Vectorized store when all elements are inside the boundaries - if (right_bound < cols) { - partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]); - } else if (left_bound < cols && right_bound >= cols) { - // Element-by-element store when some elements cross the boundaries - const int in_bound_elts_count = cols - left_bound; - partial_dbias_rowwise[c].store_to_elts(&dbias_workspace[dbias_offset], 0, - in_bound_elts_count); - } + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; } } - } else { + __syncthreads(); #pragma unroll - for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { - const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X; - const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X; - const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); - if (!col_out_of_bounds) { - dbias_workspace[dbias_offset] = partial_dbias_colwise[i]; - } + for (int i = 0; i < THREADS_Y; ++i) { + // Add extra element offset per MXFP8 scaling block [1x32] + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; } } + const int dbias_stride = cols; + const int dbias_offset_Y = blockIdx.y; + const int dbias_offset_X = blockIdx.x * 
CHUNK_DIM_X + threadIdx.x; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; + } } if (amax_ptr != nullptr) { const int warp_id = threadIdx.x / THREADS_PER_WARP; // Reduce the amax over the block - block_amax = reduce_max(block_amax, warp_id); + block_amax = reduce_max(block_amax, warp_id); } if (is_master_thread && amax_ptr != nullptr) { atomicMaxFloat(amax_ptr, block_amax); } - destroy_barriers(mbar, is_master_thread); + destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +} // namespace mxfp8_kernel constexpr size_t FP8_CHUNK_DIM_Y = 128; constexpr size_t FP8_CHUNK_DIM_X = 128; @@ -515,9 +589,12 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(128) IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(128) IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; @@ -686,8 +763,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(128) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; - __shared__ alignas(128) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; @@ -932,6 +1009,9 @@ template has_data(); bool use_colwise_scaling = output->has_columnwise_data(); #ifndef __HIP_PLATFORM_AMD__ @@ -950,16 +1030,36 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, } CheckNoopTensor(*noop, "cast_noop"); - // TODO: Make more general - const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; - const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; - const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); + + constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); + +#ifndef __HIP_PLATFORM_AMD__ + constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 
128 : 64; + + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); +#else const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y); const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); +#endif // #ifndef __HIP_PLATFORM_AMD__ + + const dim3 grid(blocks_X, blocks_Y); +#ifndef __HIP_PLATFORM_AMD__ + const size_t block_size = THREADS_PER_CHUNK; +#else + const size_t block_size = MXFP8_THREADS_PER_CHUNK; +#endif // #ifndef __HIP_PLATFORM_AMD__ const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; const size_t scale_stride_colwise = @@ -972,6 +1072,17 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const size_t dbias_rows = blocks_Y; const size_t dbias_cols = cols; +#ifndef __HIP_PLATFORM_AMD__ + ScalingType scaling_type; + if (use_rowwise_scaling && (!use_colwise_scaling)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!use_rowwise_scaling) && use_colwise_scaling) { + scaling_type = ScalingType::COLWISE; + } else if (use_rowwise_scaling && use_colwise_scaling) { + scaling_type = ScalingType::BIDIMENSIONAL; + } +#endif // #ifndef __HIP_PLATFORM_AMD__ + if constexpr (IS_DBIAS) { NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); @@ -986,76 +1097,125 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); - const dim3 block(MXFP8_THREADS_PER_CHUNK); - const dim3 grid(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_Y_colwise, SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_X_rowwise, SCALE_DIM_X, - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, #ifdef __HIP_PLATFORM_AMD__ - TRANSFORMER_ENGINE_SWITCH_CONDITION( - !(cols % (32 * sizeof(IType))), IS_ALIGNED, - cast_mxfp8_2D_kernel<<>>( - reinterpret_cast(input.data.dptr), - (IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr, - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->columnwise_data.dptr), - scales_rowwise_ptr, scales_colwise_ptr, - reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + !(cols % (32 * sizeof(IType))), IS_ALIGNED, + cast_mxfp8_2D_kernel<<>>( + reinterpret_cast(input.data.dptr), + (IS_DACT) ? 
reinterpret_cast(act_input->data.dptr) : nullptr, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + scales_rowwise_ptr, scales_colwise_ptr, + reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + ))); // NOLINT(*) #else // #ifdef __HIP_PLATFORM_AMD__ + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y, - MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype())); + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, input_type_bit_size); - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, - MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - typeToNumBits(input.dtype())); - } + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, input_type_bit_size); + } - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, - MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - typeToNumBits(output->dtype())); - } + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, output_type_bit_size); + } - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, - cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - typeToNumBits(output->dtype())); - } + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); + } - cast_mxfp8_2D_kernel<<>>( + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? 
buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + switch (scaling_type) { + case ScalingType::ROWWISE: + cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + case ScalingType::COLWISE: + cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + case ScalingType::BIDIMENSIONAL: + cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_mxfp8_2D_kernel + <<>>( tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, - reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + } #endif // #ifdef __HIP_PLATFORM_AMD__ - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - - }); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) -#ifdef __HIP_PLATFORM_AMD__ - ); // NOLINT(*) -#endif + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) } namespace detail { @@ -1171,8 +1331,8 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons case NVTE_DELAYED_TENSOR_SCALING: { if (!IS_DBIAS && !IS_DACT) { if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_gmem_alignment) && - is_aligned_tensor_data(*output, TMA_gmem_alignment)) { + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) { // Aligned AND FP8 cast_fp8_1D(input, output, stream); } else { @@ -1181,9 +1341,9 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons } } else if (!IS_DBIAS && IS_DACT) { if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_gmem_alignment) && - is_aligned_tensor_data(*output, TMA_gmem_alignment) && - is_aligned_tensor_data(*act_input, TMA_gmem_alignment)) { + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) { // Aligned AND FP8 (+dAct) cast_fp8_2D(input, act_input, output, dbias, workspace, stream); diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index 8f0a9730b..4177cc094 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -76,8 +76,8 @@ __global__ void 
__launch_bounds__(THREADS_PER_CHUNK) // const int thread_offset_X_colwise = tid_colwise_X; // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - __shared__ alignas(128) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; - __shared__ alignas(128) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM; constexpr int transaction_size = shmem_buff_size; @@ -158,7 +158,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X; const e8m0_t biased_exponent = scales_ptr[scale_idx]; - const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + const float block_scale = ptx::exp2f(biased_exponent); if constexpr (USE_ROWWISE_SCALING) { Vec in; diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 55bc247f7..c920bd7e4 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -104,6 +106,56 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3 #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +constexpr uint32_t FP32_MANTISSA_BITS = 23; +constexpr uint32_t FP32_EXPONENT_BIAS = 127; + +__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { + return (biased_exp == 0) ? 1 + : __int_as_float((254 - biased_exp) + << FP32_MANTISSA_BITS); // 127 - (biased_exp - 127) +} + +__device__ __forceinline__ float exp2f(e8m0_t biased_exp) { + return __int_as_float(biased_exp << FP32_MANTISSA_BITS); +} + +__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { +#ifdef __HIP_PLATFORM_AMD__ +#define __CUDA_ARCH_HAS_FEATURE__(x) 0 +#endif //__HIP_PLATFORM_AMD__ +#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ + (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) + uint16_t out; + asm volatile( + "{\n" + "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" + "}" + : "=h"(out) + : "f"(val)); + return *reinterpret_cast(&out); +#else + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. + if (isnan(val)) { + return 0xFF; + } + if (isinf(val)) { + return 0xFE; + } + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast(&val); + e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. 
+ if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; +#endif +} + #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor @@ -169,6 +221,159 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() { asm volatile("fence.proxy.async.shared::cta;"); } +template +struct alignas(2 * sizeof(T)) FPx2 { + T x; + T y; +}; + +using floatx2 = FPx2; +using bf16x2 = FPx2; +using fp16x2 = FPx2; +using fp8e4m3x2 = FPx2; +using fp8e5m2x2 = FPx2; + +static_assert(sizeof(floatx2) == 8); +static_assert(sizeof(bf16x2) == 4); +static_assert(sizeof(fp16x2) == 4); +static_assert(sizeof(fp8e4m3x2) == 2); +static_assert(sizeof(fp8e5m2x2) == 2); + +// SIMD like "Fused" cast + multiplication (x2) +__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, + const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + "mul.f32x2 val_pair, %1, %2; \n\t" + "mov.b64 {val2,val1}, val_pair; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); +} + +__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in, + const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + "mul.f32x2 val_pair, %1, %2; \n\t" + "mov.b64 {val2,val1}, val_pair; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); +} + +__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair_before; \n\t" + ".reg.b64 val_pair_after; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + ".reg.b16 val1_bf16; \n\t" + ".reg.b16 val2_bf16; \n\t" + "mov.b32 {val1_bf16, val2_bf16} , %1; \n\t" + "cvt.f32.bf16 val1, val1_bf16; \n\t" + "cvt.f32.bf16 val2, val2_bf16; \n\t" + "mov.b64 val_pair_before, {val1,val2}; \n\t" + "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t" + "mov.b64 {val2,val1}, val_pair_after; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "r"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); +} + +__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair_before; \n\t" + ".reg.b64 val_pair_after; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + ".reg.b16 val1_bf16; \n\t" + ".reg.b16 val2_bf16; \n\t" + "mov.b32 {val1_bf16, val2_bf16} , %1; \n\t" + "cvt.f32.bf16 val1, val1_bf16; \n\t" + "cvt.f32.bf16 val2, val2_bf16; \n\t" + "mov.b64 val_pair_before, {val1,val2}; \n\t" + "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t" + "mov.b64 {val2,val1}, val_pair_after; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "r"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); +} + +__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair_before; \n\t" + ".reg.b64 val_pair_after; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + ".reg.b16 val1_fp16; \n\t" + ".reg.b16 val2_fp16; \n\t" + "mov.b32 
{val1_fp16, val2_fp16} , %1; \n\t" + "cvt.f32.f16 val1, val1_fp16; \n\t" + "cvt.f32.f16 val2, val2_fp16; \n\t" + "mov.b64 val_pair_before, {val1,val2}; \n\t" + "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t" + "mov.b64 {val2,val1}, val_pair_after; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "r"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); +} + +__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair_before; \n\t" + ".reg.b64 val_pair_after; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + ".reg.b16 val1_fp16; \n\t" + ".reg.b16 val2_fp16; \n\t" + "mov.b32 {val1_fp16, val2_fp16} , %1; \n\t" + "cvt.f32.f16 val1, val1_fp16; \n\t" + "cvt.f32.f16 val2, val2_fp16; \n\t" + "mov.b64 val_pair_before, {val1,val2}; \n\t" + "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t" + "mov.b64 {val2,val1}, val_pair_after; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "r"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); +} + +__device__ __forceinline__ void abs_max_2x(bf16x2 &dst, const bf16x2 &p1, const bf16x2 &p2) { + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;" + : "=r"(reinterpret_cast(dst)) + : "r"(reinterpret_cast(p1)), + "r"(reinterpret_cast(p2))); +} + +__device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) { + asm volatile("max.xorsign.abs.f16x2 %0, %1, %2;" + : "=r"(reinterpret_cast(dst)) + : "r"(reinterpret_cast(p1)), + "r"(reinterpret_cast(p2))); +} + #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // namespace ptx diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh index b8fee6862..e7b0b0e21 100644 --- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh @@ -191,14 +191,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + after_dact_reg[stage] = static_cast(static_cast(after_dact_reg[stage])); + if constexpr (IS_DGATED) { + after_dgate_reg[stage] = static_cast(static_cast(after_dgate_reg[stage])); + } + } + if constexpr (USE_ROWWISE_SCALING) { if constexpr (IS_DGATED) { // dgate float amax = fabsf(after_dgate_reg[stage]); const float mx_block_X_amax = warp_reduce_max_broadcast(amax); const e8m0_t biased_exponent_X = - float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); + ptx::float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal_X = ptx::exp2f_rcp(biased_exponent_X); out_gate_rowwise_sh[shmem_idx] = static_cast(scale_reciprocal_X * after_dgate_reg[stage]); @@ -217,8 +225,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float amax = fabsf(after_dact_reg[stage]); const float mx_block_X_amax = warp_reduce_max_broadcast(amax); const e8m0_t biased_exponent_X = - float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); + ptx::float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal_X = ptx::exp2f_rcp(biased_exponent_X); 
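Several hunks in this patch insert the same "numerical truncation" step: when the input type is BF16/FP16 rather than FP32, the FP32 intermediate is cast down to the input type and back up before the amax and quantization math. A minimal host-side sketch of the effect, using NumPy's float16 as a stand-in for an FP16 input type (BF16 behaves analogously):

import numpy as np

def truncate_to_itype(elt_f32: float) -> float:
    # Downcast to the (assumed FP16) input type, then upcast back to FP32,
    # matching the static_cast round-trip the patch inserts before quantization.
    return float(np.float16(elt_f32))

elt = 1.0 / 3.0
print(elt, truncate_to_itype(elt))
# The truncated value is what subsequently feeds the amax and FP8 cast, so the
# result matches a pipeline whose intermediate really was FP16.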
out_act_rowwise_sh[shmem_idx] = static_cast(scale_reciprocal_X * after_dact_reg[stage]); @@ -273,8 +281,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } const e8m0_t biased_exponent = - float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = exp2f_rcp(biased_exponent); + ptx::float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = ptx::exp2f_rcp(biased_exponent); // Only single thread writes the computed scaling factor // Also assuming one iteration covers exactly 32 rows @@ -319,8 +327,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } const e8m0_t biased_exponent = - float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = exp2f_rcp(biased_exponent); + ptx::float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = ptx::exp2f_rcp(biased_exponent); // Only single thread writes the computed scaling factor // Also assuming one iteration covers exactly 32 rows diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index d62350e0a..383182eef 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -209,6 +209,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) partial_dbias_rowwise[chunk_X].data.elt[j] += elt; } } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } in_compute[j] = elt; if (!out_of_bounds) { thread_amax = fmaxf(thread_amax, fabsf(elt)); @@ -221,7 +225,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); const e8m0_t biased_exponent = - float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp) + (IS_NORM ? 1 : 0); // Normalization requires a +1 scale to avoid saturation + ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp) + (IS_NORM ? 1 : 0); // Normalization requires a +1 scale to avoid saturation // Only single thread writes the computed scaling factor if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { @@ -234,7 +238,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) scales_rowwise[scale_idx] = biased_exponent; } - const float block_scale_inverse = exp2f_rcp(biased_exponent); + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); #pragma unroll for (int j = 0; j < ELEMS_PER_THREAD; j++) { @@ -268,6 +272,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) partial_dbias_colwise[chunk_X] += elt; } } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } in_compute[i] = elt; if (!out_of_bounds) { amax = fmaxf(amax, fabsf(elt)); @@ -278,7 +286,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __builtin_assume(amax >= 0); block_amax = fmaxf(block_amax, amax); - const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp) + (IS_NORM ? 1 : 0); // Normalization requires a +1 scale to avoid saturation + const e8m0_t biased_exponent = ptx::float_to_e8m0(amax * Quantized_Limits::max_norm_rcp) + (IS_NORM ? 
1 : 0); // Normalization requires a +1 scale to avoid saturation const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; @@ -286,7 +294,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; scales_colwise[scale_idx] = biased_exponent; - const float block_scale_inverse = exp2f_rcp(biased_exponent); + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); #pragma unroll for (int i = 0; i < SCALE_DIM_Y; i++) { out_colwise_sh[i][tid_colwise_X] = diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh index ae5cb4bbd..30df50d74 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh @@ -102,7 +102,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X; const e8m0_t biased_exponent = scales_ptr[scale_idx]; - const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + const float block_scale = exp2f(static_cast(biased_exponent) - ptx::FP32_EXPONENT_BIAS); if constexpr (USE_ROWWISE_SCALING) { Vec in; diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 0ed5bca6b..4208f3511 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -984,10 +984,7 @@ using fp8e5m2 = te_hip_fp8_e5m2; #endif //__HIP_PLATFORM_AMD__ using e8m0_t = uint8_t; -constexpr uint32_t FP32_MANTISSA_BITS = 23; -constexpr uint32_t FP32_EXPONENT_BIAS = 127; - -enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENTIONAL = 2 }; +enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENSIONAL = 2 }; template struct Numeric_Traits; @@ -1045,47 +1042,6 @@ struct Quantized_Limits { #endif // TE_DYNAMIC_HIP_FP8_TYPE }; -__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { - // TODO: nan/inf needs to be set for any value - // of nan/inf in input not just amax. - if (isnan(val)) { - return 0xFF; - } - if (isinf(val)) { - return 0xFE; - } -#ifdef __HIP_PLATFORM_AMD__ -#define __CUDA_ARCH_HAS_FEATURE__(x) 0 -#endif //__HIP_PLATFORM_AMD__ -#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ - (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) - uint16_t out; - asm volatile( - "{\n" - "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" - "}" - : "=h"(out) - : "f"(val)); - return *reinterpret_cast(&out); -#else - if (val == 0.0f) { - return 0x00; - } - uint32_t val_u32 = *reinterpret_cast(&val); - e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); - uint32_t mantissa = val_u32 & 0x7FFFFF; - // Round up exponent and deal with satfinite. - if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { - ++exponent; - } - return exponent; -#endif -} - -__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { - return (biased_exp == 0) ? 
1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); -} - } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 0cd8f5a36..c2c4d5db3 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -922,11 +922,11 @@ def shardy_sharding_rule( class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): - """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" + """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): - """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" + """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]: diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index bf3b3b7fd..ade3aea09 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -6,6 +6,7 @@ """JAX/TE base custom ops""" import os import re +import warnings from abc import ABCMeta, abstractmethod from functools import partial from packaging import version @@ -33,19 +34,77 @@ class BasePrimitive(metaclass=ABCMeta): name = None + _is_enabled = True + + # Default list of primitives to disable for all recipes + _default_disable_names = ["GemmPrimitive"] + @classmethod def enabled(cls): """ - A custom call is marked as disabled if the `cls.__name__` does not fully match the - `NVTE_JAX_CUSTOM_CALLS_RE` pattern. - This uses the Python class name of the primitive definitions that inherit from BasePrimitive. - By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names. - For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!DBiasQuantizePrimitive$).+$'` to disable `DBiasQuantizePrimitive`. + Determines if a custom call is enabled based on a state variable and environment variables. + Checks `NVTE_JAX_CUSTOM_CALLS` (key/value format) first, then falls back to the deprecated `NVTE_JAX_CUSTOM_CALLS_RE` (regex pattern), + and finally to the internal state `_is_enabled` if neither is set. + + Environment Variables: + 1. `NVTE_JAX_CUSTOM_CALLS`: Preferred key/value format to enable/disable specific primitives or a single value 'true' or 'false' to enable/disable all primitives. + - Example 1 (global enable): 'true' enables all primitives. + - Example 2 (global disable): 'false' disables all primitives. 
+ - Example 3 (specific settings): 'DBiasQuantizePrimitive=false,GemmPrimitive=true' disables DBiasQuantizePrimitive and enables GemmPrimitive, leaving others at their default state. + Note that the default state is set at class level based on _default_disable_names. + 2. `NVTE_JAX_CUSTOM_CALLS_RE`: Deprecated regex pattern to match primitive names. + - Example: 'DBiasQuantizePrimitive' or '^(?!DBiasQuantizePrimitive$).+$' to enable/disable DBiasQuantizePrimitive. + - A deprecation warning is raised if used; it will be removed in future releases. + + Behavior: + 1. Checks if `NVTE_JAX_CUSTOM_CALLS` is set and parses key/value pairs or single true/false value. + 2. If not set, checks `NVTE_JAX_CUSTOM_CALLS_RE` (with deprecation warning) for regex matching. + 3. If neither is set, falls back to the internal state `_is_enabled`. """ - pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*") - pattern = re.compile(pattern) - is_enabled = pattern.fullmatch(cls.__name__) is not None - return is_enabled + + # Check new key/value environment variable first + custom_calls_str = os.getenv("NVTE_JAX_CUSTOM_CALLS") + if custom_calls_str is not None: + custom_calls_str = custom_calls_str.strip() + if custom_calls_str.lower() == "true": + return True + if custom_calls_str.lower() == "false": + return False + + # Parse key=value pairs + settings = {} + for pair in custom_calls_str.split(","): + pair = pair.strip() + if "=" in pair: + key, value = pair.split("=", 1) + key = key.strip() + value = value.strip().lower() + settings[key] = value == "true" + if cls.__name__ in settings: + return settings[cls.__name__] + + # Check old regex environment variable (deprecated) + pattern_str = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE") + if pattern_str is not None: + warnings.warn( + "NVTE_JAX_CUSTOM_CALLS_RE is deprecated and will be removed in future releases. Use" + " NVTE_JAX_CUSTOM_CALLS with key=value format instead (e.g.," + " 'DBiasQuantizePrimitive=false').", + DeprecationWarning, + ) + pattern = re.compile(pattern_str) + env_enabled = pattern.fullmatch(cls.__name__) is not None + return env_enabled + + # If no environment variable is set, fall back to the internal state + return cls._is_enabled + + @classmethod + def set_enabled(cls, enabled: bool): + """ + Sets the enabled state for this primitive. + """ + cls._is_enabled = enabled @staticmethod @abstractmethod @@ -112,10 +171,19 @@ def shardy_sharding_rule(*args): return "... -> ..." +# Registry to store all registered primitive classes +_primitive_registry = {} + + def register_primitive(cls): """ - register jax primitive + Register a JAX primitive and add it to the internal registry. """ + _primitive_registry[cls.__name__] = cls + + # Set default disabled state at class level based on _default_disable_names + if cls.__name__ in BasePrimitive._default_disable_names: + cls.set_enabled(False) def name_of_wrapper_p(): return cls.name + "_wrapper" @@ -153,3 +221,48 @@ def name_of_wrapper_p(): for _name, _value in transformer_engine_jax.registrations().items(): ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension else "CUDA") + + +def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False): + """ + Helper function to manage primitive states by name without modifying environment variables. + Allows enabling specific primitives, disabling specific primitives, or disabling all primitives. + This helper is used in the QuantizeConfig.initialize() methods. 
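A usage sketch for the two control paths documented above. The module path follows the file being patched (transformer_engine/jax/cpp_extensions/base.py) and the primitive names are taken from the docstring examples, so treat the commented import as an assumption rather than a guaranteed public API.

import os

# 1. Key/value environment variable: disable one primitive, enable another,
#    leaving everything else at its class-level default.
os.environ["NVTE_JAX_CUSTOM_CALLS"] = "DBiasQuantizePrimitive=false,GemmPrimitive=true"

# 2. Programmatic control through the registry helper (unknown names raise ValueError).
# from transformer_engine.jax.cpp_extensions.base import manage_primitives
# manage_primitives(enable_names=["GemmPrimitive"],
#                   disable_names=["DBiasQuantizePrimitive"],
#                   disable_all_first=False)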
+ + Args: + enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None. + disable_names: List of strings, each representing the name of a primitive class to disable. Defaults to None. + disable_all_first: Boolean, if True, disables all primitives before applying enable/disable lists. Defaults to False. + + Note: + 1. If `disable_all_first` is True, all primitives are disabled first, then `enable_names` is applied. + 2. Conflicts (a primitive in both enable and disable lists) are resolved by applying disable last. + """ + + enable_set = set(enable_names or []) + disable_set = set(disable_names or []) + + if disable_all_first: + for name, cls in _primitive_registry.items(): + if ( + isinstance(cls, type) + and issubclass(cls, BasePrimitive) + and cls is not BasePrimitive + ): + cls.set_enabled(False) + + # Apply enables + for name in enable_set: + cls = _primitive_registry.get(name) + if cls and isinstance(cls, type) and issubclass(cls, BasePrimitive): + cls.set_enabled(True) + else: + raise ValueError(f"Primitive not found in registry: {name}") + + # Apply disables (overrides enables if there's a conflict) + for name in disable_set: + cls = _primitive_registry.get(name) + if cls and isinstance(cls, type) and issubclass(cls, BasePrimitive): + cls.set_enabled(False) + else: + raise ValueError(f"Primitive not found in registry: {name}") diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4a5468b4a..da154c8b3 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -166,7 +166,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15) + impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17) inner_primitive = None outer_primitive = None @@ -188,8 +188,14 @@ def abstract( fuse_gelu, grad, use_split_accumulator, + sequence_parallel_output, + sequence_dim, ): del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator + del ( + sequence_parallel_output, + sequence_dim, + ) def _dims_are_consecutive(dims): if len(dims) <= 1: @@ -354,8 +360,12 @@ def lowering( fuse_gelu, grad, use_split_accumulator, + sequence_parallel_output, + sequence_dim, ): del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype + del sequence_parallel_output, sequence_dim + lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) lhs_transposed, rhs_transposed = _get_gemm_layout( @@ -404,6 +414,8 @@ def impl( fuse_gelu, grad, use_split_accumulator, + sequence_parallel_output, + sequence_dim, ): lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) lhs_transposed, rhs_transposed = _get_gemm_layout( @@ -441,6 +453,8 @@ def impl( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + sequence_parallel_output=sequence_parallel_output, + sequence_dim=sequence_dim, ) return outputs[:-3] # discard workspace arrays @@ -458,6 +472,8 @@ def batcher( fuse_gelu, grad, use_split_accumulator, + sequence_parallel_output, + sequence_dim, ): assert GemmPrimitive.outer_primitive is not None lhs, _, rhs, *_ = batched_args @@ -500,6 +516,8 @@ def batcher( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + sequence_parallel_output=sequence_parallel_output, + sequence_dim=sequence_dim, ), (out_bdims, bias_bdims, 
pre_gelu_bdims), ) @@ -521,7 +539,13 @@ def _decompose_operand_specs(specs, contracting_dims, batch_dims): return bspecs, lspecs, cspecs @staticmethod - def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims): + def _parse_operand_output_specs( + arg_infos, + contracting_dims, + batched_dims, + sequence_parallel_output, + sequence_dim, + ): lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( @@ -567,96 +591,66 @@ def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims): ) # Extract single leading and contracting dimension specs - (lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map( + (lhs_cspec, rhs_cspec) = map( lambda specs: None if len(specs) == 0 else specs[0], - (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none), + (lhs_cspec_not_none, rhs_cspec_not_none), ) - # Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts - # with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands. - # 1. K1 == K2 != None and N == None - # LHS: (B, M, K) - # RHS: (B, None, K) - # OUT: (B, M, None) --(AR)-> (B, M, None) - # 2. K1 == K2 != None and M == N != None - # LHS: (B, M, K) - # RHS: (B, N, K)--(AG)->(B, None, K) - # OUT: (B, M, None) --(RS)--> (B, M, N) - # 3. M == N - # LHS: (B, M, K)--(AG)->(B, M, None) - # RHS: (B, M, K)--(AG)->(B, None, None) - # OUT: (B, M, None) - # 4. M != N - # LHS: (B, M, K)--(AG)->(B, M, None) - # RHS: (B, N, K)--(AG)->(B, N, None) - # OUT: (B, M, N) - reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec - all_reduce_output = reduce_flag and rhs_lspec is None - reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec - all_reduce_spec = reduce_scatter_spec = scatter_dim = None + # Partitioning rules: + # ([B], M, K1) x ([B], N, K2)^T = ([B], M, N) + # 1. K1 == K2 != None + # - Require non-batched non-contracting dims of both LHS and RHS to be unsharded. + # - If `sequence_parallel_output=True`, then reduce-scatter the output. + # - Otherwise, all-reduce the output. + # 2. Otherwise + # - Require contracting dimensions of both LHS and RHS to be unsharded. + # - Require non-batched non-contracting dims of LHS to be unsharded. + reduce_output = rhs_cspec is not None and lhs_cspec == rhs_cspec + reduce_spec = scatter_dim = None + if reduce_output: + reduce_spec = rhs_cspec + if sequence_parallel_output: + # If the sequence dimension is not specified, assume it to be the first + # non-batched non-contracting dimension of the LHS operand. + scatter_dim = sequence_dim if sequence_dim is not None else lhs_ldims[0] + + # Always require the non-batched non-contracting dims of LHS to be unsharded + # NOTE: This will all-gather sequence-parallel inputs and preserve tensor-parallel params. + lhs_specs = tuple( + lhs_specs[i] if i in set(lhs_bdims + lhs_cdims) else None for i in range(lhs_ndim) + ) + if reduce_output: + # When reducing GEMM output, require non-batched non-contracting dims of the RHS + # operand to be unsharded (i.e. 
FSDP) + rhs_specs = tuple( + None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i] + for i in range(rhs_ndim) + ) + else: + # Otherwise, require contracting dims of both operands to be unsharded + lhs_specs = tuple(None if i in lhs_cdims else lhs_specs[i] for i in range(lhs_ndim)) + rhs_specs = tuple(None if i in rhs_cdims else rhs_specs[i] for i in range(rhs_ndim)) + # Combine modified LHS and RHS specs into the output lhs_non_contracting_specs, rhs_non_contracting_specs = map( lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims), (lhs_specs, rhs_specs), (lhs_cdims, rhs_cdims), ) - out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs) - if reduce_scatter_output: - # All-gather (if necessary) the non-batch non-contracting dimension of RHS - # (B, N, K) --(AG)-> (B, None, K) - # (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N) - rhs_spec = tuple( - rhs_spec[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim) - ) - reduce_scatter_spec = lhs_cspec - scatter_dim = out_specs.index(rhs_lspec) - - elif all_reduce_output: - # Set all output trailing dimensions to zero - out_specs = ( - *lhs_non_contracting_specs, - *[None for _ in range(len(rhs_non_contracting_specs))], - ) - all_reduce_spec = lhs_cspec - else: - # All-gather (if necessary) the non-batch contracting dimensions - # (B, M, K) --(AG)-> (B, M, None) - # (B, N, K) --(AG)-> (B, N, None) - # (B, M, None) x (B, N, None)^T = (B, M, N) - lhs_specs = tuple( - None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i] - for i in range(lhs_ndim) - ) - rhs_specs = tuple( - None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i] - for i in range(rhs_ndim) - ) - # Check if RHS non-contracting spec also appears in the LHS non-contracting specs - if rhs_lspec is not None and rhs_lspec in tuple( - lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims - ): - # All-gather (if necessary) the non-batch non-contracting dimensions of RHS - # (B, N, None) --(AG)-> (B, None, None) - # (B, M, None) x (B, None, None)^T = (B, M, None) - rhs_specs = tuple( - None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i] - for i in range(rhs_ndim) - ) - # Set all output trailing dimensions to zero - out_specs = ( - *lhs_non_contracting_specs, - *[None for _ in range(len(rhs_non_contracting_specs))], - ) + out_specs = [*lhs_non_contracting_specs, *rhs_non_contracting_specs] - # Bias and Pre-GeLU sharding is based on GEMM output - bias_specs = out_specs[len(lhs_non_contracting_specs) :] - gelu_specs = out_specs + # Bias and Pre-GeLU sharding is based on GEMM output before any scatter + bias_specs = tuple(list(out_specs[len(lhs_non_contracting_specs) :]).copy()) + gelu_specs = tuple(list(out_specs).copy()) + + # Set output scatter dim to the tensor-parallel spec + if sequence_parallel_output: + out_specs[scatter_dim] = reduce_spec return ( (lhs_specs, rhs_specs, bias_specs, gelu_specs), (out_specs, bias_specs, gelu_specs), - all_reduce_spec, - reduce_scatter_spec, + reduce_spec, scatter_dim, ) @@ -672,6 +666,8 @@ def infer_sharding_from_operands( fuse_gelu, grad, use_split_accumulator, + sequence_parallel_output, + sequence_dim, mesh, arg_infos, result_infos, @@ -686,7 +682,13 @@ def infer_sharding_from_operands( del use_split_accumulator, result_infos (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( - GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) + GemmPrimitive._parse_operand_output_specs( + arg_infos, + 
contracting_dims, + batched_dims, + sequence_parallel_output, + sequence_dim, + ) ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) @@ -714,6 +716,8 @@ def partition( fuse_gelu, grad, use_split_accumulator, + sequence_parallel_output, + sequence_dim, mesh, arg_infos, result_infos, @@ -723,10 +727,15 @@ def partition( ( (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), (out_specs, dbias_specs, pre_gelu_specs), - all_reduce_spec, - reduce_scatter_spec, + reduce_spec, scatter_dim, - ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) + ) = GemmPrimitive._parse_operand_output_specs( + arg_infos, + contracting_dims, + batched_dims, + sequence_parallel_output, + sequence_dim, + ) # Assemble argument shardings # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. @@ -781,20 +790,17 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + sequence_parallel_output=sequence_parallel_output, + sequence_dim=sequence_dim, ) # All-Reduce/Reduce-Scatter GEMM output - if all_reduce_spec is not None: - outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec) - if fuse_gelu and not grad: - outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec) - elif reduce_scatter_spec is not None: - outputs[0] = jax.lax.psum_scatter( - outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True - ) - if fuse_gelu and not grad: - outputs[2] = jax.lax.psum_scatter( - outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True + if reduce_spec is not None: + if scatter_dim is None: + outputs[0] = jax.lax.psum(outputs[0], reduce_spec) + else: + outputs[0] = jax.lax.psum_scatter( + outputs[0], reduce_spec, scatter_dimension=scatter_dim, tiled=True ) return outputs @@ -813,12 +819,14 @@ def shardy_sharding_rule( fuse_gelu, grad, use_split_accumulator, + sequence_parallel_output, + sequence_dim, mesh, operand_types, result_types, ): del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator - del mesh, result_types + del sequence_parallel_output, sequence_dim, mesh, result_types prefix = "GemmPrimitive_" @@ -907,6 +915,8 @@ def _te_gemm( fuse_gelu: bool = False, grad: bool = False, use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, + sequence_parallel_output: bool = False, + sequence_dim: int = None, ) -> Tuple[jax.Array, ...]: # Prepare non-quantized GEMM operands @@ -980,6 +990,8 @@ def _te_gemm( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + sequence_parallel_output=sequence_parallel_output, + sequence_dim=sequence_dim, ) @@ -1318,9 +1330,9 @@ def gemm( Tuple of sequences representing the contracting dimensions of the operands. batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()), Tuple of sequences representing the batched dimensions of the operands. This is *not* used - to perform a batched matrix multiplication, but it is required to avoid a potentially - undesirable reduction in any batched contracting dimensions when invoked with sharded - operands (e.g. when computing weight gradients in a Flax module). + to perform a batched matrix multiplication, but it is required for TE's custom cuBLAS GEMM + call to avoid a potentially undesirable reduction in any batched contracting dimensions + when invoked with sharded operands (e.g. when computing weight gradients in a Flax module). 
    bias: jax.Array, default = None
        Optional additive bias term, required for forward GEMM with bias fusion. Only supported
        with TE's custom call to cuBLAS GEMM.
@@ -1338,7 +1350,17 @@ def gemm(
        TE's custom call to cuBLAS GEMM.
    use_split_accumulator: bool, default = True
        Enable promoting some intermediate sums to higher precision when accumulating the result in
-        the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed.
+        the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only
+        supported with TE's custom call to cuBLAS GEMM.
+    sequence_parallel_output: bool, default = False
+        Produces an output with the first non-batched non-contracting dimension sharded with the
+        same spec as operand contracting dimensions. This effectively converts the `jax.lax.psum`
+        for the GEMM output into a `jax.lax.psum_scatter`. Only supported with TE's custom call to
+        cuBLAS GEMM.
+    sequence_dim: int, default = None
+        Index of the sequence dimension for the LHS operand. This controls which dimension of the
+        GEMM output is scattered when `sequence_parallel_output=True`. When `None`, the first
+        non-batched non-contracting dimension is assumed to be the sequence dimension.

    Returns
    -------
@@ -1369,12 +1391,20 @@ def gemm(

     if not GemmPrimitive.enabled():
         assert kwargs.get("bias", None) is None and not fuse_gelu, (
             "TE GEMM was invoked with bias fusion options that are not supported by the "
-            "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS "
+            "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
             "GEMM primitive is disabled."
         )
         assert kwargs.get("gelu_input", None) is None and not fuse_bias, (
             "TE GEMM was invoked with GeLU fusion options that are not supported by the "
-            "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS "
+            "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
+            "GEMM primitive is disabled."
+        )
+        assert (
+            not kwargs.get("sequence_parallel_output", False)
+            and kwargs.get("sequence_dim", None) is None
+        ), (
+            "TE GEMM was invoked with sequence-parallelism options that are not supported by the "
+            "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
             "GEMM primitive is disabled."
         )
         return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer)
diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py
index 7a5b31ad7..be2fb4425 100644
--- a/transformer_engine/jax/cpp_extensions/quantization.py
+++ b/transformer_engine/jax/cpp_extensions/quantization.py
@@ -524,11 +524,11 @@ def shardy_sharding_rule(


 class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
-    """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
+    """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


 class QuantizePrimitive(BaseDBiasQuantizePrimitive):
-    """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias.
No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
+    """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


 def _jax_quantize(
diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py
index a0fc7b7af..5be551dbd 100644
--- a/transformer_engine/jax/dense.py
+++ b/transformer_engine/jax/dense.py
@@ -22,6 +22,7 @@
     TensorUsage,
 )
+from .sharding import get_sequence_parallel_dim

 DENSE_BATCH_FIRST_WARNING_ISSUED = False

@@ -41,6 +42,7 @@ def dense(
     input_axes: Tuple[str, ...] = None,
     kernel_axes: Tuple[str, ...] = None,
     batch_first: bool = True,
+    sequence_parallel_output: bool = False,
     quantizer_set: QuantizerSet = noop_quantizer_set,
 ):
     """Perform dense layer transformation with optional quantization.
@@ -55,6 +57,8 @@ def dense(
         bias: Optional bias tensor to add after the transformation
         contracting_dims: Tuple of sequences specifying which dimensions to contract
         batch_first: Assume that X is batched in the first dimension.
+        sequence_parallel_output: Produce an output that is sharded in the first non-batched dim. Only
+            supported for TE custom GEMM with row-parallel kernel axes.
         quantizer_set: QuantizerSet which contains quantizers for different tensor types

     Returns:
@@ -69,13 +73,31 @@ def dense(
             output += jnp.reshape(bias, bias_new_shape)
     else:
         output = _dense(
-            x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
+            x,
+            kernel,
+            bias,
+            contracting_dims,
+            input_axes,
+            kernel_axes,
+            batch_first,
+            sequence_parallel_output,
+            quantizer_set,
         )
     return output


-@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
-def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set):
+@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7))
+def _dense(
+    x,
+    kernel,
+    bias,
+    contracting_dims,
+    input_axes,
+    kernel_axes,
+    batch_first,
+    sequence_parallel_output,
+    quantizer_set,
+):
     """Internal implementation of dense layer transformation with custom VJP.

     This function implements the core dense layer transformation logic with support
@@ -88,20 +110,38 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_fir
         contracting_dims: Contracting dimensions specification
         input_axes: Logical axes for sharding the activation input
         kernel_axes: Logical axes for sharding the weight matrix
-        quantizer_set: QuantizerSet which contains quantizers for different tensor types
         batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
+        sequence_parallel_output: Produce an output that is sharded in the first non-batched dim. Only
+            supported for TE custom GEMM with row-parallel kernel axes.
+ quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor """ output, _ = _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + batch_first, + sequence_parallel_output, + quantizer_set, ) return output def _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + batch_first, + sequence_parallel_output, + quantizer_set, ): """Forward pass rule for dense layer transformation. @@ -161,6 +201,7 @@ def _dense_fwd_rule( batched_dims=((x_bdim,), ()), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + sequence_parallel_output=sequence_parallel_output and not tex.gemm_uses_jax_dot(), ) if use_bias and tex.gemm_uses_jax_dot(): @@ -181,7 +222,7 @@ def _dense_fwd_rule( def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad + contracting_dims, input_axes, kernel_axes, batch_first, sequence_parallel_output, ctx, grad ): # pylint: disable=unused-argument """Backward pass rule for dense layer transformation. @@ -220,11 +261,22 @@ def _dense_bwd_rule( k_contracting_dim = tuple( dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims ) + + # Get sequence-parallel dimension of the FWD input (if it exists) + sequence_dim = get_sequence_parallel_dim(input_axes, fwd_x_contracting_dims, (x_bdim,)) dgrad = tex.gemm( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, contracting_dims=(g_contracting_dim, k_contracting_dim), batched_dims=((x_bdim,), ()), + sequence_parallel_output=( + sequence_dim is not None + and not sequence_parallel_output + and not tex.gemm_uses_jax_dot() + ), + sequence_dim=( + None if sequence_parallel_output or tex.gemm_uses_jax_dot() else sequence_dim + ), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 5992d3607..6670377f7 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -415,6 +415,8 @@ class DenseGeneral(TransformerEngineBase): Indicate the logical axes of sharding constraint to the input, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint. + sequence_parallel_output: bool, default = False + Produce a sequence-parallel output with the first non-batch dimension sharded over Optimization parameters ----------------------- @@ -439,6 +441,7 @@ class DenseGeneral(TransformerEngineBase): dtype: DType = jnp.float32 transpose_batch_sequence: bool = False input_axes: Tuple[str, ...] 
= () + sequence_parallel_output: bool = False def __post_init__(self): if self.transpose_batch_sequence: @@ -511,6 +514,7 @@ def __call__(self, inputs: Array) -> Array: input_axes=self.input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, + sequence_parallel_output=self.sequence_parallel_output, ) if self.enable_low_rank_adaptation: diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index f2c0bc2a1..5f309820c 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1425,6 +1425,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, name="out", + sequence_parallel_output=self.enable_sequence_parallel, )(x) out = checkpoint_name(out, "out_proj") diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 5ccfc71c2..c616aa699 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -24,6 +24,7 @@ with_sharding_constraint_by_logical_axes, TensorUsage, ) +from .sharding import get_sequence_parallel_dim LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False @@ -324,11 +325,16 @@ def _layernorm_dense_bwd_rule( ) # NT GEMM + sequence_dim = get_sequence_parallel_dim( + layernorm_input_axes, x_contracting_dims_in_fwd, (x_bdim,) + ) dgrad = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, contracting_dims=(g_constracting_dim, k_constracting_dim), batched_dims=((x_bdim,), ()), + sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(), + sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None, ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 507c49c7e..8dd045100 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -29,7 +29,10 @@ noop_quantizer_set, TensorUsage, ) -from .sharding import get_non_contracting_logical_axes +from .sharding import ( + get_non_contracting_logical_axes, + get_sequence_parallel_dim, +) LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False @@ -342,6 +345,7 @@ def _layernorm_mlp_fwd_rule( # NN GEMM # (batch..., hidden_in) x (hidden_out, hidden_in) + sequence_dim = get_sequence_parallel_dim(norm_input_axes, x_contracting_dims, (x_bdim,)) dot_2_output = tex.gemm( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), @@ -349,6 +353,8 @@ def _layernorm_mlp_fwd_rule( batched_dims=((x_bdim,), ()), bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, + sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(), + sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None, ) if use_bias_2 and tex.gemm_uses_jax_dot(): @@ -377,6 +383,7 @@ def _layernorm_mlp_fwd_rule( use_bias_2, quantizer_sets, x_bdim, + sequence_dim, ) return dot_2_output, ctx @@ -431,6 +438,7 @@ def _layernorm_mlp_bwd_rule( use_bias_2, quantizer_sets, x_bdim, + sequence_dim, ) = ctx ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets @@ -501,6 +509,8 @@ def _layernorm_mlp_bwd_rule( casted_kernel_1, contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), batched_dims=((x_bdim,), ()), + sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(), + 
sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None, ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 0b9659a46..fc7146ca5 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -368,6 +368,9 @@ def initialize(fp8_recipe: recipe.Recipe) -> None: cls.initialize(fp8_recipe) cls.AMAX_HISTORY_LEN = 0 + # Use TE GEMM instead of JAX GEMM for better performance + tex.base.manage_primitives(enable_names=["GemmPrimitive"]) + @staticmethod def finalize() -> None: """Reset the block scaling configuration.""" diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index e59c9de12..a7bbef997 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -86,17 +86,61 @@ def get_sharding_map_logic_axis_to_mesh_axis(): return te_logical_axis_to_mesh_axis -def generate_pspec(logical_axis_names): +def get_sequence_parallel_dim(logical_axes, contracting_dims, batch_dims): + """ + Get the index for the sequence-parallel dimension based on the given logical axes. + + The sequence-parallel dimension is assumed to be the only sharded non-batched non-contracting + dimension. + """ + if not logical_axes: + return None + + pspec = generate_pspec(logical_axes, with_flax_rules=True, padded=True) + ldims = [i for i in range(len(logical_axes)) if i not in set(contracting_dims + batch_dims)] + lspecs = [pspec[i] for i in ldims if pspec[i] is not None] + if len(lspecs) == 0: + return None + + assert len(lspecs) == 1, ( + "Expected only 1 non-batched non-contracting dimension to be sharded for " + f"sequence-parallelism, but found {len(lspecs)}: {pspec} @ idx {ldims}" + ) + + return pspec.index(lspecs[0]) + + +def generate_pspec(logical_axis_names, with_flax_rules=False, padded=False): """ Convert logical axes to PartitionSpec """ - rules = get_sharding_map_logic_axis_to_mesh_axis() + rules = None + if with_flax_rules: + try: + import flax + + rules = dict(flax.linen.get_logical_axis_rules()) + except ImportError: + pass + + if rules is None: + warnings.warn( + "Transformer Engine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated" + " and removed in a future version. Please use Flax logical axes with the" + " `flax.linen.logical_axis_rules()` context and optionally use" + " `transformer_engine.jax.flax.extend_logical_axis_rules()` to extend Flax axis rules" + " with Transformer Engine logical axes.", + DeprecationWarning, + ) + rules = get_sharding_map_logic_axis_to_mesh_axis() # mesh_axis_names = [rules[name] for name in logical_axis_names] mesh_axis_names = [] for name in logical_axis_names: axis_name = rules[name] if name in rules else None mesh_axis_names.append(axis_name) pspec = jax.sharding.PartitionSpec(*mesh_axis_names) + if padded: + pspec = get_padded_spec(pspec, len(mesh_axis_names)) return pspec diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 893e2d228..b35b87a83 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -630,7 +630,7 @@ def forward( If true, there are padding tokens between individual sequences in a packed batch. 
""" - with self.prepare_forward( + with torch.cuda.device(query_layer.device), self.prepare_forward( query_layer, num_gemms=3, allow_non_contiguous=True, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index e03543d40..f3bda7d20 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -439,8 +439,8 @@ def get_attention_backend( # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: - if device_compute_capability == (8, 9) and cudnn_version < (9, 12, 0): - logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.12") + if device_compute_capability == (8, 9) and cudnn_version <= (9, 12, 0): + logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.12") use_fused_attention = False if context_parallel: logger.debug("Disabling all backends for KV caching with context parallelism") @@ -615,7 +615,7 @@ def get_attention_backend( " bias for THD format" ) use_fused_attention = False - elif fp8 and head_dim_qk != head_dim_v: + elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v: logger.debug( "Disabling FusedAttention as it does not support context parallelism with FP8" " MLA attention" diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 0c7e3fe19..e8025f46d 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -412,7 +412,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument per_callable_static_grad_inputs = [None] * len(flatten_sample_args) fwd_idx = [0] * num_model_chunks bwd_idx = [0] * num_model_chunks - static_grad_outputs = None + static_grad_outputs_dict = {} previous_per_callable_bwd_idx = None for c_id in _order: if c_id > 0: @@ -444,9 +444,21 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument static_outputs = per_callable_static_outputs[per_callable_bwd_idx] bwd_graph = bwd_graphs[per_callable_bwd_idx] # For now, assumes all static_outputs require grad - if not _reuse_graph_input_output_buffers or static_grad_outputs is None: + if _reuse_graph_input_output_buffers: # Note for _reuse_graph_input_output_buffers: grad output is only used # within backward, so we can reuse the same static buffers every time. 
+ static_grad_outputs_keys = tuple( + (o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad + ) + if static_grad_outputs_keys in static_grad_outputs_dict: + static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys] + else: + static_grad_outputs = tuple( + torch.empty_like(o) if o.requires_grad else None + for o in static_outputs + ) + static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs + else: static_grad_outputs = tuple( torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a6ab1b22a..a2ace89c8 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1326,21 +1326,29 @@ def get_weight_workspace( # Try getting workspace from cache out = None - if cache_name is not None: out = self._fp8_workspaces.get(cache_name, None) - if quantizer is not None and isinstance(out, MXFP8TensorBase): + + # Reset cache if workspace is invalid + if out is not None and quantizer is not None: + reset_cache = False + if isinstance(out, Float8TensorBase): + if ( + not is_non_tn_fp8_gemm_supported() + and quantizer.columnwise_usage + and out._transpose is None + ): + reset_cache = True + elif isinstance(out, MXFP8TensorBase): if quantizer.rowwise_usage and out._rowwise_data is None: - out = None - del self._fp8_workspaces[cache_name] + reset_cache = True elif quantizer.columnwise_usage and out._columnwise_data is None: - out = None - del self._fp8_workspaces[cache_name] - - is_debug = isinstance(quantizer, DebugQuantizer) - is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor) - if is_debug != is_out_debug_tensor: + reset_cache = True + if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer): + reset_cache = True + if reset_cache: out = None + del self._fp8_workspaces[cache_name] # Gather cached Fp8 workspace if it's distributed # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index da66e68b4..cc472390f 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -742,7 +742,9 @@ def forward( if skip_fp8_weight_update is not None: is_first_microbatch = False - with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: + with torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: weight_tensors = self._get_weight_tensors() bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 5af56b2e0..0f2cbbfe0 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1535,7 +1535,9 @@ def forward( if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( + with torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ), self.prepare_forward( inp, allow_non_contiguous=False # removed .contiguous from inside the layer ) as inp: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 8772418c9..018af22c5 100644 --- 
a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1797,7 +1797,9 @@ def forward( if get_ub("fc2_fprop").is_fp8_ubuf(): fp8_output = True - with self.prepare_forward(inp, num_gemms=2) as inp: + with torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ), self.prepare_forward(inp, num_gemms=2) as inp: quantizers = ( self._get_quantizers(fp8_output) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index fb5592540..49b27818c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -67,8 +67,6 @@ ) from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..export import is_in_onnx_export_mode, assert_warmed_up from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...debug.pytorch.debug_state import TEDebugState @@ -175,16 +173,19 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") if not isinstance(inputmat, QuantizedTensorBase): - input_quantizer.set_usage( - rowwise=True, columnwise=backward_needs_input and not save_original_input - ) + own_quantized_input = True + input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): # All-gather is not supported with FP8 column-wise data input_quantizer.set_usage(columnwise=False) + if save_original_input: + # No need for column-wise data since this + # tensor will not be cached for backward pass + input_quantizer.set_usage(columnwise=False) + own_quantized_input = False inputmat = input_quantizer(inputmat) - own_quantized_input = True else: inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP @@ -352,23 +353,29 @@ def forward( inputmat = inp ctx.weight_quantizer = weight_quantizer - saved_inputmat = None ctx.backward_input_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) + # Discard unneeded data in input tensor + if ( + backward_needs_input + and own_quantized_input + and isinstance(inputmat, QuantizedTensorBase) + ): + if ctx.backward_input_needs_gather and isinstance( + quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ): + # All-gather is not supported with FP8 column-wise data + inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) + else: + # Discard row-wise data since it is not needed in backward pass + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + + # Cached input tensor + saved_inputmat = None if backward_needs_input: - if not save_original_input: - if own_quantized_input and isinstance(inputmat, QuantizedTensorBase): - # For sequence parallel in vanilla FP8, rowwise data is - # to gather the input. For MXFP8, columnwise only data - # can be allgathered. - if ( - isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase)) - or not ctx.backward_input_needs_gather - ): - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) saved_inputmat = inputmat # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. 
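The forward hunk above decides which copies of the quantized input to keep for backward, and the backward hunk that follows applies the same constraint when it quantizes a fresh input for the wgrad GEMM: the wgrad GEMM consumes column-wise data, but Float8 delayed/current scaling cannot all-gather column-wise data, so a sequence-parallel gather keeps the row-wise copy instead. A minimal, self-contained sketch of that policy (the helper and its boolean inputs are illustrative only, not Transformer Engine API):

    from typing import Tuple

    def cached_input_usage(needs_gather: bool, fp8_gather_is_rowwise_only: bool) -> Tuple[bool, bool]:
        """Return (keep_rowwise, keep_columnwise) for the input cached for the wgrad GEMM."""
        if needs_gather and fp8_gather_is_rowwise_only:
            # All-gather only works on row-wise FP8 data; the column-wise copy is
            # produced after the gather in the backward pass instead.
            return True, False
        # No gather constraint: only the column-wise copy is needed by the wgrad GEMM.
        return False, True

    assert cached_input_usage(needs_gather=True, fp8_gather_is_rowwise_only=True) == (True, False)
    assert cached_input_usage(needs_gather=False, fp8_gather_is_rowwise_only=True) == (False, True)
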
@@ -584,20 +591,26 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - input_is_quantized = isinstance(inputmat, QuantizedTensorBase) if ctx.fp8 or ctx.debug: - if not input_is_quantized: + if isinstance(inputmat, QuantizedTensorBase): + # Input tensor is already quantized + pass + elif ctx.debug: + # Debug quantizer will be applied immediately before wgrad GEMM + pass + else: + # Quantize input tensor quantizer = ctx.input_quantizer - if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): - quantizer.set_usage( - rowwise=True, - columnwise=not ctx.backward_input_needs_gather, - ) + if ctx.backward_input_needs_gather and isinstance( + quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ): + # All-gather is not supported with FP8 column-wise data + quantizer.set_usage(rowwise=True, columnwise=False) else: - quantizer.set_usage(rowwise=False, columnwise=True) + quantizer.set_usage(rowwise=True, columnwise=True) inputmat = quantizer(inputmat) else: - if input_is_quantized: + if isinstance(inputmat, QuantizedTensorBase): inputmat = inputmat.dequantize(dtype=ctx.activation_dtype) else: inputmat = cast_if_needed(inputmat, ctx.activation_dtype) @@ -1377,7 +1390,9 @@ def forward( if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( + with torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ), self.prepare_forward( inp, allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp:
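
To summarize the JAX-side behavior introduced above: when the GEMM contracting dimension is sharded, `GemmPrimitive.partition()` either all-reduces the partial products or, with `sequence_parallel_output=True`, reduce-scatters them along the sequence dimension. The standalone sketch below illustrates that collective swap with plain JAX; the mesh, shapes, and function names are illustrative and independent of Transformer Engine:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental.shard_map import shard_map  # newer JAX also exposes jax.shard_map

    # A one-device "tp" mesh keeps the sketch runnable anywhere; use more devices for real TP.
    mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("tp",))

    def gemm_allreduce(x, w):
        # x: (seq, hidden/tp), w: (hidden/tp, out) per shard -> replicate the full output.
        return jax.lax.psum(x @ w, "tp")

    def gemm_sequence_parallel(x, w):
        # Same partial product, but reduce-scattered along the sequence dimension (dim 0).
        return jax.lax.psum_scatter(x @ w, "tp", scatter_dimension=0, tiled=True)

    x = jnp.ones((8, 16))
    w = jnp.ones((16, 4))
    in_specs = (P(None, "tp"), P("tp", None))
    out_full = shard_map(gemm_allreduce, mesh, in_specs=in_specs, out_specs=P(None, None))(x, w)
    out_seqp = shard_map(gemm_sequence_parallel, mesh, in_specs=in_specs, out_specs=P("tp", None))(x, w)
    assert out_full.shape == out_seqp.shape == (8, 4)

In the modules above, the same switch is exposed as `sequence_parallel_output` on `gemm()`, `dense()`, and `DenseGeneral`, while the backward rules pick the scatter dimension via `get_sequence_parallel_dim()` from the forward input's logical axes.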