diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 2e16ed6e2b..1e07e62cca 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -110,6 +110,12 @@ * Improved the definition of `YieldOp` in the quantum dialect by removing `AnyTypeOf` [(#1696)](https://github.com/PennyLaneAI/catalyst/pull/1696) +* The bufferization of custom catalyst dialects has been migrated to the new one-shot + bufferization interface in mlir. + The new mlir bufferization interface is required by jax 0.4.29 or higher. + [(#1027)](https://github.com/PennyLaneAI/catalyst/pull/1027) + [(#1686)](https://github.com/PennyLaneAI/catalyst/pull/1686) +

Documentation 📝

Contributors ✍️

@@ -119,6 +125,7 @@ This release contains contributions from (in alphabetical order): Joey Carter, Sengthai Heng, David Ittah, +Tzung-Han Juang, Christina Lee, Erick Ochoa Lopez, Paul Haochen Wang. diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index eb3a949f1b..49ece83c56 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -230,7 +230,7 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]: "catalyst-bufferize", # Must be run before -- func.func(linalg-bufferize) "func.func(linalg-bufferize)", "func.func(tensor-bufferize)", - "quantum-bufferize", + "one-shot-bufferize{dialect-filter=quantum}", "func-bufferize", "func.func(finalizing-bufferize)", "canonicalize", # Remove dead memrefToTensorOp's diff --git a/mlir/include/Quantum/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/Quantum/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 index 0000000000..ba6e894dac --- /dev/null +++ b/mlir/include/Quantum/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,27 @@ +// Copyright 2024-2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +using namespace mlir; + +namespace catalyst { + +namespace quantum { + +void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry ®istry); + +} + +} // namespace catalyst diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td index 496c36906f..c120f5a501 100644 --- a/mlir/include/Quantum/Transforms/Passes.td +++ b/mlir/include/Quantum/Transforms/Passes.td @@ -17,17 +17,6 @@ include "mlir/Pass/PassBase.td" -def QuantumBufferizationPass : Pass<"quantum-bufferize"> { - let summary = "Bufferize tensors in quantum operations."; - - let dependentDialects = [ - "bufferization::BufferizationDialect", - "memref::MemRefDialect" - ]; - - let constructor = "catalyst::createQuantumBufferizationPass()"; -} - def QuantumConversionPass : Pass<"convert-quantum-to-llvm"> { let summary = "Perform a dialect conversion from Quantum to LLVM (QIR)."; diff --git a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp index a0c27129cb..a295156330 100644 --- a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp +++ b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp @@ -215,7 +215,7 @@ std::optional AsyncUtils::getCalleeSafe(LLVM::CallOp callOp) bool AsyncUtils::isFunctionNamed(LLVM::LLVMFuncOp funcOp, llvm::StringRef expectedName) { llvm::StringRef observedName = funcOp.getSymName(); - return observedName.equals(expectedName); + return observedName == expectedName; } bool AsyncUtils::isMlirAsyncRuntimeCreateValue(LLVM::LLVMFuncOp funcOp) diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 459e887020..cc1379a6d1 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -50,7 +50,6 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createMemrefToLLVMWithTBAAPass); mlir::registerPass(catalyst::createMitigationLoweringPass); mlir::registerPass(catalyst::createQnodeToAsyncLoweringPass); - mlir::registerPass(catalyst::createQuantumBufferizationPass); mlir::registerPass(catalyst::createQuantumConversionPass); mlir::registerPass(catalyst::createRegisterInactiveCallbackPass); mlir::registerPass(catalyst::createRemoveChainedSelfInversePass); diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp index c2621738b5..2875ad79a7 100644 --- a/mlir/lib/Driver/CompilerDriver.cpp +++ b/mlir/lib/Driver/CompilerDriver.cpp @@ -70,6 +70,7 @@ #include "Mitigation/Transforms/Passes.h" #include "QEC/IR/QECDialect.h" #include "Quantum/IR/QuantumDialect.h" +#include "Quantum/Transforms/BufferizableOpInterfaceImpl.h" #include "Quantum/Transforms/Passes.h" #include "Enzyme.h" @@ -962,6 +963,9 @@ int QuantumDriverMainFromCL(int argc, char **argv) registerAllCatalystDialects(registry); registerLLVMTranslations(registry); + // Register bufferization interfaces + catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry); + // Register and parse command line options. std::string inputFilename, outputFilename; std::string helpStr = "Catalyst Command Line Interface options. \n" diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index d1e2dc3eb9..46a3079c69 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -16,6 +16,7 @@ #include "Catalyst/Transforms/Passes.h" #include "Gradient/Transforms/Passes.h" #include "Mitigation/Transforms/Passes.h" +#include "Quantum/IR/QuantumDialect.h" #include "Quantum/Transforms/Passes.h" #include "mhlo/transforms/passes.h" #include "mlir/InitAllDialects.h" @@ -79,7 +80,9 @@ void createBufferizationPipeline(OpPassManager &pm) pm.addPass(catalyst::createCatalystBufferizationPass()); pm.addNestedPass(mlir::createLinalgBufferizePass()); pm.addNestedPass(mlir::tensor::createTensorBufferizePass()); - pm.addPass(catalyst::createQuantumBufferizationPass()); + mlir::bufferization::OneShotBufferizationOptions quantum_buffer_options; + quantum_buffer_options.opFilter.allowDialect(); + pm.addPass(mlir::bufferization::createOneShotBufferizePass(quantum_buffer_options)); pm.addPass(mlir::func::createFuncBufferizePass()); pm.addNestedPass(mlir::bufferization::createFinalizingBufferizePass()); pm.addPass(mlir::createCanonicalizerPass()); diff --git a/mlir/lib/QEC/Transforms/CommuteCliffordPastPPM.cpp b/mlir/lib/QEC/Transforms/CommuteCliffordPastPPM.cpp index 096e485388..03bd8bff81 100644 --- a/mlir/lib/QEC/Transforms/CommuteCliffordPastPPM.cpp +++ b/mlir/lib/QEC/Transforms/CommuteCliffordPastPPM.cpp @@ -13,10 +13,11 @@ // limitations under the License. #define DEBUG_TYPE "merge_ppr_ppm" -#include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +//#include "mlir/Analysis/TopologicalSortUtils.h" // enable when updating llvm #include "mlir/Transforms/TopologicalSortUtils.h" #include "QEC/IR/QECDialect.h" diff --git a/mlir/lib/QEC/Transforms/CommuteCliffordTPPR.cpp b/mlir/lib/QEC/Transforms/CommuteCliffordTPPR.cpp index af49ffe81e..ebfa5aac84 100644 --- a/mlir/lib/QEC/Transforms/CommuteCliffordTPPR.cpp +++ b/mlir/lib/QEC/Transforms/CommuteCliffordTPPR.cpp @@ -15,7 +15,7 @@ #define DEBUG_TYPE "commute_ppr" #include "llvm/Support/Debug.h" - +//#include "mlir/Analysis/TopologicalSortUtils.h" // enable when updating llvm #include "mlir/Transforms/TopologicalSortUtils.h" #include "QEC/IR/QECDialect.h" diff --git a/mlir/lib/Quantum/IR/QuantumDialect.cpp b/mlir/lib/Quantum/IR/QuantumDialect.cpp index 385f4e0ae5..7049f58e63 100644 --- a/mlir/lib/Quantum/IR/QuantumDialect.cpp +++ b/mlir/lib/Quantum/IR/QuantumDialect.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/IR/DialectImplementation.h" // needed for generated type parser #include "llvm/ADT/TypeSwitch.h" // needed for generated type parser @@ -43,6 +44,10 @@ void QuantumDialect::initialize() #define GET_OP_LIST #include "Quantum/IR/QuantumOps.cpp.inc" >(); + + declarePromisedInterfaces(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp new file mode 100644 index 0000000000..1460236cca --- /dev/null +++ b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,479 @@ +// Copyright 2024-2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +#include "Quantum/IR/QuantumOps.h" +#include "Quantum/Transforms/BufferizableOpInterfaceImpl.h" + +using namespace mlir; +using namespace mlir::bufferization; +using namespace catalyst::quantum; + +/** + * Implementation of the BufferizableOpInterface for use with one-shot bufferization. + * For more information on the interface, refer to the documentation below: + * https://mlir.llvm.org/docs/Bufferization/#extending-one-shot-bufferize + * https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td#L14 + */ + +namespace { + +// Bufferization of quantum.unitary. +// Convert Matrix into memref. +struct QubitUnitaryOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto qubitUnitaryOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(qubitUnitaryOp.getMatrix().getType()); + MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + auto toMemrefOp = + rewriter.create(loc, memrefType, qubitUnitaryOp.getMatrix()); + auto memref = toMemrefOp.getResult(); + bufferization::replaceOpWithNewBufferizedOp( + rewriter, op, qubitUnitaryOp.getOutQubits().getTypes(), + qubitUnitaryOp.getOutCtrlQubits().getTypes(), memref, qubitUnitaryOp.getInQubits(), + qubitUnitaryOp.getAdjointAttr(), qubitUnitaryOp.getInCtrlQubits(), + qubitUnitaryOp.getInCtrlValues()); + return success(); + } +}; + +// Bufferization of quantum.hermitian. +// Convert Matrix into memref. +struct HermitianOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto hermitianOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(hermitianOp.getMatrix().getType()); + MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + auto toMemrefOp = + rewriter.create(loc, memrefType, hermitianOp.getMatrix()); + auto memref = toMemrefOp.getResult(); + auto newHermitianOp = rewriter.create(loc, hermitianOp.getType(), memref, + hermitianOp.getQubits()); + bufferization::replaceOpWithBufferizedValues(rewriter, op, newHermitianOp.getObs()); + + return success(); + } +}; + +// Bufferization of quantum.hamiltonian. +// Convert coefficient tensor into memref. +struct HamiltonianOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto hamiltonianOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(hamiltonianOp.getCoeffs().getType()); + MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + auto toMemrefOp = + rewriter.create(loc, memrefType, hamiltonianOp.getCoeffs()); + auto memref = toMemrefOp.getResult(); + auto newHamiltonianOp = rewriter.create(loc, hamiltonianOp.getType(), memref, + hamiltonianOp.getTerms()); + bufferization::replaceOpWithBufferizedValues(rewriter, op, newHamiltonianOp.getObs()); + + return success(); + } +}; + +// Bufferization of quantum.sample. +// Result tensor of quantum.sample is bufferized with a corresponding memref.alloc. +// Users of the result tensor are updated to use the new memref. +struct SampleOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bool bufferizesToAllocation(Operation *op, Value value) const { return true; } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto sampleOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(sampleOp.getSamples().getType()); + MemRefType resultType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + + SmallVector allocSizes; + for (Value dynShapeDimension : sampleOp.getDynamicShape()) { + auto indexCastOp = + rewriter.create(loc, rewriter.getIndexType(), dynShapeDimension); + allocSizes.push_back(indexCastOp); + } + + Value allocVal = rewriter.create(loc, resultType, allocSizes); + auto allocedSampleOp = rewriter.create( + loc, TypeRange{}, ValueRange{sampleOp.getObs(), allocVal}, op->getAttrs()); + allocedSampleOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1})); + bufferization::replaceOpWithBufferizedValues(rewriter, op, allocVal); + return success(); + } +}; + +// Bufferization of quantum.counts. +// Result tensors of quantum.counts are bufferized with corresponding memref.alloc ops. +// Users of the result tensors are updated to use the new memrefs. +struct CountsOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bool bufferizesToAllocation(Operation *op, Value value) const { return true; } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto countsOp = cast(op); + Location loc = op->getLoc(); + + SmallVector buffers; + for (size_t i : {0, 1}) { + auto tensorType = cast(countsOp.getType(i)); + MemRefType resultType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + auto shape = cast(tensorType).getShape(); + + Value allocVal; + if (shape[0] == ShapedType::kDynamic) { + auto indexCastOp = rewriter.create(loc, rewriter.getIndexType(), + countsOp.getDynamicShape()); + allocVal = + rewriter.create(loc, resultType, ValueRange{indexCastOp}); + } + else { + allocVal = rewriter.create(loc, resultType); + } + buffers.push_back(allocVal); + } + + rewriter.create(loc, nullptr, nullptr, countsOp.getObs(), nullptr, buffers[0], + buffers[1]); + bufferization::replaceOpWithBufferizedValues(rewriter, op, buffers); + + return success(); + } +}; + +// Bufferization of quantum.probs. +// Result tensor of quantum.probs is bufferized with a corresponding memref.alloc. +// Users of the result tensor are updated to use the new memref. +struct ProbsOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bool bufferizesToAllocation(Operation *op, Value value) const { return true; } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto probsOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(probsOp.getProbabilities().getType()); + MemRefType resultType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + + Value buffer; + auto shape = cast(tensorType).getShape(); + if (shape[0] == ShapedType::kDynamic) { + auto indexCastOp = rewriter.create(loc, rewriter.getIndexType(), + probsOp.getDynamicShape()); + buffer = rewriter.create(loc, resultType, ValueRange{indexCastOp}); + } + else { + buffer = rewriter.create(loc, resultType); + } + + auto allocedProbsOp = + rewriter.create(loc, TypeRange{}, ValueRange{probsOp.getObs(), buffer}); + allocedProbsOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1})); + bufferization::replaceOpWithBufferizedValues(rewriter, op, buffer); + return success(); + } +}; + +// Bufferization of quantum.state. +// Result tensor of quantum.state is bufferized with a corresponding memref.alloc. +// Users of the result tensor are updated to use the new memref. +struct StateOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bool bufferizesToAllocation(Operation *op, Value value) const { return true; } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto stateOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(stateOp.getState().getType()); + MemRefType resultType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + + Value buffer; + auto shape = cast(tensorType).getShape(); + if (shape[0] == ShapedType::kDynamic) { + auto indexCastOp = rewriter.create(loc, rewriter.getIndexType(), + stateOp.getDynamicShape()); + buffer = rewriter.create(loc, resultType, ValueRange{indexCastOp}); + } + else { + buffer = rewriter.create(loc, resultType); + } + + auto allocedStateOp = + rewriter.create(loc, TypeRange{}, ValueRange{stateOp.getObs(), buffer}); + allocedStateOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1})); + bufferization::replaceOpWithBufferizedValues(rewriter, op, buffer); + return success(); + } +}; + +// Bufferization of quantum.set_state. +// Convert InState into memref. +struct SetStateOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto setStateOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(setStateOp.getInState().getType()); + MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + + auto toMemrefOp = + rewriter.create(loc, memrefType, setStateOp.getInState()); + auto memref = toMemrefOp.getResult(); + auto newSetStateOp = rewriter.create(loc, setStateOp.getOutQubits().getTypes(), + memref, setStateOp.getInQubits()); + bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits()); + return success(); + } +}; + +// Bufferization of quantum.set_basis_state. +// Convert BasisState into memref. +struct SetBasisStateOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto setBasisStateOp = cast(op); + Location loc = op->getLoc(); + auto tensorType = cast(setBasisStateOp.getBasisState().getType()); + MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + + auto toMemrefOp = rewriter.create( + loc, memrefType, setBasisStateOp.getBasisState()); + auto memref = toMemrefOp.getResult(); + auto newSetStateOp = rewriter.create( + loc, setBasisStateOp.getOutQubits().getTypes(), memref, setBasisStateOp.getInQubits()); + bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits()); + return success(); + } +}; + +} // namespace + +void catalyst::quantum::registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) +{ + registry.addExtension(+[](MLIRContext *ctx, catalyst::quantum::QuantumDialect *dialect) { + QubitUnitaryOp::attachInterface(*ctx); + HermitianOp::attachInterface(*ctx); + HamiltonianOp::attachInterface(*ctx); + SampleOp::attachInterface(*ctx); + CountsOp::attachInterface(*ctx); + ProbsOp::attachInterface(*ctx); + StateOp::attachInterface(*ctx); + SetStateOp::attachInterface(*ctx); + SetBasisStateOp::attachInterface(*ctx); + }); +} diff --git a/mlir/lib/Quantum/Transforms/BufferizationPatterns.cpp b/mlir/lib/Quantum/Transforms/BufferizationPatterns.cpp deleted file mode 100644 index b48493ef1e..0000000000 --- a/mlir/lib/Quantum/Transforms/BufferizationPatterns.cpp +++ /dev/null @@ -1,259 +0,0 @@ -// Copyright 2022-2023 Xanadu Quantum Technologies Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "Quantum/IR/QuantumOps.h" -#include "Quantum/Transforms/Patterns.h" - -using namespace mlir; -using namespace catalyst::quantum; - -namespace { - -struct BufferizeQubitUnitaryOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(QubitUnitaryOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - rewriter.replaceOpWithNewOp( - op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), adaptor.getMatrix(), - adaptor.getInQubits(), adaptor.getAdjointAttr(), adaptor.getInCtrlQubits(), - adaptor.getInCtrlValues()); - return success(); - } -}; - -struct BufferizeHermitianOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(HermitianOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - rewriter.replaceOpWithNewOp(op, op.getType(), adaptor.getMatrix(), - adaptor.getQubits()); - return success(); - } -}; - -struct BufferizeHamiltonianOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(HamiltonianOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - rewriter.replaceOpWithNewOp(op, op.getType(), adaptor.getCoeffs(), - adaptor.getTerms()); - return success(); - } -}; - -struct BufferizeSampleOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(SampleOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - Type tensorType = op.getType(0); - MemRefType resultType = cast(getTypeConverter()->convertType(tensorType)); - Location loc = op.getLoc(); - - SmallVector allocSizes; - for (Value dynShapeDimension : op.getDynamicShape()) { - auto indexCastOp = - rewriter.create(loc, rewriter.getIndexType(), dynShapeDimension); - allocSizes.push_back(indexCastOp); - } - - Value allocVal = rewriter.replaceOpWithNewOp(op, resultType, allocSizes); - auto allocedSampleOp = rewriter.create( - loc, TypeRange{}, ValueRange{adaptor.getObs(), allocVal}, op->getAttrs()); - allocedSampleOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1})); - return success(); - } -}; - -struct BufferizeStateOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(StateOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - Type tensorType = op.getType(0); - MemRefType resultType = cast(getTypeConverter()->convertType(tensorType)); - Location loc = op.getLoc(); - - Value buffer; - auto shape = cast(tensorType).getShape(); - if (shape[0] == ShapedType::kDynamic) { - auto indexCastOp = - rewriter.create(loc, rewriter.getIndexType(), op.getDynamicShape()); - buffer = rewriter.replaceOpWithNewOp(op, resultType, - ValueRange{indexCastOp}); - } - else { - buffer = rewriter.replaceOpWithNewOp(op, resultType); - } - - auto allocedStateOp = - rewriter.create(loc, TypeRange{}, ValueRange{adaptor.getObs(), buffer}); - allocedStateOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1})); - return success(); - } -}; - -struct BufferizeProbsOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(ProbsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - Type tensorType = op.getType(0); - MemRefType resultType = cast(getTypeConverter()->convertType(tensorType)); - Location loc = op.getLoc(); - - Value buffer; - auto shape = cast(tensorType).getShape(); - if (shape[0] == ShapedType::kDynamic) { - auto indexCastOp = - rewriter.create(loc, rewriter.getIndexType(), op.getDynamicShape()); - buffer = rewriter.replaceOpWithNewOp(op, resultType, - ValueRange{indexCastOp}); - } - else { - buffer = rewriter.replaceOpWithNewOp(op, resultType); - } - - auto allocedProbsOp = - rewriter.create(loc, TypeRange{}, ValueRange{adaptor.getObs(), buffer}); - allocedProbsOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1})); - return success(); - } -}; - -struct BufferizeCountsOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(CountsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - Location loc = op.getLoc(); - SmallVector buffers; - for (size_t i : {0, 1}) { - Type tensorType = op.getType(i); - MemRefType resultType = cast(getTypeConverter()->convertType(tensorType)); - auto shape = cast(tensorType).getShape(); - - Value allocVal; - if (shape[0] == ShapedType::kDynamic) { - auto indexCastOp = rewriter.create(loc, rewriter.getIndexType(), - op.getDynamicShape()); - allocVal = - rewriter.create(loc, resultType, ValueRange{indexCastOp}); - } - else { - allocVal = rewriter.create(loc, resultType); - } - buffers.push_back(allocVal); - } - rewriter.replaceOp(op, buffers); - - rewriter.create(loc, nullptr, nullptr, adaptor.getObs(), nullptr, buffers[0], - buffers[1]); - - return success(); - } -}; - -struct BufferizeSetStateOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(SetStateOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - Type tensorType = op.getInState().getType(); - MemRefType memrefType = cast(getTypeConverter()->convertType(tensorType)); - auto toMemrefOp = - rewriter.create(op->getLoc(), memrefType, op.getInState()); - auto memref = toMemrefOp.getResult(); - rewriter.replaceOpWithNewOp(op, op.getOutQubits().getTypes(), memref, - adaptor.getInQubits()); - return success(); - } -}; - -struct BufferizeSetBasisStateOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(SetBasisStateOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - Type tensorType = op.getBasisState().getType(); - MemRefType memrefType = cast(getTypeConverter()->convertType(tensorType)); - auto toMemrefOp = rewriter.create(op->getLoc(), memrefType, - op.getBasisState()); - auto memref = toMemrefOp.getResult(); - rewriter.replaceOpWithNewOp(op, op.getOutQubits().getTypes(), memref, - adaptor.getInQubits()); - return success(); - } -}; - -} // namespace - -namespace catalyst { -namespace quantum { - -void populateBufferizationLegality(TypeConverter &typeConverter, ConversionTarget &target) -{ - // Default to operations being legal with the exception of the ones below. - target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - // Quantum ops which return arrays need to be marked illegal when the type is a tensor. - target.addDynamicallyLegalOp( - [&](QubitUnitaryOp op) { return typeConverter.isLegal(op.getMatrix().getType()); }); - target.addDynamicallyLegalOp( - [&](HermitianOp op) { return typeConverter.isLegal(op.getMatrix().getType()); }); - target.addDynamicallyLegalOp( - [&](HamiltonianOp op) { return typeConverter.isLegal(op.getCoeffs().getType()); }); - target.addDynamicallyLegalOp([&](SampleOp op) { return op.isBufferized(); }); - target.addDynamicallyLegalOp([&](StateOp op) { return op.isBufferized(); }); - target.addDynamicallyLegalOp([&](ProbsOp op) { return op.isBufferized(); }); - target.addDynamicallyLegalOp([&](CountsOp op) { return op.isBufferized(); }); - target.addDynamicallyLegalOp([&](SetStateOp op) { return op.isBufferized(); }); - target.addDynamicallyLegalOp( - [&](SetBasisStateOp op) { return op.isBufferized(); }); -} - -void populateBufferizationPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns) -{ - patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext()); -} - -} // namespace quantum -} // namespace catalyst diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt index d983a23eb9..3a244ac4d6 100644 --- a/mlir/lib/Quantum/Transforms/CMakeLists.txt +++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt @@ -1,8 +1,7 @@ set(LIBRARY_NAME quantum-transforms) file(GLOB SRC - BufferizationPatterns.cpp - quantum_bufferize.cpp + BufferizableOpInterfaceImpl.cpp ConversionPatterns.cpp quantum_to_llvm.cpp emit_catalyst_pyface.cpp diff --git a/mlir/lib/Quantum/Transforms/quantum_bufferize.cpp b/mlir/lib/Quantum/Transforms/quantum_bufferize.cpp deleted file mode 100644 index 901f79da6f..0000000000 --- a/mlir/lib/Quantum/Transforms/quantum_bufferize.cpp +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2022-2023 Xanadu Quantum Technologies Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "Quantum/IR/QuantumOps.h" -#include "Quantum/Transforms/Passes.h" -#include "Quantum/Transforms/Patterns.h" - -using namespace mlir; -using namespace catalyst::quantum; - -namespace catalyst { -namespace quantum { - -#define GEN_PASS_DEF_QUANTUMBUFFERIZATIONPASS -#include "Quantum/Transforms/Passes.h.inc" - -struct QuantumBufferizationPass : impl::QuantumBufferizationPassBase { - using QuantumBufferizationPassBase::QuantumBufferizationPassBase; - - void runOnOperation() final - { - MLIRContext *context = &getContext(); - bufferization::BufferizeTypeConverter typeConverter; - - RewritePatternSet patterns(context); - populateBufferizationPatterns(typeConverter, patterns); - - ConversionTarget target(*context); - bufferization::populateBufferizeMaterializationLegality(target); - populateBufferizationLegality(typeConverter, target); - - if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { - signalPassFailure(); - } - } -}; - -} // namespace quantum - -std::unique_ptr createQuantumBufferizationPass() -{ - return std::make_unique(); -} - -} // namespace catalyst diff --git a/mlir/test/Quantum/BufferizationTest.mlir b/mlir/test/Quantum/BufferizationTest.mlir index 995b92c86c..fb53b96d32 100644 --- a/mlir/test/Quantum/BufferizationTest.mlir +++ b/mlir/test/Quantum/BufferizationTest.mlir @@ -12,7 +12,37 @@ // See the License for the specific language governing permissions and // limitations under the License. -// RUN: quantum-opt --quantum-bufferize --split-input-file %s | FileCheck %s +// RUN: quantum-opt --one-shot-bufferize --split-input-file %s | FileCheck %s + +func.func @qubit_unitary(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex>) { + // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : memref<2x2xcomplex> + // CHECK: {{%.+}} = quantum.unitary([[memref]] : memref<2x2xcomplex>) %arg0 : !quantum.bit + %out_qubits = quantum.unitary(%matrix : tensor<2x2xcomplex>) %q0 : !quantum.bit + + func.return +} + +// ----- + +func.func @hermitian(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex>) { + // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : memref<2x2xcomplex> + // CHECK: {{%.+}} = quantum.hermitian([[memref]] : memref<2x2xcomplex>) %arg0 : !quantum.obs + %obs = quantum.hermitian(%matrix : tensor<2x2xcomplex>) %q0 : !quantum.obs + + func.return +} + +// ----- + +func.func @hamiltonian(%obs: !quantum.obs, %coeffs: tensor<1xf64>){ + // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : memref<1xf64> + // CHECK: {{%.+}} = quantum.hamiltonian([[memref]] : memref<1xf64>) %arg0 : !quantum.obs + %hamil = quantum.hamiltonian(%coeffs: tensor<1xf64>) %obs : !quantum.obs + + func.return +} + +// ----- ////////////////// // Measurements // @@ -132,4 +162,3 @@ module @set_basis_state { return } } - diff --git a/mlir/tools/quantum-opt/quantum-opt.cpp b/mlir/tools/quantum-opt/quantum-opt.cpp index c764c04627..eda34657df 100644 --- a/mlir/tools/quantum-opt/quantum-opt.cpp +++ b/mlir/tools/quantum-opt/quantum-opt.cpp @@ -34,6 +34,7 @@ #include "Mitigation/Transforms/Passes.h" #include "QEC/IR/QECDialect.h" #include "Quantum/IR/QuantumDialect.h" +#include "Quantum/Transforms/BufferizableOpInterfaceImpl.h" #include "Quantum/Transforms/Passes.h" namespace test { @@ -61,6 +62,8 @@ int main(int argc, char **argv) registry.insert(); registry.insert(); + catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry); + return mlir::asMainReturnCode( mlir::MlirOptMain(argc, argv, "Quantum optimizer driver\n", registry)); }