Skip to content

Commit e06eba4

Browse files
erick-xanadutzunghanjuangpaul0403dime10
authored
Migrate quantum dialect to new one-shot bufferization (#1686)
**Context:** This work is based on #1027 As part of the mlir update, the bufferization of the custom catalyst dialects need to migrate to the new one-shot bufferization interface, as opposed to the old pattern-rewrite style bufferization passes. See more context in #1027. As an example, here is how the new bufferization interface is used for mlir's core `arith` dialect: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h https://github.com/llvm/llvm-project/blob/7ee0097b486b31be8b9a1750b2cd47580efd9587/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp#L54 **Description of the Change:** On the current mlir commit we track, both the old and new bufferization styles exist. The old pattern rewrite style is deprecated. To ease the workflow organization, we migrate one dialect at a time. This PR migrates the `Quantum` dialect's bufferization to the new one-shot interface. Note that the new one-shot interface is supposed to be called only once in the pipeline. However, because we haven't migrated all the dialects yet, we simply swap out the old `--quantum--bufferize` pass in-place, with the new one-shot bufferization pass running on the quantum dialect only. **Benefits:** Align with mlir practices; one step closer to updating mlir. [sc-71487] --------- Co-authored-by: Tzung-Han Juang <[email protected]> Co-authored-by: paul0403 <[email protected]> Co-authored-by: Paul <[email protected]> Co-authored-by: David Ittah <[email protected]>
1 parent 428cc7e commit e06eba4

File tree

17 files changed

+567
-344
lines changed

17 files changed

+567
-344
lines changed

doc/releases/changelog-dev.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,12 @@
110110
* Improved the definition of `YieldOp` in the quantum dialect by removing `AnyTypeOf`
111111
[(#1696)](https://github.com/PennyLaneAI/catalyst/pull/1696)
112112

113+
* The bufferization of custom catalyst dialects has been migrated to the new one-shot
114+
bufferization interface in mlir.
115+
The new mlir bufferization interface is required by jax 0.4.29 or higher.
116+
[(#1027)](https://github.com/PennyLaneAI/catalyst/pull/1027)
117+
[(#1686)](https://github.com/PennyLaneAI/catalyst/pull/1686)
118+
113119
<h3>Documentation 📝</h3>
114120

115121
<h3>Contributors ✍️</h3>
@@ -119,6 +125,7 @@ This release contains contributions from (in alphabetical order):
119125
Joey Carter,
120126
Sengthai Heng,
121127
David Ittah,
128+
Tzung-Han Juang,
122129
Christina Lee,
123130
Erick Ochoa Lopez,
124131
Paul Haochen Wang.

frontend/catalyst/pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]:
230230
"catalyst-bufferize", # Must be run before -- func.func(linalg-bufferize)
231231
"func.func(linalg-bufferize)",
232232
"func.func(tensor-bufferize)",
233-
"quantum-bufferize",
233+
"one-shot-bufferize{dialect-filter=quantum}",
234234
"func-bufferize",
235235
"func.func(finalizing-bufferize)",
236236
"canonicalize", # Remove dead memrefToTensorOp's
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright 2024-2025 Xanadu Quantum Technologies Inc.
2+
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
using namespace mlir;
18+
19+
namespace catalyst {
20+
21+
namespace quantum {
22+
23+
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry &registry);
24+
25+
}
26+
27+
} // namespace catalyst

mlir/include/Quantum/Transforms/Passes.td

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,6 @@
1717

1818
include "mlir/Pass/PassBase.td"
1919

20-
def QuantumBufferizationPass : Pass<"quantum-bufferize"> {
21-
let summary = "Bufferize tensors in quantum operations.";
22-
23-
let dependentDialects = [
24-
"bufferization::BufferizationDialect",
25-
"memref::MemRefDialect"
26-
];
27-
28-
let constructor = "catalyst::createQuantumBufferizationPass()";
29-
}
30-
3120
def QuantumConversionPass : Pass<"convert-quantum-to-llvm"> {
3221
let summary = "Perform a dialect conversion from Quantum to LLVM (QIR).";
3322

mlir/lib/Catalyst/Transforms/AsyncUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ std::optional<LLVM::LLVMFuncOp> AsyncUtils::getCalleeSafe(LLVM::CallOp callOp)
215215
bool AsyncUtils::isFunctionNamed(LLVM::LLVMFuncOp funcOp, llvm::StringRef expectedName)
216216
{
217217
llvm::StringRef observedName = funcOp.getSymName();
218-
return observedName.equals(expectedName);
218+
return observedName == expectedName;
219219
}
220220

221221
bool AsyncUtils::isMlirAsyncRuntimeCreateValue(LLVM::LLVMFuncOp funcOp)

mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ void catalyst::registerAllCatalystPasses()
5050
mlir::registerPass(catalyst::createMemrefToLLVMWithTBAAPass);
5151
mlir::registerPass(catalyst::createMitigationLoweringPass);
5252
mlir::registerPass(catalyst::createQnodeToAsyncLoweringPass);
53-
mlir::registerPass(catalyst::createQuantumBufferizationPass);
5453
mlir::registerPass(catalyst::createQuantumConversionPass);
5554
mlir::registerPass(catalyst::createRegisterInactiveCallbackPass);
5655
mlir::registerPass(catalyst::createRemoveChainedSelfInversePass);

mlir/lib/Driver/CompilerDriver.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
#include "Mitigation/Transforms/Passes.h"
7171
#include "QEC/IR/QECDialect.h"
7272
#include "Quantum/IR/QuantumDialect.h"
73+
#include "Quantum/Transforms/BufferizableOpInterfaceImpl.h"
7374
#include "Quantum/Transforms/Passes.h"
7475

7576
#include "Enzyme.h"
@@ -962,6 +963,9 @@ int QuantumDriverMainFromCL(int argc, char **argv)
962963
registerAllCatalystDialects(registry);
963964
registerLLVMTranslations(registry);
964965

966+
// Register bufferization interfaces
967+
catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry);
968+
965969
// Register and parse command line options.
966970
std::string inputFilename, outputFilename;
967971
std::string helpStr = "Catalyst Command Line Interface options. \n"

mlir/lib/Driver/Pipelines.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "Catalyst/Transforms/Passes.h"
1717
#include "Gradient/Transforms/Passes.h"
1818
#include "Mitigation/Transforms/Passes.h"
19+
#include "Quantum/IR/QuantumDialect.h"
1920
#include "Quantum/Transforms/Passes.h"
2021
#include "mhlo/transforms/passes.h"
2122
#include "mlir/InitAllDialects.h"
@@ -79,7 +80,9 @@ void createBufferizationPipeline(OpPassManager &pm)
7980
pm.addPass(catalyst::createCatalystBufferizationPass());
8081
pm.addNestedPass<mlir::func::FuncOp>(mlir::createLinalgBufferizePass());
8182
pm.addNestedPass<mlir::func::FuncOp>(mlir::tensor::createTensorBufferizePass());
82-
pm.addPass(catalyst::createQuantumBufferizationPass());
83+
mlir::bufferization::OneShotBufferizationOptions quantum_buffer_options;
84+
quantum_buffer_options.opFilter.allowDialect<catalyst::quantum::QuantumDialect>();
85+
pm.addPass(mlir::bufferization::createOneShotBufferizePass(quantum_buffer_options));
8386
pm.addPass(mlir::func::createFuncBufferizePass());
8487
pm.addNestedPass<mlir::func::FuncOp>(mlir::bufferization::createFinalizingBufferizePass());
8588
pm.addPass(mlir::createCanonicalizerPass());

mlir/lib/QEC/Transforms/CommuteCliffordPastPPM.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
// limitations under the License.
1414

1515
#define DEBUG_TYPE "merge_ppr_ppm"
16-
#include "llvm/Support/Casting.h"
17-
#include "llvm/Support/Debug.h"
1816

1917
#include "mlir/Analysis/SliceAnalysis.h"
18+
#include "llvm/Support/Casting.h"
19+
#include "llvm/Support/Debug.h"
20+
//#include "mlir/Analysis/TopologicalSortUtils.h" // enable when updating llvm
2021
#include "mlir/Transforms/TopologicalSortUtils.h"
2122

2223
#include "QEC/IR/QECDialect.h"

mlir/lib/QEC/Transforms/CommuteCliffordTPPR.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#define DEBUG_TYPE "commute_ppr"
1616

1717
#include "llvm/Support/Debug.h"
18-
18+
//#include "mlir/Analysis/TopologicalSortUtils.h" // enable when updating llvm
1919
#include "mlir/Transforms/TopologicalSortUtils.h"
2020

2121
#include "QEC/IR/QECDialect.h"

0 commit comments

Comments
 (0)