diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml
index 304b23ee5d..9e6d70465a 100644
--- a/.github/workflows/build-wheel-linux-arm64.yaml
+++ b/.github/workflows/build-wheel-linux-arm64.yaml
@@ -63,7 +63,7 @@ jobs:
uses: actions/cache@v4
with:
path: mlir/llvm-project
- key: llvm-${{ needs.constants.outputs.llvm_version }}-default-source
+ key: llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-source
enableCrossOsArchive: True
- name: Cache MHLO Source
@@ -71,7 +71,7 @@ jobs:
uses: actions/cache@v4
with:
path: mlir/mlir-hlo
- key: mhlo-${{ needs.constants.outputs.mhlo_version }}-default-source
+ key: mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-source
enableCrossOsArchive: True
- name: Cache Enzyme Source
@@ -112,14 +112,14 @@ jobs:
uses: actions/cache/restore@v4
with:
path: llvm-build
- key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-wheel-build
+ key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
- name: Restore MHLO Build
id: cache-mhlo-build
uses: actions/cache/restore@v4
with:
path: mhlo-build
- key: ${{ matrix.container_name }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build
+ key: ${{ matrix.container_name }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
lookup-only: True
- name: Restore Enzyme Build
@@ -160,7 +160,7 @@ jobs:
uses: actions/cache/save@v4
with:
path: llvm-build
- key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-wheel-build
+ key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
- name: Build MHLO Dialect
if: steps.cache-mhlo-build.outputs.cache-hit != 'true'
@@ -179,7 +179,7 @@ jobs:
uses: actions/cache/save@v4
with:
path: mhlo-build
- key: ${{ matrix.container_name }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build
+ key: ${{ matrix.container_name }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
- name: Build Enzyme
if: steps.cache-enzyme-build.outputs.cache-hit != 'true'
@@ -240,7 +240,7 @@ jobs:
uses: actions/cache/restore@v4
with:
path: llvm-build
- key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-wheel-build
+ key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
fail-on-cache-miss: True
- name: Get Cached MHLO Source
@@ -257,7 +257,7 @@ jobs:
uses: actions/cache/restore@v4
with:
path: mhlo-build
- key: ${{ matrix.container_name }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build
+ key: ${{ matrix.container_name }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
fail-on-cache-miss: True
- name: Get Cached Enzyme Source
@@ -334,7 +334,7 @@ jobs:
uses: actions/cache@v4
with:
path: llvm-build
- key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-wheel-build
+ key: ${{ matrix.container_name }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
fail-on-cache-miss: True
- name: Run Python Pytest Tests
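Note on the cache keys above: `hashFiles('mlir/patches/**')` folds the contents of every file under `mlir/patches/` into a single digest, so adding, editing, or removing any patch yields a new key and invalidates the cached LLVM/MHLO sources and builds. A rough Python sketch of the idea (illustrative only; GitHub's actual implementation hashes per-file SHA-256 digests):

```python
import hashlib
from pathlib import Path

def hash_files(root: str = "mlir/patches") -> str:
    """Illustrative stand-in for GitHub's hashFiles(): any change to any
    file under `root` changes the digest, and hence the cache key."""
    digest = hashlib.sha256()
    for path in sorted(Path(root).rglob("*")):
        if path.is_file():
            digest.update(path.read_bytes())
    return digest.hexdigest()

# Hypothetical key, mirroring the workflow expression above:
key = f"llvm-<llvm_version>-patch-{hash_files()}-default-source"
```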
diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml
index 825ce2f65e..f151c3832f 100644
--- a/.github/workflows/build-wheel-linux-x86_64.yaml
+++ b/.github/workflows/build-wheel-linux-x86_64.yaml
@@ -118,14 +118,14 @@ jobs:
uses: actions/cache/restore@v4
with:
path: llvm-build
- key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build
+ key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-${{matrix.python_version}}-wheel-build
- name: Restore MHLO Build
id: cache-mhlo-build
uses: actions/cache/restore@v4
with:
path: mhlo-build
- key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build
+ key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
lookup-only: True
- name: Restore Enzyme Build
@@ -172,6 +172,11 @@ jobs:
if: steps.cache-llvm-build.outputs.cache-hit != 'true'
run: |
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH
+
+        # TODO: Remove these patches after upgrading JAX (potentially 0.4.34 or higher).
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi
+
cmake -S mlir/llvm-project/llvm -B llvm-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_BUILD_EXAMPLES=OFF \
@@ -198,17 +203,18 @@ jobs:
uses: actions/cache/save@v4
with:
path: llvm-build
- key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build
+ key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-${{matrix.python_version}}-wheel-build
- name: Build MHLO Dialect
if: steps.cache-mhlo-build.outputs.cache-hit != 'true'
# building with LLD is a strong requirement for mhlo
run: |
+        # TODO: Remove these patches after upgrading JAX (potentially 0.4.34 or higher).
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH
-
export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt
export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi
+ if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch; fi
cmake -S mlir/mlir-hlo -B mhlo-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
@@ -228,7 +234,7 @@ jobs:
uses: actions/cache/save@v4
with:
path: mhlo-build
- key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build
+ key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
- name: Build Enzyme
if: steps.cache-enzyme-build.outputs.cache-hit != 'true'
@@ -296,7 +302,7 @@ jobs:
uses: actions/cache/restore@v4
with:
path: llvm-build
- key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-3.10-wheel-build
+ key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-3.10-wheel-build
fail-on-cache-miss: True
- name: Get Cached MHLO Source
@@ -313,7 +319,7 @@ jobs:
uses: actions/cache/restore@v4
with:
path: mhlo-build
- key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build
+ key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
fail-on-cache-miss: True
- name: Get Cached Enzyme Source
@@ -363,7 +369,12 @@ jobs:
# Build Quantum and Gradient Dialects
- name: Build MLIR Dialects
run: |
+        # TODO: Remove these patches after upgrading JAX (potentially 0.4.34 or higher).
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi
+ if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch; fi
+
cmake -S mlir -B quantum-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml
index 5d6be17222..307feb8b05 100644
--- a/.github/workflows/build-wheel-macos-arm64.yaml
+++ b/.github/workflows/build-wheel-macos-arm64.yaml
@@ -105,14 +105,14 @@ jobs:
uses: actions/cache/restore@v4
with:
path: llvm-build
- key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build
+ key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-${{matrix.python_version}}-wheel-build
- name: Restore MHLO Build
id: cache-mhlo-build
uses: actions/cache/restore@v4
with:
path: mhlo-build
- key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build
+ key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
lookup-only: True
- name: Restore Enzyme Build
@@ -137,6 +137,10 @@ jobs:
- name: Build LLVM / MLIR
if: steps.cache-llvm-build.outputs.cache-hit != 'true'
run: |
+        # TODO: Remove these patches after upgrading JAX (potentially 0.4.34 or higher).
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi
+
cmake -S mlir/llvm-project/llvm -B llvm-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_BUILD_EXAMPLES=OFF \
@@ -163,16 +167,18 @@ jobs:
uses: actions/cache/save@v4
with:
path: llvm-build
- key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build
+ key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-${{matrix.python_version}}-wheel-build
- name: Build MHLO Dialect
if: steps.cache-mhlo-build.outputs.cache-hit != 'true'
run: |
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH
+        # TODO: Remove these patches after upgrading JAX (potentially 0.4.34 or higher).
export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt
export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi
+ if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/moduleOp-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/moduleOp-mhlo.patch; fi
cmake -S mlir/mlir-hlo -B mhlo-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
@@ -192,7 +198,7 @@ jobs:
uses: actions/cache/save@v4
with:
path: mhlo-build
- key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build
+ key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
- name: Build Enzyme
if: steps.cache-enzyme-build.outputs.cache-hit != 'true'
@@ -256,7 +262,7 @@ jobs:
uses: actions/cache/restore@v4
with:
path: llvm-build
- key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-3.10-wheel-build
+ key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-3.10-wheel-build
fail-on-cache-miss: True
- name: Get Cached MHLO Source
@@ -273,7 +279,7 @@ jobs:
uses: actions/cache/restore@v4
with:
path: mhlo-build
- key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build
+ key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
fail-on-cache-miss: True
- name: Get Cached Enzyme Source
@@ -328,6 +334,11 @@ jobs:
# Build Quantum and Gradient Dialects
- name: Build MLIR Dialects
run: |
+        # TODO: Remove these patches after upgrading JAX (potentially 0.4.34 or higher).
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi
+ if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/moduleOp-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/moduleOp-mhlo.patch; fi
+
cmake -S mlir -B quantum-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
diff --git a/.github/workflows/build-wheel-macos-x86_64.yaml b/.github/workflows/build-wheel-macos-x86_64.yaml
index eebc346fd4..70d7dfeab2 100644
--- a/.github/workflows/build-wheel-macos-x86_64.yaml
+++ b/.github/workflows/build-wheel-macos-x86_64.yaml
@@ -103,14 +103,14 @@ jobs:
uses: actions/cache/restore@v4
with:
path: llvm-build
- key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build
+ key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-${{matrix.python_version}}-wheel-build
- name: Restore MHLO Build
id: cache-mhlo-build
uses: actions/cache/restore@v4
with:
path: mhlo-build
- key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build
+ key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
lookup-only: True
- name: Restore Enzyme Build
@@ -133,6 +133,10 @@ jobs:
- name: Build LLVM / MLIR
if: steps.cache-llvm-build.outputs.cache-hit != 'true'
run: |
+        # TODO: Remove these patches after upgrading JAX (potentially 0.4.34 or higher).
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi
+
cmake -S mlir/llvm-project/llvm -B llvm-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_BUILD_EXAMPLES=OFF \
@@ -159,16 +163,18 @@ jobs:
uses: actions/cache/save@v4
with:
path: llvm-build
- key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build
+ key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-${{matrix.python_version}}-wheel-build
- name: Build MHLO Dialect
if: steps.cache-mhlo-build.outputs.cache-hit != 'true'
run: |
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH
+        # TODO: Remove these patches after upgrading JAX (potentially 0.4.34 or higher).
export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt
export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi
+ if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch; fi
cmake -S mlir/mlir-hlo -B mhlo-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
@@ -188,7 +194,7 @@ jobs:
uses: actions/cache/save@v4
with:
path: mhlo-build
- key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build
+ key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
- name: Build Enzyme
if: steps.cache-enzyme-build.outputs.cache-hit != 'true'
@@ -246,7 +252,7 @@ jobs:
uses: actions/cache/restore@v4
with:
path: llvm-build
- key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-3.10-wheel-build
+ key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-3.10-wheel-build
fail-on-cache-miss: True
- name: Get Cached MHLO Source
@@ -263,7 +269,7 @@ jobs:
uses: actions/cache/restore@v4
with:
path: mhlo-build
- key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build
+ key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-wheel-build
fail-on-cache-miss: True
- name: Get Cached Enzyme Source
@@ -319,6 +325,11 @@ jobs:
# Build Quantum and Gradient Dialects
- name: Build MLIR Dialects
run: |
+        # TODO: Remove these patches after upgrading JAX (potentially 0.4.34 or higher).
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi
+ if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch; fi
+
cmake -S mlir -B quantum-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml
index 73f7becf16..6981040958 100644
--- a/.github/workflows/check-catalyst.yaml
+++ b/.github/workflows/check-catalyst.yaml
@@ -119,7 +119,7 @@ jobs:
uses: actions/cache@v4
with:
path: llvm-build
- key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-default-build-${{ matrix.compiler }}
+ key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }}
- name: Install Deps
if: steps.cache-llvm-build.outputs.cache-hit != 'true'
@@ -175,7 +175,7 @@ jobs:
uses: actions/cache@v4
with:
path: mhlo-build
- key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-default-build-${{ matrix.compiler }}
+ key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }}
- name: Get Cached LLVM Source
id: cache-llvm-source
@@ -193,7 +193,7 @@ jobs:
uses: actions/cache@v4
with:
path: llvm-build
- key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-default-build-${{ matrix.compiler }}
+ key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }}
fail-on-cache-miss: true
- name: Install Deps
@@ -263,7 +263,7 @@ jobs:
uses: actions/cache@v4
with:
path: llvm-build
- key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-default-build-${{ matrix.compiler }}
+ key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }}
fail-on-cache-miss: true
- name: Install Deps
@@ -315,7 +315,7 @@ jobs:
uses: actions/cache/restore@v4
with:
path: llvm-build
- key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-default-build-${{ matrix.compiler }}
+ key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }}
fail-on-cache-miss: true
- name: Get Cached MHLO Source
@@ -332,7 +332,7 @@ jobs:
uses: actions/cache/restore@v4
with:
path: mhlo-build
- key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-default-build-${{ matrix.compiler }}
+ key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }}
fail-on-cache-miss: true
- name: Get Cached Enzyme Source
@@ -363,6 +363,11 @@ jobs:
- name: Build MLIR Dialects
run: |
+        # TODO: Remove these patches after upgrading JAX (potentially 0.4.34 or higher).
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi
+ if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi
+ if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch; fi
+
CCACHE_DIR="$(pwd)/.ccache" \
C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \
CXX_COMPILER=$(which ${{ needs.constants.outputs[format('cxx_compiler.{0}', matrix.compiler)] }}) \
@@ -419,7 +424,7 @@ jobs:
uses: actions/cache@v4
with:
path: llvm-build
- key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-default-build-${{ matrix.compiler }}
+ key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }}
fail-on-cache-miss: true
- name: Download Quantum Build Artifact
@@ -491,7 +496,7 @@ jobs:
uses: actions/cache@v4
with:
path: llvm-build
- key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-default-build-${{ matrix.compiler }}
+ key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }}
fail-on-cache-miss: true
- name: Download Quantum Build Artifact
@@ -546,7 +551,7 @@ jobs:
uses: actions/cache@v4
with:
path: llvm-build
- key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-default-build-${{ matrix.compiler }}
+ key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-patch-${{ hashFiles('mlir/patches/**') }}-default-build-${{ matrix.compiler }}
fail-on-cache-miss: true
- name: Download Quantum Build Artifact
diff --git a/.github/workflows/scripts/linux_arm64/rh8/build_catalyst.sh b/.github/workflows/scripts/linux_arm64/rh8/build_catalyst.sh
index bb517160f6..e666e5ecb9 100644
--- a/.github/workflows/scripts/linux_arm64/rh8/build_catalyst.sh
+++ b/.github/workflows/scripts/linux_arm64/rh8/build_catalyst.sh
@@ -37,6 +37,11 @@ export PATH=/catalyst/llvm-build/bin:/opt/_internal/cpython-${PYTHON_VERSION}.${
# Install python dependencies
/usr/bin/python3 -m pip install pennylane pybind11 PyYAML cmake ninja delocate 'amazon-braket-pennylane-plugin>1.27.1'
+# Patch LLVM and MHLO. TODO: Remove these patches after upgrading JAX (potentially 0.4.34 or higher).
+if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/FunctionOpInterface-bufferization.patch; fi
+if patch --dry-run -p1 -N --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/llvm-project < mlir/patches/callOp-bufferization.patch; fi
+if patch --dry-run -p1 -N --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=mlir/mlir-hlo < mlir/patches/FunctionOpInterface-mhlo.patch; fi
+
# Build Catalyst runtime
cmake -S runtime -B runtime-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
diff --git a/.github/workflows/scripts/linux_arm64/rh8/build_llvm.sh b/.github/workflows/scripts/linux_arm64/rh8/build_llvm.sh
index 78bb6aadb8..b4ee206580 100644
--- a/.github/workflows/scripts/linux_arm64/rh8/build_llvm.sh
+++ b/.github/workflows/scripts/linux_arm64/rh8/build_llvm.sh
@@ -33,6 +33,10 @@ export PATH=/opt/_internal/cpython-${PYTHON_VERSION}.${PYTHON_SUBVERSION}/bin:/o
# Install python dependencies
/usr/bin/python3 -m pip install pennylane pybind11 PyYAML cmake ninja
+# TODO: Remove these patches after upgrading JAX (potentially 0.4.34 or higher).
+if patch --dry-run -p1 -N --directory=/catalyst/mlir/llvm-project < /catalyst/mlir/patches/FunctionOpInterface-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=/catalyst/mlir/llvm-project < /catalyst/mlir/patches/FunctionOpInterface-bufferization.patch; fi
+if patch --dry-run -p1 -N --directory=/catalyst/mlir/llvm-project < /catalyst/mlir/patches/callOp-bufferization.patch > /dev/null 2>&1; then patch -p1 --directory=/catalyst/mlir/llvm-project < /catalyst/mlir/patches/callOp-bufferization.patch; fi
+
# Build LLVM
cmake -S /catalyst/mlir/llvm-project/llvm -B /catalyst/llvm-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
diff --git a/.github/workflows/scripts/linux_arm64/rh8/build_mhlo.sh b/.github/workflows/scripts/linux_arm64/rh8/build_mhlo.sh
index 2a5b2e4fa7..e452ee22d4 100644
--- a/.github/workflows/scripts/linux_arm64/rh8/build_mhlo.sh
+++ b/.github/workflows/scripts/linux_arm64/rh8/build_mhlo.sh
@@ -38,7 +38,10 @@ sed -i -e 's/LINK_LIBS PUBLIC/LINK_LIBS PUBLIC MLIRDeallocationUtils/g' mlir/mli
export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt
export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch
+# TODO: JAX has merged this fix. Remove after the JAX upgrade.
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi
+# TODO: Remove these patches after upgrading JAX (potentially 0.4.34 or higher).
+if patch --dry-run -p1 -N --directory=/catalyst/mlir/mlir-hlo < /catalyst/mlir/patches/FunctionOpInterface-mhlo.patch > /dev/null 2>&1; then patch -p1 --directory=/catalyst/mlir/mlir-hlo < /catalyst/mlir/patches/FunctionOpInterface-mhlo.patch; fi
# Build MHLO
cmake -S /catalyst/mlir/mlir-hlo -B /catalyst/mhlo-build -G Ninja \
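The `patch --dry-run -p1 -N …; then patch …` idiom repeated in these scripts makes patching idempotent: `-N` (`--forward`) ignores already-applied hunks, and the dry run probes whether the patch still applies before touching a (possibly cached) source tree. A minimal Python sketch of the same logic (the helper name is illustrative):

```python
import subprocess

def apply_patch_once(patch_file: str, src_dir: str) -> bool:
    """Apply patch_file to src_dir only if a forward dry run succeeds,
    so re-running a build on an already-patched tree is a no-op."""
    cmd = ["patch", "-p1", "-N", f"--directory={src_dir}"]
    with open(patch_file, "rb") as fh:
        probe = subprocess.run(cmd + ["--dry-run"], stdin=fh, capture_output=True)
    if probe.returncode != 0:
        return False  # already applied (or inapplicable); leave the tree alone
    with open(patch_file, "rb") as fh:
        subprocess.run(cmd, stdin=fh, check=True)
    return True

apply_patch_once("mlir/patches/FunctionOpInterface-mhlo.patch", "mlir/mlir-hlo")
```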
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 3412e495bb..6ba963c728 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -6,6 +6,8 @@
Improvements 🛠
+* [(#1027)](https://github.com/PennyLaneAI/catalyst/pull/1027) Catalyst now supports MLIR's `one-shot-bufferize` pass, which is required for JAX v0.4.29 or higher.
+
Breaking changes 💔
Deprecations 👋
@@ -16,4 +18,7 @@
Contributors ✍️
-This release contains contributions from (in alphabetical order):
\ No newline at end of file
+This release contains contributions from (in alphabetical order):
+
+Tzung-Han Juang,
+Erick Ochoa Lopez,
diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py
index 5d3d117469..8a15985b3b 100644
--- a/frontend/catalyst/compiler.py
+++ b/frontend/catalyst/compiler.py
@@ -224,41 +224,177 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt
],
)
+# From: https://mlir.llvm.org/docs/Bufferization/#overview
+#
+# Preprocessing
+# | rewrite_in_destination_passing_style
+# | -eliminate-empty-tensors
+# Bufferization
+# | -one-shot-bufferize
+# Buffer-Level
+# Optimizations
+# | -buffer-hoisting
+# | -buffer-loop-hoisting
+# | -buffer-results-to-out-params
+# | -drop-equivalent-buffer-results
+# | -promote-buffers-to-stack
+# Deallocation
+# | -buffer-deallocation-pipeline
+
BUFFERIZATION_PASS = (
"BufferizationPass",
[
- "one-shot-bufferize{dialect-filter=memref}",
"inline",
"gradient-preprocess",
- "gradient-bufferize",
- "scf-bufferize",
- "convert-tensor-to-linalg", # tensor.pad
- "convert-elementwise-to-linalg", # Must be run before --arith-bufferize
- "arith-bufferize",
- "empty-tensor-to-alloc-tensor",
- "func.func(bufferization-bufferize)",
- "func.func(tensor-bufferize)",
- "catalyst-bufferize", # Must be run before -- func.func(linalg-bufferize)
- "func.func(linalg-bufferize)",
- "func.func(tensor-bufferize)",
- "quantum-bufferize",
- "func-bufferize",
- "func.func(finalizing-bufferize)",
- "canonicalize", # Remove dead memrefToTensorOp's
- "gradient-postprocess",
+ "convert-elementwise-to-linalg",
+ "canonicalize",
+ # Preprocessing:
+ # rewrite_in_destination_passing_style
+ #
+        # We are not rewriting everything in DPS before -one-shot-bufferize.
+        # This was discussed with the main author of the -one-shot-bufferize
+        # pass, who stated the following:
+ #
+ # One-Shot Bufferize was designed for ops that are in DPS (destination-passing style).
+ # Ops that are not in DPS can still be bufferized,
+ # but a new buffer will be allocated for every tensor result.
+ # That’s functionally correct but inefficient.
+ #
+ # I’m not sure whether it’s better to first migrate to the new bufferization,
+ # then turn the ops into DPS ops, or do it the other way around.
+ # One benefit of implementing the bufferization first is that
+ # it’s a smaller step that you can already run end-to-end.
+        # And you can think of DPS as a performance improvement on top of it.
+ #
+ # https://discourse.llvm.org/t/steps-of-migrating-to-one-shot-bufferization/81062/2
+ #
+        # Here, please note that gradient-preprocessing is different from rewriting in DPS.
+        # So, overall, we are skipping this section while we first focus on migrating to the
+        # new -one-shot-bufferize.
+ "eliminate-empty-tensors",
+ (
+ # Before we enter one-shot-bufferize, here is what we expect:
+ # * Given
+ #
+ # One-Shot Bufferize was designed for ops that are in DPS
+ # (destination-passing style).
+ # Ops that are not in DPS can still be bufferized,
+ # but a new buffer will be allocated for every tensor result.
+ # That’s functionally correct but inefficient.
+ #
+ # https://discourse.llvm.org/t/steps-of-migrating-to-one-shot-bufferization/81062/2
+ #
+ # we expect that results will be (automatically?) converted into new buffers. And it
+ # is up to us to just define the bufferization for the operands.
+ #
+ # So what is the state of the catalyst, gradient, quantum dialects at this point?
+ #
+ # Let's start with quantum:
+ #
+ # |-------------------------|--------------------|
+ # | operation | has result tensor |
+ # |-------------------------|--------------------|
+ # | quantum.set_state | |
+ # | quantum.set_basis_state | |
+ # | quantum.unitary | |
+ # | quantum.hermitian | |
+ # | quantum.hamiltonian | |
+ # | quantum.sample_op | YES |
+ # | quantum.counts_op | YES |
+ # | quantum.probs_op | YES |
+ # | quantum.state_op | YES |
+ # |-------------------------|--------------------|
+ # | catalyst.print_op | |
+ # | catalyst.custom_call | YES |
+ # | catalyst.callback | |
+ # | catalyst.callback_call | YES |
+ # | catalyst.launch_kernel | YES |
+ # |-------------------------|--------------------|
+ # | gradient.grad | YES |
+ # | gradient.value_and_grad | YES |
+ # | gradient.adjoint | YES |
+ # | gradient.backprop | YES |
+ # | gradient.jvp | YES |
+ # | gradient.vjp | YES |
+ # | gradient.forward | YES |
+ # | gradient.reverse | YES |
+ # |-------------------------|--------------------|
+ #
+            # So, for every operation marked YES above, the operands are only
+            # read, never written to; each tensor result simply gets a new buffer.
+ "one-shot-bufferize"
+ "{"
+ "bufferize-function-boundaries "
+ # - Bufferize function boundaries (experimental).
+ #
+ # By default, function boundaries are not bufferized.
+ # This is because there are currently limitations around function graph
+ # bufferization:
+ # recursive calls are not supported.
+ # As long as there are no recursive calls, function boundary bufferization can be
+ # enabled with bufferize-function-boundaries.
+ # Each tensor function argument and tensor function result is then turned into a memref.
+ # The layout map of the memref type can be controlled with function-boundary-type-conversion.
+ #
+ # https://mlir.llvm.org/docs/Bufferization/#using-one-shot-bufferize
+ "allow-return-allocs-from-loops "
+ # - Allows returning/yielding new allocations from a loop.
+ # https://github.com/llvm/llvm-project/pull/83964
+ # https://github.com/llvm/llvm-project/pull/87594
+ "function-boundary-type-conversion=identity-layout-map"
+ # - Controls layout maps when bufferizing function signatures.
+ # You can control the memref types at the function boundary with
+ # function-boundary-type-conversion. E.g., if you set it to identity-layout-map,
+ # you should get the same type as with --func-bufferize.
+ # By default, we put a fully dynamic layout map strided<[?, ?], offset: ?>
+ # because that works best if you don't know what layout map the buffers at
+ # the call site have -- you can always cast a buffer to a type with
+ # fully dynamic layout map. (But not the other way around. That may require a
+ # reallocation.)
+ #
+ # https://discord.com/channels/636084430946959380/642426447167881246/1212338527824515102
+ "}"
+ ),
+        # Remove dead memrefToTensorOps
# introduced during gradient-bufferize of callbacks
+ # TODO: Figure out how to remove this.
+ "gradient-postprocess",
"func.func(buffer-hoisting)",
"func.func(buffer-loop-hoisting)",
+ # TODO: Figure out how to include the other buffer-level optimizations.
+ # -buffer-results-to-out-params,
+ # -drop-equivalent-buffer-results,
+ # -promote-buffers-to-stack
+ # Deallocation
+ # The buffer deallocation pass has been deprecated in favor of the
+ # ownership-based buffer deallocation pipeline.
+ # The deprecated pass has some limitations that may cause memory leaks in the resulting IR.
+        # TODO: Switch to the ownership-based deallocation pipeline once the
+        # one-shot-bufferize migration is merged.
"func.func(buffer-deallocation)",
+ # catalyst.list_* operations are not bufferized through
+ # the bufferization interface
+ # This is because they store a memref inside of a memref
+ # which is incompatible with the bufferization pipeline.
"convert-arraylist-to-memref",
"convert-bufferization-to-memref",
- "canonicalize", # Must be after convert-bufferization-to-memref
+ # Must be after convert-bufferization-to-memref
# otherwise there are issues in lowering of dynamic tensors.
+ "canonicalize",
# "cse",
"cp-global-memref",
],
)
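Since Python concatenates adjacent string literals, the parenthesized block above collapses into a single pipeline entry. A quick sanity check of the assembled specification (a sketch assuming a Catalyst install that includes this change):

```python
from catalyst.compiler import BUFFERIZATION_PASS  # assumes this module layout

expected = (
    "one-shot-bufferize{bufferize-function-boundaries "
    "allow-return-allocs-from-loops "
    "function-boundary-type-conversion=identity-layout-map}"
)
assert expected in BUFFERIZATION_PASS[1]
```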
+BUFFERIZATION_ASYNC_PASS = (
+ "BufferizationPass",
+ [
+ # TODO: Can we remove copy-before-write?
+ # copy-before-write:
+ # Skip the analysis. Make a buffer copy on every write.
+ s.replace("}", " copy-before-write}") if s.startswith("one-shot-bufferize") else s
+ for s in BUFFERIZATION_PASS[1]
+ ],
+)
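The comprehension above rewrites only the `one-shot-bufferize` entry; every other pass is shared with `BUFFERIZATION_PASS`. A self-contained sketch of its effect on the assembled option string:

```python
sync = (
    "one-shot-bufferize{bufferize-function-boundaries "
    "allow-return-allocs-from-loops "
    "function-boundary-type-conversion=identity-layout-map}"
)
async_variant = sync.replace("}", " copy-before-write}")
assert async_variant.endswith("identity-layout-map copy-before-write}")
```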
MLIR_TO_LLVM_PASS = (
"MLIRToLLVMDialect",
@@ -328,7 +464,7 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt
ENFORCE_RUNTIME_INVARIANTS_PASS,
HLO_LOWERING_PASS,
QUANTUM_COMPILATION_PASS,
- BUFFERIZATION_PASS,
+ BUFFERIZATION_ASYNC_PASS,
MLIR_TO_LLVM_ASYNC_PASS,
]
diff --git a/mlir/Makefile b/mlir/Makefile
index 2cb8de388a..3d4e737f89 100644
--- a/mlir/Makefile
+++ b/mlir/Makefile
@@ -12,6 +12,12 @@ ENZYME_BUILD_DIR?=$(MK_DIR)/Enzyme/build
RT_BUILD_DIR?=$(MK_DIR)/../runtime/build
ENABLE_ASAN?=OFF
BUILD_TYPE?=Release
+# TODO: remove after JAX upgrade
+LLVM_ROOT=$(MK_DIR)/llvm-project
+LLVM_FUNCOP_PATCH_FILE=$(MK_DIR)/patches/FunctionOpInterface-bufferization.patch
+LLVM_FUNC_CALL_PATCH_FILE=$(MK_DIR)/patches/callOp-bufferization.patch
+MHLO_ROOT?=$(MK_DIR)/mlir-hlo
+MHLO_MODULE_PATCH_FILE=$(MK_DIR)/patches/FunctionOpInterface-mhlo.patch
TARGET_FILE=$(MK_DIR)/mlir-hlo/mhlo/transforms/CMakeLists.txt
PATCH_FILE=$(MK_DIR)/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch
@@ -53,7 +59,14 @@ all: llvm mhlo enzyme dialects
.PHONY: llvm
llvm:
+ # TODO: Remove these patches after upgrading Jax (potentailly for 0.4.34 or higher).
@echo "build LLVM and MLIR enabling Python bindings"
+ @if patch --dry-run -p1 -N --directory=$(LLVM_ROOT) < $(LLVM_FUNCOP_PATCH_FILE) > /dev/null 2>&1; then \
+ patch -p1 --directory=$(LLVM_ROOT) < $(LLVM_FUNCOP_PATCH_FILE); \
+ fi
+ @if patch --dry-run -p1 -N --directory=$(LLVM_ROOT) < $(LLVM_FUNC_CALL_PATCH_FILE) > /dev/null 2>&1; then \
+ patch -p1 --directory=$(LLVM_ROOT) < $(LLVM_FUNC_CALL_PATCH_FILE); \
+ fi
cmake -G Ninja -S llvm-project/llvm -B $(LLVM_BUILD_DIR) \
-DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \
-DLLVM_BUILD_EXAMPLES=OFF \
@@ -85,6 +98,10 @@ mhlo:
@if patch --dry-run -p1 -N $(TARGET_FILE) $(PATCH_FILE) > /dev/null 2>&1; then \
patch -p1 $(TARGET_FILE) $(PATCH_FILE); \
fi
+ # TODO: Remove this patch after upgrading Jax (potentailly for 0.4.34 or higher).
+ @if patch --dry-run -p1 -N --directory=$(MHLO_ROOT) < $(MHLO_MODULE_PATCH_FILE) > /dev/null 2>&1; then \
+ patch -p1 --directory=$(MHLO_ROOT) < $(MHLO_MODULE_PATCH_FILE); \
+ fi
cmake -G Ninja -S mlir-hlo -B $(MHLO_BUILD_DIR) \
-DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \
-DLLVM_ENABLE_ASSERTIONS=ON \
diff --git a/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644
index 0000000000..ec20d6f6c9
--- /dev/null
+++ b/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,9 @@
+#pragma once
+
+using namespace mlir;
+
+namespace catalyst {
+
+void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry &registry);
+
+} // namespace catalyst
diff --git a/mlir/include/Gradient/IR/GradientOps.h b/mlir/include/Gradient/IR/GradientOps.h
index c6f6afadfe..a54e110043 100644
--- a/mlir/include/Gradient/IR/GradientOps.h
+++ b/mlir/include/Gradient/IR/GradientOps.h
@@ -21,6 +21,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "Gradient/IR/GradientInterfaces.h"
diff --git a/mlir/include/Gradient/IR/GradientOps.td b/mlir/include/Gradient/IR/GradientOps.td
index e2d7660e38..3a73403b6d 100644
--- a/mlir/include/Gradient/IR/GradientOps.td
+++ b/mlir/include/Gradient/IR/GradientOps.td
@@ -17,6 +17,7 @@
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/OpBase.td"
@@ -388,7 +389,7 @@ def ReverseOp : Gradient_Op<"reverse",
}
def ReturnOp : Gradient_Op<"return",
- [Terminator, ParentOneOf<["ForwardOp", "ReverseOp"]>]> {
+ [ReturnLike, Terminator, ParentOneOf<["ForwardOp", "ReverseOp"]>]> {
let summary = "Return tapes or nothing";
diff --git a/mlir/include/Gradient/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/Gradient/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644
index 0000000000..ae5096eb39
--- /dev/null
+++ b/mlir/include/Gradient/Transforms/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,13 @@
+#pragma once
+
+using namespace mlir;
+
+namespace catalyst {
+
+namespace gradient {
+
+void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry &registry);
+
+}
+
+} // namespace catalyst
diff --git a/mlir/include/Quantum/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/Quantum/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644
index 0000000000..bf60013f70
--- /dev/null
+++ b/mlir/include/Quantum/Transforms/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,13 @@
+#pragma once
+
+using namespace mlir;
+
+namespace catalyst {
+
+namespace quantum {
+
+void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry &registry);
+
+}
+
+} // namespace catalyst
diff --git a/mlir/lib/Bufferization.md b/mlir/lib/Bufferization.md
new file mode 100644
index 0000000000..14eb74599b
--- /dev/null
+++ b/mlir/lib/Bufferization.md
@@ -0,0 +1,23 @@
+**Bufferization Interfaces:**
+
+| Bufferizable Operations | PrintOp | CustomCallOp | CallbackOp | CallbackCallOp | AdjointOp | BackpropOp | ForwardOp | ReverseOp | QubitUnitaryOp | HermitianOp | HamiltonianOp | SampleOp | StateOp | ProbsOp | CountsOp | SetStateOp | SetBasisStateOp |
+| --------------------------------| ---------| ------------ | ------------ | -------------- | --------- | ---------- | --------- | --------- | -------------- | ----------- | ------------- | -------- | ------- | ------- | -------- | ---------- | --------------- |
+| Category | catalyst | catalyst | catalyst | catalyst | gradient | gradient | gradient | gradient | quantum | quantum | quantum | quantum | quantum | quantum | quantum | quantum | quantum |
+| bufferizesToAllocation | | true | | true | | | | | | | | | | | | | |
+| bufferizesToMemoryRead | true | true | | true | true | true | | | true | true | true | false | false | false | false | false | false |
+| bufferizesToMemoryWrite | false | false | | false | false | true | | | false | false | false | false | false | false | false | false | false |
+| bufferizesToElementwiseAccess | | | | | | | | | | | | | | | | | |
+| resultBufferizesToMemoryWrite | | | | | | | | | | | | | | | | | |
+| mustBufferizeInPlace | | | | | | | | | | | | | | | | | |
+| getAliasingValues | {} | {} | | {} | {} | {} | | | {} | {} | {} | {} | {} | {} | {} | {} | {} |
+| getAliasingOpOperands | | | {} | | | | v | v | | | | | | | | | |
+| resolveConflicts | | | | | | | | | | | | | | | | | |
+| bufferize | v | v | v | v | v | v | v | v | v | v | v | v | v | v | v | v | v |
+| isWritable | | | | | | | | | | | | | | | | | |
+| isNotConflicting | | | | | | | | | | | | | | | | | |
+| verifyAnalysis | | | | | | | v | v | | | | | | | | | |
+| getBufferType | | | | | | | v | v | | | | | | | | | |
+| isRepetitiveRegion | | | | | | | | | | | | | | | | | |
+| isParallelRegion | | | | | | | | | | | | | | | | | |
+| hasTensorSemantics | | | v | | | | v | v | | | | | | | | | |
+| supportsUnstructuredControlFlow | | | false | | | | true | true | | | | | | | | | |
\ No newline at end of file
diff --git a/mlir/lib/Catalyst/IR/CatalystDialect.cpp b/mlir/lib/Catalyst/IR/CatalystDialect.cpp
index 1dce30c4cc..01419fddcb 100644
--- a/mlir/lib/Catalyst/IR/CatalystDialect.cpp
+++ b/mlir/lib/Catalyst/IR/CatalystDialect.cpp
@@ -14,6 +14,7 @@
#include "Catalyst/IR/CatalystDialect.h"
#include "Catalyst/IR/CatalystOps.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h" // needed for generated type parser
#include "mlir/Interfaces/FunctionImplementation.h"
@@ -40,6 +41,8 @@ void CatalystDialect::initialize()
#define GET_OP_LIST
#include "Catalyst/IR/CatalystOps.cpp.inc"
>();
+    declarePromisedInterfaces<bufferization::BufferizableOpInterface, PrintOp, CustomCallOp,
+                              CallbackOp, CallbackCallOp>();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp
index a0c27129cb..e811668e89 100644
--- a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp
+++ b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp
std::optional<LLVM::LLVMFuncOp> AsyncUtils::getCalleeSafe(LLVM::CallOp callOp)
bool AsyncUtils::isFunctionNamed(LLVM::LLVMFuncOp funcOp, llvm::StringRef expectedName)
{
llvm::StringRef observedName = funcOp.getSymName();
- return observedName.equals(expectedName);
+ return observedName.compare(expectedName) == 0;
}
bool AsyncUtils::isMlirAsyncRuntimeCreateValue(LLVM::LLVMFuncOp funcOp)
diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp
new file mode 100644
index 0000000000..c9033716d5
--- /dev/null
+++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -0,0 +1,293 @@
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "Catalyst/IR/CatalystOps.h"
+#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h"
+
+using namespace mlir;
+using namespace catalyst;
+
+namespace {
+/**
+ * The new bufferization interface requires `bufferizesToMemoryRead`, `bufferizesToMemoryWrite`,
+ * and `getAliasingValues`.
+ *
+ * `bufferizesToMemoryRead`: Return `true` if the buffer of the given tensor OpOperand is read.
+ *
+ * `bufferizesToMemoryWrite`: Return `true` if the buffer of the given tensor OpOperand is written
+ * (if bufferizing in-place).
+ *
+ * `getAliasingValues`: Return the values that may share the same buffer as the given
+ * OpOperand.
+ *
+ * Link: https://mlir.llvm.org/docs/Bufferization/#extending-one-shot-bufferize
+ */
+
+/// Bufferization of catalyst.print. Get memref of printOp.val.
+struct PrintOpInterface
+ // PrintOp will never write to the buffers
+    : public bufferization::BufferizableOpInterface::ExternalModel<PrintOpInterface, PrintOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto printOp = cast<PrintOp>(op);
+        if (printOp.getVal()) {
+            FailureOr<Value> source = getBuffer(rewriter, printOp.getVal(), options);
+            if (failed(source))
+                return failure();
+            bufferization::replaceOpWithNewBufferizedOp<PrintOp>(
+                rewriter, op, *source, printOp.getConstValAttr(), printOp.getPrintDescriptorAttr());
+ }
+ return success();
+ }
+};
+
+/// Bufferization of catalyst.custom_call. Mainly get buffers for arguments.
+struct CustomCallOpInterface
+ // CustomCallOp will interface with BLAS functions.
+ // This operations is not in DPS form. This means that
+ // if we can guarantee operands are never written to, then we can set
+    // This operation is not in DPS form. This means that
+ // Results will be allocated a new buffer.
+ // TODO: Double check BLAS and others. Until then, it should be safe to keep
+ // bufferizesToMemoryWrite as True.
+    : public bufferization::BufferizableOpInterface::ExternalModel<CustomCallOpInterface,
+                                                                   CustomCallOp> {
+
+ bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto customCallOp = cast<CustomCallOp>(op);
+
+ // Add bufferized arguments
+        SmallVector<Value> bufferArgs;
+ ValueRange operands = customCallOp.getOperands();
+ for (Value operand : operands) {
+            FailureOr<Value> opBuffer = getBuffer(rewriter, operand, options);
+ if (failed(opBuffer))
+ bufferArgs.push_back(operand);
+ else
+ bufferArgs.push_back(*opBuffer);
+ }
+
+ // Add bufferized return values to the arguments
+ ValueRange results = customCallOp.getResults();
+ for (Value result : results) {
+ Type resultType = result.getType();
+            RankedTensorType tensorType = dyn_cast<RankedTensorType>(resultType);
+ if (!tensorType) {
+ bufferArgs.push_back(result);
+ continue;
+ }
+ auto options = bufferization::BufferizationOptions();
+            FailureOr<Value> tensorAlloc = bufferization::allocateTensorForShapedValue(
+                rewriter, op->getLoc(), result, options, false);
+            MemRefType memrefType =
+                MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+            auto newBuffer =
+                rewriter.create<bufferization::ToMemrefOp>(op->getLoc(), memrefType, *tensorAlloc);
+ bufferArgs.push_back(newBuffer);
+ }
+
+ // Add the initial number of arguments
+        int32_t numArguments = static_cast<int32_t>(customCallOp.getNumOperands());
+ DenseI32ArrayAttr numArgumentsDenseAttr = rewriter.getDenseI32ArrayAttr({numArguments});
+
+ // Create an updated custom call operation
+        rewriter.create<CustomCallOp>(op->getLoc(), TypeRange{}, bufferArgs,
+                                      customCallOp.getCallTargetName(), numArgumentsDenseAttr);
+ size_t startIndex = bufferArgs.size() - customCallOp.getNumResults();
+        SmallVector<Value> bufferResults(bufferArgs.begin() + startIndex, bufferArgs.end());
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, bufferResults);
+
+ return success();
+ }
+};
+
+struct CallbackOpInterface
+    : public bufferization::BufferizableOpInterface::ExternalModel<CallbackOpInterface,
+                                                                   CallbackOp> {
+ bool hasTensorSemantics(Operation *op) const
+ {
+        auto isaTensor = llvm::IsaPred<TensorType>;
+
+ // A function has tensor semantics if it has tensor arguments/results.
+        auto callbackOp = cast<CallbackOp>(op);
+ bool hasTensorArg = any_of(callbackOp.getArgumentTypes(), isaTensor);
+ bool hasTensorResult = any_of(callbackOp.getResultTypes(), isaTensor);
+ if (hasTensorArg || hasTensorResult)
+ return true;
+
+ return false;
+ }
+
+ bufferization::AliasingOpOperandList
+ getAliasingOpOperands(Operation *op, Value value,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto callbackOp = cast<CallbackOp>(op);
+
+ auto argTys = callbackOp.getArgumentTypes();
+ auto retTys = callbackOp.getResultTypes();
+        SmallVector<Type> emptyRets;
+        SmallVector<Type> args(argTys.begin(), argTys.end());
+        args.insert(args.end(), retTys.begin(), retTys.end());
+        SmallVector<Type> bufferArgs;
+ for (Type ty : args) {
+            auto tensorType = dyn_cast<TensorType>(ty);
+ if (!tensorType)
+ bufferArgs.push_back(ty);
+ else
+ bufferArgs.push_back(
+ MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
+ }
+ auto callbackTy = rewriter.getFunctionType(bufferArgs, emptyRets);
+ rewriter.modifyOpInPlace(op, [&] { callbackOp.setFunctionType(callbackTy); });
+
+ return success();
+ }
+};
+
+struct CallbackCallOpInterface
+    : public bufferization::BufferizableOpInterface::ExternalModel<CallbackCallOpInterface,
+                                                                   CallbackCallOp> {
+ bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ // We can safely say false because CallbackCallOp's memrefs
+ // will be put in a JAX array and JAX arrays are immutable.
+ //
+ // Unlike NumPy arrays, JAX arrays are always immutable.
+ //
+ // https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto callOp = cast<CallbackCallOp>(op);
+
+ bufferization::BufferizeTypeConverter typeConverter;
+
+        SmallVector<Type> convertedResults;
+ if (failed(typeConverter.convertTypes(callOp.getResultTypes(), convertedResults)))
+ return failure();
+
+ if (callOp->getNumResults() != convertedResults.size())
+ return failure();
+
+        SmallVector<Value> newInputs;
+ auto operands = callOp.getOperands();
+ for (Value operand : operands) {
+            FailureOr<Value> opBuffer = getBuffer(rewriter, operand, options);
+ if (failed(opBuffer))
+ return failure();
+ newInputs.push_back(*opBuffer);
+ }
+
+ auto results = callOp.getResults();
+ auto loc = callOp->getLoc();
+        SmallVector<Value> outmemrefs;
+ for (auto result : results) {
+            FailureOr<Value> tensorAlloc =
+ bufferization::allocateTensorForShapedValue(rewriter, loc, result, options, false);
+ if (failed(tensorAlloc))
+ return failure();
+
+ auto tensor = *tensorAlloc;
+            RankedTensorType tensorTy = cast<RankedTensorType>(tensor.getType());
+ auto shape = tensorTy.getShape();
+ auto elementTy = tensorTy.getElementType();
+ auto memrefType = MemRefType::get(shape, elementTy);
+            auto toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
+ auto memref = toMemrefOp.getResult();
+ outmemrefs.push_back(memref);
+ newInputs.push_back(memref);
+ }
+
+        SmallVector<Type> emptyRets;
+        rewriter.create<CallbackCallOp>(loc, emptyRets, callOp.getCallee(), newInputs);
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, outmemrefs);
+ return success();
+ }
+};
+
+} // namespace
+
+void catalyst::registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
+{
+ registry.addExtension(+[](MLIRContext *ctx, CatalystDialect *dialect) {
+        CustomCallOp::attachInterface<CustomCallOpInterface>(*ctx);
+        PrintOp::attachInterface<PrintOpInterface>(*ctx);
+        CallbackOp::attachInterface<CallbackOpInterface>(*ctx);
+        CallbackCallOp::attachInterface<CallbackCallOpInterface>(*ctx);
+ });
+}
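One way to exercise these external models end to end is to compile a small qnode and inspect the IR dumped after the `BufferizationPass` stage; a sketch assuming a local Catalyst build with these changes and the `lightning.qubit` device installed:

```python
import pennylane as qml
from catalyst import qjit

dev = qml.device("lightning.qubit", wires=1)

@qjit(keep_intermediate=True)  # writes per-stage IR next to the workspace
@qml.qnode(dev)
def circuit(x: float):
    qml.RX(x, wires=0)
    return qml.probs(wires=0)  # quantum.probs has a tensor result to bufferize

print(circuit(0.3))
```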
diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt
index 77fb4d64b5..07f221469f 100644
--- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt
+++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt
@@ -4,6 +4,12 @@ file(GLOB SRC
ApplyTransformSequencePass.cpp
ArrayListToMemRefPass.cpp
+ scatter_lowering.cpp
+ ScatterPatterns.cpp
+ qnode_to_async_lowering.cpp
+ QnodeToAsyncPatterns.cpp
+ RegisterAllPasses.cpp
+ BufferizableOpInterfaceImpl.cpp
AsyncUtils.cpp
BufferizationPatterns.cpp
catalyst_bufferize.cpp
diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp
index 9fc3e6f9bf..a5f275adbc 100644
--- a/mlir/lib/Driver/CompilerDriver.cpp
+++ b/mlir/lib/Driver/CompilerDriver.cpp
@@ -56,6 +56,7 @@
#include "llvm/Transforms/IPO/GlobalDCE.h"
#include "Catalyst/IR/CatalystDialect.h"
+#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h"
#include "Catalyst/Transforms/Passes.h"
#include "Driver/CatalystLLVMTarget.h"
#include "Driver/CompilerDriver.h"
@@ -63,10 +64,12 @@
#include "Driver/Support.h"
#include "Gradient/IR/GradientDialect.h"
#include "Gradient/IR/GradientInterfaces.h"
+#include "Gradient/Transforms/BufferizableOpInterfaceImpl.h"
#include "Gradient/Transforms/Passes.h"
#include "Mitigation/IR/MitigationDialect.h"
#include "Mitigation/Transforms/Passes.h"
#include "Quantum/IR/QuantumDialect.h"
+#include "Quantum/Transforms/BufferizableOpInterfaceImpl.h"
#include "Quantum/Transforms/Passes.h"
#include "Enzyme.h"
@@ -75,6 +78,7 @@
using namespace mlir;
using namespace catalyst;
using namespace catalyst::driver;
+using namespace catalyst::quantum;
namespace cl = llvm::cl;
namespace catalyst::utils {
@@ -296,6 +300,11 @@ void registerAllCatalystDialects(DialectRegistry ®istry)
registry.insert();
registry.insert();
registry.insert();
+
+ // Extend one-shot bufferization pass.
+ catalyst::registerBufferizableOpInterfaceExternalModels(registry);
+ catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry);
+ catalyst::gradient::registerBufferizableOpInterfaceExternalModels(registry);
}
} // namespace
diff --git a/mlir/lib/Gradient/IR/GradientDialect.cpp b/mlir/lib/Gradient/IR/GradientDialect.cpp
index 4d9cfddb00..068079b99f 100644
--- a/mlir/lib/Gradient/IR/GradientDialect.cpp
+++ b/mlir/lib/Gradient/IR/GradientDialect.cpp
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "Gradient/IR/GradientDialect.h"
@@ -50,6 +51,8 @@ void GradientDialect::initialize()
#include "Gradient/IR/GradientOps.cpp.inc"
>();
addInterface<GradientInlinerInterface>();
+    declarePromisedInterfaces<bufferization::BufferizableOpInterface, AdjointOp, BackpropOp,
+                              ForwardOp, ReverseOp>();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp
new file mode 100644
index 0000000000..99b1b4396b
--- /dev/null
+++ b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -0,0 +1,614 @@
+#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "Gradient/IR/GradientOps.h"
+#include "Gradient/Transforms/BufferizableOpInterfaceImpl.h"
+#include "Gradient/Utils/GradientShape.h"
+#include "Quantum/IR/QuantumOps.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+using namespace catalyst::gradient;
+
+namespace {
+/**
+ * The new bufferization interface requires `bufferizesToMemoryRead`, `bufferizesToMemoryWrite`,
+ * and `getAliasingValues`.
+ *
+ * `bufferizesToMemoryRead`: Return `true` if the buffer of the given tensor OpOperand is read.
+ *
+ * `bufferizesToMemoryWrite`: Return `true` if the buffer of the given tensor OpOperand is written
+ * (if bufferizing in-place).
+ *
+ * `getAliasingOpOperands`: Return the OpResults that may share the same buffer as the given
+ * OpOperand. Note that MLIR documentation does not mention `getAliasingValues` but it seems to
+ * serve the same purpose.
+ *
+ * Bufferizing FunctionOpInterface is also not documented by MLIR. It requires
+ * `OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel`, which requires the
+ * implementation of `supportsUnstructuredControlFlow`, `hasTensorSemantics`, and
+ * `getAliasingOpOperands`.
+ *
+ * Link: https://mlir.llvm.org/docs/Bufferization/#extending-one-shot-bufferize
+ */
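+
+// For orientation (an illustrative sketch, not part of this patch): an op whose
+// single result reuses its single tensor operand's buffer in-place would answer
+// these queries roughly as follows ("MyInPlaceOp" is hypothetical):
+//
+//     bool bufferizesToMemoryRead(...) const { return true; }
+//     bool bufferizesToMemoryWrite(...) const { return true; }
+//     bufferization::AliasingValueList getAliasingValues(Operation *op, ...) const
+//     {
+//         return {{op->getOpResult(0), bufferization::BufferRelation::Equivalent}};
+//     }
+//
+// None of the models below bufferize in-place, so they all return an empty alias
+// list from `getAliasingValues`.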
+
+static BaseMemRefType
+getBufferizedFunctionArgType(FunctionOpInterface funcOp, int64_t index,
+ const bufferization::BufferizationOptions &options)
+{
+    auto tensorType = dyn_cast<TensorType>(funcOp.getArgument(index).getType());
+    assert(tensorType && "expected TensorType");
+
+    BaseMemRefType memrefType = options.functionArgTypeConverterFn(
+        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+
+    auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
+        index, bufferization::BufferizationDialect::kBufferLayoutAttrName);
+    if (!layoutAttr)
+        return memrefType;
+
+    auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
+    assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
+    return MemRefType::get(rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
+                           layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
+}
+
+static ReturnOp getAssumedUniqueReturnOp(FunctionOpInterface funcOp)
+{
+ ReturnOp returnOp;
+ for (Block &b : funcOp.getFunctionBody()) {
+        if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
+ if (returnOp)
+ return nullptr;
+ returnOp = candidateOp;
+ }
+ }
+ return returnOp;
+}
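+
+// For example (illustrative only): a body with a single return-terminated block
+// yields that terminator, while a body whose blocks end in two different returns
+// yields nullptr; the verifyAnalysis hooks below turn the nullptr case into an error.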
+
+Value generateAllocation(OpBuilder &builder, Location loc, Value reference)
+{
+    auto origMemrefType = cast<MemRefType>(reference.getType());
+ // TODO: Investigate how to get rid of identity-layout-map
+ //
+ // Hi all. For one-shot-bufferization, is there any automatic way to pass all memref symbols
+ // to AllocOp? we have an example below that triggers error: 'memref.alloc' op symbol
+ // operand count does not equal memref symbol count: expected 1, got 0 . We think we have
+ // to pass the offset symbol to AllocOp.
+ //
+ // %0 = "bufferization.to_memref"(%arg0) : (tensor) -> memref> %1 = "memref.alloc"() <{operandSegmentSizes = array}> : () ->
+ // memref>
+ //
+ // We know we can set function-signature-type-conversion=identity-layout-map to get rid of
+ // it. But according to the document, identity-layout-map could be less efficient, we still
+ // want to stick with the default setting.
+ //
+ // https://discord.com/channels/636084430946959380/642426447167881246/1281620504859512914
+ //
+ // Something looks odd here.
+ // The result of a `memref.alloc` should be a memref without identity layout.
+ // I know that the op supports operands for dims/symbols in the memref type,
+ // but I never understood why.
+    // Imo, a `memref.alloc() : memref<?xf64>` should have been generated.
+    // The result value can then be casted to `memref<?xf64, strided<[1], offset: ?>>`.
+ //
+ // https://discord.com/channels/636084430946959380/642426447167881246/1281710682160627785
+ //
+ // What I find interesting is that the comment says that
+ //
+ // "we know we can set function-signature-type-conversion=identity-layout-map to get rid of
+ // it"
+ //
+ // and that is what we are using, however we still have this rebuilding a memref without the
+ // layout. If that were true, then we could uncomment the following line and it should work.
+    // auto memrefType = origMemrefType;
+    // I can confirm that having
+    // function-signature-type-conversion=identity-layout-map makes the line above succeed while
+    // the line below fails:
+ //
+ // Get dynamic dimension sizes from the provided reference value if necessary.
+ auto memrefType = MemRefType::get(origMemrefType.getShape(), origMemrefType.getElementType());
+ //
+ // Looking at this a little bit deeper, I can say that the variable reference
+ // appears to come from a function parameter.
+ // and since it is not the identity layout, then we see the following generic MLIR when not
+ // using identity layout
+ //
+ // "func.func"() <{function_type = (memref>) -> memref>
+ //
+ // and we see this when using the identity layout:
+ //
+    //     func.func public @jit_fn(%arg0: memref<?xf64>) -> memref<?xf64>
+ //
+ // When not using identity layout but also not removing the layout in the alloca, there are
+ // errors in some cases but not in others. I believe we have to do some casts in other places as
+ // well, whenever we use allocas and the types come from the arguments.
+ //
+ // My recommendation: at some point it would be good to remove the identity-layout-map from the
+ // frontend but until we have some more resources, let's keep it along with the origMemrefType.
+
+    SmallVector<Value> dynamicDims;
+    if (!memrefType.hasStaticShape()) {
+        for (int64_t dim = 0; dim < memrefType.getRank(); dim++) {
+            if (memrefType.isDynamicDim(dim)) {
+                Value dimIndex = builder.create<index::ConstantOp>(loc, dim);
+                dynamicDims.push_back(builder.create<memref::DimOp>(loc, reference, dimIndex));
+            }
+        }
+    }
+
+    return builder.create<memref::AllocOp>(loc, memrefType, dynamicDims);
+    // Uncomment below to follow Matthias' suggestion of placing a CastOp after the AllocOp;
+    // some more tests will pass.
+    // return builder.create<memref::CastOp>(loc, origMemrefType, alloc_uncasted);
+}
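+
+// A sketch of the cast-after-alloc variant mentioned above (an assumption of how it
+// would be wired up, not code exercised by this patch):
+//
+//     Value allocUncasted = builder.create<memref::AllocOp>(loc, memrefType, dynamicDims);
+//     return builder.create<memref::CastOp>(loc, origMemrefType, allocUncasted);
+//
+// memref.cast only reinterprets the type (here, re-attaching the strided layout);
+// it does not move or copy the underlying data.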
+
+/// Helper function to generate a set of memref allocations.
+///
+/// The allocation size and shape is deduced from a list of existing memref values.
+///
+void generateAllocations(RewriterBase &rewriter, Location loc,
+                         SmallVectorImpl<Value> &allocations, ValueRange referenceValues)
+{
+    for (Value memref : referenceValues) {
+        allocations.push_back(
+            generateAllocation(rewriter, loc, cast<TypedValue<MemRefType>>(memref)));
+ }
+}
+
+struct AdjointOpInterface
+ // This operation is not in DPS style and I believe that operands will only be read.
+    : public bufferization::BufferizableOpInterface::ExternalModel<AdjointOpInterface,
+                                                                   AdjointOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto adjointOp = cast<AdjointOp>(op);
+
+        bufferization::BufferizeTypeConverter typeConverter;
+
+        SmallVector<Type> resTypes;
+        if (failed(typeConverter.convertTypes(adjointOp.getResultTypes(), resTypes)))
+            return failure();
+
+        Location loc = adjointOp.getLoc();
+        Value gradSize = adjointOp.getGradSize();
+        SmallVector<Value> memrefValues;
+        for (Type resType : resTypes) {
+            MemRefType memrefType = cast<MemRefType>(resType);
+            Value memrefValue = rewriter.create<memref::AllocOp>(loc, memrefType, gradSize);
+            memrefValues.push_back(memrefValue);
+        }
+
+        SmallVector<Value> bufferArgs;
+        ValueRange operands = adjointOp.getArgs();
+        for (Value operand : operands) {
+            FailureOr<Value> opBuffer = getBuffer(rewriter, operand, options);
+            if (failed(opBuffer))
+                return failure();
+            bufferArgs.push_back(*opBuffer);
+        }
+
+        rewriter.create<AdjointOp>(loc, TypeRange{}, adjointOp.getCalleeAttr(),
+                                   adjointOp.getGradSize(), bufferArgs, memrefValues);
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, memrefValues);
+ return success();
+ }
+};
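+
+// Sketch of the net effect of AdjointOpInterface::bufferize (shapes and textual
+// syntax are illustrative only, not the dialect's exact assembly format):
+//
+//     %grad = gradient.adjoint @circuit(%arg) size(%n) : tensor<?xf64>
+//
+// becomes, roughly,
+//
+//     %buf = memref.alloc(%n) : memref<?xf64>
+//     gradient.adjoint @circuit(%arg_buf) size(%n) in(%buf : memref<?xf64>)
+//     %grad = bufferization.to_tensor %buf
+//
+// with the trailing to_tensor materialized by replaceOpWithBufferizedValues.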
+
+struct BackpropOpInterface
+    // This operation is not in DPS style,
+    // but it has a lot of parameters, notably:
+    //     Variadic: $args
+    //     Variadic<RankedTensorOf<[AnyFloat]>>: $cotangents
+    // I think we don't write to the cotangents, and also not to the arguments,
+    // so we can set bufferizesToMemoryWrite to false.
+    // The safe assumption would be that it is true.
+    : public bufferization::BufferizableOpInterface::ExternalModel<BackpropOpInterface,
+                                                                   BackpropOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto backpropOp = cast<BackpropOp>(op);
+
+        Location loc = backpropOp.getLoc();
+        SmallVector<Value> gradients;
+        SmallVector<Value> argShadows;
+        // Conceptually a map from scalar result indices (w.r.t. other scalars) to the position in
+        // the overall list of returned gradients.
+        // For instance, a backprop op that returns (tensor, f64, tensor, f64, f64) will have
+        // scalarIndices = {1, 3, 4}.
+        SmallVector<unsigned> scalarIndices;
+        SmallVector<Type> scalarReturnTypes;
+
+        SmallVector<Value> bufferArgs;
+        ValueRange operands = backpropOp.getArgs();
+        for (Value operand : operands) {
+            if (isa<TensorType>(operand.getType())) {
+                FailureOr<Value> opBuffer = getBuffer(rewriter, operand, options);
+                if (failed(opBuffer))
+                    return failure();
+                bufferArgs.push_back(*opBuffer);
+            }
+            else {
+                bufferArgs.push_back(operand);
+            }
+        }
+
+        std::vector<Value> diffArgs =
+            computeDiffArgs(bufferArgs, backpropOp.getDiffArgIndicesAttr());
+
+ for (const auto &[idx, diffArg] : llvm::enumerate(diffArgs)) {
+ // Allocate buffers to place the differentiation results (gradients) into. Enzyme refers
+ // to these as shadow arguments. There is one result for each differentiable MemRef
+ // argument, with a matching shape and type.
+            if (isa<MemRefType>(diffArg.getType())) {
+                Value shadow = generateAllocation(rewriter, loc, diffArg);
+                gradients.push_back(shadow);
+                argShadows.push_back(shadow);
+            }
+            else if (isa<FloatType>(diffArg.getType())) {
+ scalarReturnTypes.push_back(diffArg.getType());
+ scalarIndices.push_back(idx);
+ // Put a null placeholder value that will be filled in with the result of the
+ // bufferized BackpropOp.
+ gradients.push_back(Value());
+ }
+ }
+
+ // Enzyme requires buffers for the primal outputs as well, even though we don't need their
+ // values. We'll mark them dupNoNeed later on to allow Enzyme to optimize away their
+ // computation.
+        SmallVector<Value> calleeResults, resShadows;
+        ValueRange cotangents = backpropOp.getCotangents();
+        SmallVector<Value> bufferCotangentsList;
+        for (Value operand : cotangents) {
+            FailureOr<Value> opBuffer = getBuffer(rewriter, operand, options);
+            if (failed(opBuffer))
+                return failure();
+            bufferCotangentsList.push_back(*opBuffer);
+        }
+        mlir::ValueRange bufferCotangents(bufferCotangentsList);
+
+ generateAllocations(rewriter, loc, calleeResults, bufferCotangents);
+ // Enzyme mutates the result shadows but the cotangent tensors must be immutable, so we
+ // create copies to pass into Enzyme. Concretely, this issue pops up with multiple
+ // BackpropOps that have the same cotangent tensor due to a CSE effect from one-shot
+ // bufferization.
+ generateAllocations(rewriter, loc, resShadows, bufferCotangents);
+ for (const auto &[cotangent, resShadow] : llvm::zip(bufferCotangents, resShadows)) {
+            rewriter.create<memref::CopyOp>(loc, cotangent, resShadow);
+ }
+
+ DenseIntElementsAttr diffArgIndicesAttr = backpropOp.getDiffArgIndices().value_or(nullptr);
+        auto bufferizedBackpropOp = rewriter.create<BackpropOp>(
+            loc, TypeRange{}, scalarReturnTypes, backpropOp.getCalleeAttr(), bufferArgs, argShadows,
+            calleeResults, resShadows, diffArgIndicesAttr, backpropOp.getKeepValueResultsAttr());
+ // Fill in the null placeholders.
+ for (const auto &[idx, scalarResult] :
+ llvm::enumerate(bufferizedBackpropOp.getGradients())) {
+ gradients[scalarIndices[idx]] = scalarResult;
+ }
+
+ // BackpropOp can return two results for value_and_grad: values and gradients
+ // or only one for grad: gradients
+        SmallVector<Value> results;
+ {
+ // If we are lowering a value_and_grad operation, then take values from the
+ // calleeResults
+ if (!backpropOp.getVals().empty()) {
+ results.insert(results.end(), calleeResults.begin(), calleeResults.end());
+ }
+ results.insert(results.end(), gradients.begin(), gradients.end());
+ }
+
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, results);
+ return success();
+ }
+};
+
+struct ForwardOpInterface
+ : public bufferization::OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
+ ForwardOpInterface, ForwardOp> {
+ static bool supportsUnstructuredControlFlow() { return false; }
+
+ bool hasTensorSemantics(Operation *op) const
+ {
+        auto isaTensor = llvm::IsaPred<TensorType>;
+
+        // A function has tensor semantics if it has tensor arguments/results.
+        auto forwardOp = cast<ForwardOp>(op);
+ bool hasTensorArg = any_of(forwardOp.getArgumentTypes(), isaTensor);
+ bool hasTensorResult = any_of(forwardOp.getResultTypes(), isaTensor);
+ bool hasTensorFuncInType = any_of(forwardOp.getFunctionType().getInputs(), isaTensor);
+ bool hasTensorFuncOutType = any_of(forwardOp.getFunctionType().getResults(), isaTensor);
+ if (hasTensorArg || hasTensorResult || hasTensorFuncInType || hasTensorFuncOutType)
+ return true;
+
+ return false;
+ }
+
+ bufferization::AliasingOpOperandList
+ getAliasingOpOperands(Operation *op, Value value,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+    FailureOr<BaseMemRefType> getBufferType(Operation *op, Value value,
+                                            const bufferization::BufferizationOptions &options,
+                                            SmallVector<Value> &invocationStack) const
+    {
+        auto forwardOp = cast<ForwardOp>(op);
+        auto bbArg = cast<BlockArgument>(value);
+
+ // Function arguments are special.
+ if (bbArg.getOwner() == &forwardOp.getBody().front())
+ return getBufferizedFunctionArgType(forwardOp, bbArg.getArgNumber(), options);
+
+ return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::getBufferType(
+ op, value, options, invocationStack);
+ }
+
+ LogicalResult verifyAnalysis(Operation *op, const bufferization::AnalysisState &state) const
+ {
+        auto forwardOp = cast<ForwardOp>(op);
+ // TODO: func.func with multiple returns are not supported.
+ if (!getAssumedUniqueReturnOp(forwardOp))
+ return op->emitOpError("op without unique func.return is not supported");
+ return success();
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto forwardOp = cast<ForwardOp>(op);
+        FunctionType funcType = forwardOp.getFunctionType();
+
+        // Construct the bufferized function type.
+        SmallVector<Type> argTypes;
+        for (const auto &it : llvm::enumerate(funcType.getInputs())) {
+            Type argType = it.value();
+            if (isa<TensorType>(argType)) {
+ argTypes.push_back(getBufferizedFunctionArgType(forwardOp, it.index(), options));
+ continue;
+ }
+ argTypes.push_back(argType);
+ }
+
+ ReturnOp returnOp = getAssumedUniqueReturnOp(forwardOp);
+ assert(returnOp && "expected func with single return op");
+ Location loc = returnOp.getLoc();
+
+ // 1. Bufferize every block.
+ for (Block &block : forwardOp.getBody())
+ if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options)))
+ return failure();
+
+ // 2. For each result, keep track of which inplace argument it reuses.
+        SmallVector<Value> returnValues;
+        for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+            Value returnVal = returnOperand.get();
+            auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+ rewriter.setInsertionPoint(returnOp);
+
+ // If not a tensor type just forward it.
+ if (!tensorType) {
+ returnValues.push_back(returnVal);
+ continue;
+ }
+
+            // Note: If `inferFunctionResultLayout = true`, casts are later folded
+            // away.
+            BaseMemRefType resultType = options.functionArgTypeConverterFn(
+                tensorType, *options.defaultMemorySpaceFn(tensorType), forwardOp, options);
+            Value toMemrefOp =
+                rewriter.create<bufferization::ToMemrefOp>(loc, resultType, returnVal);
+            returnValues.push_back(toMemrefOp);
+ }
+
+ // 3. Rewrite the terminator.
+ forwardOp.walk([&](ReturnOp returnOp) {
+ PatternRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(returnOp);
+            rewriter.replaceOpWithNewOp<ReturnOp>(returnOp, returnValues, returnOp.getEmpty());
+ });
+
+ // 4. Rewrite the FuncOp type to buffer form. Also preserve unused return types.
+        SmallVector<Type> returnTypes;
+        for (auto retTy : forwardOp.getResultTypes()) {
+            auto tensorType = dyn_cast<TensorType>(retTy);
+ BaseMemRefType resultType = options.functionArgTypeConverterFn(
+ tensorType, *options.defaultMemorySpaceFn(tensorType), forwardOp, options);
+ returnTypes.push_back(resultType);
+ }
+ forwardOp.setType(FunctionType::get(op->getContext(), argTypes, returnTypes));
+
+ return success();
+ }
+};
+
+struct ReverseOpInterface
+ : public bufferization::OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
+ ReverseOpInterface, ReverseOp> {
+ static bool supportsUnstructuredControlFlow() { return false; }
+
+ bool hasTensorSemantics(Operation *op) const
+ {
+        auto isaTensor = llvm::IsaPred<TensorType>;
+
+        // A function has tensor semantics if it has tensor arguments/results.
+        auto reverseOp = cast<ReverseOp>(op);
+ bool hasTensorArg = any_of(reverseOp.getArgumentTypes(), isaTensor);
+ bool hasTensorResult = any_of(reverseOp.getResultTypes(), isaTensor);
+ bool hasTensorFuncInType = any_of(reverseOp.getFunctionType().getInputs(), isaTensor);
+ bool hasTensorFuncOutType = any_of(reverseOp.getFunctionType().getResults(), isaTensor);
+ if (hasTensorArg || hasTensorResult || hasTensorFuncInType || hasTensorFuncOutType)
+ return true;
+
+ return false;
+ }
+
+ bufferization::AliasingOpOperandList
+ getAliasingOpOperands(Operation *op, Value value,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+    FailureOr<BaseMemRefType> getBufferType(Operation *op, Value value,
+                                            const bufferization::BufferizationOptions &options,
+                                            SmallVector<Value> &invocationStack) const
+    {
+        auto reverseOp = cast<ReverseOp>(op);
+        auto bbArg = cast<BlockArgument>(value);
+
+ // Function arguments are special.
+ if (bbArg.getOwner() == &reverseOp.getBody().front())
+ return getBufferizedFunctionArgType(reverseOp, bbArg.getArgNumber(), options);
+
+ return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::getBufferType(
+ op, value, options, invocationStack);
+ }
+
+ LogicalResult verifyAnalysis(Operation *op, const bufferization::AnalysisState &state) const
+ {
+        auto reverseOp = cast<ReverseOp>(op);
+ // TODO: func.func with multiple returns are not supported.
+ if (!getAssumedUniqueReturnOp(reverseOp))
+ return op->emitOpError("op without unique func.return is not supported");
+ return success();
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto reverseOp = cast<ReverseOp>(op);
+        FunctionType funcType = reverseOp.getFunctionType();
+
+        // Construct the bufferized function type.
+        SmallVector<Type> argTypes;
+        for (const auto &it : llvm::enumerate(funcType.getInputs())) {
+            Type argType = it.value();
+            if (isa<TensorType>(argType)) {
+ argTypes.push_back(getBufferizedFunctionArgType(reverseOp, it.index(), options));
+ continue;
+ }
+ argTypes.push_back(argType);
+ }
+
+ ReturnOp returnOp = getAssumedUniqueReturnOp(reverseOp);
+ assert(returnOp && "expected func with single return op");
+ Location loc = returnOp.getLoc();
+
+ // 1. Bufferize every block.
+ for (Block &block : reverseOp.getBody())
+ if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options)))
+ return failure();
+
+ // 2. For each result, keep track of which inplace argument it reuses.
+        SmallVector<Value> returnValues;
+        for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+            Value returnVal = returnOperand.get();
+            auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+ rewriter.setInsertionPoint(returnOp);
+
+ // If not a tensor type just forward it.
+ if (!tensorType) {
+ returnValues.push_back(returnVal);
+ continue;
+ }
+
+            // Note: If `inferFunctionResultLayout = true`, casts are later folded
+            // away.
+            BaseMemRefType resultType = options.functionArgTypeConverterFn(
+                tensorType, *options.defaultMemorySpaceFn(tensorType), reverseOp, options);
+            Value toMemrefOp =
+                rewriter.create<bufferization::ToMemrefOp>(loc, resultType, returnVal);
+            returnValues.push_back(toMemrefOp);
+ }
+
+ // 3. Rewrite the terminator.
+ reverseOp.walk([&](ReturnOp returnOp) {
+ PatternRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(returnOp);
+            rewriter.replaceOpWithNewOp<ReturnOp>(returnOp, returnValues, returnOp.getEmpty());
+ });
+
+ // 4. Rewrite the FuncOp type to buffer form. Also preserve unused return types.
+        SmallVector<Type> returnTypes;
+        for (auto retTy : reverseOp.getResultTypes()) {
+            auto tensorType = dyn_cast<TensorType>(retTy);
+ BaseMemRefType resultType = options.functionArgTypeConverterFn(
+ tensorType, *options.defaultMemorySpaceFn(tensorType), reverseOp, options);
+ returnTypes.push_back(resultType);
+ }
+ reverseOp.setType(FunctionType::get(op->getContext(), argTypes, returnTypes));
+
+ return success();
+ }
+};
+
+} // namespace
+
+void catalyst::gradient::registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
+{
+    registry.addExtension(+[](MLIRContext *ctx, GradientDialect *dialect) {
+        AdjointOp::attachInterface<AdjointOpInterface>(*ctx);
+        BackpropOp::attachInterface<BackpropOpInterface>(*ctx);
+        ForwardOp::attachInterface<ForwardOpInterface>(*ctx);
+        ReverseOp::attachInterface<ReverseOpInterface>(*ctx);
+    });
+}
diff --git a/mlir/lib/Gradient/Transforms/CMakeLists.txt b/mlir/lib/Gradient/Transforms/CMakeLists.txt
index 6495dc6cba..6442cfd073 100644
--- a/mlir/lib/Gradient/Transforms/CMakeLists.txt
+++ b/mlir/lib/Gradient/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LIBRARY_NAME gradient-transforms)
file(GLOB SRC
GradMethods/*.cpp
+ BufferizableOpInterfaceImpl.cpp
BufferizationPatterns.cpp
gradient_bufferize.cpp
PreprocessingPatterns.cpp
diff --git a/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp b/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp
index 46abee67ce..30baf98952 100644
--- a/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp
+++ b/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp
@@ -147,7 +147,8 @@ func::FuncOp genSplitPreprocessed(PatternRewriter &rewriter, Location loc, func:
PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(&splitFn.getBody().front());
Value paramsBuffer = rewriter.create(loc, paramsBufferType, paramCount);
- Value paramsTensor = rewriter.create(loc, paramsBuffer);
+ Value paramsTensor =
+ rewriter.create(loc, paramsBuffer, /*restrict=*/true);
qnodeQuantumArgs.push_back(paramsTensor);
MemRefType paramsProcessedType = MemRefType::get({}, rewriter.getIndexType());
@@ -289,8 +290,8 @@ func::FuncOp genArgMapFunction(PatternRewriter &rewriter, Location loc, func::Fu
else if (auto returnOp = dyn_cast(op)) {
PatternRewriter::InsertionGuard insertionGuard(rewriter);
rewriter.setInsertionPoint(returnOp);
- Value paramsVector =
- rewriter.create(loc, paramsVectorType, paramsBuffer);
+ Value paramsVector = rewriter.create(
+ loc, paramsVectorType, paramsBuffer, /*restrict=*/true);
returnOp.getOperandsMutable().assign(paramsVector);
}
});
diff --git a/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp b/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp
index b0d48aaf01..19100ea8eb 100644
--- a/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp
+++ b/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp
@@ -58,7 +58,8 @@ static std::vector computePartialDerivative(PatternRewriter &rewriter, Lo
{
constexpr double shift = llvm::numbers::pi / 2;
ShapedType shiftVectorType = RankedTensorType::get({numShifts}, rewriter.getF64Type());
- Value selectorVector = rewriter.create(loc, selectorBuffer);
+ Value selectorVector =
+ rewriter.create(loc, selectorBuffer, /*restrict=*/true);
// Define the shift vectors (pos/neg) as sparse tensor constants.
DenseElementsAttr nonZeroIndices = rewriter.getI64TensorAttr(currentShift);
@@ -284,8 +285,8 @@ func::FuncOp ParameterShiftLowering::genQGradFunction(PatternRewriter &rewriter,
std::vector gradientTensors;
gradientTensors.reserve(gradResTypes.size());
for (Value gradientBuffer : gradientBuffers) {
- gradientTensors.push_back(
- rewriter.create(loc, gradientBuffer));
+ gradientTensors.push_back(rewriter.create(
+ loc, gradientBuffer, /*restrict=*/true));
}
op->setOperands(gradientTensors);
}
diff --git a/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp b/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp
index 5c79663f0d..16c0243352 100644
--- a/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp
+++ b/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp
@@ -14,12 +14,14 @@
#include "iostream"
#include "llvm/Support/raw_ostream.h"
+#include
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -41,10 +43,11 @@ struct PostprocessForwardOp : public OpRewritePattern<ForwardOp> {
// Check if the numbers of args and returns match Enzyme's format.
auto argc = op.getArgc();
auto resc = op.getResc();
- auto tapeCount = op.getTape();
+ auto tape = op.getTape();
- if (op.getFunctionType().getNumInputs() == (argc + resc) * 2 &&
- op.getFunctionType().getNumResults() == tapeCount)
+        // If the function signature was already modified, this pattern cannot be applied.
+ if (op.getFunctionType().getNumInputs() != argc ||
+ op.getFunctionType().getNumResults() != (resc + tape))
return failure();
auto argTys = op.getArgumentTypes();
@@ -127,7 +130,9 @@ struct PostprocessReverseOp : public OpRewritePattern<ReverseOp> {
auto forwardResc = op.getResc();
auto tape = op.getTape();
- if (op.getFunctionType().getNumInputs() == (forwardArgc + forwardResc) * 2 + tape)
+        // If the function signature was already modified, this pattern cannot be applied.
+ if (op.getFunctionType().getNumInputs() != (forwardResc + tape) ||
+ op.getFunctionType().getNumResults() != forwardArgc)
return failure();
auto argTys = op.getArgumentTypes();
@@ -212,4 +217,4 @@ void populatePostprocessingPatterns(RewritePatternSet &patterns)
}
} // namespace gradient
-} // namespace catalyst
\ No newline at end of file
+} // namespace catalyst
diff --git a/mlir/lib/Gradient/Transforms/PreprocessingPatterns.cpp b/mlir/lib/Gradient/Transforms/PreprocessingPatterns.cpp
index c41fb2d46b..ec6820b001 100644
--- a/mlir/lib/Gradient/Transforms/PreprocessingPatterns.cpp
+++ b/mlir/lib/Gradient/Transforms/PreprocessingPatterns.cpp
@@ -70,7 +70,6 @@ struct PreprocessReverseOp : public OpRewritePattern<ReverseOp> {
{
if (!op.getBody().empty())
return failure();
-
Block *block;
rewriter.modifyOpInPlace(op, [&] { block = op.addEntryBlock(); });
@@ -117,4 +116,4 @@ void populatePreprocessingPatterns(RewritePatternSet &patterns)
}
} // namespace gradient
-} // namespace catalyst
\ No newline at end of file
+} // namespace catalyst
diff --git a/mlir/lib/Gradient/Transforms/gradient_preprocess.cpp b/mlir/lib/Gradient/Transforms/gradient_preprocess.cpp
index a0a003991c..f470d5f8d9 100644
--- a/mlir/lib/Gradient/Transforms/gradient_preprocess.cpp
+++ b/mlir/lib/Gradient/Transforms/gradient_preprocess.cpp
@@ -55,4 +55,4 @@ std::unique_ptr<Pass> createGradientPreprocessingPass()
return std::make_unique();
}
-} // namespace catalyst
\ No newline at end of file
+} // namespace catalyst
diff --git a/mlir/lib/Quantum/IR/QuantumDialect.cpp b/mlir/lib/Quantum/IR/QuantumDialect.cpp
index 385f4e0ae5..d4d820326f 100644
--- a/mlir/lib/Quantum/IR/QuantumDialect.cpp
+++ b/mlir/lib/Quantum/IR/QuantumDialect.cpp
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/DialectImplementation.h" // needed for generated type parser
#include "llvm/ADT/TypeSwitch.h" // needed for generated type parser
@@ -43,6 +44,9 @@ void QuantumDialect::initialize()
#define GET_OP_LIST
#include "Quantum/IR/QuantumOps.cpp.inc"
>();
+    declarePromisedInterfaces<bufferization::BufferizableOpInterface, QubitUnitaryOp, HermitianOp,
+                              HamiltonianOp, SampleOp, StateOp, ProbsOp, CountsOp, SetStateOp,
+                              SetBasisStateOp>();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp
new file mode 100644
index 0000000000..2e61dfe579
--- /dev/null
+++ b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -0,0 +1,414 @@
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "Quantum/IR/QuantumOps.h"
+#include "Quantum/Transforms/BufferizableOpInterfaceImpl.h"
+
+using namespace mlir;
+using namespace catalyst::quantum;
+
+namespace {
+/**
+ * The new bufferization interface requires `bufferizesToMemoryRead`, `bufferizesToMemoryWrite`,
+ * and `getAliasingValues`.
+ *
+ * `bufferizesToMemoryRead`: Return `true` if the buffer of the given tensor OpOperand is read.
+ *
+ * `bufferizesToMemoryWrite`: Return `true` if the buffer of the given tensor OpOperand is written
+ * (if bufferizing in-place).
+ *
+ * `getAliasingOpOperands`: Return the OpResults that may share the same buffer as the given
+ * OpOperand. Note that MLIR documentation does not mention `getAliasingValues` but it seems to
+ * serve the same purpose.
+ *
+ * Link: https://mlir.llvm.org/docs/Bufferization/#extending-one-shot-bufferize
+ */
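+
+// Common shape of the rewrites below (syntax and types illustrative only, not the
+// dialect's exact assembly format): measurement-like ops trade a tensor result for
+// a buffer they write into, e.g.
+//
+//     %probs = quantum.probs %obs : tensor<4xf64>
+//
+// becomes, roughly,
+//
+//     %buf = memref.alloc() : memref<4xf64>
+//     quantum.probs %obs in(%buf : memref<4xf64>)
+//     %probs = bufferization.to_tensor %buf
+//
+// while unitary/observable ops instead convert their constant tensor operands
+// (matrices, coefficients) to memrefs via bufferization.to_memref.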
+
+/// Bufferization of catalyst.quantum.unitary. Convert Matrix into memref.
+struct QubitUnitaryOpInterface
+    : public bufferization::BufferizableOpInterface::ExternalModel<QubitUnitaryOpInterface,
+                                                                   QubitUnitaryOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto qubitUnitaryOp = cast<QubitUnitaryOp>(op);
+        Location loc = op->getLoc();
+        auto tensorType = cast<TensorType>(qubitUnitaryOp.getMatrix().getType());
+        MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+        auto toMemrefOp =
+            rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, qubitUnitaryOp.getMatrix());
+        auto memref = toMemrefOp.getResult();
+        bufferization::replaceOpWithNewBufferizedOp<QubitUnitaryOp>(
+            rewriter, op, qubitUnitaryOp.getOutQubits().getTypes(),
+            qubitUnitaryOp.getOutCtrlQubits().getTypes(), memref, qubitUnitaryOp.getInQubits(),
+            qubitUnitaryOp.getAdjointAttr(), qubitUnitaryOp.getInCtrlQubits(),
+            qubitUnitaryOp.getInCtrlValues());
+ return success();
+ }
+};
+
+/// Bufferization of catalyst.quantum.hermitian. Convert Matrix into memref.
+struct HermitianOpInterface
+    : public bufferization::BufferizableOpInterface::ExternalModel<HermitianOpInterface,
+                                                                   HermitianOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto hermitianOp = cast<HermitianOp>(op);
+        Location loc = op->getLoc();
+        auto tensorType = cast<TensorType>(hermitianOp.getMatrix().getType());
+        MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+        auto toMemrefOp =
+            rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, hermitianOp.getMatrix());
+        auto memref = toMemrefOp.getResult();
+        auto newHermitianOp = rewriter.create<HermitianOp>(loc, hermitianOp.getType(), memref,
+                                                           hermitianOp.getQubits());
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, newHermitianOp.getObs());
+
+ return success();
+ }
+};
+
+/// Bufferization of catalyst.quantum.hamiltonian. Convert Matrix into memref.
+struct HamiltonianOpInterface
+    : public bufferization::BufferizableOpInterface::ExternalModel<HamiltonianOpInterface,
+                                                                   HamiltonianOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto hamiltonianOp = cast<HamiltonianOp>(op);
+        Location loc = op->getLoc();
+        auto tensorType = cast<TensorType>(hamiltonianOp.getCoeffs().getType());
+        MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+        auto toMemrefOp =
+            rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, hamiltonianOp.getCoeffs());
+        auto memref = toMemrefOp.getResult();
+        auto newHamiltonianOp = rewriter.create<HamiltonianOp>(loc, hamiltonianOp.getType(), memref,
+                                                               hamiltonianOp.getTerms());
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, newHamiltonianOp.getObs());
+
+ return success();
+ }
+};
+
+/// Bufferization of catalyst.quantum.sample. Replace with memref.alloc and a new
+/// catalyst.quantum.sample that uses the memory allocated by memref.alloc.
+struct SampleOpInterface
+    : public bufferization::BufferizableOpInterface::ExternalModel<SampleOpInterface, SampleOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto sampleOp = cast<SampleOp>(op);
+        Location loc = op->getLoc();
+        auto tensorType = cast<TensorType>(sampleOp.getSamples().getType());
+        MemRefType resultType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+        Value allocVal = rewriter.create<memref::AllocOp>(loc, resultType);
+        rewriter.create<SampleOp>(loc, TypeRange{}, ValueRange{sampleOp.getObs(), allocVal},
+                                  sampleOp->getAttrs());
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, allocVal);
+
+ return success();
+ }
+};
+
+/// Bufferization of catalyst.quantum.state. Replace with memref.alloc and a new
+/// catalyst.quantum.state that uses the memory allocated by memref.alloc.
+struct StateOpInterface
+    : public bufferization::BufferizableOpInterface::ExternalModel<StateOpInterface, StateOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto stateOp = cast<StateOp>(op);
+        Location loc = op->getLoc();
+        auto tensorType = cast<TensorType>(stateOp.getState().getType());
+        MemRefType resultType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+        Value allocVal = rewriter.create<memref::AllocOp>(loc, resultType);
+        rewriter.create<StateOp>(loc, TypeRange{}, ValueRange{stateOp.getObs(), allocVal});
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, allocVal);
+
+ return success();
+ }
+};
+
+/// Bufferization of catalyst.quantum.probs. Replace with memref.alloc and a new
+/// catalyst.quantum.probs that uses the memory allocated by memref.alloc.
+struct ProbsOpInterface
+    : public bufferization::BufferizableOpInterface::ExternalModel<ProbsOpInterface, ProbsOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto probsOp = cast<ProbsOp>(op);
+        Location loc = op->getLoc();
+        auto tensorType = cast<TensorType>(probsOp.getProbabilities().getType());
+        MemRefType resultType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+        Value allocVal = rewriter.create<memref::AllocOp>(loc, resultType);
+        rewriter.create<ProbsOp>(loc, TypeRange{}, ValueRange{probsOp.getObs(), allocVal});
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, allocVal);
+
+ return success();
+ }
+};
+
+/// Bufferization of catalyst.quantum.counts. Replace with memref.allocs and a new
+/// catalyst.quantum.counts that uses the memory allocated by memref.allocs.
+struct CountsOpInterface
+    : public bufferization::BufferizableOpInterface::ExternalModel<CountsOpInterface, CountsOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto countsOp = cast<CountsOp>(op);
+        Location loc = op->getLoc();
+        auto tensorType0 = cast<TensorType>(countsOp.getEigvals().getType());
+        auto tensorType1 = cast<TensorType>(countsOp.getCounts().getType());
+        MemRefType resultType0 =
+            MemRefType::get(tensorType0.getShape(), tensorType0.getElementType());
+        MemRefType resultType1 =
+            MemRefType::get(tensorType1.getShape(), tensorType1.getElementType());
+
+        Value allocVal0 = rewriter.create<memref::AllocOp>(loc, resultType0);
+        Value allocVal1 = rewriter.create<memref::AllocOp>(loc, resultType1);
+        rewriter.create<CountsOp>(loc, nullptr, nullptr, countsOp.getObs(), allocVal0, allocVal1,
+                                  countsOp.getShotsAttr());
+ bufferization::replaceOpWithBufferizedValues(rewriter, op,
+ ValueRange{allocVal0, allocVal1});
+
+ return success();
+ }
+};
+
+/// Bufferization of catalyst.quantum.set_state. Convert InState into memref.
+struct SetStateOpInterface
+    : public bufferization::BufferizableOpInterface::ExternalModel<SetStateOpInterface,
+                                                                   SetStateOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto setStateOp = cast<SetStateOp>(op);
+        Location loc = op->getLoc();
+        auto tensorType = cast<TensorType>(setStateOp.getInState().getType());
+        MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+        auto toMemrefOp =
+            rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, setStateOp.getInState());
+        auto memref = toMemrefOp.getResult();
+        auto newSetStateOp = rewriter.create<SetStateOp>(loc, setStateOp.getOutQubits().getTypes(),
+                                                         memref, setStateOp.getInQubits());
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits());
+ return success();
+ }
+};
+
+/// Bufferization of catalyst.quantum.set_basic_state. Convert BasisState into memref.
+struct SetBasisStateOpInterface
+    : public bufferization::BufferizableOpInterface::ExternalModel<SetBasisStateOpInterface,
+                                                                   SetBasisStateOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+        auto setBasisStateOp = cast<SetBasisStateOp>(op);
+        Location loc = op->getLoc();
+        auto tensorType = cast<TensorType>(setBasisStateOp.getBasisState().getType());
+        MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+        auto toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            loc, memrefType, setBasisStateOp.getBasisState());
+        auto memref = toMemrefOp.getResult();
+        auto newSetStateOp = rewriter.create<SetBasisStateOp>(
+            loc, setBasisStateOp.getOutQubits().getTypes(), memref, setBasisStateOp.getInQubits());
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits());
+ return success();
+ }
+};
+
+} // namespace
+
+void catalyst::quantum::registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
+{
+    registry.addExtension(+[](MLIRContext *ctx, QuantumDialect *dialect) {
+        QubitUnitaryOp::attachInterface<QubitUnitaryOpInterface>(*ctx);
+        HermitianOp::attachInterface<HermitianOpInterface>(*ctx);
+        HamiltonianOp::attachInterface<HamiltonianOpInterface>(*ctx);
+        SampleOp::attachInterface<SampleOpInterface>(*ctx);
+        StateOp::attachInterface<StateOpInterface>(*ctx);
+        ProbsOp::attachInterface<ProbsOpInterface>(*ctx);
+        CountsOp::attachInterface<CountsOpInterface>(*ctx);
+        SetStateOp::attachInterface<SetStateOpInterface>(*ctx);
+        SetBasisStateOp::attachInterface<SetBasisStateOpInterface>(*ctx);
+    });
+}
\ No newline at end of file
diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt
index 96ba30d23e..2504275410 100644
--- a/mlir/lib/Quantum/Transforms/CMakeLists.txt
+++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
set(LIBRARY_NAME quantum-transforms)
file(GLOB SRC
+ BufferizableOpInterfaceImpl.cpp
BufferizationPatterns.cpp
quantum_bufferize.cpp
ConversionPatterns.cpp
diff --git a/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp b/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp
index b461dc8a60..160adb70d6 100644
--- a/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp
+++ b/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp
@@ -87,12 +87,16 @@ llvm::SmallVector<Value> getReturnMemRefs(func::ReturnOp op)
*/
Value allocCopyMemrefDyn(Location loc, Value memref, PatternRewriter &rewriter)
{
-    auto memrefType = cast<MemRefType>(memref.getType());
+    auto origMemrefType = cast<MemRefType>(memref.getType());
+ // Rebuild MemRefType without memory layout.
+ auto newMemrefType =
+ MemRefType::get(origMemrefType.getShape(), origMemrefType.getElementType());
+
    llvm::SmallVector<Value> dynDims;
{
        llvm::SmallVector<int64_t> dynIndices;
int64_t ndim = 0;
- for (auto dim : memrefType.getShape()) {
+ for (auto dim : newMemrefType.getShape()) {
if (dim < 0) {
                Value dynValue = rewriter.create<memref::DimOp>(loc, memref, ndim);
dynDims.push_back(dynValue);
@@ -101,9 +105,11 @@ Value allocCopyMemrefDyn(Location loc, Value memref, PatternRewriter &rewriter)
}
}
-    Value newMemRef = rewriter.create<memref::AllocOp>(loc, memrefType, dynDims);
+    Value newMemRef = rewriter.create<memref::AllocOp>(loc, newMemrefType, dynDims);
+    // Cast back to the original memref type to preserve the memory layout.
+    Value castMemRef = rewriter.create<memref::CastOp>(loc, origMemrefType, newMemRef);
    rewriter.create<memref::CopyOp>(loc, memref, newMemRef);
-    return newMemRef;
+    return castMemRef;
}
/**
diff --git a/mlir/patches/FunctionOpInterface-bufferization.patch b/mlir/patches/FunctionOpInterface-bufferization.patch
new file mode 100644
index 0000000000..60a2e9b93f
--- /dev/null
+++ b/mlir/patches/FunctionOpInterface-bufferization.patch
@@ -0,0 +1,993 @@
+diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+index 2fda091e412a..eb0df1d92d6a 100644
+--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
++++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+@@ -11,6 +11,7 @@
+
+ #include "mlir/IR/Operation.h"
+ #include "mlir/IR/PatternMatch.h"
++#include "mlir/Interfaces/FunctionInterfaces.h"
+ #include "mlir/Support/LLVM.h"
+ #include "llvm/ADT/DenseMapInfoVariant.h"
+ #include "llvm/ADT/SetVector.h"
+@@ -260,9 +261,9 @@ struct BufferizationOptions {
+   using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
+   /// Tensor -> MemRef type converter.
+   /// Parameters: Value, memory space, func op, bufferization options
+-  using FunctionArgTypeConverterFn =
+-      std::function<BaseMemRefType(TensorType, Attribute memorySpace,
+-                                   func::FuncOp, const BufferizationOptions &)>;
++  using FunctionArgTypeConverterFn = std::function<BaseMemRefType(
++      TensorType, Attribute memorySpace, FunctionOpInterface,
++      const BufferizationOptions &)>;
+   /// Tensor -> MemRef type converter.
+   /// Parameters: Value, memory space, bufferization options
+   using UnknownTypeConverterFn = std::function<BaseMemRefType(
+       Value, Attribute memorySpace, const BufferizationOptions &)>;
+
+   /// A mapping of FuncOp BBArg indices to equivalent FuncOp BBArg indices.
+-  DenseMap<FuncOp, IndexMapping> equivalentFuncArgs;
++  DenseMap<FunctionOpInterface, IndexMapping> equivalentFuncArgs;
+
+ /// A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
+-  DenseMap<FuncOp, IndexToIndexListMapping> aliasingReturnVals;
++  DenseMap<FunctionOpInterface, IndexToIndexListMapping> aliasingReturnVals;
+
+ /// A set of all read BlockArguments of FuncOps.
+-  DenseMap<FuncOp, BbArgIndexSet> readBbArgs;
++  DenseMap<FunctionOpInterface, BbArgIndexSet> readBbArgs;
+
+ /// A set of all written-to BlockArguments of FuncOps.
+-  DenseMap<FuncOp, BbArgIndexSet> writtenBbArgs;
++  DenseMap<FunctionOpInterface, BbArgIndexSet> writtenBbArgs;
+
+ /// Keep track of which FuncOps are fully analyzed or currently being
+ /// analyzed.
+-  DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
++  DenseMap<FunctionOpInterface, FuncOpAnalysisState> analyzedFuncOps;
+
+ /// This function is called right before analyzing the given FuncOp. It
+ /// initializes the data structures for the FuncOp in this state object.
+- void startFunctionAnalysis(FuncOp funcOp);
++ void startFunctionAnalysis(FunctionOpInterface funcOp);
+ };
+
+ void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+index d51d63f243ea..c4201698468c 100644
+--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
++++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+@@ -18,6 +18,7 @@
+ #include "mlir/IR/TypeUtilities.h"
+ #include "mlir/IR/Value.h"
+ #include "mlir/Interfaces/ControlFlowInterfaces.h"
++#include "mlir/Interfaces/FunctionInterfaces.h"
+ #include "llvm/ADT/ScopeExit.h"
+ #include "llvm/Support/Debug.h"
+
+@@ -314,7 +315,7 @@ namespace {
+ /// Default function arg type converter: Use a fully dynamic layout map.
+ BaseMemRefType
+ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
+- func::FuncOp funcOp,
++ FunctionOpInterface funcOp,
+ const BufferizationOptions &options) {
+ return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+ }
+@@ -361,7 +362,7 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
+ void BufferizationOptions::setFunctionBoundaryTypeConversion(
+ LayoutMapOption layoutMapOption) {
+ functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
+- func::FuncOp funcOp,
++ FunctionOpInterface funcOp,
+ const BufferizationOptions &options) {
+ if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
+ return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
+diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+index 9fbe574ec392..9749a71f3514 100644
+--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
++++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+@@ -22,7 +22,7 @@ namespace mlir {
+ namespace bufferization {
+ namespace func_ext {
+
+-void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
++void FuncAnalysisState::startFunctionAnalysis(FunctionOpInterface funcOp) {
+ analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
+ auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
+ auto createdAliasingResults =
+diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+index 0a4072605c26..a0e5c7fff769 100644
+--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
++++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
+ using namespace mlir::bufferization::func_ext;
+
+ /// A mapping of FuncOps to their callers.
+-using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
++using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;
+
+ /// Get or create FuncAnalysisState.
+ static FuncAnalysisState &
+@@ -88,10 +88,11 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
+
+ /// Return the unique ReturnOp that terminates `funcOp`.
+ /// Return nullptr if there is no such unique ReturnOp.
+-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
+- func::ReturnOp returnOp;
+- for (Block &b : funcOp.getBody()) {
+-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
++static Operation *getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
++ Operation *returnOp = nullptr;
++ for (Block &b : funcOp.getFunctionBody()) {
++ auto candidateOp = b.getTerminator();
++    if (candidateOp && candidateOp->hasTrait<OpTrait::ReturnLike>()) {
+ if (returnOp)
+ return nullptr;
+ returnOp = candidateOp;
+@@ -126,16 +127,16 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
+ /// Store function BlockArguments that are equivalent to/aliasing a returned
+ /// value in FuncAnalysisState.
+ static LogicalResult
+-aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
++aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
++ OneShotAnalysisState &state,
+ FuncAnalysisState &funcState) {
+- if (funcOp.getBody().empty()) {
++ if (funcOp.getFunctionBody().empty()) {
+ // No function body available. Conservatively assume that every tensor
+ // return value may alias with any tensor bbArg.
+- FunctionType type = funcOp.getFunctionType();
+- for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
++ for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) {
+      if (!isa<TensorType>(inputIt.value()))
+ continue;
+- for (const auto &resultIt : llvm::enumerate(type.getResults())) {
++ for (const auto &resultIt : llvm::enumerate(funcOp.getResultTypes())) {
+        if (!isa<TensorType>(resultIt.value()))
+ continue;
+ int64_t returnIdx = resultIt.index();
+@@ -147,7 +148,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+ }
+
+ // Support only single return-terminated block in the function.
+- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
++ Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
+ assert(returnOp && "expected func with single return op");
+
+ for (OpOperand &returnVal : returnOp->getOpOperands())
+@@ -168,8 +169,8 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+ return success();
+ }
+
+-static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
+- bool isWritten) {
++static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx,
++ bool isRead, bool isWritten) {
+ OpBuilder b(funcOp.getContext());
+ Attribute accessType;
+ if (isRead && isWritten) {
+@@ -189,12 +190,12 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
+ /// function with unknown ops, we conservatively assume that such ops bufferize
+ /// to a read + write.
+ static LogicalResult
+-funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
++funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
++ OneShotAnalysisState &state,
+ FuncAnalysisState &funcState) {
+- for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
+- ++idx) {
++ for (int64_t idx = 0, e = funcOp.getNumArguments(); idx < e; ++idx) {
+ // Skip non-tensor arguments.
+-    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
++    if (!isa<TensorType>(funcOp.getArgumentTypes()[idx]))
+ continue;
+ bool isRead;
+ bool isWritten;
+@@ -204,7 +205,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+ StringRef str = accessAttr.getValue();
+ isRead = str == "read" || str == "read-write";
+ isWritten = str == "write" || str == "read-write";
+- } else if (funcOp.getBody().empty()) {
++ } else if (funcOp.getFunctionBody().empty()) {
+ // If the function has no body, conservatively assume that all args are
+ // read + written.
+ isRead = true;
+@@ -230,20 +231,19 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+
+ /// Remove bufferization attributes on FuncOp arguments.
+ static void removeBufferizationAttributes(BlockArgument bbArg) {
+-  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
++  auto funcOp = cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
+ funcOp.removeArgAttr(bbArg.getArgNumber(),
+ BufferizationDialect::kBufferLayoutAttrName);
+ funcOp.removeArgAttr(bbArg.getArgNumber(),
+ BufferizationDialect::kWritableAttrName);
+ }
+
+-/// Return the func::FuncOp called by `callOp`.
+-static func::FuncOp getCalledFunction(func::CallOp callOp) {
++static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
+ SymbolRefAttr sym =
+      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+ if (!sym)
+ return nullptr;
+-  return dyn_cast_or_null<func::FuncOp>(
++  return dyn_cast_or_null<FunctionOpInterface>(
+ SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+ }
+
+@@ -251,12 +251,13 @@ static func::FuncOp getCalledFunction(func::CallOp callOp) {
+ /// Note: This only adds new equivalence info if the called function was already
+ /// analyzed.
+ // TODO: This does not handle cyclic function call graphs etc.
+-static void equivalenceAnalysis(func::FuncOp funcOp,
++static void equivalenceAnalysis(FunctionOpInterface funcOp,
+ OneShotAnalysisState &state,
+ FuncAnalysisState &funcState) {
+- funcOp->walk([&](func::CallOp callOp) {
+- func::FuncOp calledFunction = getCalledFunction(callOp);
+- assert(calledFunction && "could not retrieved called func::FuncOp");
++ funcOp->walk([&](CallOpInterface callOp) {
++ FunctionOpInterface calledFunction = getCalledFunction(callOp);
++ if (!calledFunction)
++ return WalkResult::skip();
+
+ // No equivalence info available for the called function.
+ if (!funcState.equivalentFuncArgs.count(calledFunction))
+@@ -267,7 +268,7 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
+ int64_t bbargIdx = it.second;
+ if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
+ continue;
+- Value returnVal = callOp.getResult(returnIdx);
++ Value returnVal = callOp->getResult(returnIdx);
+ Value argVal = callOp->getOperand(bbargIdx);
+ state.unionEquivalenceClasses(returnVal, argVal);
+ }
+@@ -277,11 +278,9 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
+ }
+
+ /// Return "true" if the given function signature has tensor semantics.
+-static bool hasTensorSignature(func::FuncOp funcOp) {
+- return llvm::any_of(funcOp.getFunctionType().getInputs(),
+- llvm::IsaPred<TensorType>) ||
+- llvm::any_of(funcOp.getFunctionType().getResults(),
+- llvm::IsaPred<TensorType>);
++static bool hasTensorSignature(FunctionOpInterface funcOp) {
++ return llvm::any_of(funcOp.getArgumentTypes(), llvm::IsaPred<TensorType>) ||
++ llvm::any_of(funcOp.getResultTypes(), llvm::IsaPred<TensorType>);
+ }
+
+ /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
+@@ -291,16 +290,16 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
+ /// retrieve the called FuncOp from any func::CallOp.
+ static LogicalResult
+ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
+- SmallVectorImpl<func::FuncOp> &orderedFuncOps,
++ SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
+ FuncCallerMap &callerMap) {
+ // For each FuncOp, the set of functions called by it (i.e. the union of
+ // symbols of all nested func::CallOp).
+- DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
++ DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
+ // For each FuncOp, the number of func::CallOp it contains.
+- DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
+- WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
+- if (!funcOp.getBody().empty()) {
+- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
++ DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
++ WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
++ if (!funcOp.getFunctionBody().empty()) {
++ Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
+ if (!returnOp)
+ return funcOp->emitError()
+ << "cannot bufferize a FuncOp with tensors and "
+@@ -309,9 +308,10 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
+
+ // Collect function calls and populate the caller map.
+ numberCallOpsContainedInFuncOp[funcOp] = 0;
+- return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+- func::FuncOp calledFunction = getCalledFunction(callOp);
+- assert(calledFunction && "could not retrieved called func::FuncOp");
++ return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
++ FunctionOpInterface calledFunction = getCalledFunction(callOp);
++ if (!calledFunction)
++ return WalkResult::skip();
+ // If the called function does not have any tensors in its signature, then
+ // it is not necessary to bufferize the callee before the caller.
+ if (!hasTensorSignature(calledFunction))
+@@ -349,11 +349,11 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
+ /// most generic layout map as function return types. After bufferizing the
+ /// entire function body, a more concise memref type can potentially be used for
+ /// the return type of the function.
+-static void foldMemRefCasts(func::FuncOp funcOp) {
+- if (funcOp.getBody().empty())
++static void foldMemRefCasts(FunctionOpInterface funcOp) {
++ if (funcOp.getFunctionBody().empty())
+ return;
+
+- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
++ Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
+ SmallVector<Type> resultTypes;
+
+ for (OpOperand &operand : returnOp->getOpOperands()) {
+@@ -365,8 +365,8 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
+ }
+ }
+
+- auto newFuncType = FunctionType::get(
+- funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
++ auto newFuncType = FunctionType::get(funcOp.getContext(),
++ funcOp.getArgumentTypes(), resultTypes);
+ funcOp.setType(newFuncType);
+ }
+
+@@ -379,7 +379,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
+ FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
+
+ // A list of functions in the order in which they are analyzed + bufferized.
+- SmallVector<func::FuncOp> orderedFuncOps;
++ SmallVector<FunctionOpInterface>