diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml index 304b23ee5d..9e6d70465a 100644 --- a/.github/workflows/build-wheel-linux-arm64.yaml +++ b/.github/workflows/build-wheel-linux-arm64.yaml @@ -63,7 +63,7 @@ jobs: uses: actions/cache@v4 with: path: mlir/llvm-project - key: llvm-${{ needs.constants.outputs.llvm_version }}-default-source + key: llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-source enableCrossOsArchive: True - name: Cache MHLO Source @@ -71,7 +71,7 @@ jobs: uses: actions/cache@v4 with: path: mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-default-source + key: mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-source enableCrossOsArchive: True - name: Cache Enzyme Source @@ -112,14 +112,14 @@ jobs: uses: actions/cache/restore@v4 with: path: llvm-build - key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-wheel-build + key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build - name: Restore MHLO Build id: cache-mhlo-build uses: actions/cache/restore@v4 with: path: mhlo-build - key: ${{ matrix.container_name }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + key: ${{ matrix.container_name }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build lookup-only: True - name: Restore Enzyme Build @@ -160,7 +160,7 @@ jobs: uses: actions/cache/save@v4 with: path: llvm-build - key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-wheel-build + key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build - name: Build MHLO Dialect if: steps.cache-mhlo-build.outputs.cache-hit != 'true' @@ -179,7 +179,7 @@ jobs: uses: actions/cache/save@v4 with: path: mhlo-build - key: ${{ matrix.container_name }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + key: ${{ matrix.container_name }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' @@ -240,7 +240,7 @@ jobs: uses: actions/cache/restore@v4 with: path: llvm-build - key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-wheel-build + key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build fail-on-cache-miss: True - name: Get Cached MHLO Source @@ -257,7 +257,7 @@ jobs: uses: actions/cache/restore@v4 with: path: mhlo-build - key: ${{ matrix.container_name }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + key: ${{ matrix.container_name }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build fail-on-cache-miss: True - name: Get Cached Enzyme Source @@ -334,7 +334,7 @@ jobs: uses: actions/cache@v4 with: path: llvm-build - key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-wheel-build + key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build fail-on-cache-miss: True - name: Run Python Pytest Tests diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml 
b/.github/workflows/build-wheel-linux-x86_64.yaml index 825ce2f65e..f151c3832f 100644 --- a/.github/workflows/build-wheel-linux-x86_64.yaml +++ b/.github/workflows/build-wheel-linux-x86_64.yaml @@ -118,14 +118,14 @@ jobs: uses: actions/cache/restore@v4 with: path: llvm-build - key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build + key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-${{matrix.python_version}}-wheel-build - name: Restore MHLO Build id: cache-mhlo-build uses: actions/cache/restore@v4 with: path: mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build lookup-only: True - name: Restore Enzyme Build @@ -172,6 +172,11 @@ jobs: if: steps.cache-llvm-build.outputs.cache-hit != 'true' run: | export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH + + # TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher). + if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi + if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi + cmake -S mlir/llvm-project/llvm -B llvm-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_BUILD_EXAMPLES=OFF \ @@ -198,17 +203,18 @@ jobs: uses: actions/cache/save@v4 with: path: llvm-build - key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build + key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-${{matrix.python_version}}-wheel-build - name: Build MHLO Dialect if: steps.cache-mhlo-build.outputs.cache-hit != 'true' # building with LLD is a strong requirement for mhlo run: | + # TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher). 
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi + if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch; fi cmake -S mlir/mlir-hlo -B mhlo-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ @@ -228,7 +234,7 @@ jobs: uses: actions/cache/save@v4 with: path: mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' @@ -296,7 +302,7 @@ jobs: uses: actions/cache/restore@v4 with: path: llvm-build - key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-3.10-wheel-build + key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-3.10-wheel-build fail-on-cache-miss: True - name: Get Cached MHLO Source @@ -313,7 +319,7 @@ jobs: uses: actions/cache/restore@v4 with: path: mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build fail-on-cache-miss: True - name: Get Cached Enzyme Source @@ -363,7 +369,12 @@ jobs: # Build Quantum and Gradient Dialects - name: Build MLIR Dialects run: | + # TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher). 
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH + if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi + if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi + if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch; fi + cmake -S mlir -B quantum-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_ENABLE_ASSERTIONS=ON \ diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index 5d6be17222..307feb8b05 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -105,14 +105,14 @@ jobs: uses: actions/cache/restore@v4 with: path: llvm-build - key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-${{matrix.python_version}}-wheel-build - name: Restore MHLO Build id: cache-mhlo-build uses: actions/cache/restore@v4 with: path: mhlo-build - key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build lookup-only: True - name: Restore Enzyme Build @@ -137,6 +137,10 @@ jobs: - name: Build LLVM / MLIR if: steps.cache-llvm-build.outputs.cache-hit != 'true' run: | + # TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher). + if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi + if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi + cmake -S mlir/llvm-project/llvm -B llvm-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_BUILD_EXAMPLES=OFF \ @@ -163,16 +167,18 @@ jobs: uses: actions/cache/save@v4 with: path: llvm-build - key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-${{matrix.python_version}}-wheel-build - name: Build MHLO Dialect if: steps.cache-mhlo-build.outputs.cache-hit != 'true' run: | export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH + # TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher). 
export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi + if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/moduleOp-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/moduleOp-mhlo.patch; fi cmake -S mlir/mlir-hlo -B mhlo-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ @@ -192,7 +198,7 @@ jobs: uses: actions/cache/save@v4 with: path: mhlo-build - key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' @@ -256,7 +262,7 @@ jobs: uses: actions/cache/restore@v4 with: path: llvm-build - key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-3.10-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-3.10-wheel-build fail-on-cache-miss: True - name: Get Cached MHLO Source @@ -273,7 +279,7 @@ jobs: uses: actions/cache/restore@v4 with: path: mhlo-build - key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build fail-on-cache-miss: True - name: Get Cached Enzyme Source @@ -328,6 +334,11 @@ jobs: # Build Quantum and Gradient Dialects - name: Build MLIR Dialects run: | + # TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher). 
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi + if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi + if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/moduleOp-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/moduleOp-mhlo.patch; fi + cmake -S mlir -B quantum-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_ENABLE_ASSERTIONS=ON \ diff --git a/.github/workflows/build-wheel-macos-x86_64.yaml b/.github/workflows/build-wheel-macos-x86_64.yaml index eebc346fd4..70d7dfeab2 100644 --- a/.github/workflows/build-wheel-macos-x86_64.yaml +++ b/.github/workflows/build-wheel-macos-x86_64.yaml @@ -103,14 +103,14 @@ jobs: uses: actions/cache/restore@v4 with: path: llvm-build - key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-${{matrix.python_version}}-wheel-build - name: Restore MHLO Build id: cache-mhlo-build uses: actions/cache/restore@v4 with: path: mhlo-build - key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build lookup-only: True - name: Restore Enzyme Build @@ -133,6 +133,10 @@ jobs: - name: Build LLVM / MLIR if: steps.cache-llvm-build.outputs.cache-hit != 'true' run: | + # TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher). + if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi + if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi + cmake -S mlir/llvm-project/llvm -B llvm-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_BUILD_EXAMPLES=OFF \ @@ -159,16 +163,18 @@ jobs: uses: actions/cache/save@v4 with: path: llvm-build - key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-${{matrix.python_version}}-wheel-build - name: Build MHLO Dialect if: steps.cache-mhlo-build.outputs.cache-hit != 'true' run: | export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH + # TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher). 
export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi + if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch; fi cmake -S mlir/mlir-hlo -B mhlo-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ @@ -188,7 +194,7 @@ jobs: uses: actions/cache/save@v4 with: path: mhlo-build - key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' @@ -246,7 +252,7 @@ jobs: uses: actions/cache/restore@v4 with: path: llvm-build - key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-3.10-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-3.10-wheel-build fail-on-cache-miss: True - name: Get Cached MHLO Source @@ -263,7 +269,7 @@ jobs: uses: actions/cache/restore@v4 with: path: mhlo-build - key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build fail-on-cache-miss: True - name: Get Cached Enzyme Source @@ -319,6 +325,11 @@ jobs: # Build Quantum and Gradient Dialects - name: Build MLIR Dialects run: | + # TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher). 
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi + if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi + if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch; fi + cmake -S mlir -B quantum-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_ENABLE_ASSERTIONS=ON \ diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 73f7becf16..6981040958 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -119,7 +119,7 @@ jobs: uses: actions/cache@v4 with: path: llvm-build - key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-default-build-${{ matrix.compiler }} + key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }} - name: Install Deps if: steps.cache-llvm-build.outputs.cache-hit != 'true' @@ -175,7 +175,7 @@ jobs: uses: actions/cache@v4 with: path: mhlo-build - key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-default-build-${{ matrix.compiler }} + key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }} - name: Get Cached LLVM Source id: cache-llvm-source @@ -193,7 +193,7 @@ jobs: uses: actions/cache@v4 with: path: llvm-build - key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-default-build-${{ matrix.compiler }} + key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }} fail-on-cache-miss: true - name: Install Deps @@ -263,7 +263,7 @@ jobs: uses: actions/cache@v4 with: path: llvm-build - key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-default-build-${{ matrix.compiler }} + key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }} fail-on-cache-miss: true - name: Install Deps @@ -315,7 +315,7 @@ jobs: uses: actions/cache/restore@v4 with: path: llvm-build - key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-default-build-${{ matrix.compiler }} + key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }} fail-on-cache-miss: true - name: Get Cached MHLO Source @@ -332,7 +332,7 @@ jobs: uses: actions/cache/restore@v4 with: path: mhlo-build - key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-default-build-${{ matrix.compiler }} + key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }} fail-on-cache-miss: true - name: Get Cached Enzyme Source @@ -363,6 +363,11 @@ jobs: - name: Build MLIR Dialects run: | + # TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher). 
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi + if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi + if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch; fi + CCACHE_DIR="$(pwd)/.ccache" \ C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \ CXX_COMPILER=$(which ${{ needs.constants.outputs[format('cxx_compiler.{0}', matrix.compiler)] }}) \ @@ -419,7 +424,7 @@ jobs: uses: actions/cache@v4 with: path: llvm-build - key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-default-build-${{ matrix.compiler }} + key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }} fail-on-cache-miss: true - name: Download Quantum Build Artifact @@ -491,7 +496,7 @@ jobs: uses: actions/cache@v4 with: path: llvm-build - key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-default-build-${{ matrix.compiler }} + key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }} fail-on-cache-miss: true - name: Download Quantum Build Artifact @@ -546,7 +551,7 @@ jobs: uses: actions/cache@v4 with: path: llvm-build - key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-default-build-${{ matrix.compiler }} + key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }} fail-on-cache-miss: true - name: Download Quantum Build Artifact diff --git a/.github/workflows/scripts/linux_arm64/rh8/build_catalyst.sh b/.github/workflows/scripts/linux_arm64/rh8/build_catalyst.sh index bb517160f6..e666e5ecb9 100644 --- a/.github/workflows/scripts/linux_arm64/rh8/build_catalyst.sh +++ b/.github/workflows/scripts/linux_arm64/rh8/build_catalyst.sh @@ -37,6 +37,11 @@ export PATH=/catalyst/llvm-build/bin:/opt/_internal/cpython-${PYTHON_VERSION}.${ # Install python dependencies /usr/bin/python3 -m pip install pennylane pybind11 PyYAML cmake ninja delocate 'amazon-braket-pennylane-plugin>1.27.1' +# Patch LLVM and MHLO. TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher). 
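+# Note on the guarded form used below: `patch --dry-run -N` exits non-zero when the
+# hunks are already applied (or do not apply cleanly), so the real `patch` only runs
+# on a pristine tree and re-running this script stays idempotent. The same idea as a
+# reusable helper (a sketch only; `apply_patch_once` is a hypothetical name, not part
+# of this PR):
+#   apply_patch_once() {  # usage: apply_patch_once <source-dir> <patch-file>
+#     if patch --dry-run -p1 -N --directory="$1" < "$2" > /dev/null 2>&1; then
+#       patch -p1 --directory="$1" < "$2"
+#     fi
+#   }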
+if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi +if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi +if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch; fi + # Build Catalyst runtime cmake -S runtime -B runtime-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ diff --git a/.github/workflows/scripts/linux_arm64/rh8/build_llvm.sh b/.github/workflows/scripts/linux_arm64/rh8/build_llvm.sh index 78bb6aadb8..b4ee206580 100644 --- a/.github/workflows/scripts/linux_arm64/rh8/build_llvm.sh +++ b/.github/workflows/scripts/linux_arm64/rh8/build_llvm.sh @@ -33,6 +33,10 @@ export PATH=/opt/_internal/cpython-${PYTHON_VERSION}.${PYTHON_SUBVERSION}/bin:/o # Install python dependencies /usr/bin/python3 -m pip install pennylane pybind11 PyYAML cmake ninja +# TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher). +if patch --dry-run -p1 -N --directory=/catalyst/mlir/llvm-project < /catalyst/mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=/catalyst/mlir/llvm-project < /catalyst/mlir/patches/FunctionOpInterface-bufferization.patch; fi +if patch --dry-run -p1 -N --directory=/catalyst/mlir/llvm-project < /catalyst/mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=/catalyst/mlir/llvm-project < /catalyst/mlir/patches/callOp-bufferization.patch; fi + # Build LLVM cmake -S /catalyst/mlir/llvm-project/llvm -B /catalyst/llvm-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ diff --git a/.github/workflows/scripts/linux_arm64/rh8/build_mhlo.sh b/.github/workflows/scripts/linux_arm64/rh8/build_mhlo.sh index 2a5b2e4fa7..e452ee22d4 100644 --- a/.github/workflows/scripts/linux_arm64/rh8/build_mhlo.sh +++ b/.github/workflows/scripts/linux_arm64/rh8/build_mhlo.sh @@ -38,7 +38,10 @@ sed -i -e 's/LINK_LIBS PUBLIC/LINK_LIBS PUBLIC MLIRDeallocationUtils/g' mlir/mli export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch +# TODO: Jax has merged this fix. Remove after JAX upgrade. if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi +# TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher). +if patch --dry-run -p1 -N --directory=/catalyst/mlir/mlir-hlo < /catalyst/mlir/patches/FunctionOpInterface-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=/catalyst/mlir/mlir-hlo < /catalyst/mlir/patches/FunctionOpInterface-mhlo.patch; fi # Build MHLO cmake -S /catalyst/mlir/mlir-hlo -B /catalyst/mhlo-build -G Ninja \ diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3412e495bb..6ba963c728 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -6,6 +6,8 @@

Improvements 🛠

+* [(#1027)](https://github.com/PennyLaneAI/catalyst/pull/1027) Catalyst now supports MLIR's `one-shot-bufferize` pass, which is required for JAX v0.4.29 or higher. +

Breaking changes 💔

Deprecations 👋

@@ -16,4 +18,7 @@

Contributors ✍️

-This release contains contributions from (in alphabetical order): \ No newline at end of file +This release contains contributions from (in alphabetical order): + +Tzung-Han Juang, +Erick Ochoa Lopez, diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index 5d3d117469..8a15985b3b 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -224,41 +224,177 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt ], ) +# From: https://mlir.llvm.org/docs/Bufferization/#overview +# +# Preprocessing +# | rewrite_in_destination_passing_style +# | -eliminate-empty-tensors +# Bufferization +# | -one-shot-bufferize +# Buffer-Level +# Optimizations +# | -buffer-hoisting +# | -buffer-loop-hoisting +# | -buffer-results-to-out-params +# | -drop-equivalent-buffer-results +# | -promote-buffers-to-stack +# Deallocation +# | -buffer-deallocation-pipeline + BUFFERIZATION_PASS = ( "BufferizationPass", [ - "one-shot-bufferize{dialect-filter=memref}", "inline", "gradient-preprocess", - "gradient-bufferize", - "scf-bufferize", - "convert-tensor-to-linalg", # tensor.pad - "convert-elementwise-to-linalg", # Must be run before --arith-bufferize - "arith-bufferize", - "empty-tensor-to-alloc-tensor", - "func.func(bufferization-bufferize)", - "func.func(tensor-bufferize)", - "catalyst-bufferize", # Must be run before -- func.func(linalg-bufferize) - "func.func(linalg-bufferize)", - "func.func(tensor-bufferize)", - "quantum-bufferize", - "func-bufferize", - "func.func(finalizing-bufferize)", - "canonicalize", # Remove dead memrefToTensorOp's - "gradient-postprocess", + "convert-elementwise-to-linalg", + "canonicalize", + # Preprocessing: + # rewrite_in_destination_passing_style + # + # We are not rewriting everything in DPS before -one-shot-bufferize + # This was discussed with the main author of the -one-shot-bufferize + # pass and he stated the following: + # + # One-Shot Bufferize was designed for ops that are in DPS (destination-passing style). + # Ops that are not in DPS can still be bufferized, + # but a new buffer will be allocated for every tensor result. + # That’s functionally correct but inefficient. + # + # I’m not sure whether it’s better to first migrate to the new bufferization, + # then turn the ops into DPS ops, or do it the other way around. + # One benefit of implementing the bufferization first is that + # it’s a smaller step that you can already run end-to-end. + # And you can think of the DPS of a performance improvement on top of it. + # + # https://discourse.llvm.org/t/steps-of-migrating-to-one-shot-bufferization/81062/2 + # + # Here, please note that gradient-preprocessing is different than rewriting in DPS. + # So, overall, we are skipping this section while we first focus on migrating to the + # new -one-shot-bufferize + "eliminate-empty-tensors", + ( + # Before we enter one-shot-bufferize, here is what we expect: + # * Given + # + # One-Shot Bufferize was designed for ops that are in DPS + # (destination-passing style). + # Ops that are not in DPS can still be bufferized, + # but a new buffer will be allocated for every tensor result. + # That’s functionally correct but inefficient. + # + # https://discourse.llvm.org/t/steps-of-migrating-to-one-shot-bufferization/81062/2 + # + # we expect that results will be (automatically?) converted into new buffers. And it + # is up to us to just define the bufferization for the operands. + # + # So what is the state of the catalyst, gradient, quantum dialects at this point? 
+ # + # Let's start with quantum: + # + # |-------------------------|--------------------| + # | operation | has result tensor | + # |-------------------------|--------------------| + # | quantum.set_state | | + # | quantum.set_basis_state | | + # | quantum.unitary | | + # | quantum.hermitian | | + # | quantum.hamiltonian | | + # | quantum.sample_op | YES | + # | quantum.counts_op | YES | + # | quantum.probs_op | YES | + # | quantum.state_op | YES | + # |-------------------------|--------------------| + # | catalyst.print_op | | + # | catalyst.custom_call | YES | + # | catalyst.callback | | + # | catalyst.callback_call | YES | + # | catalyst.launch_kernel | YES | + # |-------------------------|--------------------| + # | gradient.grad | YES | + # | gradient.value_and_grad | YES | + # | gradient.adjoint | YES | + # | gradient.backprop | YES | + # | gradient.jvp | YES | + # | gradient.vjp | YES | + # | gradient.forward | YES | + # | gradient.reverse | YES | + # |-------------------------|--------------------| + # + # So what this means is that for the operands, all the ones that have the YES + # means that no operands are written to. They are only read. + "one-shot-bufferize" + "{" + "bufferize-function-boundaries " + # - Bufferize function boundaries (experimental). + # + # By default, function boundaries are not bufferized. + # This is because there are currently limitations around function graph + # bufferization: + # recursive calls are not supported. + # As long as there are no recursive calls, function boundary bufferization can be + # enabled with bufferize-function-boundaries. + # Each tensor function argument and tensor function result is then turned into a memref. + # The layout map of the memref type can be controlled with function-boundary-type-conversion. + # + # https://mlir.llvm.org/docs/Bufferization/#using-one-shot-bufferize + "allow-return-allocs-from-loops " + # - Allows returning/yielding new allocations from a loop. + # https://github.com/llvm/llvm-project/pull/83964 + # https://github.com/llvm/llvm-project/pull/87594 + "function-boundary-type-conversion=identity-layout-map" + # - Controls layout maps when bufferizing function signatures. + # You can control the memref types at the function boundary with + # function-boundary-type-conversion. E.g., if you set it to identity-layout-map, + # you should get the same type as with --func-bufferize. + # By default, we put a fully dynamic layout map strided<[?, ?], offset: ?> + # because that works best if you don't know what layout map the buffers at + # the call site have -- you can always cast a buffer to a type with + # fully dynamic layout map. (But not the other way around. That may require a + # reallocation.) + # + # https://discord.com/channels/636084430946959380/642426447167881246/1212338527824515102 + "}" + ), + # Remove dead memrefToTensorOp's # introduced during gradient-bufferize of callbacks + # TODO: Figure out how to remove this. + "gradient-postprocess", "func.func(buffer-hoisting)", "func.func(buffer-loop-hoisting)", + # TODO: Figure out how to include the other buffer-level optimizations. + # -buffer-results-to-out-params, + # -drop-equivalent-buffer-results, + # -promote-buffers-to-stack + # Deallocation + # The buffer deallocation pass has been deprecated in favor of the + # ownership-based buffer deallocation pipeline. + # The deprecated pass has some limitations that may cause memory leaks in the resulting IR. + # TODO: Switch to one-shot-bufferization once it is merged. 
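+ # For reference, the upstream core of this stage (eliminate-empty-tensors followed by
+ # one-shot-bufferize with the options assembled above) can be reproduced standalone with
+ # mlir-opt; a sketch, assuming an mlir-opt built from the matching LLVM commit:
+ #   mlir-opt --pass-pipeline="builtin.module(eliminate-empty-tensors, \
+ #       one-shot-bufferize{bufferize-function-boundaries allow-return-allocs-from-loops \
+ #       function-boundary-type-conversion=identity-layout-map})" input.mlir
+ # The Catalyst-specific passes in this list are not available in a plain mlir-opt build.
+ # The deprecated func.func(buffer-deallocation) entry below would eventually be replaced
+ # by the ownership-based -buffer-deallocation-pipeline mentioned in the overview at the
+ # top of this file.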
"func.func(buffer-deallocation)", + # catalyst.list_* operations are not bufferized through + # the bufferization interface + # This is because they store a memref inside of a memref + # which is incompatible with the bufferization pipeline. "convert-arraylist-to-memref", "convert-bufferization-to-memref", - "canonicalize", # Must be after convert-bufferization-to-memref + # Must be after convert-bufferization-to-memref # otherwise there are issues in lowering of dynamic tensors. + "canonicalize", # "cse", "cp-global-memref", ], ) +BUFFERIZATION_ASYNC_PASS = ( + "BufferizationPass", + [ + # TODO: Can we remove copy-before-write? + # copy-before-write: + # Skip the analysis. Make a buffer copy on every write. + s.replace("}", " copy-before-write}") if s.startswith("one-shot-bufferize") else s + for s in BUFFERIZATION_PASS[1] + ], +) MLIR_TO_LLVM_PASS = ( "MLIRToLLVMDialect", @@ -328,7 +464,7 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt ENFORCE_RUNTIME_INVARIANTS_PASS, HLO_LOWERING_PASS, QUANTUM_COMPILATION_PASS, - BUFFERIZATION_PASS, + BUFFERIZATION_ASYNC_PASS, MLIR_TO_LLVM_ASYNC_PASS, ] diff --git a/mlir/Makefile b/mlir/Makefile index 2cb8de388a..3d4e737f89 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -12,6 +12,12 @@ ENZYME_BUILD_DIR?=$(MK_DIR)/Enzyme/build RT_BUILD_DIR?=$(MK_DIR)/../runtime/build ENABLE_ASAN?=OFF BUILD_TYPE?=Release +# TODO: remove after JAX upgrade +LLVM_ROOT=$(MK_DIR)/llvm-project +LLVM_FUNCOP_PATCH_FILE=$(MK_DIR)/patches/FunctionOpInterface-bufferization.patch +LLVM_FUNC_CALL_PATCH_FILE=$(MK_DIR)/patches/callOp-bufferization.patch +MHLO_ROOT?=$(MK_DIR)/mlir-hlo +MHLO_MODULE_PATCH_FILE=$(MK_DIR)/patches/FunctionOpInterface-mhlo.patch TARGET_FILE=$(MK_DIR)/mlir-hlo/mhlo/transforms/CMakeLists.txt PATCH_FILE=$(MK_DIR)/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch @@ -53,7 +59,14 @@ all: llvm mhlo enzyme dialects .PHONY: llvm llvm: + # TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher). @echo "build LLVM and MLIR enabling Python bindings" + @if patch --dry-run -p1 -N --directory=$(LLVM_ROOT) < $(LLVM_FUNCOP_PATCH_FILE) > /dev/null 2>&1; then \ + patch -p1 --directory=$(LLVM_ROOT) < $(LLVM_FUNCOP_PATCH_FILE); \ + fi + @if patch --dry-run -p1 -N --directory=$(LLVM_ROOT) < $(LLVM_FUNC_CALL_PATCH_FILE) > /dev/null 2>&1; then \ + patch -p1 --directory=$(LLVM_ROOT) < $(LLVM_FUNC_CALL_PATCH_FILE); \ + fi cmake -G Ninja -S llvm-project/llvm -B $(LLVM_BUILD_DIR) \ -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ -DLLVM_BUILD_EXAMPLES=OFF \ @@ -85,6 +98,10 @@ mhlo: @if patch --dry-run -p1 -N $(TARGET_FILE) $(PATCH_FILE) > /dev/null 2>&1; then \ patch -p1 $(TARGET_FILE) $(PATCH_FILE); \ fi + # TODO: Remove this patch after upgrading Jax (potentailly for 0.4.34 or higher). 
+ @if patch --dry-run -p1 -N --directory=$(MHLO_ROOT) < $(MHLO_MODULE_PATCH_FILE) > /dev/null 2>&1; then \ + patch -p1 --directory=$(MHLO_ROOT) < $(MHLO_MODULE_PATCH_FILE); \ + fi cmake -G Ninja -S mlir-hlo -B $(MHLO_BUILD_DIR) \ -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ -DLLVM_ENABLE_ASSERTIONS=ON \ diff --git a/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 index 0000000000..ec20d6f6c9 --- /dev/null +++ b/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,9 @@ +#pragma once + +using namespace mlir; + +namespace catalyst { + +void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry ®istry); + +} // namespace catalyst diff --git a/mlir/include/Gradient/IR/GradientOps.h b/mlir/include/Gradient/IR/GradientOps.h index c6f6afadfe..a54e110043 100644 --- a/mlir/include/Gradient/IR/GradientOps.h +++ b/mlir/include/Gradient/IR/GradientOps.h @@ -21,6 +21,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "Gradient/IR/GradientInterfaces.h" diff --git a/mlir/include/Gradient/IR/GradientOps.td b/mlir/include/Gradient/IR/GradientOps.td index e2d7660e38..3a73403b6d 100644 --- a/mlir/include/Gradient/IR/GradientOps.td +++ b/mlir/include/Gradient/IR/GradientOps.td @@ -17,6 +17,7 @@ include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/OpBase.td" @@ -388,7 +389,7 @@ def ReverseOp : Gradient_Op<"reverse", } def ReturnOp : Gradient_Op<"return", - [Terminator, ParentOneOf<["ForwardOp", "ReverseOp"]>]> { + [ReturnLike, Terminator, ParentOneOf<["ForwardOp", "ReverseOp"]>]> { let summary = "Return tapes or nothing"; diff --git a/mlir/include/Gradient/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/Gradient/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 index 0000000000..ae5096eb39 --- /dev/null +++ b/mlir/include/Gradient/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,13 @@ +#pragma once + +using namespace mlir; + +namespace catalyst { + +namespace gradient { + +void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry ®istry); + +} + +} // namespace catalyst diff --git a/mlir/include/Quantum/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/Quantum/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 index 0000000000..bf60013f70 --- /dev/null +++ b/mlir/include/Quantum/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,13 @@ +#pragma once + +using namespace mlir; + +namespace catalyst { + +namespace quantum { + +void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry ®istry); + +} + +} // namespace catalyst diff --git a/mlir/lib/Bufferization.md b/mlir/lib/Bufferization.md new file mode 100644 index 0000000000..14eb74599b --- /dev/null +++ b/mlir/lib/Bufferization.md @@ -0,0 +1,23 @@ +**Bufferization Interfaces:** + +| Bufferizable Operations | PrintOp | CustomCallOp | CallbackOp | CallbackCallOp | AdjointOp | BackpropOp | ForwardOp | ReverseOp | QubitUnitaryOp | HermitianOp | HamiltonianOp | SampleOp | StateOp | ProbsOp | CountsOp | SetStateOp | SetBasisStateOp | +| --------------------------------| ---------| ------------ | ------------ | -------------- | --------- | ---------- | --------- | 
--------- | -------------- | ----------- | ------------- | -------- | ------- | ------- | -------- | ---------- | --------------- | +| Catagory | catalyst | catalyst | catalyst | catalyst | gradient | gradient | gradient | gradient | quantum | quantum | quantum | quantum | quantum | quantum | quantum | quantum | quantum | +| bufferizesToAllocation | | true | | true | | | | | | | | | | | | | | +| bufferizesToMemoryRead | true | true | | true | true | true | | | true | true | true | false | false | false | false | false | false | +| bufferizesToMemoryWrite | false | false | | false | false | true | | | false | false | false | false | false | false | false | false | false | +| bufferizesToElementwiseAccess | | | | | | | | | | | | | | | | | | +| resultBufferizesToMemoryWrite | | | | | | | | | | | | | | | | | | +| mustBufferizeInPlace | | | | | | | | | | | | | | | | | | +| getAliasingValues | {} | {} | | {} | {} | {} | | | {} | {} | {} | {} | {} | {} | {} | {} | {} | +| getAliasingOpOperands | | | {} | | | | v | v | | | | | | | | | | +| resolveConflicts | | | | | | | | | | | | | | | | | | +| bufferize | v | v | v | v | v | v | v | v | v | v | v | v | v | v | v | v | v | +| isWritable | | | | | | | | | | | | | | | | | | +| isNotConflicting | | | | | | | | | | | | | | | | | | +| verifyAnalysis | | | | | | | v | v | | | | | | | | | | +| getBufferType | | | | | | | v | v | | | | | | | | | | +| isRepetitiveRegion | | | | | | | | | | | | | | | | | | +| isParallelRegion | | | | | | | | | | | | | | | | | | +| hasTensorSemantics | | | v | | | | v | v | | | | | | | | | | +| supportsUnstructuredControlFlow | | | false | | | | true | true | | | | | | | | | | \ No newline at end of file diff --git a/mlir/lib/Catalyst/IR/CatalystDialect.cpp b/mlir/lib/Catalyst/IR/CatalystDialect.cpp index 1dce30c4cc..01419fddcb 100644 --- a/mlir/lib/Catalyst/IR/CatalystDialect.cpp +++ b/mlir/lib/Catalyst/IR/CatalystDialect.cpp @@ -14,6 +14,7 @@ #include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/IR/CatalystOps.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" // needed for generated type parser #include "mlir/Interfaces/FunctionImplementation.h" @@ -40,6 +41,8 @@ void CatalystDialect::initialize() #define GET_OP_LIST #include "Catalyst/IR/CatalystOps.cpp.inc" >(); + declarePromisedInterfaces(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp index a0c27129cb..e811668e89 100644 --- a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp +++ b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp @@ -215,7 +215,7 @@ std::optional AsyncUtils::getCalleeSafe(LLVM::CallOp callOp) bool AsyncUtils::isFunctionNamed(LLVM::LLVMFuncOp funcOp, llvm::StringRef expectedName) { llvm::StringRef observedName = funcOp.getSymName(); - return observedName.equals(expectedName); + return observedName.compare(expectedName) == 0; } bool AsyncUtils::isMlirAsyncRuntimeCreateValue(LLVM::LLVMFuncOp funcOp) diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp new file mode 100644 index 0000000000..c9033716d5 --- /dev/null +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,293 @@ +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include 
"mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "Catalyst/IR/CatalystOps.h" +#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h" + +using namespace mlir; +using namespace catalyst; + +namespace { +/** + * The new bufferization interface requires `bufferizesToMemoryRead`, `bufferizesToMemoryWrite`, + * and `getAliasingValues`. + * + * `bufferizesToMemoryRead`: Return `true` if the buffer of the given tensor OpOperand is read. + * + * `bufferizesToMemoryWrite`: Return `true` if the buffer of the given tensor OpOperand is written + * (if bufferizing in-place). + * + * `getAliasingOpOperands`: Return the OpResults that may share the same buffer as the given + * OpOperand. + * + * Link: https://mlir.llvm.org/docs/Bufferization/#extending-one-shot-bufferize + */ + +/// Bufferization of catalyst.print. Get memref of printOp.val. +struct PrintOpInterface + // PrintOp will never write to the buffers + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto printOp = cast(op); + if (printOp.getVal()) { + FailureOr source = getBuffer(rewriter, printOp.getVal(), options); + if (failed(source)) + return failure(); + bufferization::replaceOpWithNewBufferizedOp( + rewriter, op, *source, printOp.getConstValAttr(), printOp.getPrintDescriptorAttr()); + } + return success(); + } +}; + +/// Bufferization of catalyst.custom_call. Mainly get buffers for arguments. +struct CustomCallOpInterface + // CustomCallOp will interface with BLAS functions. + // This operations is not in DPS form. This means that + // if we can guarantee operands are never written to, then we can set + // bufferizesToMemoryWrite as false. + // Results will be allocated a new buffer. + // TODO: Double check BLAS and others. Until then, it should be safe to keep + // bufferizesToMemoryWrite as True. 
+ : public bufferization::BufferizableOpInterface::ExternalModel { + + bool bufferizesToAllocation(Operation *op, Value value) const { return true; } + + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto customCallOp = cast(op); + + // Add bufferized arguments + SmallVector bufferArgs; + ValueRange operands = customCallOp.getOperands(); + for (Value operand : operands) { + FailureOr opBuffer = getBuffer(rewriter, operand, options); + if (failed(opBuffer)) + bufferArgs.push_back(operand); + else + bufferArgs.push_back(*opBuffer); + } + + // Add bufferized return values to the arguments + ValueRange results = customCallOp.getResults(); + for (Value result : results) { + Type resultType = result.getType(); + RankedTensorType tensorType = dyn_cast(resultType); + if (!tensorType) { + bufferArgs.push_back(result); + continue; + } + auto options = bufferization::BufferizationOptions(); + FailureOr tensorAlloc = bufferization::allocateTensorForShapedValue( + rewriter, op->getLoc(), result, options, false); + MemRefType memrefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + auto newBuffer = + rewriter.create(op->getLoc(), memrefType, *tensorAlloc); + bufferArgs.push_back(newBuffer); + } + + // Add the initial number of arguments + int32_t numArguments = static_cast(customCallOp.getNumOperands()); + DenseI32ArrayAttr numArgumentsDenseAttr = rewriter.getDenseI32ArrayAttr({numArguments}); + + // Create an updated custom call operation + rewriter.create(op->getLoc(), TypeRange{}, bufferArgs, + customCallOp.getCallTargetName(), numArgumentsDenseAttr); + size_t startIndex = bufferArgs.size() - customCallOp.getNumResults(); + SmallVector bufferResults(bufferArgs.begin() + startIndex, bufferArgs.end()); + bufferization::replaceOpWithBufferizedValues(rewriter, op, bufferResults); + + return success(); + } +}; + +struct CallbackOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool hasTensorSemantics(Operation *op) const + { + auto isaTensor = llvm::IsaPred; + + // A function has tensor semantics if it has tensor arguments/results. 
+ auto callbackOp = cast(op); + bool hasTensorArg = any_of(callbackOp.getArgumentTypes(), isaTensor); + bool hasTensorResult = any_of(callbackOp.getResultTypes(), isaTensor); + if (hasTensorArg || hasTensorResult) + return true; + + return false; + } + + bufferization::AliasingOpOperandList + getAliasingOpOperands(Operation *op, Value value, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto callbackOp = cast(op); + + auto argTys = callbackOp.getArgumentTypes(); + auto retTys = callbackOp.getResultTypes(); + SmallVector emptyRets; + SmallVector args(argTys.begin(), argTys.end()); + args.insert(args.end(), retTys.begin(), retTys.end()); + SmallVector bufferArgs; + for (Type ty : args) { + auto tensorType = dyn_cast(ty); + if (!tensorType) + bufferArgs.push_back(ty); + else + bufferArgs.push_back( + MemRefType::get(tensorType.getShape(), tensorType.getElementType())); + } + auto callbackTy = rewriter.getFunctionType(bufferArgs, emptyRets); + rewriter.modifyOpInPlace(op, [&] { callbackOp.setFunctionType(callbackTy); }); + + return success(); + } +}; + +struct CallbackCallOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToAllocation(Operation *op, Value value) const { return true; } + + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + // We can safely say false because CallbackCallOp's memrefs + // will be put in a JAX array and JAX arrays are immutable. + // + // Unlike NumPy arrays, JAX arrays are always immutable. 
+ // + // https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto callOp = cast(op); + + bufferization::BufferizeTypeConverter typeConverter; + + SmallVector convertedResults; + if (failed(typeConverter.convertTypes(callOp.getResultTypes(), convertedResults))) + return failure(); + + if (callOp->getNumResults() != convertedResults.size()) + return failure(); + + SmallVector newInputs; + auto operands = callOp.getOperands(); + for (Value operand : operands) { + FailureOr opBuffer = getBuffer(rewriter, operand, options); + if (failed(opBuffer)) + return failure(); + newInputs.push_back(*opBuffer); + } + + auto results = callOp.getResults(); + auto loc = callOp->getLoc(); + SmallVector outmemrefs; + for (auto result : results) { + FailureOr tensorAlloc = + bufferization::allocateTensorForShapedValue(rewriter, loc, result, options, false); + if (failed(tensorAlloc)) + return failure(); + + auto tensor = *tensorAlloc; + RankedTensorType tensorTy = cast(tensor.getType()); + auto shape = tensorTy.getShape(); + auto elementTy = tensorTy.getElementType(); + auto memrefType = MemRefType::get(shape, elementTy); + auto toMemrefOp = rewriter.create(loc, memrefType, tensor); + auto memref = toMemrefOp.getResult(); + outmemrefs.push_back(memref); + newInputs.push_back(memref); + } + + SmallVector emptyRets; + rewriter.create(loc, emptyRets, callOp.getCallee(), newInputs); + bufferization::replaceOpWithBufferizedValues(rewriter, op, outmemrefs); + return success(); + } +}; + +} // namespace + +void catalyst::registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) +{ + registry.addExtension(+[](MLIRContext *ctx, CatalystDialect *dialect) { + CustomCallOp::attachInterface(*ctx); + PrintOp::attachInterface(*ctx); + CallbackOp::attachInterface(*ctx); + CallbackCallOp::attachInterface(*ctx); + }); +} diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index 77fb4d64b5..07f221469f 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -4,6 +4,12 @@ file(GLOB SRC ApplyTransformSequencePass.cpp ArrayListToMemRefPass.cpp + scatter_lowering.cpp + ScatterPatterns.cpp + qnode_to_async_lowering.cpp + QnodeToAsyncPatterns.cpp + RegisterAllPasses.cpp + BufferizableOpInterfaceImpl.cpp AsyncUtils.cpp BufferizationPatterns.cpp catalyst_bufferize.cpp diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp index 9fc3e6f9bf..a5f275adbc 100644 --- a/mlir/lib/Driver/CompilerDriver.cpp +++ b/mlir/lib/Driver/CompilerDriver.cpp @@ -56,6 +56,7 @@ #include "llvm/Transforms/IPO/GlobalDCE.h" #include "Catalyst/IR/CatalystDialect.h" +#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h" #include "Catalyst/Transforms/Passes.h" #include "Driver/CatalystLLVMTarget.h" #include "Driver/CompilerDriver.h" @@ -63,10 +64,12 @@ #include "Driver/Support.h" #include "Gradient/IR/GradientDialect.h" #include "Gradient/IR/GradientInterfaces.h" +#include "Gradient/Transforms/BufferizableOpInterfaceImpl.h" #include "Gradient/Transforms/Passes.h" #include "Mitigation/IR/MitigationDialect.h" #include "Mitigation/Transforms/Passes.h" #include "Quantum/IR/QuantumDialect.h" 
+#include "Quantum/Transforms/BufferizableOpInterfaceImpl.h" #include "Quantum/Transforms/Passes.h" #include "Enzyme.h" @@ -75,6 +78,7 @@ using namespace mlir; using namespace catalyst; using namespace catalyst::driver; +using namespace catalyst::quantum; namespace cl = llvm::cl; namespace catalyst::utils { @@ -296,6 +300,11 @@ void registerAllCatalystDialects(DialectRegistry ®istry) registry.insert(); registry.insert(); registry.insert(); + + // Extend one-shot bufferization pass. + catalyst::registerBufferizableOpInterfaceExternalModels(registry); + catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry); + catalyst::gradient::registerBufferizableOpInterfaceExternalModels(registry); } } // namespace diff --git a/mlir/lib/Gradient/IR/GradientDialect.cpp b/mlir/lib/Gradient/IR/GradientDialect.cpp index 4d9cfddb00..068079b99f 100644 --- a/mlir/lib/Gradient/IR/GradientDialect.cpp +++ b/mlir/lib/Gradient/IR/GradientDialect.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Transforms/InliningUtils.h" #include "Gradient/IR/GradientDialect.h" @@ -50,6 +51,8 @@ void GradientDialect::initialize() #include "Gradient/IR/GradientOps.cpp.inc" >(); addInterface(); + declarePromisedInterfaces(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp new file mode 100644 index 0000000000..99b1b4396b --- /dev/null +++ b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,614 @@ +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "Gradient/IR/GradientOps.h" +#include "Gradient/Transforms/BufferizableOpInterfaceImpl.h" +#include "Gradient/Utils/GradientShape.h" +#include "Quantum/IR/QuantumOps.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +using namespace catalyst::gradient; + +namespace { +/** + * The new bufferization interface requires `bufferizesToMemoryRead`, `bufferizesToMemoryWrite`, + * and `getAliasingValues`. + * + * `bufferizesToMemoryRead`: Return `true` if the buffer of the given tensor OpOperand is read. + * + * `bufferizesToMemoryWrite`: Return `true` if the buffer of the given tensor OpOperand is written + * (if bufferizing in-place). + * + * `getAliasingOpOperands`: Return the OpResults that may share the same buffer as the given + * OpOperand. Note that MLIR documentation does not mention `getAliasingValues` but it seems to + * serve the same purpose. + * + * Bufferizing FunctionOpInterface is also not documented by MLIR. 
It requires + * `OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel`, which requires the + * implementation of `supportsUnstructuredControlFlow`, `hasTensorSemantics`, and + * `getAliasingOpOperands`. + * + * Link: https://mlir.llvm.org/docs/Bufferization/#extending-one-shot-bufferize + */ + +static BaseMemRefType +getBufferizedFunctionArgType(FunctionOpInterface funcOp, int64_t index, + const bufferization::BufferizationOptions &options) +{ + auto tensorType = dyn_cast(funcOp.getArgument(index).getType()); + assert(tensorType && "expected TensorType"); + + BaseMemRefType memrefType = options.functionArgTypeConverterFn( + tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options); + + auto layoutAttr = funcOp.getArgAttrOfType( + index, bufferization::BufferizationDialect::kBufferLayoutAttrName); + if (!layoutAttr) + return memrefType; + + auto rankedMemrefType = dyn_cast(memrefType); + assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); + return MemRefType::get(rankedMemrefType.getShape(), rankedMemrefType.getElementType(), + layoutAttr.getValue(), rankedMemrefType.getMemorySpace()); +} + +static ReturnOp getAssumedUniqueReturnOp(FunctionOpInterface funcOp) +{ + ReturnOp returnOp; + for (Block &b : funcOp.getFunctionBody()) { + if (auto candidateOp = dyn_cast(b.getTerminator())) { + if (returnOp) + return nullptr; + returnOp = candidateOp; + } + } + return returnOp; +} + +Value generateAllocation(OpBuilder &builder, Location loc, Value reference) +{ + auto origMemrefType = cast(reference.getType()); + // TODO: Investigate how to get rid of identity-layout-map + // + // Hi all. For one-shot-bufferization, is there any automatic way to pass all memref symbols + // to AllocOp? we have an example below that triggers error: 'memref.alloc' op symbol + // operand count does not equal memref symbol count: expected 1, got 0 . We think we have + // to pass the offset symbol to AllocOp. + // + // %0 = "bufferization.to_memref"(%arg0) : (tensor) -> memref> %1 = "memref.alloc"() <{operandSegmentSizes = array}> : () -> + // memref> + // + // We know we can set function-signature-type-conversion=identity-layout-map to get rid of + // it. But according to the document, identity-layout-map could be less efficient, we still + // want to stick with the default setting. + // + // https://discord.com/channels/636084430946959380/642426447167881246/1281620504859512914 + // + // Something looks odd here. + // The result of a `memref.alloc` should be a memref without identity layout. + // I know that the op supports operands for dims/symbols in the memref type, + // but I never understood why. + // Imo, a `memref.alloc() : memref` should have been generated. + // The result value can then be casted to `memref>`. + // + // https://discord.com/channels/636084430946959380/642426447167881246/1281710682160627785 + // + // What I find interesting is that the comment says that + // + // "we know we can set function-signature-type-conversion=identity-layout-map to get rid of + // it" + // + // and that is what we are using, however we still have this rebuilding a memref without the + // layout. If that were true, then we could uncomment the following line and it should work. + // auto memrefType = origMemrefType; + // I can confirm that having + // function-signature-type-conversion=identity-layout-map makes the line above succed while the + // line below fail: + // + // Get dynamic dimension sizes from the provided reference value if necessary. 
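+    // As a rough sketch of the cast-back variant suggested above (the names here are
+    // hypothetical and not part of this change), the allocation could be created with an
+    // identity layout and then cast back to the strided type callers expect:
+    //
+    //   Value allocUncasted = builder.create<memref::AllocOp>(loc, memrefType, dynamicDims);
+    //   return builder.create<memref::CastOp>(loc, origMemrefType, allocUncasted);
+    //
+    // For now the identity-layout allocation below is returned directly.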
+ auto memrefType = MemRefType::get(origMemrefType.getShape(), origMemrefType.getElementType()); + // + // Looking at this a little bit deeper, I can say that the variable reference + // appears to come from a function parameter. + // and since it is not the identity layout, then we see the following generic MLIR when not + // using identity layout + // + // "func.func"() <{function_type = (memref>) -> memref> + // + // and we see this when using the identity layout: + // + // func.func public @jit_fn(%arg0: memref) -> memref + // + // When not using identity layout but also not removing the layout in the alloca, there are + // errors in some cases but not in others. I believe we have to do some casts in other places as + // well, whenever we use allocas and the types come from the arguments. + // + // My recommendation: at some point it would be good to remove the identity-layout-map from the + // frontend but until we have some more resources, let's keep it along with the origMemrefType. + + SmallVector dynamicDims; + if (!memrefType.hasStaticShape()) { + for (int64_t dim = 0; dim < memrefType.getRank(); dim++) { + if (memrefType.isDynamicDim(dim)) { + Value dimIndex = builder.create(loc, dim); + dynamicDims.push_back(builder.create(loc, reference, dimIndex)); + } + } + } + + return builder.create(loc, memrefType, dynamicDims); + // Uncomment below to follow Matthias suggestion of placing a CastOp after AllocOp + // some more tests will pass. + // return builder.create(loc, origMemrefType, alloc_uncasted); +} + +/// Helper function to generate a set of memref allocations. +/// +/// The allocation size and shape is deduced from a list of existing memref values. +/// +void generateAllocations(RewriterBase &rewriter, Location loc, SmallVectorImpl &allocations, + ValueRange referenceValues) +{ + for (Value memref : referenceValues) { + allocations.push_back( + generateAllocation(rewriter, loc, cast>(memref))); + } +} + +struct AdjointOpInterface + // This operation is not in DPS style and I believe that operands will only be read. 
+ : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto adjointOp = cast(op); + + bufferization::BufferizeTypeConverter typeConverter; + + SmallVector resTypes; + if (failed(typeConverter.convertTypes(adjointOp.getResultTypes(), resTypes))) + return failure(); + + Location loc = adjointOp.getLoc(); + Value gradSize = adjointOp.getGradSize(); + SmallVector memrefValues; + for (Type resType : resTypes) { + MemRefType memrefType = cast(resType); + Value memrefValue = rewriter.create(loc, memrefType, gradSize); + memrefValues.push_back(memrefValue); + } + + SmallVector bufferArgs; + ValueRange operands = adjointOp.getArgs(); + for (Value operand : operands) { + FailureOr opBuffer = getBuffer(rewriter, operand, options); + if (failed(opBuffer)) + return failure(); + bufferArgs.push_back(*opBuffer); + } + + rewriter.create(loc, TypeRange{}, adjointOp.getCalleeAttr(), + adjointOp.getGradSize(), bufferArgs, memrefValues); + bufferization::replaceOpWithBufferizedValues(rewriter, op, memrefValues); + return success(); + } +}; + +struct BackpropOpInterface + // This operation is not in DPS style + // but it has a lot of parameters, notably: + // Variadic: $args + // Variadic<...RankedTensorOf<[AnyFloat]>>: $cotangents + // I think we don't write to the cotangents. And also not to the arguments + // so we can set bufferizesToMemoryWrite as false. + // The safe assumption is that it should be true. + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto backpropOp = cast(op); + + Location loc = backpropOp.getLoc(); + SmallVector gradients; + SmallVector argShadows; + // Conceptually a map from scalar result indices (w.r.t. other scalars) to the position in + // the overall list of returned gradients. + // For instance, a backprop op that returns (tensor, f64, tensor, f64, f64) will have + // scalarIndices = {1, 3, 4}. 
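+        // The memref gradients are allocated up front, while the scalar slots are left as null
+        // placeholders and filled in from the results of the bufferized BackpropOp further down.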
+ SmallVector scalarIndices; + SmallVector scalarReturnTypes; + + SmallVector bufferArgs; + ValueRange operands = backpropOp.getArgs(); + for (Value operand : operands) { + if (isa(operand.getType())) { + FailureOr opBuffer = getBuffer(rewriter, operand, options); + if (failed(opBuffer)) + return failure(); + bufferArgs.push_back(*opBuffer); + } + else { + bufferArgs.push_back(operand); + } + } + + std::vector diffArgs = + computeDiffArgs(bufferArgs, backpropOp.getDiffArgIndicesAttr()); + + for (const auto &[idx, diffArg] : llvm::enumerate(diffArgs)) { + // Allocate buffers to place the differentiation results (gradients) into. Enzyme refers + // to these as shadow arguments. There is one result for each differentiable MemRef + // argument, with a matching shape and type. + if (isa(diffArg.getType())) { + Value shadow = generateAllocation(rewriter, loc, diffArg); + gradients.push_back(shadow); + argShadows.push_back(shadow); + } + else if (isa(diffArg.getType())) { + scalarReturnTypes.push_back(diffArg.getType()); + scalarIndices.push_back(idx); + // Put a null placeholder value that will be filled in with the result of the + // bufferized BackpropOp. + gradients.push_back(Value()); + } + } + + // Enzyme requires buffers for the primal outputs as well, even though we don't need their + // values. We'll mark them dupNoNeed later on to allow Enzyme to optimize away their + // computation. + SmallVector calleeResults, resShadows; + ValueRange cotangents = backpropOp.getCotangents(); + SmallVector bufferCotangentsList; + for (Value operand : cotangents) { + FailureOr opBuffer = getBuffer(rewriter, operand, options); + if (failed(opBuffer)) + return failure(); + bufferCotangentsList.push_back(*opBuffer); + } + mlir::ValueRange bufferCotangents(bufferCotangentsList); + + generateAllocations(rewriter, loc, calleeResults, bufferCotangents); + // Enzyme mutates the result shadows but the cotangent tensors must be immutable, so we + // create copies to pass into Enzyme. Concretely, this issue pops up with multiple + // BackpropOps that have the same cotangent tensor due to a CSE effect from one-shot + // bufferization. + generateAllocations(rewriter, loc, resShadows, bufferCotangents); + for (const auto &[cotangent, resShadow] : llvm::zip(bufferCotangents, resShadows)) { + rewriter.create(loc, cotangent, resShadow); + } + + DenseIntElementsAttr diffArgIndicesAttr = backpropOp.getDiffArgIndices().value_or(nullptr); + auto bufferizedBackpropOp = rewriter.create( + loc, TypeRange{}, scalarReturnTypes, backpropOp.getCalleeAttr(), bufferArgs, argShadows, + calleeResults, resShadows, diffArgIndicesAttr, backpropOp.getKeepValueResultsAttr()); + // Fill in the null placeholders. 
+ for (const auto &[idx, scalarResult] : + llvm::enumerate(bufferizedBackpropOp.getGradients())) { + gradients[scalarIndices[idx]] = scalarResult; + } + + // BackpropOp can return two results for value_and_grad: values and gradients + // or only one for grad: gradients + SmallVector results; + { + // If we are lowering a value_and_grad operation, then take values from the + // calleeResults + if (!backpropOp.getVals().empty()) { + results.insert(results.end(), calleeResults.begin(), calleeResults.end()); + } + results.insert(results.end(), gradients.begin(), gradients.end()); + } + + bufferization::replaceOpWithBufferizedValues(rewriter, op, results); + return success(); + } +}; + +struct ForwardOpInterface + : public bufferization::OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel< + ForwardOpInterface, ForwardOp> { + static bool supportsUnstructuredControlFlow() { return false; } + + bool hasTensorSemantics(Operation *op) const + { + auto isaTensor = llvm::IsaPred; + + // A function has tensor semantics if it has tensor arguments/results. + auto forwardOp = cast(op); + bool hasTensorArg = any_of(forwardOp.getArgumentTypes(), isaTensor); + bool hasTensorResult = any_of(forwardOp.getResultTypes(), isaTensor); + bool hasTensorFuncInType = any_of(forwardOp.getFunctionType().getInputs(), isaTensor); + bool hasTensorFuncOutType = any_of(forwardOp.getFunctionType().getResults(), isaTensor); + if (hasTensorArg || hasTensorResult || hasTensorFuncInType || hasTensorFuncOutType) + return true; + + return false; + } + + bufferization::AliasingOpOperandList + getAliasingOpOperands(Operation *op, Value value, + const bufferization::AnalysisState &state) const + { + return {}; + } + + FailureOr getBufferType(Operation *op, Value value, + const bufferization::BufferizationOptions &options, + SmallVector &invocationStack) const + { + auto forwardOp = cast(op); + auto bbArg = cast(value); + + // Function arguments are special. + if (bbArg.getOwner() == &forwardOp.getBody().front()) + return getBufferizedFunctionArgType(forwardOp, bbArg.getArgNumber(), options); + + return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::getBufferType( + op, value, options, invocationStack); + } + + LogicalResult verifyAnalysis(Operation *op, const bufferization::AnalysisState &state) const + { + auto forwardOp = cast(op); + // TODO: func.func with multiple returns are not supported. + if (!getAssumedUniqueReturnOp(forwardOp)) + return op->emitOpError("op without unique func.return is not supported"); + return success(); + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto forwardOp = cast(op); + FunctionType funcType = forwardOp.getFunctionType(); + + // Construct the bufferized function type. + SmallVector argTypes; + for (const auto &it : llvm::enumerate(funcType.getInputs())) { + Type argType = it.value(); + if (dyn_cast(argType)) { + argTypes.push_back(getBufferizedFunctionArgType(forwardOp, it.index(), options)); + continue; + } + argTypes.push_back(argType); + } + + ReturnOp returnOp = getAssumedUniqueReturnOp(forwardOp); + assert(returnOp && "expected func with single return op"); + Location loc = returnOp.getLoc(); + + // 1. Bufferize every block. + for (Block &block : forwardOp.getBody()) + if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options))) + return failure(); + + // 2. For each result, keep track of which inplace argument it reuses. 
+ SmallVector returnValues; + for (OpOperand &returnOperand : returnOp->getOpOperands()) { + Value returnVal = returnOperand.get(); + auto tensorType = dyn_cast(returnVal.getType()); + rewriter.setInsertionPoint(returnOp); + + // If not a tensor type just forward it. + if (!tensorType) { + returnValues.push_back(returnVal); + continue; + } + + // Note: If `inferFunctionResultLayout = true`, cast are later folded + // away. + BaseMemRefType resultType = options.functionArgTypeConverterFn( + tensorType, *options.defaultMemorySpaceFn(tensorType), forwardOp, options); + Value toMemrefOp = + rewriter.create(loc, resultType, returnVal); + returnValues.push_back(toMemrefOp); + } + + // 3. Rewrite the terminator. + forwardOp.walk([&](ReturnOp returnOp) { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(returnOp); + rewriter.replaceOpWithNewOp(returnOp, returnValues, returnOp.getEmpty()); + }); + + // 4. Rewrite the FuncOp type to buffer form. Also preserve unused return types. + SmallVector returnTypes; + for (auto retTy : forwardOp.getResultTypes()) { + auto tensorType = dyn_cast(retTy); + BaseMemRefType resultType = options.functionArgTypeConverterFn( + tensorType, *options.defaultMemorySpaceFn(tensorType), forwardOp, options); + returnTypes.push_back(resultType); + } + forwardOp.setType(FunctionType::get(op->getContext(), argTypes, returnTypes)); + + return success(); + } +}; + +struct ReverseOpInterface + : public bufferization::OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel< + ReverseOpInterface, ReverseOp> { + static bool supportsUnstructuredControlFlow() { return false; } + + bool hasTensorSemantics(Operation *op) const + { + auto isaTensor = llvm::IsaPred; + + // A function has tensor semantics if it has tensor arguments/results. + auto reverseOp = cast(op); + bool hasTensorArg = any_of(reverseOp.getArgumentTypes(), isaTensor); + bool hasTensorResult = any_of(reverseOp.getResultTypes(), isaTensor); + bool hasTensorFuncInType = any_of(reverseOp.getFunctionType().getInputs(), isaTensor); + bool hasTensorFuncOutType = any_of(reverseOp.getFunctionType().getResults(), isaTensor); + if (hasTensorArg || hasTensorResult || hasTensorFuncInType || hasTensorFuncOutType) + return true; + + return false; + } + + bufferization::AliasingOpOperandList + getAliasingOpOperands(Operation *op, Value value, + const bufferization::AnalysisState &state) const + { + return {}; + } + + FailureOr getBufferType(Operation *op, Value value, + const bufferization::BufferizationOptions &options, + SmallVector &invocationStack) const + { + auto reverseOp = cast(op); + auto bbArg = cast(value); + + // Function arguments are special. + if (bbArg.getOwner() == &reverseOp.getBody().front()) + return getBufferizedFunctionArgType(reverseOp, bbArg.getArgNumber(), options); + + return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::getBufferType( + op, value, options, invocationStack); + } + + LogicalResult verifyAnalysis(Operation *op, const bufferization::AnalysisState &state) const + { + auto reverseOp = cast(op); + // TODO: func.func with multiple returns are not supported. 
+ if (!getAssumedUniqueReturnOp(reverseOp)) + return op->emitOpError("op without unique func.return is not supported"); + return success(); + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto reverseOp = cast(op); + FunctionType funcType = reverseOp.getFunctionType(); + + // Construct the bufferized function type. + SmallVector argTypes; + for (const auto &it : llvm::enumerate(funcType.getInputs())) { + Type argType = it.value(); + if (dyn_cast(argType)) { + argTypes.push_back(getBufferizedFunctionArgType(reverseOp, it.index(), options)); + continue; + } + argTypes.push_back(argType); + } + + ReturnOp returnOp = getAssumedUniqueReturnOp(reverseOp); + assert(returnOp && "expected func with single return op"); + Location loc = returnOp.getLoc(); + + // 1. Bufferize every block. + for (Block &block : reverseOp.getBody()) + if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options))) + return failure(); + + // 2. For each result, keep track of which inplace argument it reuses. + SmallVector returnValues; + for (OpOperand &returnOperand : returnOp->getOpOperands()) { + Value returnVal = returnOperand.get(); + auto tensorType = dyn_cast(returnVal.getType()); + rewriter.setInsertionPoint(returnOp); + + // If not a tensor type just forward it. + if (!tensorType) { + returnValues.push_back(returnVal); + continue; + } + + // Note: If `inferFunctionResultLayout = true`, cast are later folded + // away. + BaseMemRefType resultType = options.functionArgTypeConverterFn( + tensorType, *options.defaultMemorySpaceFn(tensorType), reverseOp, options); + Value toMemrefOp = + rewriter.create(loc, resultType, returnVal); + returnValues.push_back(toMemrefOp); + } + + // 3. Rewrite the terminator. + reverseOp.walk([&](ReturnOp returnOp) { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(returnOp); + rewriter.replaceOpWithNewOp(returnOp, returnValues, returnOp.getEmpty()); + }); + + // 4. Rewrite the FuncOp type to buffer form. Also preserve unused return types. 
+ SmallVector returnTypes; + for (auto retTy : reverseOp.getResultTypes()) { + auto tensorType = dyn_cast(retTy); + BaseMemRefType resultType = options.functionArgTypeConverterFn( + tensorType, *options.defaultMemorySpaceFn(tensorType), reverseOp, options); + returnTypes.push_back(resultType); + } + reverseOp.setType(FunctionType::get(op->getContext(), argTypes, returnTypes)); + + return success(); + } +}; + +} // namespace + +void catalyst::gradient::registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) +{ + registry.addExtension(+[](MLIRContext *ctx, GradientDialect *dialect) { + AdjointOp::attachInterface(*ctx); + BackpropOp::attachInterface(*ctx); + ForwardOp::attachInterface(*ctx); + ReverseOp::attachInterface(*ctx); + }); +} diff --git a/mlir/lib/Gradient/Transforms/CMakeLists.txt b/mlir/lib/Gradient/Transforms/CMakeLists.txt index 6495dc6cba..6442cfd073 100644 --- a/mlir/lib/Gradient/Transforms/CMakeLists.txt +++ b/mlir/lib/Gradient/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ set(LIBRARY_NAME gradient-transforms) file(GLOB SRC GradMethods/*.cpp + BufferizableOpInterfaceImpl.cpp BufferizationPatterns.cpp gradient_bufferize.cpp PreprocessingPatterns.cpp diff --git a/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp b/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp index 46abee67ce..30baf98952 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp @@ -147,7 +147,8 @@ func::FuncOp genSplitPreprocessed(PatternRewriter &rewriter, Location loc, func: PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(&splitFn.getBody().front()); Value paramsBuffer = rewriter.create(loc, paramsBufferType, paramCount); - Value paramsTensor = rewriter.create(loc, paramsBuffer); + Value paramsTensor = + rewriter.create(loc, paramsBuffer, /*restrict=*/true); qnodeQuantumArgs.push_back(paramsTensor); MemRefType paramsProcessedType = MemRefType::get({}, rewriter.getIndexType()); @@ -289,8 +290,8 @@ func::FuncOp genArgMapFunction(PatternRewriter &rewriter, Location loc, func::Fu else if (auto returnOp = dyn_cast(op)) { PatternRewriter::InsertionGuard insertionGuard(rewriter); rewriter.setInsertionPoint(returnOp); - Value paramsVector = - rewriter.create(loc, paramsVectorType, paramsBuffer); + Value paramsVector = rewriter.create( + loc, paramsVectorType, paramsBuffer, /*restrict=*/true); returnOp.getOperandsMutable().assign(paramsVector); } }); diff --git a/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp b/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp index b0d48aaf01..19100ea8eb 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp @@ -58,7 +58,8 @@ static std::vector computePartialDerivative(PatternRewriter &rewriter, Lo { constexpr double shift = llvm::numbers::pi / 2; ShapedType shiftVectorType = RankedTensorType::get({numShifts}, rewriter.getF64Type()); - Value selectorVector = rewriter.create(loc, selectorBuffer); + Value selectorVector = + rewriter.create(loc, selectorBuffer, /*restrict=*/true); // Define the shift vectors (pos/neg) as sparse tensor constants. 
    DenseElementsAttr nonZeroIndices = rewriter.getI64TensorAttr(currentShift);
@@ -284,8 +285,8 @@ func::FuncOp ParameterShiftLowering::genQGradFunction(PatternRewriter &rewriter,
         std::vector<Value> gradientTensors;
         gradientTensors.reserve(gradResTypes.size());
         for (Value gradientBuffer : gradientBuffers) {
-            gradientTensors.push_back(
-                rewriter.create<bufferization::ToTensorOp>(loc, gradientBuffer));
+            gradientTensors.push_back(rewriter.create<bufferization::ToTensorOp>(
+                loc, gradientBuffer, /*restrict=*/true));
         }
         op->setOperands(gradientTensors);
     }
diff --git a/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp b/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp
index 5c79663f0d..16c0243352 100644
--- a/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp
+++ b/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp
@@ -14,12 +14,14 @@
 #include "iostream"
 #include "llvm/Support/raw_ostream.h"
+#include
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -41,10 +43,11 @@ struct PostprocessForwardOp : public OpRewritePattern<ForwardOp> {
         // Check if the numbers of args and returns match Enzyme's format.
         auto argc = op.getArgc();
         auto resc = op.getResc();
-        auto tapeCount = op.getTape();
+        auto tape = op.getTape();
-        if (op.getFunctionType().getNumInputs() == (argc + resc) * 2 &&
-            op.getFunctionType().getNumResults() == tapeCount)
+        // If the function signature has already been modified, this pattern cannot be applied.
+        if (op.getFunctionType().getNumInputs() != argc ||
+            op.getFunctionType().getNumResults() != (resc + tape))
             return failure();
         auto argTys = op.getArgumentTypes();
@@ -127,7 +130,9 @@ struct PostprocessReverseOp : public OpRewritePattern<ReverseOp> {
         auto forwardResc = op.getResc();
         auto tape = op.getTape();
-        if (op.getFunctionType().getNumInputs() == (forwardArgc + forwardResc) * 2 + tape)
+        // If the function signature has already been modified, this pattern cannot be applied.
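+        // (Before this postprocessing runs, the reverse wrapper is expected to take one value
+        // per primal result, the incoming cotangents, plus the tape values as inputs, giving
+        // resc + tape inputs, and to return one gradient per primal argument, giving argc results.)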
+ if (op.getFunctionType().getNumInputs() != (forwardResc + tape) || + op.getFunctionType().getNumResults() != forwardArgc) return failure(); auto argTys = op.getArgumentTypes(); @@ -212,4 +217,4 @@ void populatePostprocessingPatterns(RewritePatternSet &patterns) } } // namespace gradient -} // namespace catalyst \ No newline at end of file +} // namespace catalyst diff --git a/mlir/lib/Gradient/Transforms/PreprocessingPatterns.cpp b/mlir/lib/Gradient/Transforms/PreprocessingPatterns.cpp index c41fb2d46b..ec6820b001 100644 --- a/mlir/lib/Gradient/Transforms/PreprocessingPatterns.cpp +++ b/mlir/lib/Gradient/Transforms/PreprocessingPatterns.cpp @@ -70,7 +70,6 @@ struct PreprocessReverseOp : public OpRewritePattern { { if (!op.getBody().empty()) return failure(); - Block *block; rewriter.modifyOpInPlace(op, [&] { block = op.addEntryBlock(); }); @@ -117,4 +116,4 @@ void populatePreprocessingPatterns(RewritePatternSet &patterns) } } // namespace gradient -} // namespace catalyst \ No newline at end of file +} // namespace catalyst diff --git a/mlir/lib/Gradient/Transforms/gradient_preprocess.cpp b/mlir/lib/Gradient/Transforms/gradient_preprocess.cpp index a0a003991c..f470d5f8d9 100644 --- a/mlir/lib/Gradient/Transforms/gradient_preprocess.cpp +++ b/mlir/lib/Gradient/Transforms/gradient_preprocess.cpp @@ -55,4 +55,4 @@ std::unique_ptr createGradientPreprocessingPass() return std::make_unique(); } -} // namespace catalyst \ No newline at end of file +} // namespace catalyst diff --git a/mlir/lib/Quantum/IR/QuantumDialect.cpp b/mlir/lib/Quantum/IR/QuantumDialect.cpp index 385f4e0ae5..d4d820326f 100644 --- a/mlir/lib/Quantum/IR/QuantumDialect.cpp +++ b/mlir/lib/Quantum/IR/QuantumDialect.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/IR/DialectImplementation.h" // needed for generated type parser #include "llvm/ADT/TypeSwitch.h" // needed for generated type parser @@ -43,6 +44,9 @@ void QuantumDialect::initialize() #define GET_OP_LIST #include "Quantum/IR/QuantumOps.cpp.inc" >(); + declarePromisedInterfaces(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp new file mode 100644 index 0000000000..2e61dfe579 --- /dev/null +++ b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,414 @@ +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "Quantum/IR/QuantumOps.h" +#include "Quantum/Transforms/BufferizableOpInterfaceImpl.h" + +using namespace mlir; +using namespace catalyst::quantum; + +namespace { +/** + * The new bufferization interface requires `bufferizesToMemoryRead`, `bufferizesToMemoryWrite`, + * and `getAliasingValues`. + * + * `bufferizesToMemoryRead`: Return `true` if the buffer of the given tensor OpOperand is read. + * + * `bufferizesToMemoryWrite`: Return `true` if the buffer of the given tensor OpOperand is written + * (if bufferizing in-place). + * + * `getAliasingOpOperands`: Return the OpResults that may share the same buffer as the given + * OpOperand. Note that MLIR documentation does not mention `getAliasingValues` but it seems to + * serve the same purpose. 
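+ *
+ * All of the external models below return an empty list from `getAliasingValues`: the results
+ * of these quantum ops never reuse the buffers of their tensor operands, so one-shot
+ * bufferization is free to materialize fresh memrefs (via `bufferization.to_memref` or
+ * `memref.alloc`) for them.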
+ * + * Link: https://mlir.llvm.org/docs/Bufferization/#extending-one-shot-bufferize + */ + +/// Bufferization of catalyst.quantum.unitary. Convert Matrix into memref. +struct QubitUnitaryOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto qubitUnitaryOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(qubitUnitaryOp.getMatrix().getType()); + MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + auto toMemrefOp = + rewriter.create(loc, memrefType, qubitUnitaryOp.getMatrix()); + auto memref = toMemrefOp.getResult(); + bufferization::replaceOpWithNewBufferizedOp( + rewriter, op, qubitUnitaryOp.getOutQubits().getTypes(), + qubitUnitaryOp.getOutCtrlQubits().getTypes(), memref, qubitUnitaryOp.getInQubits(), + qubitUnitaryOp.getAdjointAttr(), qubitUnitaryOp.getInCtrlQubits(), + qubitUnitaryOp.getInCtrlValues()); + return success(); + } +}; + +/// Bufferization of catalyst.quantum.hermitian. Convert Matrix into memref. +struct HermitianOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto hermitianOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(hermitianOp.getMatrix().getType()); + MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + auto toMemrefOp = + rewriter.create(loc, memrefType, hermitianOp.getMatrix()); + auto memref = toMemrefOp.getResult(); + auto newHermitianOp = rewriter.create(loc, hermitianOp.getType(), memref, + hermitianOp.getQubits()); + bufferization::replaceOpWithBufferizedValues(rewriter, op, newHermitianOp.getObs()); + + return success(); + } +}; + +/// Bufferization of catalyst.quantum.hamiltonian. Convert Matrix into memref. 
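+/// (For the Hamiltonian op the converted operand is the coefficients tensor.) Schematically,
+/// with illustrative types and approximate syntax, the rewrite performed below is roughly:
+///
+///   %obs = quantum.hamiltonian(%coeffs : tensor<2xf64>) %t0, %t1 : !quantum.obs
+///
+///   --->
+///
+///   %m = bufferization.to_memref %coeffs : memref<2xf64>
+///   %obs = quantum.hamiltonian(%m : memref<2xf64>) %t0, %t1 : !quantum.obs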
+struct HamiltonianOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto hamiltonianOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(hamiltonianOp.getCoeffs().getType()); + MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + auto toMemrefOp = + rewriter.create(loc, memrefType, hamiltonianOp.getCoeffs()); + auto memref = toMemrefOp.getResult(); + auto newHamiltonianOp = rewriter.create(loc, hamiltonianOp.getType(), memref, + hamiltonianOp.getTerms()); + bufferization::replaceOpWithBufferizedValues(rewriter, op, newHamiltonianOp.getObs()); + + return success(); + } +}; + +/// Bufferization of catalyst.quantum.sample. Replace with memref.alloc and a new +/// catalyst.quantum.sample that uses the memory allocated by memref.alloc. +struct SampleOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto sampleOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(sampleOp.getSamples().getType()); + MemRefType resultType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + + Value allocVal = rewriter.create(loc, resultType); + rewriter.create(loc, TypeRange{}, ValueRange{sampleOp.getObs(), allocVal}, + sampleOp->getAttrs()); + bufferization::replaceOpWithBufferizedValues(rewriter, op, allocVal); + + return success(); + } +}; + +/// Bufferization of catalyst.quantum.state. Replace with memref.alloc and a new +/// catalyst.quantum.state that uses the memory allocated by memref.alloc. 
+struct StateOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto stateOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(stateOp.getState().getType()); + MemRefType resultType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + + Value allocVal = rewriter.create(loc, resultType); + rewriter.create(loc, TypeRange{}, ValueRange{stateOp.getObs(), allocVal}); + bufferization::replaceOpWithBufferizedValues(rewriter, op, allocVal); + + return success(); + } +}; + +/// Bufferization of catalyst.quantum.probs. Replace with memref.alloc and a new +/// catalyst.quantum.probs that uses the memory allocated by memref.alloc. +struct ProbsOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto probsOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(probsOp.getProbabilities().getType()); + MemRefType resultType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + + Value allocVal = rewriter.create(loc, resultType); + rewriter.create(loc, TypeRange{}, ValueRange{probsOp.getObs(), allocVal}); + bufferization::replaceOpWithBufferizedValues(rewriter, op, allocVal); + + return success(); + } +}; + +/// Bufferization of catalyst.quantum.counts. Replace with memref.allocs and a new +/// catalyst.quantum.counts that uses the memory allocated by memref.allocs. 
+struct CountsOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto countsOp = cast(op); + Location loc = op->getLoc(); + auto tensorType0 = cast(countsOp.getEigvals().getType()); + auto tensorType1 = cast(countsOp.getCounts().getType()); + MemRefType resultType0 = + MemRefType::get(tensorType0.getShape(), tensorType0.getElementType()); + MemRefType resultType1 = + MemRefType::get(tensorType1.getShape(), tensorType1.getElementType()); + + Value allocVal0 = rewriter.create(loc, resultType0); + Value allocVal1 = rewriter.create(loc, resultType1); + rewriter.create(loc, nullptr, nullptr, countsOp.getObs(), allocVal0, allocVal1, + countsOp.getShotsAttr()); + bufferization::replaceOpWithBufferizedValues(rewriter, op, + ValueRange{allocVal0, allocVal1}); + + return success(); + } +}; + +/// Bufferization of catalyst.quantum.set_state. Convert InState into memref. +struct SetStateOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto setStateOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(setStateOp.getInState().getType()); + MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + + auto toMemrefOp = + rewriter.create(loc, memrefType, setStateOp.getInState()); + auto memref = toMemrefOp.getResult(); + auto newSetStateOp = rewriter.create(loc, setStateOp.getOutQubits().getTypes(), + memref, setStateOp.getInQubits()); + bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits()); + return success(); + } +}; + +/// Bufferization of catalyst.quantum.set_basic_state. Convert BasisState into memref. 
+struct SetBasisStateOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto setBasisStateOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(setBasisStateOp.getBasisState().getType()); + MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + + auto toMemrefOp = rewriter.create( + loc, memrefType, setBasisStateOp.getBasisState()); + auto memref = toMemrefOp.getResult(); + auto newSetStateOp = rewriter.create( + loc, setBasisStateOp.getOutQubits().getTypes(), memref, setBasisStateOp.getInQubits()); + bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits()); + return success(); + } +}; + +} // namespace + +void catalyst::quantum::registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) +{ + registry.addExtension(+[](MLIRContext *ctx, QuantumDialect *dialect) { + QubitUnitaryOp::attachInterface(*ctx); + HermitianOp::attachInterface(*ctx); + HamiltonianOp::attachInterface(*ctx); + SampleOp::attachInterface(*ctx); + StateOp::attachInterface(*ctx); + ProbsOp::attachInterface(*ctx); + CountsOp::attachInterface(*ctx); + SetStateOp::attachInterface(*ctx); + SetBasisStateOp::attachInterface(*ctx); + }); +} \ No newline at end of file diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt index 96ba30d23e..2504275410 100644 --- a/mlir/lib/Quantum/Transforms/CMakeLists.txt +++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ set(LIBRARY_NAME quantum-transforms) file(GLOB SRC + BufferizableOpInterfaceImpl.cpp BufferizationPatterns.cpp quantum_bufferize.cpp ConversionPatterns.cpp diff --git a/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp b/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp index b461dc8a60..160adb70d6 100644 --- a/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp +++ b/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp @@ -87,12 +87,16 @@ llvm::SmallVector getReturnMemRefs(func::ReturnOp op) */ Value allocCopyMemrefDyn(Location loc, Value memref, PatternRewriter &rewriter) { - auto memrefType = cast(memref.getType()); + auto origMemrefType = cast(memref.getType()); + // Rebuild MemRefType without memory layout. + auto newMemrefType = + MemRefType::get(origMemrefType.getShape(), origMemrefType.getElementType()); + llvm::SmallVector dynDims; { llvm::SmallVector dynIndices; int64_t ndim = 0; - for (auto dim : memrefType.getShape()) { + for (auto dim : newMemrefType.getShape()) { if (dim < 0) { Value dynValue = rewriter.create(loc, memref, ndim); dynDims.push_back(dynValue); @@ -101,9 +105,11 @@ Value allocCopyMemrefDyn(Location loc, Value memref, PatternRewriter &rewriter) } } - Value newMemRef = rewriter.create(loc, memrefType, dynDims); + Value newMemRef = rewriter.create(loc, newMemrefType, dynDims); + // Cast memrefType back to maintain memory layout. 
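+    // The fresh allocation uses an identity layout, while callers of this helper still expect
+    // the original (possibly strided) memref type, so the allocated buffer is cast back to that
+    // type before being returned.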
+ Value castMemRef = rewriter.create(loc, origMemrefType, newMemRef); rewriter.create(loc, memref, newMemRef); - return newMemRef; + return castMemRef; } /** diff --git a/mlir/patches/FunctionOpInterface-bufferization.patch b/mlir/patches/FunctionOpInterface-bufferization.patch new file mode 100644 index 0000000000..60a2e9b93f --- /dev/null +++ b/mlir/patches/FunctionOpInterface-bufferization.patch @@ -0,0 +1,993 @@ +diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +index 2fda091e412a..eb0df1d92d6a 100644 +--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h ++++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +@@ -11,6 +11,7 @@ + + #include "mlir/IR/Operation.h" + #include "mlir/IR/PatternMatch.h" ++#include "mlir/Interfaces/FunctionInterfaces.h" + #include "mlir/Support/LLVM.h" + #include "llvm/ADT/DenseMapInfoVariant.h" + #include "llvm/ADT/SetVector.h" +@@ -260,9 +261,9 @@ struct BufferizationOptions { + using AnalysisStateInitFn = std::function; + /// Tensor -> MemRef type converter. + /// Parameters: Value, memory space, func op, bufferization options +- using FunctionArgTypeConverterFn = +- std::function; ++ using FunctionArgTypeConverterFn = std::function; + /// Tensor -> MemRef type converter. + /// Parameters: Value, memory space, bufferization options + using UnknownTypeConverterFn = std::function equivalentFuncArgs; ++ DenseMap equivalentFuncArgs; + + /// A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices. +- DenseMap aliasingReturnVals; ++ DenseMap aliasingReturnVals; + + /// A set of all read BlockArguments of FuncOps. +- DenseMap readBbArgs; ++ DenseMap readBbArgs; + + /// A set of all written-to BlockArguments of FuncOps. +- DenseMap writtenBbArgs; ++ DenseMap writtenBbArgs; + + /// Keep track of which FuncOps are fully analyzed or currently being + /// analyzed. +- DenseMap analyzedFuncOps; ++ DenseMap analyzedFuncOps; + + /// This function is called right before analyzing the given FuncOp. It + /// initializes the data structures for the FuncOp in this state object. +- void startFunctionAnalysis(FuncOp funcOp); ++ void startFunctionAnalysis(FunctionOpInterface funcOp); + }; + + void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); +diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +index d51d63f243ea..c4201698468c 100644 +--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp ++++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +@@ -18,6 +18,7 @@ + #include "mlir/IR/TypeUtilities.h" + #include "mlir/IR/Value.h" + #include "mlir/Interfaces/ControlFlowInterfaces.h" ++#include "mlir/Interfaces/FunctionInterfaces.h" + #include "llvm/ADT/ScopeExit.h" + #include "llvm/Support/Debug.h" + +@@ -314,7 +315,7 @@ namespace { + /// Default function arg type converter: Use a fully dynamic layout map. 
+ BaseMemRefType + defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace, +- func::FuncOp funcOp, ++ FunctionOpInterface funcOp, + const BufferizationOptions &options) { + return getMemRefTypeWithFullyDynamicLayout(type, memorySpace); + } +@@ -361,7 +362,7 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const { + void BufferizationOptions::setFunctionBoundaryTypeConversion( + LayoutMapOption layoutMapOption) { + functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace, +- func::FuncOp funcOp, ++ FunctionOpInterface funcOp, + const BufferizationOptions &options) { + if (layoutMapOption == LayoutMapOption::IdentityLayoutMap) + return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, +diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +index 9fbe574ec392..9749a71f3514 100644 +--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp ++++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +@@ -22,7 +22,7 @@ namespace mlir { + namespace bufferization { + namespace func_ext { + +-void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) { ++void FuncAnalysisState::startFunctionAnalysis(FunctionOpInterface funcOp) { + analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress; + auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping()); + auto createdAliasingResults = +diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +index 0a4072605c26..a0e5c7fff769 100644 +--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp ++++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +@@ -75,7 +75,7 @@ using namespace mlir::bufferization; + using namespace mlir::bufferization::func_ext; + + /// A mapping of FuncOps to their callers. +-using FuncCallerMap = DenseMap>; ++using FuncCallerMap = DenseMap>; + + /// Get or create FuncAnalysisState. + static FuncAnalysisState & +@@ -88,10 +88,11 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) { + + /// Return the unique ReturnOp that terminates `funcOp`. + /// Return nullptr if there is no such unique ReturnOp. +-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { +- func::ReturnOp returnOp; +- for (Block &b : funcOp.getBody()) { +- if (auto candidateOp = dyn_cast(b.getTerminator())) { ++static Operation *getAssumedUniqueReturnOp(FunctionOpInterface funcOp) { ++ Operation *returnOp = nullptr; ++ for (Block &b : funcOp.getFunctionBody()) { ++ auto candidateOp = b.getTerminator(); ++ if (candidateOp && candidateOp->hasTrait()) { + if (returnOp) + return nullptr; + returnOp = candidateOp; +@@ -126,16 +127,16 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal, + /// Store function BlockArguments that are equivalent to/aliasing a returned + /// value in FuncAnalysisState. + static LogicalResult +-aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, ++aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp, ++ OneShotAnalysisState &state, + FuncAnalysisState &funcState) { +- if (funcOp.getBody().empty()) { ++ if (funcOp.getFunctionBody().empty()) { + // No function body available. Conservatively assume that every tensor + // return value may alias with any tensor bbArg. 
+- FunctionType type = funcOp.getFunctionType(); +- for (const auto &inputIt : llvm::enumerate(type.getInputs())) { ++ for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) { + if (!isa(inputIt.value())) + continue; +- for (const auto &resultIt : llvm::enumerate(type.getResults())) { ++ for (const auto &resultIt : llvm::enumerate(funcOp.getResultTypes())) { + if (!isa(resultIt.value())) + continue; + int64_t returnIdx = resultIt.index(); +@@ -147,7 +148,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, + } + + // Support only single return-terminated block in the function. +- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); ++ Operation *returnOp = getAssumedUniqueReturnOp(funcOp); + assert(returnOp && "expected func with single return op"); + + for (OpOperand &returnVal : returnOp->getOpOperands()) +@@ -168,8 +169,8 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, + return success(); + } + +-static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead, +- bool isWritten) { ++static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx, ++ bool isRead, bool isWritten) { + OpBuilder b(funcOp.getContext()); + Attribute accessType; + if (isRead && isWritten) { +@@ -189,12 +190,12 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead, + /// function with unknown ops, we conservatively assume that such ops bufferize + /// to a read + write. + static LogicalResult +-funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state, ++funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp, ++ OneShotAnalysisState &state, + FuncAnalysisState &funcState) { +- for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e; +- ++idx) { ++ for (int64_t idx = 0, e = funcOp.getNumArguments(); idx < e; ++idx) { + // Skip non-tensor arguments. +- if (!isa(funcOp.getFunctionType().getInput(idx))) ++ if (!isa(funcOp.getArgumentTypes()[idx])) + continue; + bool isRead; + bool isWritten; +@@ -204,7 +205,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state, + StringRef str = accessAttr.getValue(); + isRead = str == "read" || str == "read-write"; + isWritten = str == "write" || str == "read-write"; +- } else if (funcOp.getBody().empty()) { ++ } else if (funcOp.getFunctionBody().empty()) { + // If the function has no body, conservatively assume that all args are + // read + written. + isRead = true; +@@ -230,20 +231,19 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state, + + /// Remove bufferization attributes on FuncOp arguments. + static void removeBufferizationAttributes(BlockArgument bbArg) { +- auto funcOp = cast(bbArg.getOwner()->getParentOp()); ++ auto funcOp = cast(bbArg.getOwner()->getParentOp()); + funcOp.removeArgAttr(bbArg.getArgNumber(), + BufferizationDialect::kBufferLayoutAttrName); + funcOp.removeArgAttr(bbArg.getArgNumber(), + BufferizationDialect::kWritableAttrName); + } + +-/// Return the func::FuncOp called by `callOp`. 
+-static func::FuncOp getCalledFunction(func::CallOp callOp) { ++static FunctionOpInterface getCalledFunction(CallOpInterface callOp) { + SymbolRefAttr sym = + llvm::dyn_cast_if_present(callOp.getCallableForCallee()); + if (!sym) + return nullptr; +- return dyn_cast_or_null( ++ return dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom(callOp, sym)); + } + +@@ -251,12 +251,13 @@ static func::FuncOp getCalledFunction(func::CallOp callOp) { + /// Note: This only adds new equivalence info if the called function was already + /// analyzed. + // TODO: This does not handle cyclic function call graphs etc. +-static void equivalenceAnalysis(func::FuncOp funcOp, ++static void equivalenceAnalysis(FunctionOpInterface funcOp, + OneShotAnalysisState &state, + FuncAnalysisState &funcState) { +- funcOp->walk([&](func::CallOp callOp) { +- func::FuncOp calledFunction = getCalledFunction(callOp); +- assert(calledFunction && "could not retrieved called func::FuncOp"); ++ funcOp->walk([&](CallOpInterface callOp) { ++ FunctionOpInterface calledFunction = getCalledFunction(callOp); ++ if (!calledFunction) ++ return WalkResult::skip(); + + // No equivalence info available for the called function. + if (!funcState.equivalentFuncArgs.count(calledFunction)) +@@ -267,7 +268,7 @@ static void equivalenceAnalysis(func::FuncOp funcOp, + int64_t bbargIdx = it.second; + if (!state.isInPlace(callOp->getOpOperand(bbargIdx))) + continue; +- Value returnVal = callOp.getResult(returnIdx); ++ Value returnVal = callOp->getResult(returnIdx); + Value argVal = callOp->getOperand(bbargIdx); + state.unionEquivalenceClasses(returnVal, argVal); + } +@@ -277,11 +278,9 @@ static void equivalenceAnalysis(func::FuncOp funcOp, + } + + /// Return "true" if the given function signature has tensor semantics. +-static bool hasTensorSignature(func::FuncOp funcOp) { +- return llvm::any_of(funcOp.getFunctionType().getInputs(), +- llvm::IsaPred) || +- llvm::any_of(funcOp.getFunctionType().getResults(), +- llvm::IsaPred); ++static bool hasTensorSignature(FunctionOpInterface funcOp) { ++ return llvm::any_of(funcOp.getArgumentTypes(), llvm::IsaPred) || ++ llvm::any_of(funcOp.getResultTypes(), llvm::IsaPred); + } + + /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by +@@ -291,16 +290,16 @@ static bool hasTensorSignature(func::FuncOp funcOp) { + /// retrieve the called FuncOp from any func::CallOp. + static LogicalResult + getFuncOpsOrderedByCalls(ModuleOp moduleOp, +- SmallVectorImpl &orderedFuncOps, ++ SmallVectorImpl &orderedFuncOps, + FuncCallerMap &callerMap) { + // For each FuncOp, the set of functions called by it (i.e. the union of + // symbols of all nested func::CallOp). +- DenseMap> calledBy; ++ DenseMap> calledBy; + // For each FuncOp, the number of func::CallOp it contains. +- DenseMap numberCallOpsContainedInFuncOp; +- WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult { +- if (!funcOp.getBody().empty()) { +- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); ++ DenseMap numberCallOpsContainedInFuncOp; ++ WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult { ++ if (!funcOp.getFunctionBody().empty()) { ++ Operation *returnOp = getAssumedUniqueReturnOp(funcOp); + if (!returnOp) + return funcOp->emitError() + << "cannot bufferize a FuncOp with tensors and " +@@ -309,9 +308,10 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, + + // Collect function calls and populate the caller map. 
+ numberCallOpsContainedInFuncOp[funcOp] = 0; +- return funcOp.walk([&](func::CallOp callOp) -> WalkResult { +- func::FuncOp calledFunction = getCalledFunction(callOp); +- assert(calledFunction && "could not retrieved called func::FuncOp"); ++ return funcOp.walk([&](CallOpInterface callOp) -> WalkResult { ++ FunctionOpInterface calledFunction = getCalledFunction(callOp); ++ if (!calledFunction) ++ return WalkResult::skip(); + // If the called function does not have any tensors in its signature, then + // it is not necessary to bufferize the callee before the caller. + if (!hasTensorSignature(calledFunction)) +@@ -349,11 +349,11 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, + /// most generic layout map as function return types. After bufferizing the + /// entire function body, a more concise memref type can potentially be used for + /// the return type of the function. +-static void foldMemRefCasts(func::FuncOp funcOp) { +- if (funcOp.getBody().empty()) ++static void foldMemRefCasts(FunctionOpInterface funcOp) { ++ if (funcOp.getFunctionBody().empty()) + return; + +- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); ++ Operation *returnOp = getAssumedUniqueReturnOp(funcOp); + SmallVector resultTypes; + + for (OpOperand &operand : returnOp->getOpOperands()) { +@@ -365,8 +365,8 @@ static void foldMemRefCasts(func::FuncOp funcOp) { + } + } + +- auto newFuncType = FunctionType::get( +- funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes); ++ auto newFuncType = FunctionType::get(funcOp.getContext(), ++ funcOp.getArgumentTypes(), resultTypes); + funcOp.setType(newFuncType); + } + +@@ -379,7 +379,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, + FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state); + + // A list of functions in the order in which they are analyzed + bufferized. +- SmallVector orderedFuncOps; ++ SmallVector orderedFuncOps; + + // A mapping of FuncOps to their callers. + FuncCallerMap callerMap; +@@ -388,7 +388,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, + return failure(); + + // Analyze ops. +- for (func::FuncOp funcOp : orderedFuncOps) { ++ for (FunctionOpInterface funcOp : orderedFuncOps) { + if (!state.getOptions().isOpAllowed(funcOp)) + continue; + +@@ -416,7 +416,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, + + void mlir::bufferization::removeBufferizationAttributesInModule( + ModuleOp moduleOp) { +- moduleOp.walk([&](func::FuncOp op) { ++ moduleOp.walk([&](FunctionOpInterface op) { + for (BlockArgument bbArg : op.getArguments()) + removeBufferizationAttributes(bbArg); + }); +@@ -430,7 +430,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( + IRRewriter rewriter(moduleOp.getContext()); + + // A list of functions in the order in which they are analyzed + bufferized. +- SmallVector orderedFuncOps; ++ SmallVector orderedFuncOps; + + // A mapping of FuncOps to their callers. + FuncCallerMap callerMap; +@@ -439,11 +439,11 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( + return failure(); + + // Bufferize functions. +- for (func::FuncOp funcOp : orderedFuncOps) { ++ for (FunctionOpInterface funcOp : orderedFuncOps) { + // Note: It would be good to apply cleanups here but we cannot as aliasInfo + // would be invalidated. + +- if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) { ++ if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) { + // This function was not analyzed and RaW conflicts were not resolved. 
+       // Buffer copies must be inserted before every write.
+       OneShotBufferizationOptions updatedOptions = options;
+@@ -463,7 +463,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
+   // Bufferize all other ops.
+   for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
+     // Functions were already bufferized.
+-    if (isa<func::FuncOp>(&op))
++    if (isa<FunctionOpInterface>(&op))
+       continue;
+     if (failed(bufferizeOp(&op, options, statistics)))
+       return failure();
+@@ -490,12 +490,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
+     // FuncOps whose names are specified in options.noAnalysisFuncFilter will
+     // not be analyzed. Ops in these FuncOps will not be analyzed as well.
+     OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
+-      auto func = dyn_cast<func::FuncOp>(op);
++      auto func = dyn_cast<FunctionOpInterface>(op);
+       if (!func)
+-        func = op->getParentOfType<func::FuncOp>();
++        func = op->getParentOfType<FunctionOpInterface>();
+       if (func)
+         return llvm::is_contained(options.noAnalysisFuncFilter,
+-                                  func.getSymName());
++                                  func.getName());
+       return false;
+     };
+     OneShotBufferizationOptions updatedOptions(options);
+diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
+index 3c50a9e72d9d..588aa8a85a84 100644
+--- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
++++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
+@@ -1,4 +1,4 @@
+-// RUN: mlir-opt --transform-interpreter %s -split-input-file -verify-diagnostics | FileCheck %s
++// RUN: mlir-opt --transform-interpreter="debug-payload-root-tag=payload" %s -split-input-file -verify-diagnostics | FileCheck %s
+ 
+ // Test One-Shot Bufferize.
+ 
+@@ -12,19 +12,21 @@ module attributes {transform.with_named_sequence} {
+ 
+ // CHECK-LABEL: func @test_function(
+ // CHECK-SAME: %[[A:.*]]: tensor<?xf32>
+-func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+-  %c0 = arith.constant 0 : index
++module @payload attributes { transform.target_tag = "payload" } {
++  func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
++    %c0 = arith.constant 0 : index
+ 
+-  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+-  // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+-  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+-  // CHECK: memref.copy %[[A_memref]], %[[alloc]]
+-  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+-  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+-  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
++    // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
++    // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
++    // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
++    // CHECK: memref.copy %[[A_memref]], %[[alloc]]
++    // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
++    // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
++    %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+ 
+-  // CHECK: return %[[res_tensor]]
+-  return %0 : tensor<?xf32>
++    // CHECK: return %[[res_tensor]]
++    return %0 : tensor<?xf32>
++  }
+ }
+ 
+ // -----
+@@ -42,19 +44,21 @@ module attributes {transform.with_named_sequence} {
+ 
+ // CHECK-LABEL: func @test_function(
+ // CHECK-SAME: %[[A:.*]]: tensor<?xf32>
+ // CHECK-NOT: memref.copy
+-func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+-  %c0 = arith.constant 0 : index
++module @payload attributes { transform.target_tag = "payload" } {
++  func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
++    %c0 = arith.constant 0 : index
+ 
+-  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+-  // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+-  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+-  // CHECK: linalg.copy ins(%[[A_memref]] : memref<{{.*}}>) outs(%[[alloc]]
+-  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+-  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+-  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
++    // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
++    // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
++    // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
++    // CHECK: linalg.copy ins(%[[A_memref]] : memref<{{.*}}>) outs(%[[alloc]]
++    // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
++    // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
++    %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+ 
+-  // CHECK: return %[[res_tensor]]
+-  return %0 : tensor<?xf32>
++    // CHECK: return %[[res_tensor]]
++    return %0 : tensor<?xf32>
++  }
+ }
+ 
+ // -----
+@@ -72,13 +76,15 @@ module attributes {transform.with_named_sequence} {
+ 
+ // CHECK-LABEL: func @test_function_analysis(
+ // CHECK-SAME: %[[A:.*]]: tensor<?xf32>
+-func.func @test_function_analysis(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+-  %c0 = arith.constant 0 : index
+-  // CHECK: vector.transfer_write
+-  // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "none"]}
+-  // CHECK-SAME: tensor<?xf32>
+-  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+-  return %0 : tensor<?xf32>
++module @payload attributes { transform.target_tag = "payload" } {
++  func.func @test_function_analysis(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
++    %c0 = arith.constant 0 : index
++    // CHECK: vector.transfer_write
++    // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "none"]}
++    // CHECK-SAME: tensor<?xf32>
++    %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
++    return %0 : tensor<?xf32>
++  }
+ }
+ 
+ // -----
+@@ -95,10 +101,12 @@ module attributes {transform.with_named_sequence} {
+   }
+ }
+ 
+-func.func @test_unknown_op_failure() -> (tensor<?xf32>) {
+-  // expected-error @+1 {{op was not bufferized}}
+-  %0 = "test.dummy_op"() : () -> (tensor<?xf32>)
+-  return %0 : tensor<?xf32>
++module @payload attributes { transform.target_tag = "payload" } {
++  func.func @test_unknown_op_failure() -> (tensor<?xf32>) {
++    // expected-error @+1 {{op was not bufferized}}
++    %0 = "test.dummy_op"() : () -> (tensor<?xf32>)
++    return %0 : tensor<?xf32>
++  }
+ }
+ 
+ // -----
+@@ -111,7 +119,7 @@ module attributes {transform.with_named_sequence} {
+   }
+ }
+ 
+-module {
++module @payload attributes { transform.target_tag = "payload" } {
+   // CHECK-LABEL: func @test_function(
+   // CHECK-SAME: %[[A:.*]]: tensor<?xf32>
+   func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+@@ -146,11 +154,13 @@ module attributes {transform.with_named_sequence} {
+ // CHECK-SAME: %[[A:.*]]: memref<12x9xf32>,
+ // CHECK-SAME: %[[B:.*]]: memref<9x6xf32>,
+ // CHECK-SAME: %[[C:.*]]: memref<12x6xf32>) -> memref<12x6xf32> {
+-func.func @matmul(%A: tensor<12x9xf32>, %B: tensor<9x6xf32>, %C: tensor<12x6xf32>) -> tensor<12x6xf32> {
+-  // CHECK: linalg.matmul ins(%[[A]], %[[B]] : memref<12x9xf32>, memref<9x6xf32>) outs(%[[C]] : memref<12x6xf32>)
+-  %D = linalg.matmul ins(%A, %B: tensor<12x9xf32>, tensor<9x6xf32>) outs(%C: tensor<12x6xf32>) -> tensor<12x6xf32>
+-  // CHECK: return %[[C]] : memref<12x6xf32>
+-  return %D : tensor<12x6xf32>
++module @payload attributes { transform.target_tag = "payload" } {
++  func.func @matmul(%A: tensor<12x9xf32>, %B: tensor<9x6xf32>, %C: tensor<12x6xf32>) ->
tensor<12x6xf32> { ++ // CHECK: linalg.matmul ins(%[[A]], %[[B]] : memref<12x9xf32>, memref<9x6xf32>) outs(%[[C]] : memref<12x6xf32>) ++ %D = linalg.matmul ins(%A, %B: tensor<12x9xf32>, tensor<9x6xf32>) outs(%C: tensor<12x6xf32>) -> tensor<12x6xf32> ++ // CHECK: return %[[C]] : memref<12x6xf32> ++ return %D : tensor<12x6xf32> ++ } + } + + // ----- +@@ -165,10 +175,12 @@ module attributes {transform.with_named_sequence} { + } + + // Expect `bufferization.empty_tensor_to_alloc_tensor` to replace the tensor.empty. +-func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> { +- // CHECK: bufferization.alloc_tensor +- %0 = tensor.empty() : tensor<2x2xf32> +- return %0 : tensor<2x2xf32> ++module @payload attributes { transform.target_tag = "payload" } { ++ func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> { ++ // CHECK: bufferization.alloc_tensor ++ %0 = tensor.empty() : tensor<2x2xf32> ++ return %0 : tensor<2x2xf32> ++ } + } + + // ----- +@@ -185,13 +197,15 @@ module attributes {transform.with_named_sequence} { + // CHECK: tensor.extract_slice + // CHECK: linalg.fill + // CHECK: tensor.insert_slice +-func.func @empty_tensor_elimination( +- %t: tensor<10xf32>, %f: f32) -> tensor<10xf32> { +- %0 = tensor.empty() : tensor<5xf32> +- %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> +- %2 = tensor.insert_slice %1 into %t [1][5][1] +- : tensor<5xf32> into tensor<10xf32> +- return %2 : tensor<10xf32> ++module @payload attributes { transform.target_tag = "payload" } { ++ func.func @empty_tensor_elimination( ++ %t: tensor<10xf32>, %f: f32) -> tensor<10xf32> { ++ %0 = tensor.empty() : tensor<5xf32> ++ %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> ++ %2 = tensor.insert_slice %1 into %t [1][5][1] ++ : tensor<5xf32> into tensor<10xf32> ++ return %2 : tensor<10xf32> ++ } + } + + // ----- +@@ -208,12 +222,14 @@ module attributes {transform.with_named_sequence} { + // CHECK: memref.alloca + // CHECK: scf.for + // CHECK: memref.store +-func.func @buffer_loop_hoisting(%lb: index, %ub: index, %step: index, %f: f32, %pos: index) { +- scf.for %iv = %lb to %ub step %step { +- %0 = memref.alloca() : memref<5xf32> +- memref.store %f, %0[%pos] : memref<5xf32> ++module @payload attributes { transform.target_tag = "payload" } { ++ func.func @buffer_loop_hoisting(%lb: index, %ub: index, %step: index, %f: f32, %pos: index) { ++ scf.for %iv = %lb to %ub step %step { ++ %0 = memref.alloca() : memref<5xf32> ++ memref.store %f, %0[%pos] : memref<5xf32> ++ } ++ return + } +- return + } + + // ----- +@@ -231,10 +247,12 @@ module attributes {transform.with_named_sequence} { + + // Expect `bufferization.bufferize_to_allocation` to create an alloc. 
+ // CHECK-LABEL: func.func @empty_to_tensor_alloc() +-func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> { +- // CHECK-NEXT: %[[alloca:.*]] = memref.alloca() : memref<2x2xf32> +- // CHECK-NEXT: %[[tensor:.*]] = bufferization.to_tensor %[[alloca]] restrict writable : memref<2x2xf32> +- // CHECK-NEXT: return %[[tensor]] : tensor<2x2xf32> +- %0 = bufferization.alloc_tensor() : tensor<2x2xf32> +- return %0 : tensor<2x2xf32> ++module @payload attributes { transform.target_tag = "payload" } { ++ func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> { ++ // CHECK-NEXT: %[[alloca:.*]] = memref.alloca() : memref<2x2xf32> ++ // CHECK-NEXT: %[[tensor:.*]] = bufferization.to_tensor %[[alloca]] restrict writable : memref<2x2xf32> ++ // CHECK-NEXT: return %[[tensor]] : tensor<2x2xf32> ++ %0 = bufferization.alloc_tensor() : tensor<2x2xf32> ++ return %0 : tensor<2x2xf32> ++ } + } +diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir +index c00b47fb936e..3e637a3ec49a 100644 +--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir ++++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir +@@ -1,15 +1,17 @@ +-// RUN: mlir-opt %s --transform-interpreter -test-transform-dialect-erase-schedule --test-lower-to-llvm --split-input-file | FileCheck %s ++// RUN: mlir-opt %s --transform-interpreter="debug-payload-root-tag=payload" -test-transform-dialect-erase-schedule --test-lower-to-llvm --split-input-file | FileCheck %s + + // CHECK-LABEL: llvm.func @matmul_tensors +-func.func @matmul_tensors( +- %arg0: tensor<2x4xf32>, %arg1: tensor<4x6xf32>, %arg2: tensor<2x6xf32>) +- -> tensor<2x6xf32> { +-// CHECK-NOT: linalg +-// CHECK: llvm.intr.fmuladd{{.*}} +- %0 = linalg.matmul ins(%arg0, %arg1: tensor<2x4xf32>, tensor<4x6xf32>) +- outs(%arg2: tensor<2x6xf32>) +- -> tensor<2x6xf32> +- return %0 : tensor<2x6xf32> ++module @payload attributes { transform.target_tag = "payload" } { ++ func.func @matmul_tensors( ++ %arg0: tensor<2x4xf32>, %arg1: tensor<4x6xf32>, %arg2: tensor<2x6xf32>) ++ -> tensor<2x6xf32> { ++ // CHECK-NOT: linalg ++ // CHECK: llvm.intr.fmuladd{{.*}} ++ %0 = linalg.matmul ins(%arg0, %arg1: tensor<2x4xf32>, tensor<4x6xf32>) ++ outs(%arg2: tensor<2x6xf32>) ++ -> tensor<2x6xf32> ++ return %0 : tensor<2x6xf32> ++ } + } + + module attributes {transform.with_named_sequence} { +diff --git a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir +index 3f8d2ea06641..9c223737750a 100644 +--- a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir ++++ b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir +@@ -1,4 +1,4 @@ +-// RUN: mlir-opt --split-input-file --transform-interpreter %s | FileCheck %s ++// RUN: mlir-opt --split-input-file --transform-interpreter="debug-payload-root-tag=payload" %s | FileCheck %s + + // CHECK-LABEL: func @matmul_divisible + // CHECK: scf.forall +@@ -24,19 +24,21 @@ + // CHECK: scf.forall + // CHECK: vector.transfer_read + // CHECK: vector.transfer_write +-func.func @matmul_divisible(%A: tensor<1024x1024xf32>, +- %B: tensor<1024x1024xf32>, +- %C: tensor<1024x1024xf32>) +- -> tensor<1024x1024xf32> +-{ +- %cst = arith.constant 0.000000e+00 : f32 +- %0 = linalg.fill ins(%cst : f32) +- outs(%C : tensor<1024x1024xf32>) ++module @payload attributes { transform.target_tag = "payload" } { ++ func.func @matmul_divisible(%A: tensor<1024x1024xf32>, ++ %B: tensor<1024x1024xf32>, ++ %C: tensor<1024x1024xf32>) + -> tensor<1024x1024xf32> +- %1 = linalg.matmul ins(%A, %B : 
tensor<1024x1024xf32>, tensor<1024x1024xf32>) +- outs(%0 : tensor<1024x1024xf32>) +- -> tensor<1024x1024xf32> +- return %1 : tensor<1024x1024xf32> ++ { ++ %cst = arith.constant 0.000000e+00 : f32 ++ %0 = linalg.fill ins(%cst : f32) ++ outs(%C : tensor<1024x1024xf32>) ++ -> tensor<1024x1024xf32> ++ %1 = linalg.matmul ins(%A, %B : tensor<1024x1024xf32>, tensor<1024x1024xf32>) ++ outs(%0 : tensor<1024x1024xf32>) ++ -> tensor<1024x1024xf32> ++ return %1 : tensor<1024x1024xf32> ++ } + } + + module attributes {transform.with_named_sequence} { +@@ -143,19 +145,21 @@ module attributes {transform.with_named_sequence} { + // CHECK: linalg.matmul + // CHECK: vector.transfer_read + // CHECK: vector.transfer_write ++module @payload attributes { transform.target_tag = "payload" } { + func.func @matmul_not_divisible(%A: tensor<1023x1023xf32>, +- %B: tensor<1023x1023xf32>, +- %C: tensor<1023x1023xf32>) +- -> tensor<1023x1023xf32> +-{ +- %cst = arith.constant 0.000000e+00 : f32 +- %0 = linalg.fill ins(%cst : f32) +- outs(%C : tensor<1023x1023xf32>) ++ %B: tensor<1023x1023xf32>, ++ %C: tensor<1023x1023xf32>) + -> tensor<1023x1023xf32> +- %1 = linalg.matmul ins(%A, %B : tensor<1023x1023xf32>, tensor<1023x1023xf32>) +- outs(%0 : tensor<1023x1023xf32>) +- -> tensor<1023x1023xf32> +- return %1 : tensor<1023x1023xf32> ++ { ++ %cst = arith.constant 0.000000e+00 : f32 ++ %0 = linalg.fill ins(%cst : f32) ++ outs(%C : tensor<1023x1023xf32>) ++ -> tensor<1023x1023xf32> ++ %1 = linalg.matmul ins(%A, %B : tensor<1023x1023xf32>, tensor<1023x1023xf32>) ++ outs(%0 : tensor<1023x1023xf32>) ++ -> tensor<1023x1023xf32> ++ return %1 : tensor<1023x1023xf32> ++ } + } + + module attributes {transform.with_named_sequence} { +diff --git a/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir b/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir +index f2e9e839b7c4..5e5657980ba1 100644 +--- a/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir ++++ b/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir +@@ -1,5 +1,5 @@ + +-// RUN: mlir-opt --transform-interpreter -cse -canonicalize -split-input-file -verify-diagnostics %s | FileCheck %s ++// RUN: mlir-opt --transform-interpreter="debug-payload-root-tag=payload" -cse -canonicalize -split-input-file -verify-diagnostics %s | FileCheck %s + + #map = affine_map<()[s0] -> (-s0 + 12, 7)> + +@@ -7,43 +7,45 @@ + // CHECK-SAME: %[[arg0:.*]]: memref<24x12xf32, strided<[?, ?], offset: ?>>, + // CHECK-SAME: %[[arg1:.*]]: memref<12x25xf32, strided<[?, ?], offset: ?>>, + // CHECK-SAME: %[[arg2:.*]]: memref<24x25xf32, strided<[?, ?], offset: ?>>, +-func.func @pad_to_memory_space(%arg0: tensor<24x12xf32>, +- %arg1: tensor<12x25xf32>, +- %arg2: tensor<24x25xf32>, +- %iv0 : index, %iv1 : index, +- %iv2 : index) -> tensor<24x25xf32> { +- %0 = affine.min #map()[%iv2] +- +- // CHECK: %[[s0:.*]] = memref.subview %[[arg0]] +- %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32> +- // CHECK: %[[s1:.*]] = memref.subview %[[arg1]] +- %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor +- // CHECK: %[[s2:.*]] = memref.subview %[[arg2]] +- %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32> +- +- // CHECK: %[[alloc0:.*]] = memref.alloc() : memref<4x7xf32, 3> +- // CHECK: linalg.fill {{.*}} outs(%[[alloc0]] +- // CHECK: %[[alloc0_view:.*]] = memref.subview %[[alloc0]][0, 0] [4, %{{.*}}] [1, 1] +- // CHECK: memref.copy %[[s0]], %[[alloc0_view]] +- +- // CHECK: 
%[[alloc1:.*]] = memref.alloc() : memref<7x5xf32, 3> +- // CHECK: linalg.fill {{.*}} outs(%[[alloc1]] +- // CHECK: %[[alloc1_view:.*]] = memref.subview %[[alloc1]][0, 0] [%{{.*}}, 5] [1, 1] +- // CHECK: memref.copy %[[s1]], %[[alloc1_view]] +- +- // CHECK: %[[alloc2:.*]] = memref.alloc() : memref<4x5xf32, 3> +- // CHECK-NOT: linalg.fill {{.*}} outs(%[[alloc2]] +- // No subview because there is 0 padding +- // CHECK: memref.copy %[[s2]], %[[alloc2]] +- +- // CHECK: linalg.matmul ins(%[[alloc0]], %[[alloc1]] : {{.*}}) outs(%[[alloc2]] : {{.*}}) +- // Copy back result. +- // CHECK: memref.copy %[[alloc2]], %[[s2]] +- %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32> +- +- // insert_slice bufferizes to a no-op. +- %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32> +- func.return %5 : tensor<24x25xf32> ++module @payload attributes { transform.target_tag = "payload" } { ++ func.func @pad_to_memory_space(%arg0: tensor<24x12xf32>, ++ %arg1: tensor<12x25xf32>, ++ %arg2: tensor<24x25xf32>, ++ %iv0 : index, %iv1 : index, ++ %iv2 : index) -> tensor<24x25xf32> { ++ %0 = affine.min #map()[%iv2] ++ ++ // CHECK: %[[s0:.*]] = memref.subview %[[arg0]] ++ %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32> ++ // CHECK: %[[s1:.*]] = memref.subview %[[arg1]] ++ %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor ++ // CHECK: %[[s2:.*]] = memref.subview %[[arg2]] ++ %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32> ++ ++ // CHECK: %[[alloc0:.*]] = memref.alloc() : memref<4x7xf32, 3> ++ // CHECK: linalg.fill {{.*}} outs(%[[alloc0]] ++ // CHECK: %[[alloc0_view:.*]] = memref.subview %[[alloc0]][0, 0] [4, %{{.*}}] [1, 1] ++ // CHECK: memref.copy %[[s0]], %[[alloc0_view]] ++ ++ // CHECK: %[[alloc1:.*]] = memref.alloc() : memref<7x5xf32, 3> ++ // CHECK: linalg.fill {{.*}} outs(%[[alloc1]] ++ // CHECK: %[[alloc1_view:.*]] = memref.subview %[[alloc1]][0, 0] [%{{.*}}, 5] [1, 1] ++ // CHECK: memref.copy %[[s1]], %[[alloc1_view]] ++ ++ // CHECK: %[[alloc2:.*]] = memref.alloc() : memref<4x5xf32, 3> ++ // CHECK-NOT: linalg.fill {{.*}} outs(%[[alloc2]] ++ // No subview because there is 0 padding ++ // CHECK: memref.copy %[[s2]], %[[alloc2]] ++ ++ // CHECK: linalg.matmul ins(%[[alloc0]], %[[alloc1]] : {{.*}}) outs(%[[alloc2]] : {{.*}}) ++ // Copy back result. ++ // CHECK: memref.copy %[[alloc2]], %[[s2]] ++ %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32> ++ ++ // insert_slice bufferizes to a no-op. 
++ %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32> ++ func.return %5 : tensor<24x25xf32> ++ } + } + + module attributes {transform.with_named_sequence} { +@@ -69,40 +71,42 @@ module attributes {transform.with_named_sequence} { + // CHECK-SAME: %[[arg0:.*]]: memref<24x12xf32, strided<[?, ?], offset: ?>>, + // CHECK-SAME: %[[arg1:.*]]: memref<12x25xf32, strided<[?, ?], offset: ?>>, + // CHECK-SAME: %[[arg2:.*]]: memref<24x25xf32, strided<[?, ?], offset: ?>>, +-func.func @vectorize_and_bufferize_pad(%arg0: tensor<24x12xf32>, +- %arg1: tensor<12x25xf32>, +- %arg2: tensor<24x25xf32>, +- %iv0 : index, %iv1 : index, +- %iv2 : index) -> tensor<24x25xf32> { +- %0 = affine.min #map()[%iv2] +- +- // CHECK: %[[s0:.*]] = memref.subview %[[arg0]] +- %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32> +- // CHECK: %[[s1:.*]] = memref.subview %[[arg1]] +- %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor +- // CHECK: %[[s2:.*]] = memref.subview %[[arg2]] +- %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32> +- +- // CHECK: %[[v0:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s0]] +- // CHECK: %[[alloc0:.*]] = memref.alloc() : memref<4x7xf32, 3> +- // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v0]], %[[alloc0]] +- +- // CHECK: %[[v1:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s1]] +- // CHECK: %[[alloc1:.*]] = memref.alloc() : memref<7x5xf32, 3> +- // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v1]], %[[alloc1]] +- +- // CHECK: %[[v2:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s2]] +- // CHECK: %[[alloc2:.*]] = memref.alloc() : memref<4x5xf32, 3> +- // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v2]], %[[alloc0]] +- +- // CHECK: linalg.matmul ins(%[[alloc0]], %[[alloc1]] : {{.*}}) outs(%[[alloc2]] : {{.*}}) +- // Copy back result. +- // CHECK: memref.copy %[[alloc2]], %[[s2]] +- %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32> +- +- // insert_slice bufferizes to a no-op. 
+- %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32> +- func.return %5 : tensor<24x25xf32> ++module @payload attributes { transform.target_tag = "payload" } { ++ func.func @vectorize_and_bufferize_pad(%arg0: tensor<24x12xf32>, ++ %arg1: tensor<12x25xf32>, ++ %arg2: tensor<24x25xf32>, ++ %iv0 : index, %iv1 : index, ++ %iv2 : index) -> tensor<24x25xf32> { ++ %0 = affine.min #map()[%iv2] ++ ++ // CHECK: %[[s0:.*]] = memref.subview %[[arg0]] ++ %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32> ++ // CHECK: %[[s1:.*]] = memref.subview %[[arg1]] ++ %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor ++ // CHECK: %[[s2:.*]] = memref.subview %[[arg2]] ++ %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32> ++ ++ // CHECK: %[[v0:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s0]] ++ // CHECK: %[[alloc0:.*]] = memref.alloc() : memref<4x7xf32, 3> ++ // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v0]], %[[alloc0]] ++ ++ // CHECK: %[[v1:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s1]] ++ // CHECK: %[[alloc1:.*]] = memref.alloc() : memref<7x5xf32, 3> ++ // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v1]], %[[alloc1]] ++ ++ // CHECK: %[[v2:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s2]] ++ // CHECK: %[[alloc2:.*]] = memref.alloc() : memref<4x5xf32, 3> ++ // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v2]], %[[alloc0]] ++ ++ // CHECK: linalg.matmul ins(%[[alloc0]], %[[alloc1]] : {{.*}}) outs(%[[alloc2]] : {{.*}}) ++ // Copy back result. ++ // CHECK: memref.copy %[[alloc2]], %[[s2]] ++ %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32> ++ ++ // insert_slice bufferizes to a no-op. 
++ %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32> ++ func.return %5 : tensor<24x25xf32> ++ } + } + + module attributes {transform.with_named_sequence} { +diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir +index 4b38db79bff3..0439844dc66c 100644 +--- a/mlir/test/Dialect/Vector/transform-vector.mlir ++++ b/mlir/test/Dialect/Vector/transform-vector.mlir +@@ -1,16 +1,18 @@ +-// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s ++// RUN: mlir-opt --transform-interpreter="debug-payload-root-tag=payload" %s --split-input-file | FileCheck %s + + // CHECK-LABEL: func @matmul_tensors +-func.func @matmul_tensors( +- %arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>, %arg2: tensor<8x32xf32>) +- -> tensor<8x32xf32> { +-// CHECK-NOT: linalg +-// CHECK: vector.extract {{.*}} : vector<4xf32> from vector<8x4xf32> +-// CHECK: vector.store {{.*}} : memref<8x32xf32>, vector<4xf32> +- %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x16xf32>, tensor<16x32xf32>) +- outs(%arg2: tensor<8x32xf32>) +- -> tensor<8x32xf32> +- return %0 : tensor<8x32xf32> ++module @payload attributes { transform.target_tag = "payload" } { ++ func.func @matmul_tensors( ++ %arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>, %arg2: tensor<8x32xf32>) ++ -> tensor<8x32xf32> { ++ // CHECK-NOT: linalg ++ // CHECK: vector.extract {{.*}} : vector<4xf32> from vector<8x4xf32> ++ // CHECK: vector.store {{.*}} : memref<8x32xf32>, vector<4xf32> ++ %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x16xf32>, tensor<16x32xf32>) ++ outs(%arg2: tensor<8x32xf32>) ++ -> tensor<8x32xf32> ++ return %0 : tensor<8x32xf32> ++ } + } + + module attributes {transform.with_named_sequence} { +@@ -76,11 +78,13 @@ module attributes {transform.with_named_sequence} { + // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} + // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32> + // CHECK-NEXT: return %[[R]] : vector<64x64xf32> +-func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> { +- %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32> +- %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32> +- %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32> +- return %result : vector<64x64xf32> ++module @payload attributes { transform.target_tag = "payload" } { ++ func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> { ++ %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32> ++ %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32> ++ %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32> ++ return %result : vector<64x64xf32> ++ } + } + + module attributes {transform.with_named_sequence} { +diff 
--git a/mlir/test/Examples/transform/ChH/full.mlir b/mlir/test/Examples/transform/ChH/full.mlir +index 259475ebdbf4..85dbf6702332 100644 +--- a/mlir/test/Examples/transform/ChH/full.mlir ++++ b/mlir/test/Examples/transform/ChH/full.mlir +@@ -1,8 +1,6 @@ +-// RUN: mlir-opt %s --transform-interpreter \ +-// RUN: --test-transform-dialect-erase-schedule \ +-// RUN: --math-uplift-to-fma \ +-// RUN: --convert-bufferization-to-memref \ +-// RUN: --test-lower-to-llvm |\ ++// RUN: mlir-opt %s --transform-interpreter="debug-payload-root-tag=payload" \ ++// RUN: --test-transform-dialect-erase-schedule |\ ++// RUN: mlir-opt -pass-pipeline='builtin.module(builtin.module(math-uplift-to-fma,convert-bufferization-to-memref,test-lower-to-llvm))' - |\ + // RUN: FileCheck %s + + // Fixed-size tensor types to be used in convolution. +@@ -19,6 +17,7 @@ + // tensors annotated with attributes from the `bufferization` dialect. These + // attributes hint the bufferization pass to assume buffers can be directly + // used for these tensors without reshaping. ++module @payload attributes { transform.target_tag = "payload" } { + func.func @conv( + %input: !tinput {bufferization.writable = false, + bufferization.access = "read", +@@ -84,7 +83,7 @@ func.func @conv( + + return %relued : !toutput + } +- ++} + // Module containing the transformation script to be applied. The attribute + // is required to correctly verify the use of named (macro-like) sequences. + module attributes { transform.with_named_sequence } { diff --git a/mlir/patches/FunctionOpInterface-mhlo.patch b/mlir/patches/FunctionOpInterface-mhlo.patch new file mode 100644 index 0000000000..74fe5f3dd9 --- /dev/null +++ b/mlir/patches/FunctionOpInterface-mhlo.patch @@ -0,0 +1,21 @@ +diff --git a/transforms/bufferize_pass.cc b/transforms/bufferize_pass.cc +index 1e810cff2..c91c49710 100644 +--- a/transforms/bufferize_pass.cc ++++ b/transforms/bufferize_pass.cc +@@ -66,6 +66,7 @@ limitations under the License. + #include "mlir/IR/Operation.h" + #include "mlir/IR/PatternMatch.h" + #include "mlir/IR/Visitors.h" ++#include "mlir/Interfaces/FunctionInterfaces.h" + #include "mlir/Support/LLVM.h" + #include "mlir/Transforms/DialectConversion.h" + #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +@@ -235,7 +236,7 @@ struct OneShotBufferizePass + opts.allowReturnAllocsFromLoops = true; + opts.bufferizeFunctionBoundaries = true; + opts.functionArgTypeConverterFn = +- [=](TensorType tensorType, Attribute memorySpace, func::FuncOp funcOp, ++ [=](TensorType tensorType, Attribute memorySpace, FunctionOpInterface funcOp, + const bufferization::BufferizationOptions& options) { + // Functions created by fusion outlining should have fully dynamic + // layout. All other functions (for now only "main") gets static diff --git a/mlir/patches/callOp-bufferization.patch b/mlir/patches/callOp-bufferization.patch new file mode 100644 index 0000000000..2b7180fdd7 --- /dev/null +++ b/mlir/patches/callOp-bufferization.patch @@ -0,0 +1,71 @@ +diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +index 053ea7935260a2..9fbe574ec392dc 100644 +--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp ++++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +@@ -258,20 +258,23 @@ struct CallOpInterface + return failure(); + Value buffer = *maybeBuffer; + +- // Caller / callee type mismatch is handled with a CastOp. 
++      // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
+       auto memRefType = funcType.getInput(opOperand.getOperandNumber());
+       // Since we don't yet have a clear layout story, to_memref may
+       // conservatively turn tensors into more dynamic memref than necessary.
+       // If the memref type of the callee fails, introduce an extra memref.cast
+       // that will either canonicalize away or fail compilation until we can do
+-      // something better.
++      // something better. Insert a reallocation + copy if it cannot be
++      // statically guaranteed that a direct cast would be valid.
+       if (buffer.getType() != memRefType) {
+-        assert(
+-            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
+-            "CallOp::bufferize: cast incompatible");
+-        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
+-                                                           memRefType, buffer);
+-        buffer = castBuffer;
++        auto memrefDstType = dyn_cast<MemRefType>(memRefType);
++        assert(memrefDstType &&
++               "buffer layout not supported on unranked tensors");
++        FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
++            rewriter, buffer, memrefDstType, options);
++        if (failed(replacement))
++          return failure();
++        buffer = *replacement;
+       }
+       newOperands.push_back(buffer);
+     }
+diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+index 0248afb11f1672..0d5224514e3a02 100644
+--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
++++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+@@ -71,6 +71,30 @@ func.func @return_extract_slice(%idx: index, %sz: index) -> (tensor<2x?xf32>)
+ 
+ // -----
+ 
++// CHECK-NO-LAYOUT-MAP-LABEL: func.func @foo(
++// CHECK-NO-LAYOUT-MAP-SAME: %[[VAL_0:.*]]: memref<3x8xf16>) -> memref<3x8xf16> {
++// CHECK-NO-LAYOUT-MAP: return %[[VAL_0]] : memref<3x8xf16>
++// CHECK-NO-LAYOUT-MAP: }
++func.func @foo(%arg0: tensor<3x8xf16>) -> tensor<3x8xf16> {
++  return %arg0 : tensor<3x8xf16>
++}
++
++// CHECK-NO-LAYOUT-MAP-LABEL: func.func @call_extract_slice(
++// CHECK-NO-LAYOUT-MAP-SAME: %[[VAL_0:.*]]: memref<4x8xf16>) -> memref<3x8xf16> {
++// CHECK-NO-LAYOUT-MAP: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][1, 0] [3, 8] [1, 1] : memref<4x8xf16> to memref<3x8xf16, strided<[8, 1], offset: 8>>
++// CHECK-NO-LAYOUT-MAP: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3x8xf16>
++// CHECK-NO-LAYOUT-MAP: memref.copy %[[VAL_1]], %[[VAL_2]] : memref<3x8xf16, strided<[8, 1], offset: 8>> to memref<3x8xf16>
++// CHECK-NO-LAYOUT-MAP: %[[VAL_3:.*]] = call @foo(%[[VAL_2]]) : (memref<3x8xf16>) -> memref<3x8xf16>
++// CHECK-NO-LAYOUT-MAP: return %[[VAL_3]] : memref<3x8xf16>
++// CHECK-NO-LAYOUT-MAP: }
++func.func @call_extract_slice(%arg0: tensor<4x8xf16>) -> (tensor<3x8xf16>) {
++  %0 = tensor.extract_slice %arg0[1, 0] [3, 8] [1, 1] : tensor<4x8xf16> to tensor<3x8xf16>
++  %1 = call @foo(%0) : (tensor<3x8xf16>) -> tensor<3x8xf16>
++  return %1 : tensor<3x8xf16>
++}
++
++// -----
++
+ // CHECK-LABEL: func private @private_func
+ // CHECK-NO-LAYOUT-MAP-LABEL: func private @private_func(memref<?xf32>) -> f32
+ func.func private @private_func(tensor<?xf32>) -> (f32)
diff --git a/mlir/test/Gradient/PS_QuantumGradientTest.mlir b/mlir/test/Gradient/PS_QuantumGradientTest.mlir
index c13ca339d4..8c8034981b 100644
--- a/mlir/test/Gradient/PS_QuantumGradientTest.mlir
+++ b/mlir/test/Gradient/PS_QuantumGradientTest.mlir
@@ -425,7 +425,7 @@ func.func @multi_res_circuit(%arg0: f64) -> (f64, tensor<2xf64>) attributes {qno
     %r
= quantum.alloc(1) : !quantum.reg %q_0 = quantum.extract %r[%idx] : !quantum.reg -> !quantum.bit - // CHECK: [[SEL:%.+]] = bufferization.to_tensor [[SELBUFF]] : memref<0xindex> + // CHECK: [[SEL:%.+]] = bufferization.to_tensor [[SELBUFF]] restrict : memref<0xindex> // CHECK: [[EVALPOS:%.+]]:2 = call @multi_res_circuit.shifted(%arg0, [[SHIFTPOS]], [[SEL]]) : {{.+}} -> (f64, tensor<2xf64>) // CHECK: [[EVALNEG:%.+]]:2 = call @multi_res_circuit.shifted(%arg0, [[SHIFTNEG]], [[SEL]]) : {{.+}} -> (f64, tensor<2xf64>) // CHECK: [[DIFF0:%.+]] = arith.subf [[EVALPOS]]#0, [[EVALNEG]]#0 diff --git a/mlir/tools/quantum-opt/quantum-opt.cpp b/mlir/tools/quantum-opt/quantum-opt.cpp index 01365f7162..26e50520cb 100644 --- a/mlir/tools/quantum-opt/quantum-opt.cpp +++ b/mlir/tools/quantum-opt/quantum-opt.cpp @@ -25,12 +25,15 @@ #include "mhlo/IR/hlo_ops.h" #include "Catalyst/IR/CatalystDialect.h" +#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h" #include "Catalyst/Transforms/Passes.h" #include "Gradient/IR/GradientDialect.h" +#include "Gradient/Transforms/BufferizableOpInterfaceImpl.h" #include "Gradient/Transforms/Passes.h" #include "Mitigation/IR/MitigationDialect.h" #include "Mitigation/Transforms/Passes.h" #include "Quantum/IR/QuantumDialect.h" +#include "Quantum/Transforms/BufferizableOpInterfaceImpl.h" #include "Quantum/Transforms/Passes.h" namespace test { @@ -55,6 +58,10 @@ int main(int argc, char **argv) registry.insert(); registry.insert(); + catalyst::registerBufferizableOpInterfaceExternalModels(registry); + catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry); + catalyst::gradient::registerBufferizableOpInterfaceExternalModels(registry); + return mlir::asMainReturnCode( mlir::MlirOptMain(argc, argv, "Quantum optimizer driver\n", registry)); }