diff --git a/.dep-versions b/.dep-versions index bf7a56edbb..eb2ff7df75 100644 --- a/.dep-versions +++ b/.dep-versions @@ -2,7 +2,7 @@ # To update JAX version alongside compatible dependency tags, run the following script: # python3 .github/workflows/set_dep_versions.py {JAX_version} jax=0.6.2 -mhlo=1dd2e71331014ae0373f6bf900ce6be393357190 +stablehlo=69d6dae46e1c7de36e6e6973654754f05353cba5 llvm=f8cb7987c64dcffb72414a40560055cb717dbf74 enzyme=v0.0.186 diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml index 4b83d7ae38..b79f6c164c 100644 --- a/.github/workflows/build-wheel-linux-arm64.yaml +++ b/.github/workflows/build-wheel-linux-arm64.yaml @@ -84,12 +84,12 @@ jobs: key: llvm-${{ needs.constants.outputs.llvm_version }}-container-source enableCrossOsArchive: True - - name: Cache MHLO Source - id: cache-mhlo-source + - name: Cache Stablehlo Source + id: cache-stablehlo-source uses: actions/cache@v4 with: - path: ${{ github.workspace }}/mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-container-source + path: ${{ github.workspace }}/mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-container-source enableCrossOsArchive: True - name: Cache Enzyme Source @@ -109,25 +109,18 @@ jobs: path: ${{ github.workspace }}/mlir/llvm-project - name: Patch LLVM Source - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' + if: steps.cache-llvm-source.outputs.cache-hit != 'true' run: | cd $GITHUB_WORKSPACE/mlir/llvm-project git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch - - name: Clone MHLO Submodule - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' + - name: Clone Stablehlo Submodule + if: steps.cache-stablehlo-source.outputs.cache-hit != 'true' uses: actions/checkout@v4 with: - repository: tensorflow/mlir-hlo - ref: ${{ needs.constants.outputs.mhlo_version }} - path: ${{ github.workspace }}/mlir/mlir-hlo - - - name: Patch MHLO Source - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' - run: | - cd $GITHUB_WORKSPACE/mlir/mlir-hlo - git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch - git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch + repository: openxla/stablehlo + ref: ${{ needs.constants.outputs.stablehlo_version }} + path: ${{ github.workspace }}/mlir/stablehlo - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' @@ -151,12 +144,12 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ matrix.python_version }}-wheel-build - - name: Check MHLO Build Cache - id: cache-mhlo-build + - name: Check Stablehlo Build Cache + id: cache-stablehlo-build uses: actions/cache/restore@v4 with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build lookup-only: True - name: Check Enzyme Build Cache @@ -170,7 +163,7 @@ jobs: - name: Install dependencies if: | steps.cache-llvm-build.outputs.cache-hit != 'true' || - steps.cache-mhlo-build.outputs.cache-hit != 'true' || + steps.cache-stablehlo-build.outputs.cache-hit != 'true' || steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | cat /etc/dnf.conf | sed "s/\[main\]/\[main\]\ntimeout=5/g" > /etc/dnf.conf @@ -207,31 +200,24 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build - - name: Build MHLO Dialect - if: steps.cache-mhlo-build.outputs.cache-hit != 'true' - # building with LLD is a strong requirement for mhlo + - name: Build Stablehlo Dialect + if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' run: | - export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - - cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \ - -DCMAKE_BUILD_TYPE=Release \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ - -DLLVM_ENABLE_ZLIB=FORCE_ON \ - -DLLVM_ENABLE_ZSTD=OFF \ - -DCMAKE_CXX_VISIBILITY_PRESET=default \ - -DLLVM_ENABLE_LLD=ON - - LIT_FILTER_OUT="chlo_legalize_to_mhlo" cmake --build $GITHUB_WORKSPACE/mhlo-build --target check-mlir-hlo + C_COMPILER=$(which gcc) \ + CXX_COMPILER=$(which g++) \ + LLVM_BUILD_DIR="$(pwd)/llvm-build" \ + STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ + COMPILER_LAUNCHER="" \ + ENABLE_LLD=OFF \ + make stablehlo - - name: Save MHLO Build - id: save-mhlo-build - if: steps.cache-mhlo-build.outputs.cache-hit != 'true' + - name: Save Stablehlo Build + id: save-stablehlo-build + if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' uses: actions/cache/save@v4 with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' @@ -305,21 +291,21 @@ jobs: key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-3.11-wheel-build fail-on-cache-miss: True - - name: Get Cached MHLO Source - id: cache-mhlo-source + - name: Get Cached Stablehlo Source + id: cache-stablehlo-source uses: actions/cache/restore@v4 with: - path: ${{ github.workspace }}/mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-container-source + path: ${{ github.workspace }}/mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-container-source enableCrossOsArchive: True fail-on-cache-miss: True - - name: Get Cached MHLO Build - id: cache-mhlo-build + - name: Get Cached Stablehlo Build + id: cache-stablehlo-build uses: actions/cache/restore@v4 with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build fail-on-cache-miss: True - name: Get Cached Enzyme Source @@ -372,8 +358,8 @@ jobs: -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DMHLO_DIR="$GITHUB_WORKSPACE/mhlo-build/lib/cmake/mlir-hlo" \ - -DMHLO_BINARY_DIR="$GITHUB_WORKSPACE/mhlo-build/bin" \ + -DSTABLEHLO_DIR="$GITHUB_WORKSPACE/mlir/stablehlo" \ + -DSTABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ -DEnzyme_DIR="$GITHUB_WORKSPACE/enzyme-build" \ -DENZYME_SRC_DIR="$GITHUB_WORKSPACE/mlir/Enzyme" \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ @@ -396,7 +382,7 @@ jobs: run: | PYTHON=python${{ matrix.python_version }} \ LLVM_BUILD_DIR="$GITHUB_WORKSPACE/llvm-build" \ - MHLO_BUILD_DIR="$GITHUB_WORKSPACE/mhlo-build" \ + STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ DIALECTS_BUILD_DIR="$GITHUB_WORKSPACE/quantum-build" \ RT_BUILD_DIR="$GITHUB_WORKSPACE/runtime-build" \ OQC_BUILD_DIR="$GITHUB_WORKSPACE/oqc-build" \ diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml index abd4f069ab..d6faa14248 100644 --- a/.github/workflows/build-wheel-linux-x86_64.yaml +++ b/.github/workflows/build-wheel-linux-x86_64.yaml @@ -103,12 +103,12 @@ jobs: key: llvm-${{ needs.constants.outputs.llvm_version }}-container-source enableCrossOsArchive: True - - name: Cache MHLO Source - id: cache-mhlo-source + - name: Cache Stablehlo Source + id: cache-stablehlo-source uses: actions/cache@v4 with: - path: ${{ github.workspace }}/mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-container-source + path: ${{ github.workspace }}/mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-container-source enableCrossOsArchive: True - name: Cache Enzyme Source @@ -128,25 +128,18 @@ jobs: path: ${{ github.workspace }}/mlir/llvm-project - name: Patch LLVM Source - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' + if: steps.cache-llvm-source.outputs.cache-hit != 'true' run: | cd $GITHUB_WORKSPACE/mlir/llvm-project git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch - - name: Clone MHLO Submodule - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' + - name: Clone Stablehlo Submodule + if: steps.cache-stablehlo-source.outputs.cache-hit != 'true' uses: actions/checkout@v4 with: - repository: tensorflow/mlir-hlo - ref: ${{ needs.constants.outputs.mhlo_version }} - path: ${{ github.workspace }}/mlir/mlir-hlo - - - name: Patch MHLO Source - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' - run: | - cd $GITHUB_WORKSPACE/mlir/mlir-hlo - git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch - git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch + repository: openxla/stablehlo + ref: ${{ needs.constants.outputs.stablehlo_version }} + path: ${{ github.workspace }}/mlir/stablehlo - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' @@ -170,12 +163,12 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ matrix.python_version }}-wheel-build - - name: Check MHLO Build Cache - id: cache-mhlo-build + - name: Check Stablehlo Build Cache + id: cache-stablehlo-build uses: actions/cache/restore@v4 with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build lookup-only: True - name: Check Enzyme Build Cache @@ -189,7 +182,7 @@ jobs: - name: Install dependencies (AlmaLinux) if: | steps.cache-llvm-build.outputs.cache-hit != 'true' || - steps.cache-mhlo-build.outputs.cache-hit != 'true' || + steps.cache-stablehlo-build.outputs.cache-hit != 'true' || steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | # Reduce wait time for repos not responding @@ -205,7 +198,6 @@ jobs: PYTHON_BINS=$(find /opt/_internal/cpython-${{ matrix.python_version }}.*/bin -maxdepth 1 -type d | tr '\n' ':' | sed 's/:$//') echo $PYTHON_BINS >> $GITHUB_PATH - # LLD is required for MHLO builds. # (Don't forget to add the build directory to PATH in subsequent steps, so # other tools can find it, in particular collect2 invoked by gcc.) - name: Build LLVM / MLIR @@ -230,31 +222,24 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build - - name: Build MHLO Dialect - if: steps.cache-mhlo-build.outputs.cache-hit != 'true' - # building with LLD is a strong requirement for mhlo + - name: Build Stablehlo Dialect + if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' run: | - export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - - cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \ - -DCMAKE_BUILD_TYPE=Release \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ - -DLLVM_ENABLE_ZLIB=FORCE_ON \ - -DLLVM_ENABLE_ZSTD=OFF \ - -DCMAKE_CXX_VISIBILITY_PRESET=default \ - -DLLVM_ENABLE_LLD=ON - - LIT_FILTER_OUT="chlo_legalize_to_mhlo" cmake --build $GITHUB_WORKSPACE/mhlo-build --target check-mlir-hlo + C_COMPILER=$(which gcc) \ + CXX_COMPILER=$(which g++) \ + LLVM_BUILD_DIR="$(pwd)/llvm-build" \ + STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ + COMPILER_LAUNCHER="" \ + ENABLE_LLD=OFF \ + make stablehlo - - name: Save MHLO Build - id: save-mhlo-build - if: steps.cache-mhlo-build.outputs.cache-hit != 'true' + - name: Save Stablehlo Build + id: save-stablehlo-build + if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' uses: actions/cache/save@v4 with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' @@ -328,21 +313,21 @@ jobs: key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-3.11-wheel-build fail-on-cache-miss: True - - name: Get Cached MHLO Source - id: cache-mhlo-source + - name: Get Cached Stablehlo Source + id: cache-stablehlo-source uses: actions/cache/restore@v4 with: - path: ${{ github.workspace }}/mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-container-source + path: ${{ github.workspace }}/mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-container-source enableCrossOsArchive: True fail-on-cache-miss: True - - name: Get Cached MHLO Build - id: cache-mhlo-build + - name: Get Cached Stablehlo Build + id: cache-stablehlo-build uses: actions/cache/restore@v4 with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build fail-on-cache-miss: True - name: Get Cached Enzyme Source @@ -397,8 +382,8 @@ jobs: -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DMHLO_DIR="$GITHUB_WORKSPACE/mhlo-build/lib/cmake/mlir-hlo" \ - -DMHLO_BINARY_DIR="$GITHUB_WORKSPACE/mhlo-build/bin" \ + -DSTABLEHLO_DIR="$GITHUB_WORKSPACE/mlir/stablehlo" \ + -DSTABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ -DEnzyme_DIR="$GITHUB_WORKSPACE/enzyme-build" \ -DENZYME_SRC_DIR="$GITHUB_WORKSPACE/mlir/Enzyme" \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ @@ -421,7 +406,7 @@ jobs: run: | PYTHON=python${{ matrix.python_version }} \ LLVM_BUILD_DIR="$GITHUB_WORKSPACE/llvm-build" \ - MHLO_BUILD_DIR="$GITHUB_WORKSPACE/mhlo-build" \ + STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ DIALECTS_BUILD_DIR="$GITHUB_WORKSPACE/quantum-build" \ RT_BUILD_DIR="$GITHUB_WORKSPACE/runtime-build" \ OQC_BUILD_DIR="$GITHUB_WORKSPACE/oqc-build" \ diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index 784c47b4ff..b189279b32 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -89,12 +89,12 @@ jobs: key: llvm-${{ needs.constants.outputs.llvm_version }}-default-source enableCrossOsArchive: True - - name: Cache MHLO Source - id: cache-mhlo-source + - name: Cache Stablehlo Source + id: cache-stablehlo-source uses: actions/cache@v4 with: - path: ${{ github.workspace }}/mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-default-source + path: ${{ github.workspace }}/mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-container-source enableCrossOsArchive: True - name: Cache Enzyme Source @@ -114,25 +114,18 @@ jobs: path: ${{ github.workspace }}/mlir/llvm-project - name: Patch LLVM Source - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' + if: steps.cache-llvm-source.outputs.cache-hit != 'true' run: | cd $GITHUB_WORKSPACE/mlir/llvm-project git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch - - name: Clone MHLO Submodule - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' + - name: Clone Stablehlo Submodule + if: steps.cache-stablehlo-source.outputs.cache-hit != 'true' uses: actions/checkout@v4 with: - repository: tensorflow/mlir-hlo - ref: ${{ needs.constants.outputs.mhlo_version }} - path: ${{ github.workspace }}/mlir/mlir-hlo - - - name: Patch MHLO Source - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' - run: | - cd $GITHUB_WORKSPACE/mlir/mlir-hlo - git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch - git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch + repository: openxla/stablehlo + ref: ${{ needs.constants.outputs.stablehlo_version }} + path: ${{ github.workspace }}/mlir/stablehlo - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' @@ -156,12 +149,12 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ matrix.python_version }}-wheel-build - - name: Check MHLO Build Cache - id: cache-mhlo-build + - name: Check Stablehlo Build Cache + id: cache-stablehlo-build uses: actions/cache/restore@v4 with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + path: ${{ github.workspace }}/stablehlo-build + key: ${{ runner.os }}-${{ runner.arch }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build lookup-only: True - name: Check Enzyme Build Cache @@ -200,30 +193,21 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ matrix.python_version }}-wheel-build - - name: Build MHLO Dialect - if: steps.cache-mhlo-build.outputs.cache-hit != 'true' + - name: Build Stablehlo Dialect + if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' run: | - export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - - cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \ - -DCMAKE_BUILD_TYPE=Release \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ - -DLLVM_ENABLE_LLD=OFF \ - -DLLVM_ENABLE_ZLIB=FORCE_ON \ - -DLLVM_ENABLE_ZSTD=OFF \ - -DCMAKE_CXX_VISIBILITY_PRESET=default - - cmake --build $GITHUB_WORKSPACE/mhlo-build --target check-mlir-hlo - - - name: Save MHLO Build - id: save-mhlo-build - if: steps.cache-mhlo-build.outputs.cache-hit != 'true' + LLVM_BUILD_DIR="$(pwd)/llvm-build" \ + STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ + COMPILER_LAUNCHER="" \ + make stablehlo + + - name: Save Stablehlo Build + id: save-stablehlo-build + if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' uses: actions/cache/save@v4 with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + path: ${{ github.workspace }}/stablehlo-build + key: ${{ runner.os }}-${{ runner.arch }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' @@ -291,21 +275,21 @@ jobs: key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ needs.constants.outputs.primary_python_version }}-wheel-build fail-on-cache-miss: True - - name: Get Cached MHLO Source - id: cache-mhlo-source + - name: Get Cached Stablehlo Source + id: cache-stablehlo-source uses: actions/cache/restore@v4 with: - path: ${{ github.workspace }}/mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-default-source + path: ${{ github.workspace }}/mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-container-source enableCrossOsArchive: True fail-on-cache-miss: True - - name: Get Cached MHLO Build - id: cache-mhlo-build + - name: Get Cached Stablehlo Build + id: cache-stablehlo-build uses: actions/cache/restore@v4 with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + path: ${{ github.workspace }}/stablehlo-build + key: ${{ runner.os }}-${{ runner.arch }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build fail-on-cache-miss: True - name: Get Cached Enzyme Source @@ -377,8 +361,8 @@ jobs: -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DMHLO_DIR="$GITHUB_WORKSPACE/mhlo-build/lib/cmake/mlir-hlo" \ - -DMHLO_BINARY_DIR="$GITHUB_WORKSPACE/mhlo-build/bin" \ + -DSTABLEHLO_DIR="$GITHUB_WORKSPACE/mlir/stablehlo" \ + -DSTABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ -DEnzyme_DIR="$GITHUB_WORKSPACE/enzyme-build" \ -DENZYME_SRC_DIR="$GITHUB_WORKSPACE/mlir/Enzyme" \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ @@ -401,7 +385,7 @@ jobs: run: | PYTHON=python${{ matrix.python_version }} \ LLVM_BUILD_DIR="$GITHUB_WORKSPACE/llvm-build" \ - MHLO_BUILD_DIR="$GITHUB_WORKSPACE/mhlo-build" \ + STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ DIALECTS_BUILD_DIR="$GITHUB_WORKSPACE/quantum-build" \ RT_BUILD_DIR="$GITHUB_WORKSPACE/runtime-build" \ OQC_BUILD_DIR="$GITHUB_WORKSPACE/oqc-build" \ diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 89efafb509..d78e130215 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -172,8 +172,8 @@ jobs: COMPILER_LAUNCHER="" \ make llvm - mhlo: - name: MHLO Dialect Build + stablehlo: + name: Stablehlo Dialect Build needs: [constants, llvm, determine_runner] runs-on: ${{ needs.determine_runner.outputs.runner_group }} strategy: @@ -189,32 +189,32 @@ jobs: with: python-version: ${{ needs.constants.outputs.primary_python_version }} - - name: Cache MHLO Source - id: cache-mhlo-source + - name: Cache Stablehlo Source + id: cache-stablehlo-source uses: actions/cache@v4 with: - path: mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-default-source + path: mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-default-source enableCrossOsArchive: true - - name: Clone MHLO Submodule - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' + - name: Clone Stablehlo Submodule + if: steps.cache-stablehlo-source.outputs.cache-hit != 'true' uses: actions/checkout@v4 with: - repository: tensorflow/mlir-hlo - ref: ${{ needs.constants.outputs.mhlo_version }} - path: mlir/mlir-hlo + repository: openxla/stablehlo + ref: ${{ needs.constants.outputs.stablehlo_version }} + path: mlir/stablehlo - - name: Cache MHLO Build - id: cache-mhlo + - name: Cache Stablehlo Build + id: cache-stablehlo uses: actions/cache@v4 with: - path: mhlo-build - key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }}-0 + path: stablehlo-build + key: ${{ runner.os }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-ci-build-${{ matrix.compiler }}-0 - name: Get Cached LLVM Source id: cache-llvm-source - if: steps.cache-mhlo.outputs.cache-hit != 'true' + if: steps.cache-stablehlo.outputs.cache-hit != 'true' uses: actions/cache@v4 with: path: mlir/llvm-project @@ -224,7 +224,7 @@ jobs: - name: Get Cached LLVM Build id: cache-llvm-build - if: steps.cache-mhlo.outputs.cache-hit != 'true' + if: steps.cache-stablehlo.outputs.cache-hit != 'true' uses: actions/cache@v4 with: path: llvm-build @@ -232,20 +232,19 @@ jobs: fail-on-cache-miss: true - name: Install Deps - if: steps.cache-mhlo.outputs.cache-hit != 'true' + if: steps.cache-stablehlo.outputs.cache-hit != 'true' run: | sudo apt-get update sudo apt-get install -y cmake ninja-build clang lld - - - name: Build MHLO Dialect - if: steps.cache-mhlo.outputs.cache-hit != 'true' + - name: Build Stablehlo Dialect + if: steps.cache-stablehlo.outputs.cache-hit != 'true' run: | C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \ CXX_COMPILER=$(which ${{ needs.constants.outputs[format('cxx_compiler.{0}', matrix.compiler)] }}) \ LLVM_BUILD_DIR="$(pwd)/llvm-build" \ - MHLO_BUILD_DIR="$(pwd)/mhlo-build" \ + STABLEHLO_BUILD_DIR="$(pwd)/stablehlo-build" \ COMPILER_LAUNCHER="" \ - make mhlo + make stablehlo enzyme: name: Enzyme Build @@ -324,7 +323,7 @@ jobs: quantum: name: Quantum Dialects Build - needs: [constants, llvm, mhlo, enzyme, determine_runner] + needs: [constants, llvm, stablehlo, enzyme, determine_runner] runs-on: ${{ needs.determine_runner.outputs.runner_group }} strategy: matrix: @@ -363,21 +362,21 @@ jobs: key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-ci-build-${{ matrix.compiler }} fail-on-cache-miss: true - - name: Get Cached MHLO Source - id: cache-mhlo-source + - name: Get Cached Stablehlo Source + id: cache-stablehlo-source uses: actions/cache/restore@v4 with: - path: mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-default-source + path: mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-default-source enableCrossOsArchive: true fail-on-cache-miss: true - - name: Get Cached MHLO Build - id: cache-mhlo + - name: Get Cached Stablehlo Build + id: cache-stablehlo uses: actions/cache/restore@v4 with: - path: mhlo-build - key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }}-0 + path: stablehlo-build + key: ${{ runner.os }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-ci-build-${{ matrix.compiler }}-0 fail-on-cache-miss: true - name: Get Cached Enzyme Source @@ -412,7 +411,7 @@ jobs: C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \ CXX_COMPILER=$(which ${{ needs.constants.outputs[format('cxx_compiler.{0}', matrix.compiler)] }}) \ LLVM_BUILD_DIR="$(pwd)/llvm-build" \ - MHLO_BUILD_DIR="$(pwd)/mhlo-build" \ + STABLEHLO_BUILD_DIR="$(pwd)/stablehlo-build" \ ENZYME_BUILD_DIR="$(pwd)/enzyme-build" \ DIALECTS_BUILD_DIR="$(pwd)/quantum-build" \ make dialects diff --git a/.github/workflows/check-jax-release.yaml b/.github/workflows/check-jax-release.yaml index 7e181478c7..37221f4bb4 100644 --- a/.github/workflows/check-jax-release.yaml +++ b/.github/workflows/check-jax-release.yaml @@ -38,7 +38,7 @@ jobs: - name: Re-read versions run: | echo "LLVM_REVISION=$(grep llvm .dep-versions | awk -F '=' '{ print $2 }')" >> $GITHUB_ENV - echo "MHLO_REVISION=$(grep mhlo .dep-versions | awk -F '=' '{ print $2 }')" >> $GITHUB_ENV + echo "STABLEHLO_REVISION=$(grep stablehlo .dep-versions | awk -F '=' '{ print $2 }')" >> $GITHUB_ENV echo "ENZYME_REVISION=$(grep enzyme .dep-versions | awk -F '=' '{ print $2 }')" >> $GITHUB_ENV - name: Clone LLVM repo @@ -48,12 +48,12 @@ jobs: ref: ${{ env.LLVM_REVISION }} path: mlir/llvm-project - - name: Clone MHLO repo + - name: Clone Stablehlo repo uses: actions/checkout@v4 with: - repository: tensorflow/mlir-hlo - ref: ${{ env.MHLO_REVISION }} - path: mlir/mlir-hlo + repository: openxla/stablehlo + ref: ${{ env.STABLEHLO_REVISION }} + path: mlir/stablehlo - name: Clone Enzyme repo uses: actions/checkout@v4 @@ -72,7 +72,7 @@ jobs: - name: Build MHLO run: | - make mhlo + make stablehlo - name: Build Enzyme run: | diff --git a/.github/workflows/check-pl-compat.yaml b/.github/workflows/check-pl-compat.yaml index e26fabbb05..29962195fd 100644 --- a/.github/workflows/check-pl-compat.yaml +++ b/.github/workflows/check-pl-compat.yaml @@ -76,15 +76,15 @@ jobs: - uses: actions/cache/restore@v4 if: ${{ inputs.catalyst != 'stable' }} with: - path: mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-default-source + path: mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-default-source enableCrossOsArchive: True fail-on-cache-miss: True - uses: actions/cache/restore@v4 if: ${{ inputs.catalyst != 'stable' }} with: - path: mhlo-build - key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-gcc + path: stablehlo-build + key: ${{ runner.os }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-ci-build-gcc fail-on-cache-miss: True - uses: actions/cache/restore@v4 if: ${{ inputs.catalyst != 'stable' }} @@ -120,7 +120,7 @@ jobs: ENABLE_LLD=ON \ RT_BUILD_DIR="$(pwd)/runtime-build" \ LLVM_BUILD_DIR="$(pwd)/llvm-build" \ - MHLO_BUILD_DIR="$(pwd)/mhlo-build" \ + STABLEHLO_BUILD_DIR="$(pwd)/stablehlo-build" \ ENZYME_BUILD_DIR="$(pwd)/enzyme-build" \ DIALECTS_BUILD_DIR="$(pwd)/quantum-build" \ ENABLE_OPENQASM=ON \ diff --git a/.github/workflows/constants.yaml b/.github/workflows/constants.yaml index 7e6510517f..52a2ad7dcc 100644 --- a/.github/workflows/constants.yaml +++ b/.github/workflows/constants.yaml @@ -19,9 +19,9 @@ on: llvm_version: description: "LLVM version" value: ${{ jobs.set-constants.outputs.llvm_version }} - mhlo_version: - description: "MHLO version" - value: ${{ jobs.set-constants.outputs.mhlo_version }} + stablehlo_version: + description: "Stablehlo version" + value: ${{ jobs.set-constants.outputs.stablehlo_version }} enzyme_version: description: "Enzyme version" value: ${{ jobs.set-constants.outputs.enzyme_version }} @@ -69,9 +69,9 @@ jobs: id: llvm_version run: echo "llvm_version=$(grep llvm .dep-versions | awk -F '=' '{ print $2 }')" >> $GITHUB_OUTPUT - - name: MHLO version - id: mhlo_version - run: echo "mhlo_version=$(grep mhlo .dep-versions | awk -F '=' '{ print $2 }')" >> $GITHUB_OUTPUT + - name: Stablehlo version + id: stablehlo_version + run: echo "stablehlo_version=$(grep stablehlo .dep-versions | awk -F '=' '{ print $2 }')" >> $GITHUB_OUTPUT - name: Enzyme version id: enzyme_version @@ -113,7 +113,7 @@ jobs: outputs: llvm_version: ${{ steps.llvm_version.outputs.llvm_version }} - mhlo_version: ${{ steps.mhlo_version.outputs.mhlo_version }} + stablehlo_version: ${{ steps.stablehlo_version.outputs.stablehlo_version }} enzyme_version: ${{ steps.enzyme_version.outputs.enzyme_version }} python_versions: ${{ steps.python_versions.outputs.python_versions }} python_test_versions: ${{ steps.python_test_versions.outputs.python_test_versions }} diff --git a/.github/workflows/set_dep_versions.py b/.github/workflows/set_dep_versions.py deleted file mode 100644 index ce9a998c18..0000000000 --- a/.github/workflows/set_dep_versions.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2022-2023 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This module computes commit hashes for LLVM and MLIR-HLO based on a given JAX version. -""" - -# pylint: disable=line-too-long -# pylint: disable=anomalous-backslash-in-string -# pylint: disable=consider-using-with - -import os -import re -import sys - -import requests - -jax_version = sys.argv[1] -dep_versions_path = os.path.join(os.path.dirname(__file__), "../../.dep-versions") -catalyst_init_path = os.path.join(os.path.dirname(__file__), "../../frontend/catalyst/__init__.py") - -assert os.path.isfile(dep_versions_path) -assert os.path.isfile(catalyst_init_path) - -url = f"https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax_version}/WORKSPACE" -response = requests.get(url) -match = re.search(r'strip_prefix = "xla-([a-zA-Z0-9]*)"', response.text) -if not match: - url = f"https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax_version}/third_party/xla/workspace.bzl" - response = requests.get(url) - match = re.search(r'XLA_COMMIT = "([a-zA-Z0-9]*)"', response.text) -xla_commit = match.group(1) - -url = f"https://raw.githubusercontent.com/openxla/xla/{xla_commit}/third_party/llvm/workspace.bzl" -response = requests.get(url) -match = re.search(r'LLVM_COMMIT = "([a-zA-Z0-9]*)"', response.text) -llvm_commit = match.group(1) - -# If the XLA commit is an "Integrate LLVM" commit we need to get the piper_id directly from there -# to look up the corresponding mlir-hlo commit. -url = f"https://api.github.com/repos/openxla/xla/commits?sha={xla_commit}&per_page=1" -response = requests.get(url).json() -match = re.search(r"Integrate LLVM", response[0]["commit"]["message"]) -if match: - match = re.search(r"PiperOrigin-RevId: ([0-9]*)", response[0]["commit"]["message"]) - piper_id = match.group(1) -else: - # Otherwise, we get the last commit in the XLA repository that touched the mlir-hlo files, and - # get its piper_id to get the same commit in the standalone mlir-hlo repo. - url = f"https://api.github.com/repos/openxla/xla/commits?sha={xla_commit}&path=xla/mlir_hlo&per_page=1" - response = requests.get(url).json() - xla_hlo_commit = response[0]["sha"] - match = re.search(r"PiperOrigin-RevId: ([0-9]*)", response[0]["commit"]["message"]) - piper_id = match.group(1) - -url = f"https://api.github.com/search/commits?q=repo:tensorflow/mlir-hlo+{piper_id}" -response = requests.get(url).json() -hlo_commit = response["items"][0]["sha"] - -quote = '"' -# Update each version using sed -cmds = [ - f"sed -i '' 's/^jax=.*/jax={jax_version}/' {dep_versions_path}", - f"sed -i '' 's/^mhlo=.*/mhlo={hlo_commit}/' {dep_versions_path}", - f"sed -i '' 's/^llvm=.*/llvm={llvm_commit}/' {dep_versions_path}", - # Update jaxlib version in __init__.py - rf"sed -i '' 's/_jaxlib_version = {quote}\([0-9.]\+\){quote}/_jaxlib_version = {quote}{jax_version}{quote}/g' {catalyst_init_path}", -] - -for cmd in cmds: - res = os.system(cmd) - assert res == 0 diff --git a/.gitmodules b/.gitmodules index 0148d21bdd..9fa2499eb3 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ -[submodule "mlir-hlo"] - path = mlir/mlir-hlo - url = https://github.com/tensorflow/mlir-hlo.git +[submodule "stablehlo"] + path = mlir/stablehlo + url = https://github.com/openxla/stablehlo.git shallow = true ignore = dirty [submodule "llvm-project"] diff --git a/Makefile b/Makefile index 541d8b6418..1b1f7a407c 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ BLACKVERSIONMINOR := $(if $(BLACKVERSIONMINOR),$(BLACKVERSIONMINOR),0) MK_ABSPATH := $(abspath $(lastword $(MAKEFILE_LIST))) MK_DIR := $(dir $(MK_ABSPATH)) LLVM_BUILD_DIR ?= $(MK_DIR)/mlir/llvm-project/build -MHLO_BUILD_DIR ?= $(MK_DIR)/mlir/mlir-hlo/bazel-build +STABLEHLO_BUILD_DIR ?= $(MK_DIR)/mlir/stablehlo/build DIALECTS_SRC_DIR ?= $(MK_DIR)/mlir DIALECTS_BUILD_DIR ?= $(MK_DIR)/mlir/build RT_BUILD_DIR ?= $(MK_DIR)/runtime/build @@ -119,15 +119,15 @@ frontend: $(PYTHON) -m pip install -e . --extra-index-url https://test.pypi.org/simple $(PIP_VERBOSE_FLAG) rm -r frontend/pennylane_catalyst.egg-info -.PHONY: mlir llvm mhlo enzyme dialects runtime oqc +.PHONY: mlir llvm stablehlo enzyme dialects runtime oqc mlir: $(MAKE) -C mlir all llvm: $(MAKE) -C mlir llvm -mhlo: - $(MAKE) -C mlir mhlo +stablehlo: + $(MAKE) -C mlir stablehlo enzyme: $(MAKE) -C mlir enzyme @@ -274,7 +274,7 @@ clean: clean-all: clean clean-mlir clean-runtime clean-oqc clean-catalyst: clean clean-dialects clean-runtime clean-oqc -.PHONY: clean-mlir clean-dialects clean-plugin clean-llvm clean-mhlo clean-enzyme +.PHONY: clean-mlir clean-dialects clean-plugin clean-llvm clean-stablehlo clean-enzyme clean-mlir: $(MAKE) -C mlir clean @@ -290,8 +290,8 @@ clean-llvm: reset-llvm: $(MAKE) -C mlir reset-llvm -clean-mhlo: - $(MAKE) -C mlir clean-mhlo +clean-stablehlo: + $(MAKE) -C mlir clean-stablehlo clean-enzyme: $(MAKE) -C mlir clean-enzyme diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 5b6ef008c6..054599f132 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -42,13 +42,15 @@ * The JAX version used by Catalyst is updated to 0.6.2. [(#1897)](https://github.com/PennyLaneAI/catalyst/pull/1897) -* The version of LLVM, mlir-hlo, and Enzyme used by Catalyst has been updated. +* The version of LLVM and Enzyme used by Catalyst has been updated. + The mlir-hlo dependency has been replaced with stablehlo. [(#1916)](https://github.com/PennyLaneAI/catalyst/pull/1916) + [(#1921)](https://github.com/PennyLaneAI/catalyst/pull/1921) The LLVM version has been updated to [commit f8cb798](https://github.com/llvm/llvm-project/tree/f8cb7987c64dcffb72414a40560055cb717dbf74). - The mlir-hlo version has been updated to - [commit 1dd2e71](https://github.com/tensorflow/mlir-hlo/tree/1dd2e71331014ae0373f6bf900ce6be393357190). + The stablehlo version has been updated to + [commit 69d6dae](https://github.com/openxla/stablehlo/commit/69d6dae46e1c7de36e6e6973654754f05353cba5). The Enzyme version has been updated to [v0.0.186](https://github.com/EnzymeAD/Enzyme/releases/tag/v0.0.186). diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index e7be5aaf0e..328a4d6623 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -226,13 +226,13 @@ def get_hlo_lowering_stage(_options: CompileOptions) -> List[str]: """Returns the list of passes to lower StableHLO to upstream MLIR dialects.""" hlo_lowering = [ "canonicalize", - "func.func(chlo-legalize-to-hlo)", - "stablehlo-legalize-to-hlo", - "func.func(mhlo-legalize-control-flow)", - "func.func(hlo-legalize-to-linalg)", - "func.func(mhlo-legalize-to-std)", - "func.func(mhlo-legalize-sort)", - "convert-to-signless", + "func.func(chlo-legalize-to-stablehlo)", + "func.func(stablehlo-legalize-control-flow)", + "func.func(stablehlo-aggressive-simplification)", + "stablehlo-legalize-to-linalg", + "func.func(stablehlo-legalize-to-std)", + "func.func(stablehlo-legalize-sort)", + "stablehlo-convert-to-signless", "canonicalize", "scatter-lowering", "hlo-custom-call-lowering", diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index c0a083010b..894ddb8a12 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -30,37 +30,61 @@ endif() ######################### find_package(MLIR REQUIRED CONFIG) -if(NOT CATALYST_DOCS_ONLY) - find_package(MHLO REQUIRED CONFIG) -endif() message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") -message(STATUS "Using MHLOConfig.cmake in: ${MHLO_DIR}") set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) -# Taken from mlir-hlo/mhlo/transforms/CMakeLists.txt. -# Unfortunately, AllMhloPasses doesn't appear to be exported. -set(ALL_MHLO_PASSES - ChloPasses - MhloPasses - StablehloPasses - MhloToArithmeticConversion - MhloToMemrefConversion - HloToLinalgUtils - MhloToLinalg - MhloToStablehlo - StablehloToMhlo -) - list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") -list(APPEND CMAKE_MODULE_PATH "${MHLO_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules") + +# Discover stablehlo libraries and bundle them into a cmake target for catalyst to link to +# This is because stablehlo does not have a Config.cmake +# so we cannot use find_package(stablehlo) and have to do this manually +set(STABLEHLO_LIBS + ChloCAPI + ChloOps + StablehloAssemblyFormat + StablehloBase + StablehloBroadcastUtils + StablehloCAPI + StablehloLinalgTransforms + StablehloOps + StablehloOptimizationPasses + StablehloPasses + StablehloPassUtils + StablehloRegister + StablehloTypeConversion + StablehloTypeInference + Version + VhloCAPI + VhloOps + VhloTypes +) + +foreach(STABLEHLO_LIB IN LISTS STABLEHLO_LIBS) + add_library(${STABLEHLO_LIB} STATIC IMPORTED GLOBAL) + set_property(TARGET ${STABLEHLO_LIB} PROPERTY + IMPORTED_LOCATION "${STABLEHLO_BUILD_DIR}/lib/lib${STABLEHLO_LIB}.a" + ) +endforeach() + +add_library(ExternalStablehloLib INTERFACE) + +foreach(STABLEHLO_LIB IN LISTS STABLEHLO_LIBS) + target_link_libraries(ExternalStablehloLib INTERFACE ${STABLEHLO_LIB}) +endforeach() + +target_include_directories(ExternalStablehloLib SYSTEM INTERFACE + ${STABLEHLO_DIR} + ${STABLEHLO_BUILD_DIR} # for the generated .inc files +) + # Policy CMP0175 was introduced in CMake 4.31 and raises warnings in the upstream CMake modules. # Policy CMP0177 was introduced in CMake 4.31 and raises warnings in the upstream CMake modules. # TODO: Remove once they (and us) have updated their code to deal with it. @@ -88,9 +112,6 @@ if(QUANTUM_ENABLE_BINDINGS_PYTHON) mlir_configure_python_dev_packages() endif() -list(GET MHLO_INCLUDE_DIRS 1 MLIRHLO_DIR) -list(GET MHLO_INCLUDE_DIRS 2 MLIRHLO_BUILD_DIR) - set(CATALYST_MAIN_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/include) set(CATALYST_GEN_INCLUDE_DIR ${PROJECT_BINARY_DIR}/include) set(CATALYST_LIB_DIR ${PROJECT_BINARY_DIR}) @@ -98,9 +119,6 @@ set(CATALYST_LIB_DIR ${PROJECT_BINARY_DIR}) include_directories(SYSTEM ${LLVM_INCLUDE_DIRS} ${MLIR_INCLUDE_DIRS} - ${MHLO_INCLUDE_DIRS} - ${MLIRHLO_DIR}/stablehlo - ${MLIRHLO_BUILD_DIR}/stablehlo ) link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) diff --git a/mlir/Makefile b/mlir/Makefile index 5d2dab0c1b..8fc76e11e7 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -8,7 +8,7 @@ MK_DIR := $(dir $(MK_ABSPATH)) DIALECTS_BUILD_DIR ?= $(MK_DIR)/build DIALECTS_DOCS_BUILD_DIR ?= $(MK_DIR)/build-docs LLVM_BUILD_DIR ?= $(MK_DIR)/llvm-project/build -MHLO_BUILD_DIR ?= $(MK_DIR)/mlir-hlo/bazel-build +STABLEHLO_BUILD_DIR ?= $(MK_DIR)/stablehlo/build ENZYME_BUILD_DIR ?= $(MK_DIR)/Enzyme/build RT_BUILD_DIR ?= $(MK_DIR)/../runtime/build ENABLE_ASAN ?= OFF @@ -42,9 +42,9 @@ LLVM_TARGETS ?= check-mlir llvm-symbolizer .PHONY: help help: @echo "Please use \`make ' where is one of" - @echo " all to build MLIR, MLIR-HLO and custom Catalyst dialects" + @echo " all to build MLIR, Stablehlo and custom Catalyst dialects" @echo " llvm to build MLIR enabling Python bindings" - @echo " mhlo to build MLIR-HLO" + @echo " stablehlo to build stablehlo" @echo " enzyme to build Enzyme" @echo " dialects to build custom Catalyst MLIR dialects" @echo " dialect-docs to build custom Catalyst MLIR dialect documentation" @@ -54,7 +54,7 @@ help: @echo " format [version=?] to apply C++ formatter; use with 'version={version}' to run clang-format-{version} instead of clang-format" .PHONY: all -all: llvm mhlo enzyme dialects dialect-docs plugin +all: llvm stablehlo enzyme dialects dialect-docs plugin .PHONY: llvm llvm: @@ -94,36 +94,28 @@ llvm: # test to reduce unnecessary dependencies. LIT_FILTER_OUT="Bytecode|tosa-to-tensor|execution_engine" cmake --build $(LLVM_BUILD_DIR) --target $(LLVM_TARGETS) -.PHONY: mhlo -mhlo: - @echo "build MLIR-HLO" +.PHONY: stablehlo +stablehlo: + @echo "build stablehlo" - # Patch MHLO shardy dependency - @if cd mlir-hlo; git apply --check $(MK_DIR)/patches/mhlo-remove-shardy.patch; then \ - git apply $(MK_DIR)/patches/mhlo-remove-shardy.patch; \ - fi - - # Patch a MHLO bug with std::sort - @if cd mlir-hlo; git apply --check $(MK_DIR)/patches/mhlo-rename-sort.patch; then \ - git apply $(MK_DIR)/patches/mhlo-rename-sort.patch; \ - fi - cmake -G Ninja -S mlir-hlo -B $(MHLO_BUILD_DIR) \ + cmake -G Ninja -S stablehlo -B $(STABLEHLO_BUILD_DIR) \ + -DSTABLEHLO_ENABLE_LLD=$(ENABLE_LLD) \ -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ - -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_DIR=$(LLVM_BUILD_DIR)/lib/cmake/mlir \ - -DPython3_EXECUTABLE=$(PYTHON) \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_LLD=$(ENABLE_LLD) \ + -DLLVM_ENABLE_ZLIB=$(ENABLE_ZLIB) \ + -DSTABLEHLO_ENABLE_BINDINGS_PYTHON=OFF \ + -DSTABLEHLO_ENABLE_SPLIT_DWARF=ON \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ -DCMAKE_C_COMPILER=$(C_COMPILER) \ -DCMAKE_CXX_COMPILER=$(CXX_COMPILER) \ -DCMAKE_C_COMPILER_LAUNCHER=$(COMPILER_LAUNCHER) \ -DCMAKE_CXX_COMPILER_LAUNCHER=$(COMPILER_LAUNCHER) \ -DCMAKE_EXE_LINKER_FLAGS=$(USE_SANITIZER_FLAGS) \ - -DLLVM_ENABLE_LLD=$(ENABLE_LLD) \ - -DLLVM_ENABLE_ZLIB=$(ENABLE_ZLIB) \ - -DLLVM_ENABLE_ZSTD=$(ENABLE_ZSTD) \ -DCMAKE_CXX_VISIBILITY_PRESET=$(SYMBOL_VISIBILITY) - # TODO: figure out why this test is failing - LIT_FILTER_OUT="chlo_legalize_to_mhlo" cmake --build $(MHLO_BUILD_DIR) --target check-mlir-hlo + cmake --build $(STABLEHLO_BUILD_DIR) .PHONY: enzyme enzyme: TARGET_FILE := $(MK_DIR)/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -183,8 +175,8 @@ dialects: -DEnzyme_DIR=$(ENZYME_BUILD_DIR) \ -DENZYME_SRC_DIR=$(MK_DIR)/Enzyme \ -DMLIR_DIR=$(LLVM_BUILD_DIR)/lib/cmake/mlir \ - -DMHLO_DIR=$(MHLO_BUILD_DIR)/lib/cmake/mlir-hlo \ - -DMHLO_BINARY_DIR=$(MHLO_BUILD_DIR)/bin \ + -DSTABLEHLO_DIR=$(MK_DIR)/stablehlo \ + -DSTABLEHLO_BUILD_DIR=$(STABLEHLO_BUILD_DIR) \ -DRUNTIME_LIB_DIR=$(RT_BUILD_DIR)/lib \ -DMLIR_LIB_DIR=$(LLVM_BUILD_DIR)/lib \ -DCMAKE_C_COMPILER=$(C_COMPILER) \ @@ -222,8 +214,8 @@ test: @echo "test the Catalyst MLIR dialects test suite" cmake --build $(DIALECTS_BUILD_DIR) --target check-dialects -.PHONY: clean clean-dialects clean-enzyme clean-mhlo clean-plugin clean-dialect-docs -clean: clean-dialects clean-llvm clean-mhlo clean-enzyme clean-plugin +.PHONY: clean clean-dialects clean-enzyme clean-stablehlo clean-plugin clean-dialect-docs +clean: clean-dialects clean-llvm clean-stablehlo clean-enzyme clean-plugin clean-dialects: @echo "clean catalyst dialect build files" @@ -234,15 +226,15 @@ clean-llvm: rm -rf $(LLVM_BUILD_DIR) cd llvm-project; git clean -fd; git checkout . +clean-stablehlo: + @echo "clean Stablehlo dialect build files" + rm -rf $(STABLEHLO_BUILD_DIR) + cd stablehlo; git clean -fd; git checkout . + reset-llvm: @echo "reset llvm git state to the commit tracked in .dep-versions without deleting llvm builds" cd llvm-project; git clean -fd; git checkout . -clean-mhlo: - @echo "clean HLO dialect build files" - rm -rf $(MHLO_BUILD_DIR) - cd mlir-hlo; git clean -fd; git checkout . - clean-enzyme: @echo "clean enzyme build files" rm -rf $(ENZYME_BUILD_DIR) diff --git a/mlir/cmake/modules/CMakeLists.txt b/mlir/cmake/modules/CMakeLists.txt index 835aea8ced..b262fb0b01 100644 --- a/mlir/cmake/modules/CMakeLists.txt +++ b/mlir/cmake/modules/CMakeLists.txt @@ -28,7 +28,7 @@ set(llvm_cmake_builddir "${LLVM_BINARY_DIR}/${LLVM_INSTALL_PACKAGE_DIR}") get_property(MLIR_EXPORTS GLOBAL PROPERTY MLIR_EXPORTS) set(TARGETS_TO_REMOVE nlohmann_json tomlplusplus_tomlplusplus ion-transforms CatalystCompilerDriver QECUtils QuantumCAPI qec-transforms) list(REMOVE_ITEM MLIR_EXPORTS ${TARGETS_TO_REMOVE}) -export(TARGETS ${MLIR_EXPORTS} FILE ${catalyst_cmake_builddir}/CatalystTargets.cmake) +export(TARGETS ${MLIR_EXPORTS} ExternalStablehloLib FILE ${catalyst_cmake_builddir}/CatalystTargets.cmake) # Generate MlirConfig.cmake for the build tree. set(CATALYST_CONFIG_CMAKE_DIR "${catalyst_cmake_builddir}") diff --git a/mlir/include/CMakeLists.txt b/mlir/include/CMakeLists.txt index 6f6a692c66..85180214a9 100644 --- a/mlir/include/CMakeLists.txt +++ b/mlir/include/CMakeLists.txt @@ -1,9 +1,9 @@ add_subdirectory(Catalyst) -add_subdirectory(Quantum) -add_subdirectory(QEC) add_subdirectory(Gradient) +add_subdirectory(hlo-extensions) add_subdirectory(Ion) add_subdirectory(MBQC) add_subdirectory(Mitigation) +add_subdirectory(QEC) +add_subdirectory(Quantum) add_subdirectory(Test) -add_subdirectory(mlir-hlo) diff --git a/mlir/include/Catalyst/Transforms/CMakeLists.txt b/mlir/include/Catalyst/Transforms/CMakeLists.txt index 52802b77fc..68deb5bc72 100644 --- a/mlir/include/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/include/Catalyst/Transforms/CMakeLists.txt @@ -1,4 +1,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name catalyst) -add_public_tablegen_target(MLIRCatalystPassIncGen) +add_public_tablegen_target(MLIRCatalystPassIncGen + DEPENDS ExternalStablehloLib +) add_mlir_doc(Passes CatalystPasses ./ -gen-pass-doc) diff --git a/mlir/include/Catalyst/Transforms/Passes.h b/mlir/include/Catalyst/Transforms/Passes.h index f39918b697..87e33e67e5 100644 --- a/mlir/include/Catalyst/Transforms/Passes.h +++ b/mlir/include/Catalyst/Transforms/Passes.h @@ -30,13 +30,11 @@ std::unique_ptr createDetensorizeFunctionBoundaryPass(); std::unique_ptr createDetensorizeSCFPass(); std::unique_ptr createDisableAssertionPass(); std::unique_ptr createGEPInboundsPass(); -std::unique_ptr createHloCustomCallLoweringPass(); std::unique_ptr createInlineNestedModulePass(); std::unique_ptr createMemrefCopyToLinalgCopyPass(); std::unique_ptr createMemrefToLLVMWithTBAAPass(); std::unique_ptr createQnodeToAsyncLoweringPass(); std::unique_ptr createRegisterInactiveCallbackPass(); -std::unique_ptr createScatterLoweringPass(); std::unique_ptr createSplitMultipleTapesPass(); void registerAllCatalystPasses(); diff --git a/mlir/include/Catalyst/Transforms/Passes.td b/mlir/include/Catalyst/Transforms/Passes.td index 269cdeb6e0..c76860cb54 100644 --- a/mlir/include/Catalyst/Transforms/Passes.td +++ b/mlir/include/Catalyst/Transforms/Passes.td @@ -65,32 +65,6 @@ def CatalystConversionPass : Pass<"convert-catalyst-to-llvm"> { let constructor = "catalyst::createCatalystConversionPass()"; } -def ScatterLoweringPass : Pass<"scatter-lowering"> { - let summary = "Lower scatter op from Stable HLO to loops."; - - let dependentDialects = [ - "mlir::func::FuncDialect", - "index::IndexDialect", - "mhlo::MhloDialect", - "tensor::TensorDialect", - "scf::SCFDialect" - ]; - - let constructor = "catalyst::createScatterLoweringPass()"; -} - -def HloCustomCallLoweringPass : Pass<"hlo-custom-call-lowering"> { - let summary = "Lower custom calls op from Stable HLO to CallOp."; - - let dependentDialects = [ - "index::IndexDialect", - "mlir::func::FuncDialect", - "catalyst::CatalystDialect", - ]; - - let constructor = "catalyst::createHloCustomCallLoweringPass()"; -} - def QnodeToAsyncLoweringPass : Pass<"qnode-to-async-lowering"> { let summary = "Lower Qnode func and call operations to async func and call operations."; diff --git a/mlir/include/Catalyst/Transforms/Patterns.h b/mlir/include/Catalyst/Transforms/Patterns.h index 6bbf3150ff..53ebc04245 100644 --- a/mlir/include/Catalyst/Transforms/Patterns.h +++ b/mlir/include/Catalyst/Transforms/Patterns.h @@ -21,10 +21,6 @@ namespace catalyst { -void populateScatterPatterns(mlir::RewritePatternSet &); - -void populateHloCustomCallPatterns(mlir::RewritePatternSet &); - void populateQnodeToAsyncPatterns(mlir::RewritePatternSet &); void populateDisableAssertionPatterns(mlir::RewritePatternSet &); diff --git a/mlir/include/hlo-extensions/CMakeLists.txt b/mlir/include/hlo-extensions/CMakeLists.txt new file mode 100644 index 0000000000..d7c71deda7 --- /dev/null +++ b/mlir/include/hlo-extensions/CMakeLists.txt @@ -0,0 +1,16 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name stablehlo) +add_public_tablegen_target(STABLEHLOCatalystPassIncGen + DEPENDS ExternalStablehloLib +) + +# The following is modified from the +# tensorflow/mlir-hlo +# repository at +# https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/CMakeLists.txt +# to build the rewrite patterns for the --stablehlo-legalize-to-std pass +set(LLVM_TARGET_DEFINITIONS stablehlo_legalize_to_standard_patterns.td) +include_directories( + ${CATALYST_MAIN_INCLUDE_DIR}/../stablehlo) +mlir_tablegen(generated_stablehlo_legalize_to_standard.cpp.inc -gen-rewriters) +add_public_tablegen_target(MLIRStablehloLegalizeToStandardIncGen) diff --git a/mlir/include/hlo-extensions/Passes.h b/mlir/include/hlo-extensions/Passes.h new file mode 100644 index 0000000000..c15ed04114 --- /dev/null +++ b/mlir/include/hlo-extensions/Passes.h @@ -0,0 +1,27 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "mlir/Pass/Pass.h" + +namespace catalyst { +std::unique_ptr createHloCustomCallLoweringPass(); +std::unique_ptr createScatterLoweringPass(); +std::unique_ptr createStablehloLegalizeSortPass(); +std::unique_ptr createStablehloLegalizeToStdPass(); +std::unique_ptr createStablehloLegalizeControlFlowPass(); +} // namespace catalyst diff --git a/mlir/include/mlir-hlo/Passes.td b/mlir/include/hlo-extensions/Passes.td similarity index 52% rename from mlir/include/mlir-hlo/Passes.td rename to mlir/include/hlo-extensions/Passes.td index b4b7791018..bfa517c406 100644 --- a/mlir/include/mlir-hlo/Passes.td +++ b/mlir/include/hlo-extensions/Passes.td @@ -31,31 +31,60 @@ // limitations under the License. // ==============================================================================*/ -#ifndef CATALYST_MLIRHLO_PASSES -#define CATALYST_MLIRHLO_PASSES +#ifndef CATALYST_STABLEHLO_PASSES +#define CATALYST_STABLEHLO_PASSES include "mlir/Pass/PassBase.td" -// mhlo legalize sort pass. -def MhloLegalizeSortPass : Pass<"mhlo-legalize-sort", "func::FuncOp"> { - let summary = "Legalize from Mhlo sort to SCF control flow."; - let constructor = "createMhloLegalizeSortPass()"; +// -------------------- Catalyst's own hlo-related passes ------------------------ // + +def ScatterLoweringPass : Pass<"scatter-lowering"> { + let summary = "Lower scatter op from Stable HLO to loops."; + + let dependentDialects = [ + "mlir::func::FuncDialect", + "index::IndexDialect", + "stablehlo::StablehloDialect", + "tensor::TensorDialect", + "scf::SCFDialect" + ]; + + let constructor = "catalyst::createScatterLoweringPass()"; +} + +def HloCustomCallLoweringPass : Pass<"hlo-custom-call-lowering"> { + let summary = "Lower custom calls op from Stable HLO to CallOp."; + + let dependentDialects = [ + "index::IndexDialect", + "mlir::func::FuncDialect", + "catalyst::CatalystDialect", + ]; + + let constructor = "catalyst::createHloCustomCallLoweringPass()"; +} + +// -------------------- upstream mhlo passes removed in stablehlo ------------------------ // +// stablehlo legalize sort pass. +def StablehloLegalizeSortPass : Pass<"stablehlo-legalize-sort", "func::FuncOp"> { + let summary = "Legalize from Stablehlo sort to SCF control flow."; + let constructor = "createStablehloLegalizeSortPass()"; let dependentDialects = ["arith::ArithDialect", "bufferization::BufferizationDialect", "scf::SCFDialect", "tensor::TensorDialect"]; } -// mhlo legalize to std pass. -def MhloLegalizeToStandardPass : Pass<"mhlo-legalize-to-std", "func::FuncOp"> { - let summary = "Legalize from MHLO dialect to standard dialect."; - let constructor = "createMhloLegalizeToStdPass()"; +// stablehlo legalize to std pass. +def StablehloLegalizeToStandardPass : Pass<"stablehlo-legalize-to-std", "func::FuncOp"> { + let summary = "Legalize from Stablehlo dialect to standard dialect."; + let constructor = "createStablehloLegalizeToStdPass()"; } -// mhlo legalize to control flow pass. -def MhloLegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "func::FuncOp"> { - let summary = "Legalize from MHLO control flow to SCF control flow."; - let constructor = "createMhloLegalizeControlFlowPass()"; +// stablehlo legalize to control flow pass. +def StablehloLegalizeControlFlowPass : Pass<"stablehlo-legalize-control-flow", "func::FuncOp"> { + let summary = "Legalize from Stablehlo control flow to SCF control flow."; + let constructor = "createStablehloLegalizeControlFlowPass()"; let dependentDialects = ["scf::SCFDialect", "tensor::TensorDialect"]; } -#endif // CATALYST_MLIRHLO_PASSES +#endif // CATALYST_STABLEHLO_PASSES diff --git a/mlir/include/mlir-hlo/Passes.h b/mlir/include/hlo-extensions/Patterns.h similarity index 69% rename from mlir/include/mlir-hlo/Passes.h rename to mlir/include/hlo-extensions/Patterns.h index 6ce80c494e..8caa0ecc1c 100644 --- a/mlir/include/mlir-hlo/Passes.h +++ b/mlir/include/hlo-extensions/Patterns.h @@ -14,12 +14,14 @@ #pragma once -#include - -#include "mlir/Pass/Pass.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" namespace catalyst { - std::unique_ptr createMhloLegalizeSortPass(); - std::unique_ptr createMhloLegalizeToStdPass(); - std::unique_ptr createMhloLegalizeControlFlowPass(); -} + +void populateScatterPatterns(mlir::RewritePatternSet &); + +void populateHloCustomCallPatterns(mlir::RewritePatternSet &); + +} // namespace catalyst diff --git a/mlir/include/mlir-hlo/mhlo_legalize_to_standard_patterns.td b/mlir/include/hlo-extensions/stablehlo_legalize_to_standard_patterns.td similarity index 77% rename from mlir/include/mlir-hlo/mhlo_legalize_to_standard_patterns.td rename to mlir/include/hlo-extensions/stablehlo_legalize_to_standard_patterns.td index a8b365bc18..ffeddb9e93 100644 --- a/mlir/include/mlir-hlo/mhlo_legalize_to_standard_patterns.td +++ b/mlir/include/hlo-extensions/stablehlo_legalize_to_standard_patterns.td @@ -36,22 +36,19 @@ -// This is the legalization pattern definition file for MHLO to StandardOps. +// This is the legalization pattern definition file for Stablehlo to StandardOps. include "mlir/IR/OpBase.td" include "mlir/Dialect/Arith/IR/ArithOps.td" include "mlir/Dialect/Math/IR/MathOps.td" include "mlir/Dialect/Func/IR/FuncOps.td" -include "mhlo/IR/hlo_ops.td" -// TODO: change the above mhlo include line to the following when migrating to stablehlo -//include "stablehlo/dialect/StablehloOps.td" +include "stablehlo/dialect/StablehloOps.td" //===----------------------------------------------------------------------===// // Nullary op patterns. //===----------------------------------------------------------------------===// -// TODO: update `MHLO_BlahOp` to `StableHLO_BlahOp` when migrating to stablehlo. -def : Pat<(MHLO_ConstantOp ElementsAttr:$value), +def : Pat<(StableHLO_ConstantOp ElementsAttr:$value), (Arith_ConstantOp $value)>; //===----------------------------------------------------------------------===// @@ -77,46 +74,46 @@ def createDenormalIEEE : NativeCodeCall< // Unary Lowering Patterns. -def : Pat<(MHLO_CeilOp HLO_FpTensor:$i), (Math_CeilOp $i, (createFastMathNone ))>; +def : Pat<(StableHLO_CeilOp HLO_FpTensor:$i), (Math_CeilOp $i, (createFastMathNone ))>; // Binary Lowering Patterns. -def : Pat<(MHLO_AndOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(StableHLO_AndOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_AndIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_OrOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(StableHLO_OrOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_OrIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(StableHLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r), (Arith_AddFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_SubtractOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(StableHLO_SubtractOp HLO_FpTensor:$l, HLO_FpTensor:$r), (Arith_SubFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(StableHLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r), (Arith_MulFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(StableHLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r), (Arith_DivFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(StableHLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r), (Arith_RemFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(StableHLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_AddIOp $l, $r, (createOverflowNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_SubtractOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(StableHLO_SubtractOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_SubIOp $l, $r, (createOverflowNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(StableHLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_MulIOp $l, $r, (createOverflowNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(StableHLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_DivSIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(StableHLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_RemSIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_SelectOp $pred, $tv, $fv), +def : Pat<(StableHLO_SelectOp $pred, $tv, $fv), (SelectOp $pred, $tv, $fv), [(IsSameSizeConstraint $pred, $tv), (IsSameSizeConstraint $tv, $fv)]>; diff --git a/mlir/include/mlir-hlo/CMakeLists.txt b/mlir/include/mlir-hlo/CMakeLists.txt deleted file mode 100644 index 584f1e2617..0000000000 --- a/mlir/include/mlir-hlo/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name mlir-hlo) -add_public_tablegen_target(MLIRHLOCatalystPassIncGen) - -# The following is modified from the -# tensorflow/mlir-hlo -# repository at -# https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/CMakeLists.txt -# to build the rewrite patterns for the --mhlo-legalize-to-std pass -set(LLVM_TARGET_DEFINITIONS mhlo_legalize_to_standard_patterns.td) -include_directories( - ${CATALYST_MAIN_INCLUDE_DIR}/../mlir-hlo) -mlir_tablegen(generated_mhlo_legalize_to_standard.cpp.inc -gen-rewriters) -add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen) diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index 949f171567..d6afa814aa 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -4,7 +4,7 @@ add_mlir_public_c_api_library(QuantumCAPI LINK_LIBS PRIVATE MLIRCatalyst catalyst-transforms - catalyst-mhlo-transforms + catalyst-stablehlo-transforms MLIRQuantum quantum-transforms MLIRQEC diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt index e9043f25ed..687046ef64 100644 --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -1,11 +1,11 @@ -add_subdirectory(Driver) add_subdirectory(CAPI) add_subdirectory(Catalyst) -add_subdirectory(Quantum) -add_subdirectory(QEC) +add_subdirectory(Driver) add_subdirectory(Gradient) +add_subdirectory(hlo-extensions) add_subdirectory(Ion) add_subdirectory(MBQC) add_subdirectory(Mitigation) +add_subdirectory(QEC) +add_subdirectory(Quantum) add_subdirectory(Test) -add_subdirectory(mlir-hlo) diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index d8d3a81ed0..41776762a7 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -14,8 +14,6 @@ file(GLOB SRC DisableAssertionPatterns.cpp GEPInboundsPass.cpp GEPInboundsPatterns.cpp - hlo_custom_call_lowering.cpp - HloCustomCallPatterns.cpp InlineNestedModules.cpp MemrefCopyToLinalgCopyPass.cpp MemrefCopyToLinalgCopyPatterns.cpp @@ -23,8 +21,6 @@ file(GLOB SRC QnodeToAsyncPatterns.cpp RegisterAllPasses.cpp RegisterInactiveCallbackPass.cpp - scatter_lowering.cpp - ScatterPatterns.cpp SplitMultipleTapes.cpp TBAAPatterns.cpp TBAATagsPass.cpp @@ -40,7 +36,6 @@ set(LIBS set(DEPENDS MLIRCatalystPassIncGen ) - add_mlir_library(${LIBRARY_NAME} STATIC ${SRC} LINK_LIBS PRIVATE ${LIBS} DEPENDS ${DEPENDS}) target_compile_features(${LIBRARY_NAME} PUBLIC cxx_std_20) target_include_directories(${LIBRARY_NAME} PUBLIC diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index c6a2d983d5..1ef631247b 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mlir/Pass/PassRegistry.h" + #include "Catalyst/Transforms/Passes.h" #include "Gradient/Transforms/Passes.h" #include "Ion/Transforms/Passes.h" @@ -20,8 +22,7 @@ #include "QEC/Transforms/Passes.h" #include "Quantum/Transforms/Passes.h" #include "Test/Transforms/Passes.h" -#include "mlir-hlo/Passes.h" -#include "mlir/Pass/PassRegistry.h" +#include "hlo-extensions/Passes.h" void catalyst::registerAllCatalystPasses() { @@ -60,9 +61,6 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createMBQCConversionPass); mlir::registerPass(catalyst::createMemrefCopyToLinalgCopyPass); mlir::registerPass(catalyst::createMemrefToLLVMWithTBAAPass); - mlir::registerPass(catalyst::createMhloLegalizeSortPass); - mlir::registerPass(catalyst::createMhloLegalizeToStdPass); - mlir::registerPass(catalyst::createMhloLegalizeControlFlowPass); mlir::registerPass(catalyst::createMitigationLoweringPass); mlir::registerPass(catalyst::createQnodeToAsyncLoweringPass); mlir::registerPass(catalyst::createQuantumConversionPass); @@ -70,6 +68,9 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createRemoveChainedSelfInversePass); mlir::registerPass(catalyst::createMergeRotationsPass); mlir::registerPass(catalyst::createScatterLoweringPass); + mlir::registerPass(catalyst::createStablehloLegalizeControlFlowPass); + mlir::registerPass(catalyst::createStablehloLegalizeSortPass); + mlir::registerPass(catalyst::createStablehloLegalizeToStdPass); mlir::registerPass(catalyst::createSplitMultipleTapesPass); mlir::registerPass(catalyst::createTestPass); } diff --git a/mlir/lib/Driver/CMakeLists.txt b/mlir/lib/Driver/CMakeLists.txt index 1bb5720366..c3a82be9ec 100644 --- a/mlir/lib/Driver/CMakeLists.txt +++ b/mlir/lib/Driver/CMakeLists.txt @@ -25,6 +25,7 @@ set(LIBS ${conversion_libs} ${extension_libs} ${translation_libs} + ExternalStablehloLib MLIROptLib MLIRCatalyst catalyst-transforms @@ -40,10 +41,7 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms - MhloRegisterDialects - StablehloRegister MLIRCatalystTest - ${ALL_MHLO_PASSES} ${ENZYME_LIB} ) diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp index d72ef39ebb..701dd68a93 100644 --- a/mlir/lib/Driver/CompilerDriver.cpp +++ b/mlir/lib/Driver/CompilerDriver.cpp @@ -24,8 +24,11 @@ #include #include -#include "mhlo/IR/register.h" -#include "mhlo/transforms/passes.h" +#include "stablehlo/dialect/Register.h" +#include "stablehlo/integrations/c/StablehloPasses.h" +#include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/optimization/Passes.h" + #include "mlir/IR/DialectRegistry.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllExtensions.h" @@ -34,7 +37,6 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Target/LLVMIR/Export.h" -#include "stablehlo/dialect/Register.h" #include "llvm/Analysis/CGSCCPassManager.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/IR/LegacyPassManager.h" @@ -294,7 +296,6 @@ void registerAllCatalystDialects(DialectRegistry ®istry) registerAllExtensions(registry); // HLO - mhlo::registerAllMhloDialects(registry); stablehlo::registerAllDialects(registry); // Catalyst @@ -962,7 +963,8 @@ int QuantumDriverMainFromCL(int argc, char **argv) registerAllPasses(); registerAllCatalystPasses(); registerAllCatalystPipelines(); - mhlo::registerAllMhloPasses(); + mlirRegisterAllStablehloPasses(); + mlir::stablehlo::registerOptimizationPasses(); registerAllCatalystDialects(registry); registerLLVMTranslations(registry); diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index aff124dc3d..ccf4e4ab4f 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -12,21 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "Driver/Pipelines.h" +#include + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "stablehlo/conversions/linalg/transforms/Passes.h" +#include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/optimization/Passes.h" + #include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/Transforms/Passes.h" +#include "Driver/Pipelines.h" #include "Gradient/IR/GradientDialect.h" #include "Gradient/Transforms/Passes.h" #include "Mitigation/Transforms/Passes.h" #include "Quantum/IR/QuantumDialect.h" #include "Quantum/Transforms/Passes.h" -#include "mhlo/transforms/passes.h" -#include "mlir-hlo/Passes.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/InitAllDialects.h" -#include "mlir/InitAllPasses.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/Passes.h" +#include "hlo-extensions/Passes.h" using namespace mlir; namespace catalyst { @@ -41,20 +46,23 @@ void createEnforceRuntimeInvariantsPipeline(OpPassManager &pm) void createHloLoweringPipeline(OpPassManager &pm) { pm.addPass(mlir::createCanonicalizerPass()); - pm.addNestedPass(mhlo::createChloLegalizeToHloPass()); - pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); - pm.addNestedPass(catalyst::createMhloLegalizeControlFlowPass()); - pm.addNestedPass(mhlo::createLegalizeHloToLinalgPass()); - pm.addNestedPass(catalyst::createMhloLegalizeToStdPass()); - pm.addNestedPass(catalyst::createMhloLegalizeSortPass()); - pm.addPass(mlir::mhlo::createConvertToSignlessPass()); + + pm.addNestedPass(stablehlo::createChloLegalizeToStablehloPass()); + pm.addNestedPass(catalyst::createStablehloLegalizeControlFlowPass()); + stablehlo::StablehloAggressiveSimplificationPassOptions ASoptions; + pm.addNestedPass( + stablehlo::createStablehloAggressiveSimplificationPass(ASoptions)); + pm.addNestedPass(stablehlo::createStablehloLegalizeToLinalgPass()); + pm.addNestedPass(catalyst::createStablehloLegalizeToStdPass()); + pm.addNestedPass(catalyst::createStablehloLegalizeSortPass()); + pm.addPass(stablehlo::createStablehloConvertToSignlessPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(catalyst::createScatterLoweringPass()); pm.addPass(catalyst::createHloCustomCallLoweringPass()); pm.addPass(mlir::createCSEPass()); - mlir::LinalgDetensorizePassOptions options; - options.aggressiveMode = true; - pm.addNestedPass(mlir::createLinalgDetensorizePass(options)); + mlir::LinalgDetensorizePassOptions LDoptions; + LDoptions.aggressiveMode = true; + pm.addNestedPass(mlir::createLinalgDetensorizePass(LDoptions)); pm.addPass(catalyst::createDetensorizeSCFPass()); pm.addPass(mlir::createCanonicalizerPass()); } diff --git a/mlir/lib/mlir-hlo/CMakeLists.txt b/mlir/lib/hlo-extensions/CMakeLists.txt similarity index 61% rename from mlir/lib/mlir-hlo/CMakeLists.txt rename to mlir/lib/hlo-extensions/CMakeLists.txt index 8ad4f319ba..de0833b29d 100644 --- a/mlir/lib/mlir-hlo/CMakeLists.txt +++ b/mlir/lib/hlo-extensions/CMakeLists.txt @@ -1,9 +1,13 @@ -set(LIBRARY_NAME catalyst-mhlo-transforms) +set(LIBRARY_NAME catalyst-stablehlo-transforms) file(GLOB SRC - mhlo_legalize_control_flow.cpp - mhlo_legalize_sort.cpp - mhlo_legalize_to_std.cpp + hlo_custom_call_lowering.cpp + HloCustomCallPatterns.cpp + scatter_lowering.cpp + ScatterPatterns.cpp + stablehlo_legalize_control_flow.cpp + stablehlo_legalize_sort.cpp + stablehlo_legalize_to_std.cpp ) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) @@ -11,12 +15,13 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) set(LIBS ${dialect_libs} ${conversion_libs} + ExternalStablehloLib ) set(DEPENDS + STABLEHLOCatalystPassIncGen MLIRCatalystPassIncGen - MLIRHLOCatalystPassIncGen - MLIRMhloLegalizeToStandardIncGen + MLIRStablehloLegalizeToStandardIncGen ) add_mlir_library(${LIBRARY_NAME} STATIC ${SRC} LINK_LIBS PRIVATE ${LIBS} DEPENDS ${DEPENDS}) diff --git a/mlir/lib/Catalyst/Transforms/HloCustomCallPatterns.cpp b/mlir/lib/hlo-extensions/HloCustomCallPatterns.cpp similarity index 96% rename from mlir/lib/Catalyst/Transforms/HloCustomCallPatterns.cpp rename to mlir/lib/hlo-extensions/HloCustomCallPatterns.cpp index 019abc33ce..2168cc46e9 100644 --- a/mlir/lib/Catalyst/Transforms/HloCustomCallPatterns.cpp +++ b/mlir/lib/hlo-extensions/HloCustomCallPatterns.cpp @@ -14,22 +14,22 @@ #define DEBUG_TYPE "scatter" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "llvm/Support/Debug.h" #include "Catalyst/IR/CatalystOps.h" -#include "mhlo/IR/hlo_ops.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "llvm/Support/Debug.h" using namespace mlir; namespace catalyst { -struct HloCustomCallOpRewritePattern : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; +struct HloCustomCallOpRewritePattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - mlir::LogicalResult matchAndRewrite(mhlo::CustomCallOp op, + mlir::LogicalResult matchAndRewrite(stablehlo::CustomCallOp op, mlir::PatternRewriter &rewriter) const override { StringRef calleeName = op.getCallTargetName(); diff --git a/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp b/mlir/lib/hlo-extensions/ScatterPatterns.cpp similarity index 97% rename from mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp rename to mlir/lib/hlo-extensions/ScatterPatterns.cpp index 3a95533737..0887be843b 100644 --- a/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp +++ b/mlir/lib/hlo-extensions/ScatterPatterns.cpp @@ -22,17 +22,16 @@ #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" - -#include "mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" using namespace mlir; namespace catalyst { -struct ScatterOpRewritePattern : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; +struct ScatterOpRewritePattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - void emitIndicesError(mhlo::ScatterOp op) const + void emitIndicesError(stablehlo::ScatterOp op) const { op.emitError() << "Indices are not unique and/or not sorted. Note that when using multiple indices " @@ -44,7 +43,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern << ", sorted: " << op.getIndicesAreSorted(); } - mlir::LogicalResult onlyOneInputUpdateAndResult(mhlo::ScatterOp op) const + mlir::LogicalResult onlyOneInputUpdateAndResult(stablehlo::ScatterOp op) const { // Semantics of scatter: // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter @@ -64,7 +63,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern return op.getResults().size() == 1 ? success() : failure(); } - mlir::LogicalResult isAssignment(mhlo::ScatterOp op) const + mlir::LogicalResult isAssignment(stablehlo::ScatterOp op) const { // From: // C23: update_computation has type @@ -90,7 +89,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern return failure(); } - mhlo::ReturnOp returnOp = dyn_cast(block.getTerminator()); + stablehlo::ReturnOp returnOp = dyn_cast(block.getTerminator()); if (!returnOp) { return failure(); } @@ -98,7 +97,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern return returnOp.getResults().front() == block.getArgument(1) ? success() : failure(); } - mlir::LogicalResult noBatching(mhlo::ScatterOp op) const + mlir::LogicalResult noBatching(stablehlo::ScatterOp op) const { // Ok, now that we know it is an assignment, we need to worry about // where exactly are we assigning and what are we assigning. @@ -123,7 +122,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern // return scatterDimNumbers.getInputBatchingDims().empty() ? success() : failure(); } - mlir::LogicalResult singleFullSlices(mhlo::ScatterOp op) const + mlir::LogicalResult singleFullSlices(stablehlo::ScatterOp op) const { // From: // More formally, for all update_index in index_space(updates[0]): @@ -144,13 +143,13 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern return rank == scatterDimNumbers.getUpdateWindowDims().size() ? success() : failure(); } - mlir::LogicalResult canBeDoneWithSingleTensorInsertSlice(mhlo::ScatterOp op) const + mlir::LogicalResult canBeDoneWithSingleTensorInsertSlice(stablehlo::ScatterOp op) const { return cast(op.getScatterIndices().getType()).getRank() == 1 ? success() : failure(); } - mlir::LogicalResult lowerToTensorInsertSlice(mhlo::ScatterOp op, + mlir::LogicalResult lowerToTensorInsertSlice(stablehlo::ScatterOp op, mlir::PatternRewriter &rewriter) const { // mhlo::ScatterOp is exactly the same as stablehlo::ScatterOp @@ -284,7 +283,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern return success(); } - mlir::LogicalResult matchAndRewrite(mhlo::ScatterOp op, + mlir::LogicalResult matchAndRewrite(stablehlo::ScatterOp op, mlir::PatternRewriter &rewriter) const override { // FastPath @@ -457,7 +456,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern }; // Store all the necessary variables for the SCF for op in above defined struct - UpdateData getUpdateData(mhlo::ScatterOp &op, mlir::PatternRewriter &rewriter, + UpdateData getUpdateData(stablehlo::ScatterOp &op, mlir::PatternRewriter &rewriter, mlir::Location loc) const { UpdateData data; diff --git a/mlir/lib/Catalyst/Transforms/hlo_custom_call_lowering.cpp b/mlir/lib/hlo-extensions/hlo_custom_call_lowering.cpp similarity index 89% rename from mlir/lib/Catalyst/Transforms/hlo_custom_call_lowering.cpp rename to mlir/lib/hlo-extensions/hlo_custom_call_lowering.cpp index 290d3827a6..cb0a97dc8e 100644 --- a/mlir/lib/Catalyst/Transforms/hlo_custom_call_lowering.cpp +++ b/mlir/lib/hlo-extensions/hlo_custom_call_lowering.cpp @@ -16,18 +16,17 @@ #include -#include "llvm/Support/Debug.h" - -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" +#include "llvm/Support/Debug.h" #include "Catalyst/IR/CatalystDialect.h" -#include "Catalyst/Transforms/Patterns.h" +#include "hlo-extensions/Passes.h" +#include "hlo-extensions/Patterns.h" using namespace llvm; using namespace mlir; @@ -35,7 +34,7 @@ using namespace catalyst; namespace catalyst { #define GEN_PASS_DEF_HLOCUSTOMCALLLOWERINGPASS -#include "Catalyst/Transforms/Passes.h.inc" +#include "hlo-extensions/Passes.h.inc" struct HloCustomCallLoweringPass : impl::HloCustomCallLoweringPassBase { using HloCustomCallLoweringPassBase::HloCustomCallLoweringPassBase; diff --git a/mlir/lib/Catalyst/Transforms/scatter_lowering.cpp b/mlir/lib/hlo-extensions/scatter_lowering.cpp similarity index 89% rename from mlir/lib/Catalyst/Transforms/scatter_lowering.cpp rename to mlir/lib/hlo-extensions/scatter_lowering.cpp index e56b35390a..1726ce8a62 100644 --- a/mlir/lib/Catalyst/Transforms/scatter_lowering.cpp +++ b/mlir/lib/hlo-extensions/scatter_lowering.cpp @@ -16,19 +16,18 @@ #include -#include "llvm/Support/Debug.h" - -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" +#include "llvm/Support/Debug.h" -#include "Catalyst/Transforms/Patterns.h" +#include "hlo-extensions/Passes.h" +#include "hlo-extensions/Patterns.h" using namespace llvm; using namespace mlir; @@ -36,7 +35,7 @@ using namespace catalyst; namespace catalyst { #define GEN_PASS_DEF_SCATTERLOWERINGPASS -#include "Catalyst/Transforms/Passes.h.inc" +#include "hlo-extensions/Passes.h.inc" struct ScatterLoweringPass : impl::ScatterLoweringPassBase { using ScatterLoweringPassBase::ScatterLoweringPassBase; diff --git a/mlir/lib/mlir-hlo/mhlo_legalize_control_flow.cpp b/mlir/lib/hlo-extensions/stablehlo_legalize_control_flow.cpp similarity index 75% rename from mlir/lib/mlir-hlo/mhlo_legalize_control_flow.cpp rename to mlir/lib/hlo-extensions/stablehlo_legalize_control_flow.cpp index fd9641cb81..03ad07f78d 100644 --- a/mlir/lib/mlir-hlo/mhlo_legalize_control_flow.cpp +++ b/mlir/lib/hlo-extensions/stablehlo_legalize_control_flow.cpp @@ -33,16 +33,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// The modifications are porting the pass from the upstream MHLO namespace to +// The modifications are porting the pass from the upstream stablehlo namespace to // catalyst namespace. -// This file implements logic for lowering MHLO dialect to SCF dialect. +// This file implements logic for lowering Stablehlo dialect to SCF dialect. #include #include #include -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project @@ -56,38 +54,35 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -// #include "stablehlo/dialect/StablehloOps.h" -// #include "stablehlo/transforms/Passes.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" #include "llvm/Support/Casting.h" -#include "mlir-hlo/Passes.h" +#include "hlo-extensions/Passes.h" using namespace mlir; -using namespace mhlo; -// using namespace stablehlo; +using namespace stablehlo; using namespace catalyst; namespace catalyst { -#define GEN_PASS_DEF_MHLOLEGALIZECONTROLFLOWPASS -#define GEN_PASS_DECL_MHLOLEGALIZECONTROLFLOWPASS -// #define GEN_PASS_DEF_STABLEHLOLEGALIZECONTROLFLOWPASS -// #define GEN_PASS_DECL_STABLEHLOLEGALIZECONTROLFLOWPASS -#include "mlir-hlo/Passes.h.inc" +#define GEN_PASS_DEF_STABLEHLOLEGALIZECONTROLFLOWPASS +#define GEN_PASS_DECL_STABLEHLOLEGALIZECONTROLFLOWPASS +#include "hlo-extensions/Passes.h.inc" } // namespace catalyst namespace { -// All transformations in this file take mhlo blocks which end with +// All transformations in this file take stablehlo blocks which end with // stablehlo::ReturnOp and lower to SCF ops which end with scf::YieldOp. Inline an // entire block with the only change being return -> yield. -void inlineMhloRegionIntoSCFRegion(PatternRewriter &rewriter, Region &mhlo, Region &scf) +void inlineStablehloRegionIntoSCFRegion(PatternRewriter &rewriter, Region &r, Region &scf) { // Remove an existing block, then move the region over. if (!scf.empty()) rewriter.eraseBlock(&scf.back()); - rewriter.inlineRegionBefore(mhlo, scf, scf.end()); + rewriter.inlineRegionBefore(r, scf, scf.end()); // Fix up the terminator. PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToEnd(&scf.back()); @@ -95,7 +90,7 @@ void inlineMhloRegionIntoSCFRegion(PatternRewriter &rewriter, Region &mhlo, Regi rewriter.replaceOpWithNewOp(terminator, terminator->getOperands()); } -// mhlo ops need inputs to be tensors, but scalar values can be a scalar tensor +// stablehlo ops need inputs to be tensors, but scalar values can be a scalar tensor // or a 1 element tensor. To handle this, collapse shape before extracting the // scalar value when necessary. Value extractTensorValue(OpBuilder &b, Value tensor) @@ -116,7 +111,7 @@ struct ScfForBounds { unsigned indexArgIndex; }; -std::optional extractForBounds(mhlo::WhileOp op) +std::optional extractForBounds(stablehlo::WhileOp op) { auto &cond = op.getCond().front(); auto &body = op.getBody().front(); @@ -129,10 +124,10 @@ std::optional extractForBounds(mhlo::WhileOp op) return mlir::cast(v).getArgNumber(); }; - auto compare = llvm::dyn_cast(cond.front()); + auto compare = llvm::dyn_cast(cond.front()); // If the rhs of the comapare is defined outside the block, it's a constant // within the loop. - if (!compare || compare.getComparisonDirection() != mhlo::ComparisonDirection::LT || + if (!compare || compare.getComparisonDirection() != stablehlo::ComparisonDirection::LT || compare.getRhs().getParentBlock() == &cond || !getElementTypeOrSelf(compare.getLhs().getType()).isSignlessIntOrIndex()) { return std::nullopt; @@ -142,7 +137,7 @@ std::optional extractForBounds(mhlo::WhileOp op) if (!iterArg) return std::nullopt; - auto add = llvm::dyn_cast_or_null( + auto add = llvm::dyn_cast_or_null( body.getTerminator()->getOperand(*iterArg).getDefiningOp()); if (!add || matchBbArg(add.getLhs(), body) != iterArg || add.getRhs().getParentBlock() == &body) { @@ -158,10 +153,10 @@ std::optional extractForBounds(mhlo::WhileOp op) } // Rewrites `stablehlo.while` to `scf.while` or `scf.for`. -struct WhileOpPattern : public OpConversionPattern { +struct WhileOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(mhlo::WhileOp op, OpAdaptor adaptor, + LogicalResult matchAndRewrite(stablehlo::WhileOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); @@ -173,8 +168,8 @@ struct WhileOpPattern : public OpConversionPattern { extractTensorValue(rewriter, bounds->step), adaptor.getOperands()); rewriter.setInsertionPointToEnd(newForOp.getBody()); - // Inline while body, and only replace the mhlo.return with an scf.yield. - inlineMhloRegionIntoSCFRegion(rewriter, op.getBody(), newForOp.getRegion()); + // Inline while body, and only replace the stablehlo.return with an scf.yield. + inlineStablehloRegionIntoSCFRegion(rewriter, op.getBody(), newForOp.getRegion()); auto indexArg = newForOp.getRegion().insertArgument( unsigned{0}, newForOp.getLowerBound().getType(), loc); auto oldIndexArg = newForOp.getRegion().getArgument(1 + bounds->indexArgIndex); @@ -194,39 +189,40 @@ struct WhileOpPattern : public OpConversionPattern { // needs to be extracted and used with an scf.condition. rewriter.inlineRegionBefore(op.getCond(), newWhileOp.getBefore(), newWhileOp.getBefore().end()); - auto conditionReturn = cast(newWhileOp.getBefore().front().getTerminator()); + auto conditionReturn = + cast(newWhileOp.getBefore().front().getTerminator()); rewriter.setInsertionPointToEnd(&newWhileOp.getBefore().front()); Value i1 = extractTensorValue(rewriter, conditionReturn->getOperand(0)); rewriter.replaceOpWithNewOp(conditionReturn, i1, newWhileOp.getBeforeArguments()); - // Inline while body, and only replace the mhlo.return with an scf.yield. - inlineMhloRegionIntoSCFRegion(rewriter, op.getBody(), newWhileOp.getAfter()); + // Inline while body, and only replace the stablehlo.return with an scf.yield. + inlineStablehloRegionIntoSCFRegion(rewriter, op.getBody(), newWhileOp.getAfter()); rewriter.replaceOp(op, newWhileOp.getResults()); return success(); } }; -// Rewrites `mhlo.if` to `scf.if`. -struct IfOpPattern : public OpConversionPattern { +// Rewrites `stablehlo.if` to `scf.if`. +struct IfOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(mhlo::IfOp op, OpAdaptor adaptor, + LogicalResult matchAndRewrite(stablehlo::IfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto scfIf = rewriter.create(op.getLoc(), op.getResultTypes(), extractTensorValue(rewriter, adaptor.getPred()), /*withElseRegion=*/true); - inlineMhloRegionIntoSCFRegion(rewriter, op.getTrueBranch(), scfIf.getThenRegion()); - inlineMhloRegionIntoSCFRegion(rewriter, op.getFalseBranch(), scfIf.getElseRegion()); + inlineStablehloRegionIntoSCFRegion(rewriter, op.getTrueBranch(), scfIf.getThenRegion()); + inlineStablehloRegionIntoSCFRegion(rewriter, op.getFalseBranch(), scfIf.getElseRegion()); rewriter.replaceOp(op, scfIf.getResults()); return success(); } }; -// Rewrites `mhlo.case` to a nested `scf.if`. -struct CaseOpPattern : public OpConversionPattern { +// Rewrites `stablehlo.case` to a nested `scf.if`. +struct CaseOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; // Recursively create if/else ops to handle each possible value in a case op. @@ -243,21 +239,21 @@ struct CaseOpPattern : public OpConversionPattern { auto constAttr = DenseElementsAttr::get( shapedType, {mlir::cast(outerBuilder.getI32IntegerAttr(currentIdx))}); Value currentIdxVal = - outerBuilder.create(loc, idxValue.getType(), constAttr); + outerBuilder.create(loc, idxValue.getType(), constAttr); auto scfIf = outerBuilder.create( loc, op.getResultTypes(), extractTensorValue(outerBuilder, - outerBuilder.create(loc, idxValue, currentIdxVal, - ComparisonDirection::EQ)), + outerBuilder.create( + loc, idxValue, currentIdxVal, ComparisonDirection::EQ)), /*withElseRegion=*/true); - inlineMhloRegionIntoSCFRegion(outerBuilder, op.getBranches()[currentIdx], - scfIf.getThenRegion()); + inlineStablehloRegionIntoSCFRegion(outerBuilder, op.getBranches()[currentIdx], + scfIf.getThenRegion()); int nextIdx = currentIdx + 1; // Don't recurse for the final default block. if (currentIdx == static_cast(finalIdx)) { - inlineMhloRegionIntoSCFRegion(outerBuilder, op.getBranches()[nextIdx], - scfIf.getElseRegion()); + inlineStablehloRegionIntoSCFRegion(outerBuilder, op.getBranches()[nextIdx], + scfIf.getElseRegion()); } else { PatternRewriter::InsertionGuard guard(outerBuilder); @@ -268,14 +264,14 @@ struct CaseOpPattern : public OpConversionPattern { return scfIf; } - LogicalResult matchAndRewrite(mhlo::CaseOp op, OpAdaptor adaptor, + LogicalResult matchAndRewrite(stablehlo::CaseOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Inline the op if there is only a default block. if (op.getBranches().size() == 1) { Block &block = op.getBranches().front().front(); auto results = block.getTerminator()->getOperands(); - // Remove the mhlo.return terminator, then inline the block. + // Remove the stablehlo.return terminator, then inline the block. rewriter.eraseOp(block.getTerminator()); rewriter.inlineBlockBefore(/*source=*/&block, /*dest=*/op.getOperation(), /*argValues=*/{}); @@ -289,8 +285,9 @@ struct CaseOpPattern : public OpConversionPattern { } }; -struct MhloLegalizeControlFlowPass - : public catalyst::impl::MhloLegalizeControlFlowPassBase { +struct StablehloLegalizeControlFlowPass + : public catalyst::impl::StablehloLegalizeControlFlowPassBase< + StablehloLegalizeControlFlowPass> { // Perform the lowering to MLIR control flow. void runOnOperation() override { @@ -302,7 +299,7 @@ struct MhloLegalizeControlFlowPass mlir::ConversionTarget target(*ctx); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - target.addIllegalOp(); + target.addIllegalOp(); if (failed(applyPartialConversion(f, target, std::move(patterns)))) { signalPassFailure(); @@ -312,7 +309,7 @@ struct MhloLegalizeControlFlowPass } // namespace -std::unique_ptr catalyst::createMhloLegalizeControlFlowPass() +std::unique_ptr catalyst::createStablehloLegalizeControlFlowPass() { - return std::make_unique(); + return std::make_unique(); } diff --git a/mlir/lib/mlir-hlo/mhlo_legalize_sort.cpp b/mlir/lib/hlo-extensions/stablehlo_legalize_sort.cpp similarity index 96% rename from mlir/lib/mlir-hlo/mhlo_legalize_sort.cpp rename to mlir/lib/hlo-extensions/stablehlo_legalize_sort.cpp index 98a87b1726..29162dccbd 100644 --- a/mlir/lib/mlir-hlo/mhlo_legalize_sort.cpp +++ b/mlir/lib/hlo-extensions/stablehlo_legalize_sort.cpp @@ -33,7 +33,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// The modifications are porting the pass from the upstream MHLO namespace to +// The modifications are porting the pass from the upstream stablehlo namespace to // catalyst namespace. // This file implements logic for lowering stablehlo.sort to the SCF dialect. @@ -41,8 +41,6 @@ limitations under the License. #include #include -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -63,24 +61,21 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" #include "llvm/ADT/STLExtras.h" -// #include "stablehlo/dialect/StablehloOps.h" -// #include "stablehlo/transforms/Passes.h" -#include "mlir-hlo/Passes.h" +#include "hlo-extensions/Passes.h" using namespace mlir; -using namespace mhlo; -// using namespace stablehlo; +using namespace stablehlo; using namespace catalyst; namespace catalyst { -#define GEN_PASS_DEF_MHLOLEGALIZESORTPASS -#define GEN_PASS_DECL_MHLOLEGALIZESORTPASS -// #define GEN_PASS_DEF_STABLEHLOLEGALIZESORTPASS -// #define GEN_PASS_DECL_STABLEHLOLEGALIZESORTPASS -#include "mlir-hlo/Passes.h.inc" +#define GEN_PASS_DEF_STABLEHLOLEGALIZESORTPASS +#define GEN_PASS_DECL_STABLEHLOLEGALIZESORTPASS +#include "hlo-extensions/Passes.h.inc" } // namespace catalyst @@ -574,8 +569,8 @@ struct SortOpPattern : public OpRewritePattern { } }; -struct MhloLegalizeSortPass - : public catalyst::impl::MhloLegalizeSortPassBase { +struct StablehloLegalizeSortPass + : public catalyst::impl::StablehloLegalizeSortPassBase { // Perform the lowering to MLIR control flow. void runOnOperation() override { @@ -587,7 +582,7 @@ struct MhloLegalizeSortPass mlir::ConversionTarget target(*ctx); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - target.addIllegalOp(); + target.addIllegalOp(); if (failed(applyPartialConversion(f, target, std::move(patterns)))) { signalPassFailure(); @@ -597,7 +592,7 @@ struct MhloLegalizeSortPass } // namespace -std::unique_ptr catalyst::createMhloLegalizeSortPass() +std::unique_ptr catalyst::createStablehloLegalizeSortPass() { - return std::make_unique(); + return std::make_unique(); } diff --git a/mlir/lib/mlir-hlo/mhlo_legalize_to_std.cpp b/mlir/lib/hlo-extensions/stablehlo_legalize_to_std.cpp similarity index 82% rename from mlir/lib/mlir-hlo/mhlo_legalize_to_std.cpp rename to mlir/lib/hlo-extensions/stablehlo_legalize_to_std.cpp index 020f39b297..1cec1aa55a 100644 --- a/mlir/lib/mlir-hlo/mhlo_legalize_to_std.cpp +++ b/mlir/lib/hlo-extensions/stablehlo_legalize_to_std.cpp @@ -33,18 +33,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// The modifications are porting the pass from the upstream MHLO namespace to +// The modifications are porting the pass from the upstream stablehlo namespace to // catalyst namespace. -// This file implements logic for lowering MHLO dialect to Standard dialect. +// This file implements logic for lowering Stablehlo dialect to Standard dialect. #include #include #include -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" -#include "mhlo/transforms/rewriters.h" // (??) #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -52,34 +49,31 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -// #include "stablehlo/dialect/StablehloOps.h" -// #include "stablehlo/transforms/Passes.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" -#include "mlir-hlo/Passes.h" +#include "hlo-extensions/Passes.h" using namespace mlir; -using namespace mhlo; -// using namespace stablehlo; +using namespace stablehlo; using namespace catalyst; namespace catalyst { -#define GEN_PASS_DEF_MHLOLEGALIZETOSTANDARDPASS -#define GEN_PASS_DECL_MHLOLEGALIZETOSTANDARDPASS -// #define GEN_PASS_DEF_STABLEHLOLEGALIZETOSTANDARDPASS -// #define GEN_PASS_DECL_STABLEHLOLEGALIZETOSTANDARDPASS -#include "mlir-hlo/Passes.h.inc" -#include "mlir-hlo/generated_mhlo_legalize_to_standard.cpp.inc" +#define GEN_PASS_DEF_STABLEHLOLEGALIZETOSTANDARDPASS +#define GEN_PASS_DECL_STABLEHLOLEGALIZETOSTANDARDPASS +#include "hlo-extensions/Passes.h.inc" +#include "hlo-extensions/generated_stablehlo_legalize_to_standard.cpp.inc" } // namespace catalyst namespace { -class CompareIConvert : public OpRewritePattern { +class CompareIConvert : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mhlo::CompareOp op, PatternRewriter &rewriter) const override + LogicalResult matchAndRewrite(stablehlo::CompareOp op, PatternRewriter &rewriter) const override { auto lhs = op.getLhs(); auto rhs = op.getRhs(); @@ -124,11 +118,11 @@ class CompareIConvert : public OpRewritePattern { } }; -class CompareFConvert : public OpRewritePattern { +class CompareFConvert : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mhlo::CompareOp op, PatternRewriter &rewriter) const override + LogicalResult matchAndRewrite(stablehlo::CompareOp op, PatternRewriter &rewriter) const override { auto lhs = op.getLhs(); auto rhs = op.getRhs(); @@ -177,11 +171,11 @@ class CompareFConvert : public OpRewritePattern { // convert the integer constant to iota result type. For complex types, the real // part is replaced with the generated constant and the imaginary part is // replaced with zero tensor. -class ConvertIotaOp : public OpRewritePattern { +class ConvertIotaOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mhlo::IotaOp op, PatternRewriter &rewriter) const override + LogicalResult matchAndRewrite(stablehlo::IotaOp op, PatternRewriter &rewriter) const override { auto outputType = mlir::cast(op.getType()); auto outputSize = outputType.getNumElements(); @@ -233,19 +227,19 @@ class ConvertIotaOp : public OpRewritePattern { auto zeroes = rewriter.create( loc, DenseIntElementsAttr::get(intShapeType, APInt(bitwidth, 0))); auto imagZeroes = rewriter.create(loc, intOrFloatShapeTy, zeroes); - rewriter.replaceOpWithNewOp(op, iotaConst, imagZeroes); + rewriter.replaceOpWithNewOp(op, iotaConst, imagZeroes); return success(); } }; -void populateMhloToStdPatterns(RewritePatternSet *patterns, mlir::MLIRContext *ctx) +void populateStablehloToStdPatterns(RewritePatternSet *patterns, mlir::MLIRContext *ctx) { populateWithGenerated(*patterns); patterns->add(ctx); } -struct MhloLegalizeToStandardPass - : public catalyst::impl::MhloLegalizeToStandardPassBase { +struct StablehloLegalizeToStandardPass + : public catalyst::impl::StablehloLegalizeToStandardPassBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -255,14 +249,14 @@ struct MhloLegalizeToStandardPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); - populateMhloToStdPatterns(&patterns, &getContext()); + populateStablehloToStdPatterns(&patterns, &getContext()); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; } // end anonymous namespace -std::unique_ptr catalyst::createMhloLegalizeToStdPass() +std::unique_ptr catalyst::createStablehloLegalizeToStdPass() { - return std::make_unique(); + return std::make_unique(); } diff --git a/mlir/mlir-hlo b/mlir/mlir-hlo deleted file mode 160000 index 1dd2e71331..0000000000 --- a/mlir/mlir-hlo +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1dd2e71331014ae0373f6bf900ce6be393357190 diff --git a/mlir/patches/mhlo-remove-shardy.patch b/mlir/patches/mhlo-remove-shardy.patch deleted file mode 100644 index 32ce71061f..0000000000 --- a/mlir/patches/mhlo-remove-shardy.patch +++ /dev/null @@ -1,132 +0,0 @@ -From 70172e8399383d6c1964d73a2d20cba3c55a3279 Mon Sep 17 00:00:00 2001 -From: paul0403 -Date: Thu, 29 May 2025 10:06:35 -0400 -Subject: [PATCH] remove shardy dependency - ---- - bindings/c/CMakeLists.txt | 1 - - stablehlo_ext/CMakeLists.txt | 1 + - stablehlo_ext/analysis/CMakeLists.txt | 3 ++- - stablehlo_ext/transforms/CMakeLists.txt | 7 ++++++- - stablehlo_ext/transforms/stablehlo_refine_shapes.cpp | 3 --- - tests/lit.cfg.py | 1 + - tools/mlir-hlo-opt/mlir-hlo-opt.cc | 2 -- - 7 files changed, 10 insertions(+), 8 deletions(-) - -diff --git a/bindings/c/CMakeLists.txt b/bindings/c/CMakeLists.txt -index fd2a5c2c..53d916d5 100644 ---- a/bindings/c/CMakeLists.txt -+++ b/bindings/c/CMakeLists.txt -@@ -10,7 +10,6 @@ add_mlir_public_c_api_library(MLIRHLOCAPIDialects - MhloPasses - MhloToArithmeticConversion - MhloToMemrefConversion -- MhloToStandard - MhloToLinalg - MhloToStablehlo - StablehloToMhlo -diff --git a/stablehlo_ext/CMakeLists.txt b/stablehlo_ext/CMakeLists.txt -index 3e55a89d..e8d318f1 100644 ---- a/stablehlo_ext/CMakeLists.txt -+++ b/stablehlo_ext/CMakeLists.txt -@@ -12,5 +12,6 @@ - # See the License for the specific language governing permissions and - # limitations under the License. - -+add_subdirectory(analysis) - add_subdirectory(IR) - add_subdirectory(transforms) -diff --git a/stablehlo_ext/analysis/CMakeLists.txt b/stablehlo_ext/analysis/CMakeLists.txt -index 726d340d..0c0259b8 100644 ---- a/stablehlo_ext/analysis/CMakeLists.txt -+++ b/stablehlo_ext/analysis/CMakeLists.txt -@@ -1,5 +1,6 @@ - add_mlir_library(MhloAnalysis -- shape_component_analysis.cc -+ shape_component_analysis.cpp -+ PARTIAL_SOURCES_INTENDED - - DEPENDS - mlir-headers -diff --git a/stablehlo_ext/transforms/CMakeLists.txt b/stablehlo_ext/transforms/CMakeLists.txt -index ee58f490..2d7cc22c 100644 ---- a/stablehlo_ext/transforms/CMakeLists.txt -+++ b/stablehlo_ext/transforms/CMakeLists.txt -@@ -20,9 +20,14 @@ add_mlir_dialect_library(StablehloExtensionPasses - PARTIAL_SOURCES_INTENDED - chlo_recompose_ops.cpp - chlo_preserve_high_level_ops.cpp -+ sink_constants_to_control_flow.cpp -+ stablehlo_add_quant_dequant_conv.cpp - stablehlo_canonicalize_dynamism.cpp -+ stablehlo_canonicalize_from_hlo_import.cpp -+ stablehlo_legalize_quant_composite.cpp -+ stablehlo_prepare_for_hlo_export.cpp - stablehlo_refine_shapes.cpp -- sdy_refine_shapes.cpp -+ symbolic_shape_optimization.cpp - - DEPENDS - StablehloExtensionPassesIncGen -diff --git a/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp b/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp -index cabd6a9f..2e64b4ed 100644 ---- a/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp -+++ b/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp -@@ -34,7 +34,6 @@ limitations under the License. - #include "stablehlo_ext/IR/base.h" - #include "stablehlo_ext/IR/stablehlo_ops.h" - #include "stablehlo_ext/transforms/passes.h" // NOLINT: Used in passes.h.inc --#include "stablehlo_ext/transforms/sdy_refine_shapes.h" - - namespace mlir { - namespace stablehlo_ext { -@@ -154,7 +153,6 @@ struct StablehloRefineShapesPass - patterns->add(context); - patterns->add(context); - patterns->add(context); -- populateSdyShapeRefinementPatterns(context, patterns); - }; - - if (failed(stablehlo::refineEntryFunction(*context, func, -@@ -172,7 +170,6 @@ void populateStablehloExtRefineShapesPatterns(RewritePatternSet *patterns, - patterns->add(context); - patterns->add(context); - patterns->add(context); -- populateSdyShapeRefinementPatterns(context, patterns); - } - - } // namespace stablehlo_ext -diff --git a/tests/lit.cfg.py b/tests/lit.cfg.py -index ab20fbb5..6c61aec5 100644 ---- a/tests/lit.cfg.py -+++ b/tests/lit.cfg.py -@@ -32,6 +32,7 @@ config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) - - # suffixes: A list of file extensions to treat as test files. - config.suffixes = ['.mlir'] -+config.excludes = ['sdy_refine_shapes.mlir'] - - # test_source_root: The root path where tests are located. - config.test_source_root = os.path.dirname(__file__) -diff --git a/tools/mlir-hlo-opt/mlir-hlo-opt.cc b/tools/mlir-hlo-opt/mlir-hlo-opt.cc -index f018cbdc..b4474850 100644 ---- a/tools/mlir-hlo-opt/mlir-hlo-opt.cc -+++ b/tools/mlir-hlo-opt/mlir-hlo-opt.cc -@@ -20,7 +20,6 @@ limitations under the License. - #include "mlir/InitAllExtensions.h" - #include "mlir/InitAllPasses.h" - #include "mlir/Tools/mlir-opt/MlirOptMain.h" --#include "shardy/dialect/sdy/ir/dialect.h" - #include "stablehlo/dialect/Register.h" - #include "stablehlo_ext/transforms/passes.h" - #include "transforms/gpu_passes.h" -@@ -41,6 +40,5 @@ int main(int argc, char** argv) { - registerAllExtensions(registry); - mhlo::registerAllMhloDialects(registry); - stablehlo::registerAllDialects(registry); -- registry.insert(); - return failed(MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); - } --- -2.34.1 - diff --git a/mlir/patches/mhlo-rename-sort.patch b/mlir/patches/mhlo-rename-sort.patch deleted file mode 100644 index c356cc35e3..0000000000 --- a/mlir/patches/mhlo-rename-sort.patch +++ /dev/null @@ -1,15 +0,0 @@ -diff --git a/utils/cycle_detector.cc b/utils/cycle_detector.cc -index e3901ae88..890f39654 100644 ---- a/utils/cycle_detector.cc -+++ b/utils/cycle_detector.cc -@@ -199,8 +199,8 @@ static void backwardDfs(GraphCycles::Rep* r, int32_t n, int32_t lowerBound) { - // Recomputes rank assignments to make them compatible with the edges (producer - // has smaller rank than its consumer) - static void reorder(GraphCycles::Rep* r) { -- sort(r->nodes, &r->deltab); -- sort(r->nodes, &r->deltaf); -+ mlir::sort(r->nodes, &r->deltab); -+ mlir::sort(r->nodes, &r->deltaf); - - // Adds contents of delta lists to list (backwards deltas first). - r->list.clear(); diff --git a/mlir/stablehlo b/mlir/stablehlo new file mode 160000 index 0000000000..69d6dae46e --- /dev/null +++ b/mlir/stablehlo @@ -0,0 +1 @@ +Subproject commit 69d6dae46e1c7de36e6e6973654754f05353cba5 diff --git a/mlir/test/Catalyst/HloCustomCallsTest.mlir b/mlir/test/Catalyst/HloCustomCallsTest.mlir index 019dc7a279..4804478574 100644 --- a/mlir/test/Catalyst/HloCustomCallsTest.mlir +++ b/mlir/test/Catalyst/HloCustomCallsTest.mlir @@ -19,6 +19,6 @@ func.func @custom_call(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { // CHECK: %cst_0 = arith.constant dense<3> : tensor // CHECK: %0 = catalyst.custom_call fn("lapack_dgesdd_ffi") (%cst, %cst_0, %cst_0, %arg0) : (tensor, tensor, tensor, tensor<3x3xf64>) -> tensor<3x3xf64> // CHECK: return %0 : tensor<3x3xf64> - %0 = mhlo.custom_call @lapack_dgesdd_ffi(%arg0) {api_version = 2 : i32, backend_config = "", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#mhlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xf64>) -> tensor<3x3xf64> + %0 = stablehlo.custom_call @lapack_dgesdd_ffi(%arg0) {api_version = 2 : i32, backend_config = "", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xf64>) -> tensor<3x3xf64> return %0 : tensor<3x3xf64> } diff --git a/mlir/test/Catalyst/ScatterTest.mlir b/mlir/test/Catalyst/ScatterTest.mlir index 9ce56aa949..8266678880 100644 --- a/mlir/test/Catalyst/ScatterTest.mlir +++ b/mlir/test/Catalyst/ScatterTest.mlir @@ -26,14 +26,14 @@ func.func public @scatter_multiply(%arg0: tensor<3xf64>, %arg1: tensor) -> %2 = arith.select %0, %1, %extracted_1 : i64 %3 = arith.trunci %2 : i64 to i32 %from_elements = tensor.from_elements %3 : tensor<1xi32> - %4 = "mhlo.scatter"(%arg0, %from_elements, %cst) ({ + %4 = "stablehlo.scatter"(%arg0, %from_elements, %cst) ({ ^bb0(%arg2: tensor, %arg3: tensor): %extracted_2 = tensor.extract %arg2[] : tensor %extracted_3 = tensor.extract %arg3[] : tensor %5 = arith.mulf %extracted_2, %extracted_3 : f64 %from_elements_4 = tensor.from_elements %5 : tensor - mhlo.return %from_elements_4 : tensor - }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<3xf64>, tensor<1xi32>, tensor) -> tensor<3xf64> + stablehlo.return %from_elements_4 : tensor + }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xf64>, tensor<1xi32>, tensor) -> tensor<3xf64> return %4 : tensor<3xf64> } @@ -74,14 +74,14 @@ func.func public @two_scatter(%arg0: tensor<3xf64>, %arg1: tensor) -> tenso %2 = arith.select %0, %1, %extracted_2 : i64 %3 = arith.trunci %2 : i64 to i32 %from_elements = tensor.from_elements %3 : tensor<1xi32> - %4 = "mhlo.scatter"(%arg0, %from_elements, %cst_0) ({ + %4 = "stablehlo.scatter"(%arg0, %from_elements, %cst_0) ({ ^bb0(%arg2: tensor, %arg3: tensor): %extracted_7 = tensor.extract %arg2[] : tensor %extracted_8 = tensor.extract %arg3[] : tensor %12 = arith.mulf %extracted_7, %extracted_8 : f64 %from_elements_9 = tensor.from_elements %12 : tensor - mhlo.return %from_elements_9 : tensor - }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<3xf64>, tensor<1xi32>, tensor) -> tensor<3xf64> + stablehlo.return %from_elements_9 : tensor + }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xf64>, tensor<1xi32>, tensor) -> tensor<3xf64> %extracted_3 = tensor.extract %arg1[] : tensor %5 = arith.cmpi slt, %extracted_3, %c0_i64 : i64 %extracted_4 = tensor.extract %arg1[] : tensor @@ -90,14 +90,14 @@ func.func public @two_scatter(%arg0: tensor<3xf64>, %arg1: tensor) -> tenso %7 = arith.select %5, %6, %extracted_5 : i64 %8 = arith.trunci %7 : i64 to i32 %from_elements_6 = tensor.from_elements %8 : tensor<1xi32> - %9 = "mhlo.scatter"(%arg0, %from_elements_6, %cst) ({ + %9 = "stablehlo.scatter"(%arg0, %from_elements_6, %cst) ({ ^bb0(%arg2: tensor, %arg3: tensor): %extracted_7 = tensor.extract %arg2[] : tensor %extracted_8 = tensor.extract %arg3[] : tensor %12 = arith.addf %extracted_7, %extracted_8 : f64 %from_elements_9 = tensor.from_elements %12 : tensor - mhlo.return %from_elements_9 : tensor - }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<3xf64>, tensor<1xi32>, tensor) -> tensor<3xf64> + stablehlo.return %from_elements_9 : tensor + }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xf64>, tensor<1xi32>, tensor) -> tensor<3xf64> %10 = tensor.empty() : tensor<3xf64> %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%4, %9 : tensor<3xf64>, tensor<3xf64>) outs(%10 : tensor<3xf64>) { ^bb0(%in: f64, %in_7: f64, %out: f64): @@ -149,15 +149,15 @@ func.func public @two_scatter(%arg0: tensor<3xf64>, %arg1: tensor) -> tenso func.func public @full_example_scatter(%input: tensor<3x4x2xi64>, %update: tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64> attributes {llvm.emit_c_interface} { %scatter_indices = arith.constant dense<2> : tensor<2x3x2xi32> - %result = "mhlo.scatter"(%input, %scatter_indices, %update) ({ + %result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({ ^bb0(%arg2: tensor, %arg3: tensor): %extracted_1 = tensor.extract %arg2[] : tensor %extracted_2 = tensor.extract %arg3[] : tensor %1 = arith.addi %extracted_1, %extracted_2 : i64 %from_elements = tensor.from_elements %1 : tensor - mhlo.return %from_elements : tensor + stablehlo.return %from_elements : tensor }) { - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [2, 3], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], @@ -205,7 +205,7 @@ func.func public @example_no_update_dim(%arg0: tensor<4xf64>) -> tensor<4xf64> { ^bb0(%in: i32, %out: i32): linalg.yield %in : i32 } -> tensor<2x1xi32> - %2 = "mhlo.scatter"(%cst, %1, %cst_0) ({ + %2 = "stablehlo.scatter"(%cst, %1, %cst_0) ({ ^bb0(%arg1: tensor, %arg2: tensor): %3 = tensor.empty() : tensor %4 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = []} ins(%arg1, %arg2 : tensor, tensor) outs(%3 : tensor) { @@ -213,8 +213,8 @@ func.func public @example_no_update_dim(%arg0: tensor<4xf64>) -> tensor<4xf64> { %5 = arith.addf %in, %in_2 : f64 linalg.yield %5 : f64 } -> tensor - mhlo.return %4 : tensor - }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<4xf64>, tensor<2x1xi32>, tensor<2xf64>) -> tensor<4xf64> + stablehlo.return %4 : tensor + }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<4xf64>, tensor<2x1xi32>, tensor<2xf64>) -> tensor<4xf64> return %2 : tensor<4xf64> } @@ -252,7 +252,7 @@ func.func public @example_no_update_dim(%arg0: tensor<4xf64>) -> tensor<4xf64> { // CHECK-LABEL: @test_happy_path module @test_happy_path { - // CHECK-NOT: mhlo.scatter + // CHECK-NOT: stablehlo.scatter // CHECK-DAG: [[cst0:%.+]] = index.constant 0 // CHECK-DAG: [[inputs:%.+]] = "test.op"() : () -> tensor<[[dim1:.*]]x[[dim0:.*]]xf64> // CHECK-DAG: [[scatter_indices:%.+]] = "test.op"() : () -> tensor<1xi32> @@ -263,17 +263,17 @@ module @test_happy_path { %inputs = "test.op"() : () -> (tensor<7x5xf64>) %scatter_indices = "test.op"() : () -> (tensor<1xi32>) %updates = "test.op"() : () -> (tensor<5xf64>) - %results = "mhlo.scatter"(%inputs, %scatter_indices, %updates) <{ + %results = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) <{ indices_are_sorted = true, unique_indices = true, - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [0], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0] > }> ({ ^bb0(%arg3: tensor, %arg4: tensor): - mhlo.return %arg4 : tensor + stablehlo.return %arg4 : tensor }) : (tensor<7x5xf64>, tensor<1xi32>, tensor<5xf64>) -> tensor<7x5xf64> "test.op"(%results) : (tensor<7x5xf64>) -> () @@ -286,17 +286,17 @@ module @test_multiple_inputs { %scatter_indices = "test.op"() : () -> (tensor<1xi32>) %updates = "test.op"() : () -> (tensor<5xf64>) // expected-error@+1 {{Only one input, update, and result}} - %results:2 = "mhlo.scatter"(%inputs, %inputs, %scatter_indices, %updates, %updates) <{ + %results:2 = "stablehlo.scatter"(%inputs, %inputs, %scatter_indices, %updates, %updates) <{ indices_are_sorted = true, unique_indices = true, - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [0], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0] > }> ({ ^bb0(%arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor): - mhlo.return %arg4, %arg6 : tensor, tensor + stablehlo.return %arg4, %arg6 : tensor, tensor }) : (tensor<7x5xf64>, tensor<7x5xf64>, tensor<1xi32>, tensor<5xf64>, tensor<5xf64>) -> (tensor<7x5xf64>, tensor<7x5xf64>) "test.op"(%results#0, %results#1) : (tensor<7x5xf64>, tensor<7x5xf64>) -> () } @@ -312,10 +312,10 @@ module @test_is_not_assignment { // CHECK-NOT: tensor.insert_slice - %results = "mhlo.scatter"(%inputs, %scatter_indices, %updates) <{ + %results = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) <{ indices_are_sorted = true, unique_indices = true, - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [0], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0] @@ -323,7 +323,7 @@ module @test_is_not_assignment { }> ({ ^bb0(%arg3: tensor, %arg4: tensor): %add = stablehlo.add %arg3, %arg4 : tensor - mhlo.return %add : tensor + stablehlo.return %add : tensor }) : (tensor<7x5xf64>, tensor<1xi32>, tensor<5xf64>) -> tensor<7x5xf64> "test.op"(%results) : (tensor<7x5xf64>) -> () } @@ -345,17 +345,17 @@ module @insert_tensor_rank_2 { // CHECK: [[idx:%.+]] = arith.index_cast [[scatter_idx]] : i32 to index // CHECK: tensor.insert_slice [[updates]] into [[inputs]][[[idx]], 0, 0] [1, [[dim1]], [[dim0]]] [1, 1, 1] : tensor<7x5xf64> into tensor<9x7x5xf64> - %results = "mhlo.scatter"(%inputs, %scatter_indices, %updates) <{ + %results = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) <{ indices_are_sorted = true, unique_indices = true, - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [0, 1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0] > }> ({ ^bb0(%arg3: tensor, %arg4: tensor): - mhlo.return %arg4 : tensor + stablehlo.return %arg4 : tensor }) : (tensor<9x7x5xf64>, tensor<1xi32>, tensor<7x5xf64>) -> tensor<9x7x5xf64> "test.op"(%results) : (tensor<9x7x5xf64>) -> () } @@ -380,17 +380,17 @@ module @two_dyn_indices { // CHECK-DAG: [[idx__1:%.+]] = arith.index_cast [[scatter_idx_1]] : i32 to index // CHECK: tensor.insert_slice [[updates]] into [[inputs]][[[idx__0]], [[idx__1]], 0] [1, 1, [[dim0]]] [1, 1, 1] : tensor<[[dim0]]xf64> into tensor<9x7x5xf64> - %results = "mhlo.scatter"(%inputs, %scatter_indices, %updates) <{ + %results = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) <{ indices_are_sorted = true, unique_indices = true, - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [0], inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1] > }> ({ ^bb0(%arg3: tensor, %arg4: tensor): - mhlo.return %arg4 : tensor + stablehlo.return %arg4 : tensor }) : (tensor<9x7x5xf64>, tensor<2xi32>, tensor<5xf64>) -> tensor<9x7x5xf64> "test.op"(%results) : (tensor<9x7x5xf64>) -> () } @@ -414,17 +414,17 @@ module @two_dyn_indices_reverted { // CHECK-DAG: [[idx__0:%.+]] = arith.index_cast [[scatter_idx_0]] : i32 to index // CHECK-DAG: [[idx__1:%.+]] = arith.index_cast [[scatter_idx_1]] : i32 to index - %results = "mhlo.scatter"(%inputs, %scatter_indices, %updates) <{ + %results = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) <{ indices_are_sorted = true, unique_indices = true, - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [0], inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [1, 0] // This line is changed > }> ({ ^bb0(%arg3: tensor, %arg4: tensor): - mhlo.return %arg4 : tensor + stablehlo.return %arg4 : tensor }) : (tensor<9x7x5xf64>, tensor<2xi32>, tensor<5xf64>) -> tensor<9x7x5xf64> "test.op"(%results) : (tensor<9x7x5xf64>) -> () diff --git a/mlir/test/frontend/lit.site.cfg.py.in b/mlir/test/frontend/lit.site.cfg.py.in index 2a6816dbfd..e888797f16 100644 --- a/mlir/test/frontend/lit.site.cfg.py.in +++ b/mlir/test/frontend/lit.site.cfg.py.in @@ -5,7 +5,6 @@ config.python_executable = "@Python3_EXECUTABLE@" config.frontend_test_dir = "@CMAKE_BINARY_DIR@" + "/test/frontend" config.quantum_bin_dir = "@CMAKE_BINARY_DIR@" + "/bin" config.mlir_bindings_dir = "@CMAKE_BINARY_DIR@" + "/python_packages/quantum" -config.mhlo_bin_dir = "@MHLO_BINARY_DIR@" config.lrt_lib_dir = "@RUNTIME_LIB_DIR@" config.mlir_lib_dir = "@MLIR_LIB_DIR@" diff --git a/mlir/tools/catalyst-cli/CMakeLists.txt b/mlir/tools/catalyst-cli/CMakeLists.txt index a5af9f411a..1dd3d9693d 100644 --- a/mlir/tools/catalyst-cli/CMakeLists.txt +++ b/mlir/tools/catalyst-cli/CMakeLists.txt @@ -23,10 +23,11 @@ set(LIBS ${dialect_libs} ${conversion_libs} ${extension_libs} + ExternalStablehloLib MLIROptLib MLIRCatalyst catalyst-transforms - catalyst-mhlo-transforms + catalyst-stablehlo-transforms MLIRQuantum quantum-transforms MLIRQEC @@ -39,10 +40,7 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms - MhloRegisterDialects - StablehloRegister MLIRCatalystTest - ${ALL_MHLO_PASSES} ${ENZYME_LIB} CatalystCompilerDriver ) diff --git a/mlir/tools/quantum-lsp-server/CMakeLists.txt b/mlir/tools/quantum-lsp-server/CMakeLists.txt index 2bfcd7e134..f4a7c2e727 100644 --- a/mlir/tools/quantum-lsp-server/CMakeLists.txt +++ b/mlir/tools/quantum-lsp-server/CMakeLists.txt @@ -3,6 +3,7 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) set(LIBS ${dialect_libs} ${conversion_libs} + ExternalStablehloLib MLIRLspServerLib MLIRCatalyst MLIRQuantum @@ -11,8 +12,6 @@ set(LIBS MLIRMBQC MLIRMitigation MLIRIon - MhloRegisterDialects - StablehloRegister ) add_llvm_executable(quantum-lsp-server quantum-lsp-server.cpp) diff --git a/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp b/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp index 3160ea1919..a98de69942 100644 --- a/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp +++ b/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp @@ -24,7 +24,6 @@ #include "QEC/IR/QECDialect.h" #include "Quantum/IR/QuantumDialect.h" -#include "mhlo/IR/register.h" #include "stablehlo/dialect/Register.h" int main(int argc, char **argv) @@ -39,7 +38,6 @@ int main(int argc, char **argv) registry.insert(); registry.insert(); - mlir::mhlo::registerAllMhloDialects(registry); mlir::stablehlo::registerAllDialects(registry); return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); diff --git a/mlir/tools/quantum-opt/CMakeLists.txt b/mlir/tools/quantum-opt/CMakeLists.txt index 81b7aa7815..10c6ed5a0f 100644 --- a/mlir/tools/quantum-opt/CMakeLists.txt +++ b/mlir/tools/quantum-opt/CMakeLists.txt @@ -5,10 +5,11 @@ set(LIBS ${dialect_libs} ${conversion_libs} ${extension_libs} + ExternalStablehloLib MLIROptLib MLIRCatalyst catalyst-transforms - catalyst-mhlo-transforms + catalyst-stablehlo-transforms MLIRQuantum quantum-transforms MLIRQEC @@ -21,12 +22,9 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms - MhloRegisterDialects - StablehloRegister MLIRCatalystTest MLIRCatalystUtils MLIRTestDialect - ${ALL_MHLO_PASSES} ) add_mlir_tool(quantum-opt quantum-opt.cpp DEPENDS ${LIBS} SUPPORT_PLUGINS) diff --git a/mlir/tools/quantum-opt/quantum-opt.cpp b/mlir/tools/quantum-opt/quantum-opt.cpp index 36a5b72c2a..1d252733ea 100644 --- a/mlir/tools/quantum-opt/quantum-opt.cpp +++ b/mlir/tools/quantum-opt/quantum-opt.cpp @@ -16,8 +16,6 @@ #include // ifstream #include //regex -#include "mhlo/IR/register.h" -#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/DialectRegistry.h" @@ -25,11 +23,13 @@ #include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "stablehlo/dialect/Register.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/integrations/c/StablehloPasses.h" +#include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/optimization/Passes.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/raw_ostream.h" -#include "mhlo/IR/hlo_ops.h" - #include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h" #include "Catalyst/Transforms/Passes.h" @@ -55,12 +55,12 @@ int main(int argc, char **argv) llvm::cl::AddExtraVersionPrinter(catalyst::printVersion); mlir::registerAllPasses(); catalyst::registerAllCatalystPasses(); - mlir::mhlo::registerAllMhloPasses(); + mlirRegisterAllStablehloPasses(); + mlir::stablehlo::registerOptimizationPasses(); mlir::DialectRegistry registry; mlir::registerAllDialects(registry); test::registerTestDialect(registry); - mlir::mhlo::registerAllMhloDialects(registry); mlir::stablehlo::registerAllDialects(registry); mlir::func::registerAllExtensions(registry); registry.insert(); @@ -70,7 +70,7 @@ int main(int argc, char **argv) registry.insert(); registry.insert(); registry.insert(); - registry.insert(); + registry.insert(); catalyst::registerBufferizableOpInterfaceExternalModels(registry); catalyst::gradient::registerBufferizableOpInterfaceExternalModels(registry);