diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 2e16ed6e2b..1e07e62cca 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -110,6 +110,12 @@
* Improved the definition of `YieldOp` in the quantum dialect by removing `AnyTypeOf`
[(#1696)](https://github.com/PennyLaneAI/catalyst/pull/1696)
+* The bufferization of custom catalyst dialects has been migrated to the new one-shot
+ bufferization interface in mlir.
+ The new mlir bufferization interface is required by jax 0.4.29 or higher.
+ [(#1027)](https://github.com/PennyLaneAI/catalyst/pull/1027)
+ [(#1686)](https://github.com/PennyLaneAI/catalyst/pull/1686)
+
Documentation 📝
Contributors ✍️
@@ -119,6 +125,7 @@ This release contains contributions from (in alphabetical order):
Joey Carter,
Sengthai Heng,
David Ittah,
+Tzung-Han Juang,
Christina Lee,
Erick Ochoa Lopez,
Paul Haochen Wang.
diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py
index eb3a949f1b..49ece83c56 100644
--- a/frontend/catalyst/pipelines.py
+++ b/frontend/catalyst/pipelines.py
@@ -230,7 +230,7 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]:
"catalyst-bufferize", # Must be run before -- func.func(linalg-bufferize)
"func.func(linalg-bufferize)",
"func.func(tensor-bufferize)",
- "quantum-bufferize",
+ "one-shot-bufferize{dialect-filter=quantum}",
"func-bufferize",
"func.func(finalizing-bufferize)",
"canonicalize", # Remove dead memrefToTensorOp's
diff --git a/mlir/include/Quantum/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/Quantum/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644
index 0000000000..ba6e894dac
--- /dev/null
+++ b/mlir/include/Quantum/Transforms/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,27 @@
+// Copyright 2024-2025 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+using namespace mlir;
+
+namespace catalyst {
+
+namespace quantum {
+
+void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry ®istry);
+
+}
+
+} // namespace catalyst
diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td
index 496c36906f..c120f5a501 100644
--- a/mlir/include/Quantum/Transforms/Passes.td
+++ b/mlir/include/Quantum/Transforms/Passes.td
@@ -17,17 +17,6 @@
include "mlir/Pass/PassBase.td"
-def QuantumBufferizationPass : Pass<"quantum-bufferize"> {
- let summary = "Bufferize tensors in quantum operations.";
-
- let dependentDialects = [
- "bufferization::BufferizationDialect",
- "memref::MemRefDialect"
- ];
-
- let constructor = "catalyst::createQuantumBufferizationPass()";
-}
-
def QuantumConversionPass : Pass<"convert-quantum-to-llvm"> {
let summary = "Perform a dialect conversion from Quantum to LLVM (QIR).";
diff --git a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp
index a0c27129cb..a295156330 100644
--- a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp
+++ b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp
@@ -215,7 +215,7 @@ std::optional AsyncUtils::getCalleeSafe(LLVM::CallOp callOp)
bool AsyncUtils::isFunctionNamed(LLVM::LLVMFuncOp funcOp, llvm::StringRef expectedName)
{
llvm::StringRef observedName = funcOp.getSymName();
- return observedName.equals(expectedName);
+ return observedName == expectedName;
}
bool AsyncUtils::isMlirAsyncRuntimeCreateValue(LLVM::LLVMFuncOp funcOp)
diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
index 459e887020..cc1379a6d1 100644
--- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
+++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
@@ -50,7 +50,6 @@ void catalyst::registerAllCatalystPasses()
mlir::registerPass(catalyst::createMemrefToLLVMWithTBAAPass);
mlir::registerPass(catalyst::createMitigationLoweringPass);
mlir::registerPass(catalyst::createQnodeToAsyncLoweringPass);
- mlir::registerPass(catalyst::createQuantumBufferizationPass);
mlir::registerPass(catalyst::createQuantumConversionPass);
mlir::registerPass(catalyst::createRegisterInactiveCallbackPass);
mlir::registerPass(catalyst::createRemoveChainedSelfInversePass);
diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp
index c2621738b5..2875ad79a7 100644
--- a/mlir/lib/Driver/CompilerDriver.cpp
+++ b/mlir/lib/Driver/CompilerDriver.cpp
@@ -70,6 +70,7 @@
#include "Mitigation/Transforms/Passes.h"
#include "QEC/IR/QECDialect.h"
#include "Quantum/IR/QuantumDialect.h"
+#include "Quantum/Transforms/BufferizableOpInterfaceImpl.h"
#include "Quantum/Transforms/Passes.h"
#include "Enzyme.h"
@@ -962,6 +963,9 @@ int QuantumDriverMainFromCL(int argc, char **argv)
registerAllCatalystDialects(registry);
registerLLVMTranslations(registry);
+ // Register bufferization interfaces
+ catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry);
+
// Register and parse command line options.
std::string inputFilename, outputFilename;
std::string helpStr = "Catalyst Command Line Interface options. \n"
diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp
index d1e2dc3eb9..46a3079c69 100644
--- a/mlir/lib/Driver/Pipelines.cpp
+++ b/mlir/lib/Driver/Pipelines.cpp
@@ -16,6 +16,7 @@
#include "Catalyst/Transforms/Passes.h"
#include "Gradient/Transforms/Passes.h"
#include "Mitigation/Transforms/Passes.h"
+#include "Quantum/IR/QuantumDialect.h"
#include "Quantum/Transforms/Passes.h"
#include "mhlo/transforms/passes.h"
#include "mlir/InitAllDialects.h"
@@ -79,7 +80,9 @@ void createBufferizationPipeline(OpPassManager &pm)
pm.addPass(catalyst::createCatalystBufferizationPass());
pm.addNestedPass(mlir::createLinalgBufferizePass());
pm.addNestedPass(mlir::tensor::createTensorBufferizePass());
- pm.addPass(catalyst::createQuantumBufferizationPass());
+ mlir::bufferization::OneShotBufferizationOptions quantum_buffer_options;
+ quantum_buffer_options.opFilter.allowDialect();
+ pm.addPass(mlir::bufferization::createOneShotBufferizePass(quantum_buffer_options));
pm.addPass(mlir::func::createFuncBufferizePass());
pm.addNestedPass(mlir::bufferization::createFinalizingBufferizePass());
pm.addPass(mlir::createCanonicalizerPass());
diff --git a/mlir/lib/QEC/Transforms/CommuteCliffordPastPPM.cpp b/mlir/lib/QEC/Transforms/CommuteCliffordPastPPM.cpp
index 096e485388..03bd8bff81 100644
--- a/mlir/lib/QEC/Transforms/CommuteCliffordPastPPM.cpp
+++ b/mlir/lib/QEC/Transforms/CommuteCliffordPastPPM.cpp
@@ -13,10 +13,11 @@
// limitations under the License.
#define DEBUG_TYPE "merge_ppr_ppm"
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+//#include "mlir/Analysis/TopologicalSortUtils.h" // enable when updating llvm
#include "mlir/Transforms/TopologicalSortUtils.h"
#include "QEC/IR/QECDialect.h"
diff --git a/mlir/lib/QEC/Transforms/CommuteCliffordTPPR.cpp b/mlir/lib/QEC/Transforms/CommuteCliffordTPPR.cpp
index af49ffe81e..ebfa5aac84 100644
--- a/mlir/lib/QEC/Transforms/CommuteCliffordTPPR.cpp
+++ b/mlir/lib/QEC/Transforms/CommuteCliffordTPPR.cpp
@@ -15,7 +15,7 @@
#define DEBUG_TYPE "commute_ppr"
#include "llvm/Support/Debug.h"
-
+//#include "mlir/Analysis/TopologicalSortUtils.h" // enable when updating llvm
#include "mlir/Transforms/TopologicalSortUtils.h"
#include "QEC/IR/QECDialect.h"
diff --git a/mlir/lib/Quantum/IR/QuantumDialect.cpp b/mlir/lib/Quantum/IR/QuantumDialect.cpp
index 385f4e0ae5..7049f58e63 100644
--- a/mlir/lib/Quantum/IR/QuantumDialect.cpp
+++ b/mlir/lib/Quantum/IR/QuantumDialect.cpp
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/DialectImplementation.h" // needed for generated type parser
#include "llvm/ADT/TypeSwitch.h" // needed for generated type parser
@@ -43,6 +44,10 @@ void QuantumDialect::initialize()
#define GET_OP_LIST
#include "Quantum/IR/QuantumOps.cpp.inc"
>();
+
+ declarePromisedInterfaces();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp
new file mode 100644
index 0000000000..1460236cca
--- /dev/null
+++ b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -0,0 +1,479 @@
+// Copyright 2024-2025 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+#include "Quantum/IR/QuantumOps.h"
+#include "Quantum/Transforms/BufferizableOpInterfaceImpl.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace catalyst::quantum;
+
+/**
+ * Implementation of the BufferizableOpInterface for use with one-shot bufferization.
+ * For more information on the interface, refer to the documentation below:
+ * https://mlir.llvm.org/docs/Bufferization/#extending-one-shot-bufferize
+ * https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td#L14
+ */
+
+namespace {
+
+// Bufferization of quantum.unitary.
+// Convert Matrix into memref.
+struct QubitUnitaryOpInterface
+ : public bufferization::BufferizableOpInterface::ExternalModel {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+ auto qubitUnitaryOp = cast(op);
+ Location loc = op->getLoc();
+ auto tensorType = cast(qubitUnitaryOp.getMatrix().getType());
+ MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+ auto toMemrefOp =
+ rewriter.create(loc, memrefType, qubitUnitaryOp.getMatrix());
+ auto memref = toMemrefOp.getResult();
+ bufferization::replaceOpWithNewBufferizedOp(
+ rewriter, op, qubitUnitaryOp.getOutQubits().getTypes(),
+ qubitUnitaryOp.getOutCtrlQubits().getTypes(), memref, qubitUnitaryOp.getInQubits(),
+ qubitUnitaryOp.getAdjointAttr(), qubitUnitaryOp.getInCtrlQubits(),
+ qubitUnitaryOp.getInCtrlValues());
+ return success();
+ }
+};
+
+// Bufferization of quantum.hermitian.
+// Convert Matrix into memref.
+struct HermitianOpInterface
+ : public bufferization::BufferizableOpInterface::ExternalModel {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+ auto hermitianOp = cast(op);
+ Location loc = op->getLoc();
+ auto tensorType = cast(hermitianOp.getMatrix().getType());
+ MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+ auto toMemrefOp =
+ rewriter.create(loc, memrefType, hermitianOp.getMatrix());
+ auto memref = toMemrefOp.getResult();
+ auto newHermitianOp = rewriter.create(loc, hermitianOp.getType(), memref,
+ hermitianOp.getQubits());
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, newHermitianOp.getObs());
+
+ return success();
+ }
+};
+
+// Bufferization of quantum.hamiltonian.
+// Convert coefficient tensor into memref.
+struct HamiltonianOpInterface
+ : public bufferization::BufferizableOpInterface::ExternalModel {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+ auto hamiltonianOp = cast(op);
+ Location loc = op->getLoc();
+ auto tensorType = cast(hamiltonianOp.getCoeffs().getType());
+ MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+ auto toMemrefOp =
+ rewriter.create(loc, memrefType, hamiltonianOp.getCoeffs());
+ auto memref = toMemrefOp.getResult();
+ auto newHamiltonianOp = rewriter.create(loc, hamiltonianOp.getType(), memref,
+ hamiltonianOp.getTerms());
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, newHamiltonianOp.getObs());
+
+ return success();
+ }
+};
+
+// Bufferization of quantum.sample.
+// Result tensor of quantum.sample is bufferized with a corresponding memref.alloc.
+// Users of the result tensor are updated to use the new memref.
+struct SampleOpInterface
+ : public bufferization::BufferizableOpInterface::ExternalModel {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+ auto sampleOp = cast(op);
+ Location loc = op->getLoc();
+ auto tensorType = cast(sampleOp.getSamples().getType());
+ MemRefType resultType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+ SmallVector allocSizes;
+ for (Value dynShapeDimension : sampleOp.getDynamicShape()) {
+ auto indexCastOp =
+ rewriter.create(loc, rewriter.getIndexType(), dynShapeDimension);
+ allocSizes.push_back(indexCastOp);
+ }
+
+ Value allocVal = rewriter.create(loc, resultType, allocSizes);
+ auto allocedSampleOp = rewriter.create(
+ loc, TypeRange{}, ValueRange{sampleOp.getObs(), allocVal}, op->getAttrs());
+ allocedSampleOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1}));
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, allocVal);
+ return success();
+ }
+};
+
+// Bufferization of quantum.counts.
+// Result tensors of quantum.counts are bufferized with corresponding memref.alloc ops.
+// Users of the result tensors are updated to use the new memrefs.
+struct CountsOpInterface
+ : public bufferization::BufferizableOpInterface::ExternalModel {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+ auto countsOp = cast(op);
+ Location loc = op->getLoc();
+
+ SmallVector buffers;
+ for (size_t i : {0, 1}) {
+ auto tensorType = cast(countsOp.getType(i));
+ MemRefType resultType =
+ MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+ auto shape = cast(tensorType).getShape();
+
+ Value allocVal;
+ if (shape[0] == ShapedType::kDynamic) {
+ auto indexCastOp = rewriter.create(loc, rewriter.getIndexType(),
+ countsOp.getDynamicShape());
+ allocVal =
+ rewriter.create(loc, resultType, ValueRange{indexCastOp});
+ }
+ else {
+ allocVal = rewriter.create(loc, resultType);
+ }
+ buffers.push_back(allocVal);
+ }
+
+ rewriter.create(loc, nullptr, nullptr, countsOp.getObs(), nullptr, buffers[0],
+ buffers[1]);
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, buffers);
+
+ return success();
+ }
+};
+
+// Bufferization of quantum.probs.
+// Result tensor of quantum.probs is bufferized with a corresponding memref.alloc.
+// Users of the result tensor are updated to use the new memref.
+struct ProbsOpInterface
+ : public bufferization::BufferizableOpInterface::ExternalModel {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+ auto probsOp = cast(op);
+ Location loc = op->getLoc();
+ auto tensorType = cast(probsOp.getProbabilities().getType());
+ MemRefType resultType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+ Value buffer;
+ auto shape = cast(tensorType).getShape();
+ if (shape[0] == ShapedType::kDynamic) {
+ auto indexCastOp = rewriter.create(loc, rewriter.getIndexType(),
+ probsOp.getDynamicShape());
+ buffer = rewriter.create(loc, resultType, ValueRange{indexCastOp});
+ }
+ else {
+ buffer = rewriter.create(loc, resultType);
+ }
+
+ auto allocedProbsOp =
+ rewriter.create(loc, TypeRange{}, ValueRange{probsOp.getObs(), buffer});
+ allocedProbsOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1}));
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, buffer);
+ return success();
+ }
+};
+
+// Bufferization of quantum.state.
+// Result tensor of quantum.state is bufferized with a corresponding memref.alloc.
+// Users of the result tensor are updated to use the new memref.
+struct StateOpInterface
+ : public bufferization::BufferizableOpInterface::ExternalModel {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+ auto stateOp = cast(op);
+ Location loc = op->getLoc();
+ auto tensorType = cast(stateOp.getState().getType());
+ MemRefType resultType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+ Value buffer;
+ auto shape = cast(tensorType).getShape();
+ if (shape[0] == ShapedType::kDynamic) {
+ auto indexCastOp = rewriter.create(loc, rewriter.getIndexType(),
+ stateOp.getDynamicShape());
+ buffer = rewriter.create(loc, resultType, ValueRange{indexCastOp});
+ }
+ else {
+ buffer = rewriter.create(loc, resultType);
+ }
+
+ auto allocedStateOp =
+ rewriter.create(loc, TypeRange{}, ValueRange{stateOp.getObs(), buffer});
+ allocedStateOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1}));
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, buffer);
+ return success();
+ }
+};
+
+// Bufferization of quantum.set_state.
+// Convert InState into memref.
+struct SetStateOpInterface
+ : public bufferization::BufferizableOpInterface::ExternalModel {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+ auto setStateOp = cast(op);
+ Location loc = op->getLoc();
+ auto tensorType = cast(setStateOp.getInState().getType());
+ MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+ auto toMemrefOp =
+ rewriter.create(loc, memrefType, setStateOp.getInState());
+ auto memref = toMemrefOp.getResult();
+ auto newSetStateOp = rewriter.create(loc, setStateOp.getOutQubits().getTypes(),
+ memref, setStateOp.getInQubits());
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits());
+ return success();
+ }
+};
+
+// Bufferization of quantum.set_basis_state.
+// Convert BasisState into memref.
+struct SetBasisStateOpInterface
+ : public bufferization::BufferizableOpInterface::ExternalModel {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+ auto setBasisStateOp = cast(op);
+ Location loc = op->getLoc();
+ auto tensorType = cast(setBasisStateOp.getBasisState().getType());
+ MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+ auto toMemrefOp = rewriter.create(
+ loc, memrefType, setBasisStateOp.getBasisState());
+ auto memref = toMemrefOp.getResult();
+ auto newSetStateOp = rewriter.create(
+ loc, setBasisStateOp.getOutQubits().getTypes(), memref, setBasisStateOp.getInQubits());
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits());
+ return success();
+ }
+};
+
+} // namespace
+
+void catalyst::quantum::registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry)
+{
+ registry.addExtension(+[](MLIRContext *ctx, catalyst::quantum::QuantumDialect *dialect) {
+ QubitUnitaryOp::attachInterface(*ctx);
+ HermitianOp::attachInterface(*ctx);
+ HamiltonianOp::attachInterface(*ctx);
+ SampleOp::attachInterface(*ctx);
+ CountsOp::attachInterface(*ctx);
+ ProbsOp::attachInterface(*ctx);
+ StateOp::attachInterface(*ctx);
+ SetStateOp::attachInterface(*ctx);
+ SetBasisStateOp::attachInterface(*ctx);
+ });
+}
diff --git a/mlir/lib/Quantum/Transforms/BufferizationPatterns.cpp b/mlir/lib/Quantum/Transforms/BufferizationPatterns.cpp
deleted file mode 100644
index b48493ef1e..0000000000
--- a/mlir/lib/Quantum/Transforms/BufferizationPatterns.cpp
+++ /dev/null
@@ -1,259 +0,0 @@
-// Copyright 2022-2023 Xanadu Quantum Technologies Inc.
-
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-
-// http://www.apache.org/licenses/LICENSE-2.0
-
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Index/IR/IndexDialect.h"
-#include "mlir/Dialect/Index/IR/IndexOps.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-#include "Quantum/IR/QuantumOps.h"
-#include "Quantum/Transforms/Patterns.h"
-
-using namespace mlir;
-using namespace catalyst::quantum;
-
-namespace {
-
-struct BufferizeQubitUnitaryOp : public OpConversionPattern {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(QubitUnitaryOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override
- {
- rewriter.replaceOpWithNewOp(
- op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), adaptor.getMatrix(),
- adaptor.getInQubits(), adaptor.getAdjointAttr(), adaptor.getInCtrlQubits(),
- adaptor.getInCtrlValues());
- return success();
- }
-};
-
-struct BufferizeHermitianOp : public OpConversionPattern {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(HermitianOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override
- {
- rewriter.replaceOpWithNewOp(op, op.getType(), adaptor.getMatrix(),
- adaptor.getQubits());
- return success();
- }
-};
-
-struct BufferizeHamiltonianOp : public OpConversionPattern {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(HamiltonianOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override
- {
- rewriter.replaceOpWithNewOp(op, op.getType(), adaptor.getCoeffs(),
- adaptor.getTerms());
- return success();
- }
-};
-
-struct BufferizeSampleOp : public OpConversionPattern {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(SampleOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override
- {
- Type tensorType = op.getType(0);
- MemRefType resultType = cast(getTypeConverter()->convertType(tensorType));
- Location loc = op.getLoc();
-
- SmallVector allocSizes;
- for (Value dynShapeDimension : op.getDynamicShape()) {
- auto indexCastOp =
- rewriter.create(loc, rewriter.getIndexType(), dynShapeDimension);
- allocSizes.push_back(indexCastOp);
- }
-
- Value allocVal = rewriter.replaceOpWithNewOp(op, resultType, allocSizes);
- auto allocedSampleOp = rewriter.create(
- loc, TypeRange{}, ValueRange{adaptor.getObs(), allocVal}, op->getAttrs());
- allocedSampleOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1}));
- return success();
- }
-};
-
-struct BufferizeStateOp : public OpConversionPattern {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(StateOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override
- {
- Type tensorType = op.getType(0);
- MemRefType resultType = cast(getTypeConverter()->convertType(tensorType));
- Location loc = op.getLoc();
-
- Value buffer;
- auto shape = cast(tensorType).getShape();
- if (shape[0] == ShapedType::kDynamic) {
- auto indexCastOp =
- rewriter.create(loc, rewriter.getIndexType(), op.getDynamicShape());
- buffer = rewriter.replaceOpWithNewOp(op, resultType,
- ValueRange{indexCastOp});
- }
- else {
- buffer = rewriter.replaceOpWithNewOp(op, resultType);
- }
-
- auto allocedStateOp =
- rewriter.create(loc, TypeRange{}, ValueRange{adaptor.getObs(), buffer});
- allocedStateOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1}));
- return success();
- }
-};
-
-struct BufferizeProbsOp : public OpConversionPattern {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(ProbsOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override
- {
- Type tensorType = op.getType(0);
- MemRefType resultType = cast(getTypeConverter()->convertType(tensorType));
- Location loc = op.getLoc();
-
- Value buffer;
- auto shape = cast(tensorType).getShape();
- if (shape[0] == ShapedType::kDynamic) {
- auto indexCastOp =
- rewriter.create(loc, rewriter.getIndexType(), op.getDynamicShape());
- buffer = rewriter.replaceOpWithNewOp(op, resultType,
- ValueRange{indexCastOp});
- }
- else {
- buffer = rewriter.replaceOpWithNewOp(op, resultType);
- }
-
- auto allocedProbsOp =
- rewriter.create(loc, TypeRange{}, ValueRange{adaptor.getObs(), buffer});
- allocedProbsOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1}));
- return success();
- }
-};
-
-struct BufferizeCountsOp : public OpConversionPattern {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(CountsOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override
- {
- Location loc = op.getLoc();
- SmallVector buffers;
- for (size_t i : {0, 1}) {
- Type tensorType = op.getType(i);
- MemRefType resultType = cast(getTypeConverter()->convertType(tensorType));
- auto shape = cast(tensorType).getShape();
-
- Value allocVal;
- if (shape[0] == ShapedType::kDynamic) {
- auto indexCastOp = rewriter.create(loc, rewriter.getIndexType(),
- op.getDynamicShape());
- allocVal =
- rewriter.create(loc, resultType, ValueRange{indexCastOp});
- }
- else {
- allocVal = rewriter.create(loc, resultType);
- }
- buffers.push_back(allocVal);
- }
- rewriter.replaceOp(op, buffers);
-
- rewriter.create(loc, nullptr, nullptr, adaptor.getObs(), nullptr, buffers[0],
- buffers[1]);
-
- return success();
- }
-};
-
-struct BufferizeSetStateOp : public OpConversionPattern {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(SetStateOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override
- {
- Type tensorType = op.getInState().getType();
- MemRefType memrefType = cast(getTypeConverter()->convertType(tensorType));
- auto toMemrefOp =
- rewriter.create(op->getLoc(), memrefType, op.getInState());
- auto memref = toMemrefOp.getResult();
- rewriter.replaceOpWithNewOp(op, op.getOutQubits().getTypes(), memref,
- adaptor.getInQubits());
- return success();
- }
-};
-
-struct BufferizeSetBasisStateOp : public OpConversionPattern {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(SetBasisStateOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override
- {
- Type tensorType = op.getBasisState().getType();
- MemRefType memrefType = cast(getTypeConverter()->convertType(tensorType));
- auto toMemrefOp = rewriter.create(op->getLoc(), memrefType,
- op.getBasisState());
- auto memref = toMemrefOp.getResult();
- rewriter.replaceOpWithNewOp(op, op.getOutQubits().getTypes(), memref,
- adaptor.getInQubits());
- return success();
- }
-};
-
-} // namespace
-
-namespace catalyst {
-namespace quantum {
-
-void populateBufferizationLegality(TypeConverter &typeConverter, ConversionTarget &target)
-{
- // Default to operations being legal with the exception of the ones below.
- target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
- // Quantum ops which return arrays need to be marked illegal when the type is a tensor.
- target.addDynamicallyLegalOp(
- [&](QubitUnitaryOp op) { return typeConverter.isLegal(op.getMatrix().getType()); });
- target.addDynamicallyLegalOp(
- [&](HermitianOp op) { return typeConverter.isLegal(op.getMatrix().getType()); });
- target.addDynamicallyLegalOp(
- [&](HamiltonianOp op) { return typeConverter.isLegal(op.getCoeffs().getType()); });
- target.addDynamicallyLegalOp([&](SampleOp op) { return op.isBufferized(); });
- target.addDynamicallyLegalOp([&](StateOp op) { return op.isBufferized(); });
- target.addDynamicallyLegalOp([&](ProbsOp op) { return op.isBufferized(); });
- target.addDynamicallyLegalOp([&](CountsOp op) { return op.isBufferized(); });
- target.addDynamicallyLegalOp([&](SetStateOp op) { return op.isBufferized(); });
- target.addDynamicallyLegalOp(
- [&](SetBasisStateOp op) { return op.isBufferized(); });
-}
-
-void populateBufferizationPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)
-{
- patterns.add(typeConverter, patterns.getContext());
- patterns.add(typeConverter, patterns.getContext());
- patterns.add(typeConverter, patterns.getContext());
- patterns.add(typeConverter, patterns.getContext());
- patterns.add(typeConverter, patterns.getContext());
- patterns.add(typeConverter, patterns.getContext());
- patterns.add(typeConverter, patterns.getContext());
- patterns.add(typeConverter, patterns.getContext());
- patterns.add(typeConverter, patterns.getContext());
-}
-
-} // namespace quantum
-} // namespace catalyst
diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt
index d983a23eb9..3a244ac4d6 100644
--- a/mlir/lib/Quantum/Transforms/CMakeLists.txt
+++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt
@@ -1,8 +1,7 @@
set(LIBRARY_NAME quantum-transforms)
file(GLOB SRC
- BufferizationPatterns.cpp
- quantum_bufferize.cpp
+ BufferizableOpInterfaceImpl.cpp
ConversionPatterns.cpp
quantum_to_llvm.cpp
emit_catalyst_pyface.cpp
diff --git a/mlir/lib/Quantum/Transforms/quantum_bufferize.cpp b/mlir/lib/Quantum/Transforms/quantum_bufferize.cpp
deleted file mode 100644
index 901f79da6f..0000000000
--- a/mlir/lib/Quantum/Transforms/quantum_bufferize.cpp
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright 2022-2023 Xanadu Quantum Technologies Inc.
-
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-
-// http://www.apache.org/licenses/LICENSE-2.0
-
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-#include "Quantum/IR/QuantumOps.h"
-#include "Quantum/Transforms/Passes.h"
-#include "Quantum/Transforms/Patterns.h"
-
-using namespace mlir;
-using namespace catalyst::quantum;
-
-namespace catalyst {
-namespace quantum {
-
-#define GEN_PASS_DEF_QUANTUMBUFFERIZATIONPASS
-#include "Quantum/Transforms/Passes.h.inc"
-
-struct QuantumBufferizationPass : impl::QuantumBufferizationPassBase {
- using QuantumBufferizationPassBase::QuantumBufferizationPassBase;
-
- void runOnOperation() final
- {
- MLIRContext *context = &getContext();
- bufferization::BufferizeTypeConverter typeConverter;
-
- RewritePatternSet patterns(context);
- populateBufferizationPatterns(typeConverter, patterns);
-
- ConversionTarget target(*context);
- bufferization::populateBufferizeMaterializationLegality(target);
- populateBufferizationLegality(typeConverter, target);
-
- if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) {
- signalPassFailure();
- }
- }
-};
-
-} // namespace quantum
-
-std::unique_ptr createQuantumBufferizationPass()
-{
- return std::make_unique();
-}
-
-} // namespace catalyst
diff --git a/mlir/test/Quantum/BufferizationTest.mlir b/mlir/test/Quantum/BufferizationTest.mlir
index 995b92c86c..fb53b96d32 100644
--- a/mlir/test/Quantum/BufferizationTest.mlir
+++ b/mlir/test/Quantum/BufferizationTest.mlir
@@ -12,7 +12,37 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// RUN: quantum-opt --quantum-bufferize --split-input-file %s | FileCheck %s
+// RUN: quantum-opt --one-shot-bufferize --split-input-file %s | FileCheck %s
+
+func.func @qubit_unitary(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex>) {
+ // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : memref<2x2xcomplex>
+ // CHECK: {{%.+}} = quantum.unitary([[memref]] : memref<2x2xcomplex>) %arg0 : !quantum.bit
+ %out_qubits = quantum.unitary(%matrix : tensor<2x2xcomplex>) %q0 : !quantum.bit
+
+ func.return
+}
+
+// -----
+
+func.func @hermitian(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex>) {
+ // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : memref<2x2xcomplex>
+ // CHECK: {{%.+}} = quantum.hermitian([[memref]] : memref<2x2xcomplex>) %arg0 : !quantum.obs
+ %obs = quantum.hermitian(%matrix : tensor<2x2xcomplex>) %q0 : !quantum.obs
+
+ func.return
+}
+
+// -----
+
+func.func @hamiltonian(%obs: !quantum.obs, %coeffs: tensor<1xf64>){
+ // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : memref<1xf64>
+ // CHECK: {{%.+}} = quantum.hamiltonian([[memref]] : memref<1xf64>) %arg0 : !quantum.obs
+ %hamil = quantum.hamiltonian(%coeffs: tensor<1xf64>) %obs : !quantum.obs
+
+ func.return
+}
+
+// -----
//////////////////
// Measurements //
@@ -132,4 +162,3 @@ module @set_basis_state {
return
}
}
-
diff --git a/mlir/tools/quantum-opt/quantum-opt.cpp b/mlir/tools/quantum-opt/quantum-opt.cpp
index c764c04627..eda34657df 100644
--- a/mlir/tools/quantum-opt/quantum-opt.cpp
+++ b/mlir/tools/quantum-opt/quantum-opt.cpp
@@ -34,6 +34,7 @@
#include "Mitigation/Transforms/Passes.h"
#include "QEC/IR/QECDialect.h"
#include "Quantum/IR/QuantumDialect.h"
+#include "Quantum/Transforms/BufferizableOpInterfaceImpl.h"
#include "Quantum/Transforms/Passes.h"
namespace test {
@@ -61,6 +62,8 @@ int main(int argc, char **argv)
registry.insert();
registry.insert();
+ catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry);
+
return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "Quantum optimizer driver\n", registry));
}