diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index cfba6d7225..2922da501e 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -21,10 +21,12 @@ concurrency: cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} permissions: read-all env: + TRITON_BUILD_WITH_CCACHE: "true" TRITON_BUILD_WITH_CLANG_LLD: "TRUE" TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" TRITON_DISABLE_LINE_INFO: 1 PROTON_SKIP_PC_SAMPLING_TEST: 1 + CCACHE_COMPRESS: "true" jobs: Runner-Preparation: runs-on: ubuntu-latest @@ -39,6 +41,11 @@ jobs: if: github.event_name == 'pull_request' run: | echo "enable_integration=true" >> $GITHUB_ENV + - name: Decide manual trigger integration test enablement + # Always enable integration tests when manually triggered + if: github.event_name == 'workflow_dispatch' + run: | + echo "enable_integration=true" >> $GITHUB_ENV - name: Checkout post-submit commits if: github.event_name == 'push' uses: actions/checkout@v4 @@ -154,6 +161,8 @@ jobs: strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-CUDA)}} + env: + RUNNER_TYPE: ${{ matrix.runner[0] }} steps: - name: Checkout uses: actions/checkout@v4 @@ -199,22 +208,28 @@ jobs: # "restore" step. This is to prevent the caches from accumulating stale # files over time. name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' + id: restore-build-cache + if: github.ref != 'refs/heads/main' uses: actions/cache/restore@v4 with: path: | ~/.triton/cache - ~/.cache/ccache + ~/.ccache # Restore the most recent cache entry. 
- restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- + restore-keys: | + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}- + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}- # We expect this cache key never to hit and for us to fall back # unconditionally to the restore-key, so it doesn't actually matter # what we put here (so long as it doesn't hit an existing key). - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directory + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} + - name: Inspect cache directories run: | mkdir -p ~/.triton - ls -alh ~/.triton + du -h -d 1 ~/.triton + + mkdir -p ~/.ccache + du -h -d 1 ~/.ccache - name: Update PATH run: | echo "$HOME/.local/bin" >> $GITHUB_PATH @@ -224,12 +239,14 @@ jobs: python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit - name: Install Triton env: - TRITON_BUILD_WITH_CCACHE: "true" CUDA_HOME: "/usr/local/cuda" run: | echo "PATH is '$PATH'" cd python - python3 -m pip install '.[tests]' + ccache --zero-stats + python3 -m pip install -v '.[tests]' + - name: CCache Stats + run: ccache --print-stats - name: Run lit tests run: | cd python @@ -278,6 +295,13 @@ jobs: cd third_party/proton/test python3 -m pytest -s . cd .. + - name: Inspect cache directories + run: | + mkdir -p ~/.triton + du -h -d 1 ~/.triton + + mkdir -p ~/.ccache + du -h -d 1 ~/.ccache - # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs. 
# @@ -287,22 +311,17 @@ jobs: if: github.ref == 'refs/heads/main' uses: actions/cache/save@v4 with: - path: ~/.triton/cache ~/.cache/ccache - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directories - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** - - mkdir -p ~/.cache/ccache - ls -alh ~/.cache/ccache - du -sh ~/.cache/ccache + path: | + ~/.triton/cache + ~/.ccache + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} Integration-Tests-AMD: needs: Runner-Preparation if: needs.Runner-Preparation.outputs.matrix-HIP != '' runs-on: ${{ matrix.runner }} timeout-minutes: 30 + env: + RUNNER_TYPE: ${{ matrix.runner[1] }} strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-HIP)}} @@ -355,22 +374,28 @@ jobs: # "restore" step. This is to prevent the caches from accumulating stale # files over time. name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' + id: restore-build-cache + if: github.ref != 'refs/heads/main' uses: actions/cache/restore@v4 with: path: | ~/.triton/cache - ~/.cache/ccache + ~/.ccache # Restore the most recent cache entry. - restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- + restore-keys: | + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}- + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}- # We expect this cache key never to hit and for us to fall back # unconditionally to the restore-key, so it doesn't actually matter # what we put here (so long as it doesn't hit an existing key). 
- key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directory + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} + - name: Inspect cache directories run: | mkdir -p ~/.triton - ls -alh ~/.triton + du -h -d 1 ~/.triton + + mkdir -p ~/.ccache + du -h -d 1 ~/.ccache - name: Update PATH run: | echo "/opt/rocm/llvm/bin" >> $GITHUB_PATH @@ -378,17 +403,24 @@ jobs: run: | python3 -m pip install --upgrade pip python3 -m pip install lit + - name: Install apt dependencies + run: | + apt-get update + apt-get install -y ccache - name: Install Triton id: amd-install-triton run: | echo "PATH is '$PATH'" pip uninstall -y triton cd python + ccache --zero-stats pip install -v -e '.[tests]' - name: Clean up after an unsuccessful build if: ${{ !success() && steps.amd-install-triton.outcome != 'success' }} run: | rm -rf ~/.triton + - name: CCache Stats + run: ccache --print-stats - name: Run lit tests run: | cd python @@ -431,6 +463,13 @@ jobs: cd python cd "build/$(ls build | grep -i cmake)" ctest -j32 + - name: Inspect cache directories + run: | + mkdir -p ~/.triton + du -h -d 1 ~/.triton + + mkdir -p ~/.ccache + du -h -d 1 ~/.ccache - # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs.
# @@ -440,17 +479,10 @@ jobs: if: github.ref == 'refs/heads/main' uses: actions/cache/save@v4 with: - path: ~/.triton/cache ~/.cache/ccache - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directories - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** - - mkdir -p ~/.cache/ccache - ls -alh ~/.cache/ccache - du -sh ~/.cache/ccache + path: | + ~/.triton/cache + ~/.ccache + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - name: Clean up caches run: | rm -rf ~/.triton/cache @@ -458,10 +490,12 @@ jobs: needs: Runner-Preparation if: needs.Runner-Preparation.outputs.matrix-MACOS != '' runs-on: ${{ matrix.runner }} - timeout-minutes: 30 + timeout-minutes: 40 strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-MACOS)}} + env: + RUNNER_TYPE: ${{ matrix.runner[0] }} steps: - name: Checkout uses: actions/checkout@v4 @@ -470,7 +504,7 @@ jobs: - name: Install brew dependencies run: | brew update - brew install ccache llvm@19 lld + brew install ccache llvm@19 lld coreutils - name: Compute cache keys id: cache-key run: | @@ -511,22 +545,28 @@ jobs: # "restore" step. This is to prevent the caches from accumulating stale # files over time. name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' + id: restore-build-cache + if: github.ref != 'refs/heads/main' uses: actions/cache/restore@v4 with: path: | ~/.triton/cache - ~/.cache/ccache + ~/.ccache # Restore the most recent cache entry.
- restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- + restore-keys: | + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}- + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}- # We expect this cache key never to hit and for us to fall back # unconditionally to the restore-key, so it doesn't actually matter # what we put here (so long as it doesn't hit an existing key). - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directory + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} + - name: Inspect cache directories run: | mkdir -p ~/.triton - ls -alh ~/.triton + du -h -d 1 ~/.triton + + mkdir -p ~/.ccache + du -h -d 1 ~/.ccache - name: Update PATH run: | echo "$HOME/.local/bin" >> $GITHUB_PATH @@ -539,7 +579,6 @@ jobs: python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-xdist lit pybind11 - name: Install Triton env: - TRITON_BUILD_WITH_CCACHE: "true" TRITON_BUILD_WITH_O1: "true" # macos-latest has 3 vcpus and 7GB DRAM, to save memory we limit the number of jobs to 3 # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories @@ -548,7 +587,17 @@ jobs: source ~/.venv/bin/activate echo "PATH is '$PATH'" cd python - python3 -m pip install --no-build-isolation . + ccache --zero-stats + python3 -m pip install -v --no-build-isolation .
+ - name: CCache Stats + run: ccache --print-stats + - name: Inspect cache directories + run: | + mkdir -p ~/.triton + du -h -d 1 ~/.triton + + mkdir -p ~/.ccache + du -h -d 1 ~/.ccache - # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs. # @@ -558,14 +607,7 @@ jobs: if: github.ref == 'refs/heads/main' uses: actions/cache/save@v4 with: - path: ~/.triton/cache ~/.cache/ccache - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directories - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** - - mkdir -p ~/.cache/ccache - ls -alh ~/.cache/ccache - du -sh ~/.cache/ccache + path: | + ~/.triton/cache + ~/.ccache + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index 7da4aa0793..7de7264272 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -23,10 +23,12 @@ concurrency: permissions: read-all env: + TRITON_BUILD_WITH_CCACHE: "true" TRITON_BUILD_WITH_CLANG_LLD: "TRUE" TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" TRITON_DISABLE_LINE_INFO: 1 PROTON_SKIP_PC_SAMPLING_TEST: 1 + CCACHE_COMPRESS: "true" jobs: Runner-Preparation: @@ -43,6 +45,12 @@ jobs: run: | echo "enable_integration=true" >> $GITHUB_ENV + - name: Decide manual trigger integration test enablement + # Always enable integration tests when manually triggered + if: github.event_name == 'workflow_dispatch' + run: | + echo "enable_integration=true" >> $GITHUB_ENV + - name: Checkout post-submit commits if: github.event_name == 'push' uses: actions/checkout@v4 @@ -174,6 +182,9 @@ jobs: matrix: runner:
${{fromJson(needs.Runner-Preparation.outputs.matrix-CUDA)}} + env: + RUNNER_TYPE: ${{ matrix.runner[0] }} + steps: - name: Checkout uses: actions/checkout@v4 @@ -225,24 +236,30 @@ jobs: # files over time. - &restore-build-artifacts-step name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' + id: restore-build-cache + if: github.ref != 'refs/heads/main' uses: actions/cache/restore@v4 with: path: | ~/.triton/cache - ~/.cache/ccache + ~/.ccache # Restore the most recent cache entry. - restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- + restore-keys: | + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}- + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}- # We expect this cache key never to hit and for us to fall back # unconditionally to the restore-key, so it doesn't actually matter # what we put here (so long as it doesn't hit an existing key).
- key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - &inspect-cache-directory-step - name: Inspect cache directory + - &inspect-cache-directories-step + name: Inspect cache directories run: | mkdir -p ~/.triton - ls -alh ~/.triton + du -h -d 1 ~/.triton + + mkdir -p ~/.ccache + du -h -d 1 ~/.ccache - name: Update PATH run: | @@ -255,12 +272,16 @@ jobs: - name: Install Triton env: - TRITON_BUILD_WITH_CCACHE: "true" CUDA_HOME: "/usr/local/cuda" run: | echo "PATH is '$PATH'" cd python - python3 -m pip install '.[tests]' + ccache --zero-stats + python3 -m pip install -v '.[tests]' + + - &print-ccache-stats + name: CCache Stats + run: ccache --print-stats - &run-lit-tests-step name: Run lit tests @@ -319,6 +340,8 @@ jobs: python3 -m pytest -s . cd .. + - *inspect-cache-directories-step + # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs.
# @@ -329,19 +352,10 @@ jobs: if: github.ref == 'refs/heads/main' uses: actions/cache/save@v4 with: - path: ~/.triton/cache ~/.cache/ccache - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - - &inspect-cache-directories-step - name: Inspect cache directories - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** - - mkdir -p ~/.cache/ccache - ls -alh ~/.cache/ccache - du -sh ~/.cache/ccache + path: | + ~/.triton/cache + ~/.ccache + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} Integration-Tests-AMD: needs: Runner-Preparation @@ -350,6 +364,9 @@ jobs: runs-on: ${{ matrix.runner }} timeout-minutes: 30 + env: + RUNNER_TYPE: ${{ matrix.runner[1] }} + strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-HIP)}} @@ -369,7 +386,7 @@ jobs: - *compute-cache-keys-step - *cache-build-dependencies-step - *restore-build-artifacts-step - - *inspect-cache-directory-step + - *inspect-cache-directories-step - name: Update PATH run: | @@ -380,12 +397,18 @@ jobs: python3 -m pip install --upgrade pip python3 -m pip install lit + - name: Install apt dependencies + run: | + apt-get update + apt-get install -y ccache + - name: Install Triton id: amd-install-triton run: | echo "PATH is '$PATH'" pip uninstall -y triton cd python + ccache --zero-stats pip install -v -e '.[tests]' - name: Clean up after an unsuccessful build @@ -393,6 +416,7 @@ jobs: run: | rm -rf ~/.triton + - *print-ccache-stats - *run-lit-tests-step - name: Run python tests on HIP @@ -423,8 +447,8 @@ jobs: - *run-proton-tests-step - *run-cpp-unittests-step - - *save-build-artifacts-step - *inspect-cache-directories-step + - *save-build-artifacts-step - name: Clean up caches run: | @@ -434,10 +458,14 @@ jobs: needs: Runner-Preparation if:
needs.Runner-Preparation.outputs.matrix-MACOS != '' runs-on: ${{ matrix.runner }} - timeout-minutes: 30 + timeout-minutes: 40 strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-MACOS)}} + + env: + RUNNER_TYPE: ${{ matrix.runner[0] }} + steps: - name: Checkout uses: actions/checkout@v4 @@ -446,12 +474,12 @@ jobs: - name: Install brew dependencies run: | brew update - brew install ccache llvm@19 lld + brew install ccache llvm@19 lld coreutils - *compute-cache-keys-step - *cache-build-dependencies-step - *restore-build-artifacts-step - - *inspect-cache-directory-step + - *inspect-cache-directories-step - name: Update PATH run: | @@ -465,7 +493,6 @@ jobs: python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-xdist lit pybind11 - name: Install Triton env: - TRITON_BUILD_WITH_CCACHE: "true" TRITON_BUILD_WITH_O1: "true" # macos-latest has 3 vcpus and 7GB DRAM, to save memory we limit the number of jobs to 3 # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories @@ -474,7 +501,9 @@ jobs: source ~/.venv/bin/activate echo "PATH is '$PATH'" cd python - python3 -m pip install --no-build-isolation . + ccache --zero-stats + python3 -m pip install -v --no-build-isolation . 
- - *save-build-artifacts-step + - *print-ccache-stats - *inspect-cache-directories-step + - *save-build-artifacts-step diff --git a/CMakeLists.txt b/CMakeLists.txt index aa9bd605c9..e4d16d4f9d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,6 +45,13 @@ if(TRITON_BUILD_WITH_CCACHE) endif() endif() +set(TRITON_PARALLEL_LINK_JOBS "" CACHE STRING + "Define the maximum number of concurrent link jobs (Ninja only).") +if (TRITON_PARALLEL_LINK_JOBS) + set_property(GLOBAL APPEND PROPERTY JOB_POOLS link_job_pool=${TRITON_PARALLEL_LINK_JOBS}) + set(CMAKE_JOB_POOL_LINK link_job_pool) +endif() + # Ensure Python3 vars are set correctly # used conditionally in this file and by lit tests @@ -226,6 +233,9 @@ if(TRITON_BUILD_PYTHON_MODULE) if (TRITON_BUILD_PROTON) add_subdirectory(third_party/proton) endif() + # We always build proton dialect + list(APPEND TRITON_PLUGIN_NAMES "proton") + add_subdirectory(third_party/proton/dialect) get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS) @@ -334,6 +344,7 @@ if(NOT TRITON_BUILD_PYTHON_MODULE) foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) add_subdirectory(third_party/${CODEGEN_BACKEND}) endforeach() + add_subdirectory(third_party/proton/dialect) endif() if(WIN32) option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON) diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index b103adeaba..a59956af5c 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -12,6 +12,7 @@ #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "amd/include/TritonAMDGPUTransforms/Passes.h" #include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -93,14 +94,15 @@ inline void 
registerTritonDialects(mlir::DialectRegistry &registry) { mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); // TODO: register Triton & TritonGPU passes - registry.insert(); + registry + .insert(); } diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 8e8b089549..2d06980809 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -374,24 +374,24 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere) // -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be // completed before we can remove the layoutIsOK check: - // 1. Support for AMD's MFMA and WMMA + // 1. Support for AMD's WMMA std::function layoutIsOK = [&](Attribute layout) { - if (auto nvidiaMma = dyn_cast(layout)) { - if (useLegacyMMAConversion) { - return false; - } - return true; + if (isa(layout)) { + return !useLegacyMMAConversion; } if (auto dotOperand = dyn_cast(layout)) { - if (auto nvidiaMma = - dyn_cast(dotOperand.getParent())) { - if (useLegacyMMAConversion) { - return false; - } + auto parent = dotOperand.getParent(); + if (isa(parent) && useLegacyMMAConversion) { + return false; + } + if (auto nvidiaMma = dyn_cast(parent)) { if (nvidiaMma.isAmpere()) { return true; } } + if (isa(parent)) { + return true; + } return false; } if (isa(layout)) { diff --git a/python/setup.py b/python/setup.py index 65388d8664..1e6dee4cf6 100644 --- a/python/setup.py +++ b/python/setup.py @@ -523,6 +523,7 @@ def build_extension(self, ext): "TRITON_BUILD_PROTON", "TRITON_BUILD_TUTORIALS", "TRITON_BUILD_WITH_CCACHE", + "TRITON_PARALLEL_LINK_JOBS", ] cmake_args += [f"-D{option}={os.getenv(option)}" for option in passthrough_args if option in os.environ] diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index
83c9e535d8..a2c8f48718 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -7,6 +7,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.func public @shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { // CHECK-NOT: store // CHECK-NOT: load + // CHECK: llvm.return %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> tt.return } @@ -21,6 +22,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.func public @no_shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { // CHECK: store // CHECK: load + // CHECK: llvm.return %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> tt.return } diff --git a/test/Proton/ops.mlir b/test/Proton/ops.mlir new file mode 100644 index 0000000000..22a17e3f0f --- /dev/null +++ b/test/Proton/ops.mlir @@ -0,0 +1,15 @@ +// RUN: triton-opt --split-input-file %s -cse -canonicalize | FileCheck %s + +module { + // CHECK-LABEL: proton_record + tt.func @proton_record() { + // CHECK: proton.record() {isStart = true, regionId = 1 : i32} + // CHECK-NEXT: proton.record() {isStart = false, regionId = 1 : i32} + // CHECK-NEXT: tt.return + proton.record() {isStart = true, regionId = 1 : i32} + proton.record() {isStart = false, regionId = 1 : i32} + tt.return + } +} // end module + +// ----- diff --git a/third_party/proton/dialect/CMakeLists.txt b/third_party/proton/dialect/CMakeLists.txt new file mode 100644 index 0000000000..c7b5413a0e --- /dev/null +++ b/third_party/proton/dialect/CMakeLists.txt @@ -0,0 +1,7 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +add_subdirectory(include) +add_subdirectory(lib) +if(TRITON_BUILD_PYTHON_MODULE) + add_triton_plugin(TritonProton ${CMAKE_CURRENT_SOURCE_DIR}/triton_proton.cc LINK_LIBS ProtonIR) +endif() diff --git 
a/third_party/proton/dialect/include/CMakeLists.txt b/third_party/proton/dialect/include/CMakeLists.txt new file mode 100644 index 0000000000..0ca0f41c5a --- /dev/null +++ b/third_party/proton/dialect/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/third_party/proton/dialect/include/Dialect/CMakeLists.txt b/third_party/proton/dialect/include/Dialect/CMakeLists.txt new file mode 100644 index 0000000000..f18c30ba1a --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Proton) diff --git a/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt b/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt b/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt new file mode 100644 index 0000000000..4645b0ebcd --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt @@ -0,0 +1,18 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS ProtonOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=proton) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=proton) +mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(ProtonDialect ProtonDialect dialects/ -gen-dialect-doc) +add_mlir_doc(ProtonOps ProtonOps dialects/ -gen-op-doc) +add_public_tablegen_target(ProtonTableGen) + +set(LLVM_TARGET_DEFINITIONS ProtonAttrDefs.td) +mlir_tablegen(ProtonAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(ProtonAttrDefs.cpp.inc -gen-attrdef-defs) 
+add_public_tablegen_target(ProtonAttrDefsIncGen) diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h b/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h new file mode 100644 index 0000000000..680a205f08 --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h @@ -0,0 +1,23 @@ +#ifndef TRITON_DIALECT_PROTON_IR_DIALECT_H_ +#define TRITON_DIALECT_PROTON_IR_DIALECT_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/PatternMatch.h" +#include "proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc" +#include "proton/dialect/include/Dialect/Proton/IR/OpsEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "proton/dialect/include/Dialect/Proton/IR/Ops.h.inc" + +namespace mlir { +namespace triton { +namespace proton {} // namespace proton +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_PROTON_IR_DIALECT_H_ diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td new file mode 100644 index 0000000000..d469fbb35f --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td @@ -0,0 +1,12 @@ +#ifndef PROTON_ATTRDEFS +#define PROTON_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "ProtonDialect.td" + +class Proton_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { +} + +#endif // PROTON_ATTRDEFS diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td new file mode 100644 index 0000000000..245f2e09a2 --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td @@ -0,0 +1,18 @@ +#ifndef PROTON_DIALECT +#define PROTON_DIALECT + 
+include "mlir/IR/OpBase.td" + +def Proton_Dialect : Dialect { + let name = "proton"; + let cppNamespace = "::mlir::triton::proton"; + + let description = [{ + Proton Dialect provides core ops for building third-party compiler-based + performance profiling and analysis tools. + }]; + + let dependentDialects = []; +} + +#endif diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td new file mode 100644 index 0000000000..d18a48d5d1 --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td @@ -0,0 +1,65 @@ +#ifndef PROTON_OPS +#define PROTON_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "ProtonDialect.td" +include "ProtonAttrDefs.td" + +class TT_Proton_Op traits = []> : + Op { +} + +// Proton profiling metric. +def MetricAttr : I32EnumAttr< + "Metric", "", + [ + I32EnumAttrCase<"CYCLE", 0, "cycle">, + ]> { + let cppNamespace = "::mlir::triton::proton"; +} + +// Proton profiling granularity. +def GranularityAttr : I32EnumAttr< + "Granularity", "", + [ + I32EnumAttrCase<"WARPGROUP", 0, "warpgroup">, + I32EnumAttrCase<"WARP", 1, "warp">, + ]> { + let cppNamespace = "::mlir::triton::proton"; +} + +def TT_RecordOp : TT_Proton_Op<"record", [DeclareOpInterfaceMethods]> { + let summary = "Record a GPU hardware event"; + + let description = [{ + The operator records GPU events from performance counters. + Currently only cycle counter is supported. + + Example: + + ```mlir + proton.record() {isStart = true, regionId = 4 : i32} + ... + proton.record() {isStart = false, regionId = 4 : i32} + ... + proton.record() {isStart = true, regionId = 1 : i32, granularity = 1 : i32} + ... 
+ proton.record() {isStart = false, regionId = 1 : i32, granularity = 1 : i32} + ``` + }]; + let arguments = ( + ins BoolAttr: $isStart, + ConfinedAttr:$regionId, + DefaultValuedAttr:$metric, + DefaultValuedAttr:$granularity + ); + let assemblyFormat = " `(` operands `)` attr-dict"; +} + +#endif // PROTON_OPS diff --git a/third_party/proton/dialect/lib/CMakeLists.txt b/third_party/proton/dialect/lib/CMakeLists.txt new file mode 100644 index 0000000000..0ca0f41c5a --- /dev/null +++ b/third_party/proton/dialect/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/third_party/proton/dialect/lib/Dialect/CMakeLists.txt b/third_party/proton/dialect/lib/Dialect/CMakeLists.txt new file mode 100644 index 0000000000..f18c30ba1a --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Proton) diff --git a/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt b/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt b/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt new file mode 100644 index 0000000000..5eea5cb3cf --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(ProtonIR + Dialect.cpp + Ops.cpp + + DEPENDS + ProtonTableGen + ProtonAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRLLVMDialect + TritonIR + TritonGPUIR +) diff --git a/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp b/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp new file mode 100644 index 0000000000..60c2852654 --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp @@ -0,0 +1,25 @@ +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" 
+ +// clang-format off +#include "Dialect/Proton/IR/Dialect.h" +#include "Dialect/Proton/IR/Dialect.cpp.inc" +// clang-format on + +using namespace mlir; +using namespace mlir::triton::proton; + +void mlir::triton::proton::ProtonDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "Dialect/Proton/IR/ProtonAttrDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "Dialect/Proton/IR/Ops.cpp.inc" + >(); +} + +#define GET_ATTRDEF_CLASSES +#include "Dialect/Proton/IR/ProtonAttrDefs.cpp.inc" diff --git a/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp b/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp new file mode 100644 index 0000000000..1a0799aea1 --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp @@ -0,0 +1,33 @@ +#include "Dialect/Proton/IR/Dialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +#define GET_OP_CLASSES +#include "Dialect/Proton/IR/Ops.cpp.inc" +#include "Dialect/Proton/IR/OpsEnums.cpp.inc" + +namespace mlir { +namespace triton { +namespace proton { + +// -- RecordOp -- +void RecordOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +} // namespace proton +} // namespace triton +} // namespace mlir diff --git a/third_party/proton/dialect/triton_proton.cc b/third_party/proton/dialect/triton_proton.cc new file mode 100644 index 0000000000..8046539794 --- /dev/null +++ b/third_party/proton/dialect/triton_proton.cc @@ -0,0 
+1,20 @@ +#include "Dialect/Proton/IR/Dialect.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include +#include +#include + +namespace py = pybind11; + +void init_triton_proton(py::module &&m) { + auto passes = m.def_submodule("passes"); + + // load dialects + m.def("load_dialects", [](mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); +}