From a1284963ab91d1befc3a69feecc5d423ac29ddf9 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 30 Apr 2025 08:04:14 -0400 Subject: [PATCH 01/29] init; some boilerplate --- .../Transforms/BufferizableOpInterfaceImpl.h | 23 +++++++++++++++++++ mlir/lib/Driver/CompilerDriver.cpp | 4 ++++ 2 files changed, 27 insertions(+) create mode 100644 mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h diff --git a/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 index 0000000000..89eb431f8f --- /dev/null +++ b/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,23 @@ +// Copyright 2024-2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +using namespace mlir; + +namespace catalyst { + +void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry &registry); + +} // namespace catalyst diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp index c2621738b5..69d19caac7 100644 --- a/mlir/lib/Driver/CompilerDriver.cpp +++ b/mlir/lib/Driver/CompilerDriver.cpp @@ -56,6 +56,7 @@ #include "llvm/Transforms/IPO/GlobalDCE.h" #include "Catalyst/IR/CatalystDialect.h" +#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h" #include "Catalyst/Transforms/Passes.h" #include "Driver/CatalystLLVMTarget.h" #include "Driver/CompilerDriver.h" @@ -962,6 +963,9 @@ int QuantumDriverMainFromCL(int argc, char **argv) registerAllCatalystDialects(registry); registerLLVMTranslations(registry); + // Register bufferization interfaces + catalyst::registerBufferizableOpInterfaceExternalModels(registry); + // Register and parse command line options. std::string inputFilename, outputFilename; std::string helpStr = "Catalyst Command Line Interface options. 
\n" From db5dcccdd62543a6576dc6e7e9fbdd7b9b2c225b Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 30 Apr 2025 08:44:20 -0400 Subject: [PATCH 02/29] more boilerplate --- mlir/lib/Catalyst/IR/CatalystDialect.cpp | 3 +++ .../BufferizableOpInterfaceImpl.cpp | 24 +++++++++++++++++++ mlir/lib/Catalyst/Transforms/CMakeLists.txt | 3 ++- 3 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp diff --git a/mlir/lib/Catalyst/IR/CatalystDialect.cpp b/mlir/lib/Catalyst/IR/CatalystDialect.cpp index 1dce30c4cc..41190fe130 100644 --- a/mlir/lib/Catalyst/IR/CatalystDialect.cpp +++ b/mlir/lib/Catalyst/IR/CatalystDialect.cpp @@ -14,6 +14,7 @@ #include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/IR/CatalystOps.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" // needed for generated type parser #include "mlir/Interfaces/FunctionImplementation.h" @@ -40,6 +41,8 @@ void CatalystDialect::initialize() #define GET_OP_LIST #include "Catalyst/IR/CatalystOps.cpp.inc" >(); + + declarePromisedInterfaces(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp new file mode 100644 index 0000000000..41c1375190 --- /dev/null +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,24 @@ +// Copyright 2024-2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace { + // stuff here +} // namespace + +void catalyst::registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) +{ + registry.addExtension(+[](MLIRContext *ctx, catalyst::quantum::QuantumDialect *dialect) { + //QubitUnitaryOp::attachInterface(*ctx); + }); +} diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index 59e5619eca..64c558db66 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -4,7 +4,8 @@ file(GLOB SRC ApplyTransformSequencePass.cpp ArrayListToMemRefPass.cpp AsyncUtils.cpp - BufferizationPatterns.cpp + BufferizableOpInterfaceImpl.cpp + BufferizationPatterns.cpp // remove catalyst_bufferize.cpp catalyst_to_llvm.cpp DetectQNodes.cpp From 99b24870a3cf958e5342110ac32c359fbf764dcd Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 30 Apr 2025 08:46:05 -0400 Subject: [PATCH 03/29] boilerplate for quantum-opt --- mlir/tools/quantum-opt/quantum-opt.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlir/tools/quantum-opt/quantum-opt.cpp b/mlir/tools/quantum-opt/quantum-opt.cpp index c764c04627..c33dfd00da 100644 --- a/mlir/tools/quantum-opt/quantum-opt.cpp +++ b/mlir/tools/quantum-opt/quantum-opt.cpp @@ -25,6 +25,7 @@ #include "mhlo/IR/hlo_ops.h" #include "Catalyst/IR/CatalystDialect.h" +#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h" #include "Catalyst/Transforms/Passes.h" #include "Gradient/IR/GradientDialect.h" #include "Gradient/Transforms/Passes.h" @@ -61,6 +62,8 @@ int main(int argc, char 
**argv) registry.insert(); registry.insert(); + catalyst::registerBufferizableOpInterfaceExternalModels(registry); + return mlir::asMainReturnCode( mlir::MlirOptMain(argc, argv, "Quantum optimizer driver\n", registry)); } From a7ff16d09d42fbfe71dca2f72561f64ab761123f Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 30 Apr 2025 08:51:06 -0400 Subject: [PATCH 04/29] boilerplate... --- .../Transforms/BufferizableOpInterfaceImpl.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index 41c1375190..14bc910103 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -12,13 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +#include "Catalyst/IR/CatalystOps.h" +#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h" + +using namespace mlir; +using namespace mlir::bufferization; +using namespace catalyst; + namespace { // stuff here } // namespace void catalyst::registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) { - registry.addExtension(+[](MLIRContext *ctx, catalyst::quantum::QuantumDialect *dialect) { + registry.addExtension(+[](MLIRContext *ctx, catalyst::CatalystDialect *dialect) { //QubitUnitaryOp::attachInterface(*ctx); }); } From a30ab8ac864cebf8cb97948256469757719a221e Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Thu, 29 Aug 2024 13:30:15 -0400 Subject: [PATCH 05/29] (cherry pick) Add CustomCall bufferization --- .../BufferizableOpInterfaceImpl.cpp | 111 +++++++++++++++++- 1 file changed, 106 insertions(+), 5 
deletions(-) diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index 14bc910103..ee8d688fe8 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -25,12 +25,113 @@ using namespace mlir::bufferization; using namespace catalyst; namespace { - // stuff here + +/// Bufferization of catalyst.print. Get memref of printOp.val. +struct PrintOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const { + return false; + } + + bufferization::AliasingValueList getAliasingValues(Operation *op, + OpOperand &opOperand, + const bufferization::AnalysisState &state) const { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const { + auto printOp = cast(op); + if (printOp.getVal()) { + FailureOr source = getBuffer(rewriter, printOp.getVal(), options); + if (failed(source)) + return failure(); + bufferization::replaceOpWithNewBufferizedOp(rewriter, op, *source, + printOp.getConstValAttr(), printOp.getPrintDescriptorAttr()); + } + return success(); + } +}; + +struct CustomCallOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const { + return true; + } + + bufferization::AliasingValueList getAliasingValues(Operation *op, + OpOperand &opOperand, + const 
bufferization::AnalysisState &state) const { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const { + auto customCallOp = cast(op); + + // Add bufferized arguments + SmallVector bufferArgs; + ValueRange operands = customCallOp.getOperands(); + for (Value operand : operands) { + FailureOr opBuffer = getBuffer(rewriter, operand, options); + if (failed(opBuffer)) + return failure(); + bufferArgs.push_back(*opBuffer); + } + + // Add bufferized return values to the arguments + ValueRange results = customCallOp.getResults(); + for (Value result : results) { + Type resultType = result.getType(); + RankedTensorType tensorType = dyn_cast(resultType); + if (!tensorType) { + return failure(); + } + auto options = bufferization::BufferizationOptions(); + FailureOr tensorAlloc = bufferization::allocateTensorForShapedValue( + rewriter, op->getLoc(), result, options, false); + MemRefType memrefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + auto newBuffer = + rewriter.create(op->getLoc(), memrefType, *tensorAlloc); + bufferArgs.push_back(newBuffer); + } + + // Add the initial number of arguments + int32_t numArguments = static_cast(customCallOp.getNumOperands()); + DenseI32ArrayAttr numArgumentsDenseAttr = rewriter.getDenseI32ArrayAttr({numArguments}); + + // Create an updated custom call operation + rewriter.create(op->getLoc(), TypeRange{}, bufferArgs, + customCallOp.getCallTargetName(), numArgumentsDenseAttr); + size_t startIndex = bufferArgs.size() - customCallOp.getNumResults(); + SmallVector bufferResults(bufferArgs.begin() + startIndex, bufferArgs.end()); + bufferization::replaceOpWithBufferizedValues(rewriter, op, bufferResults); + + return success(); + } +}; + } // namespace -void catalyst::registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) -{ - registry.addExtension(+[](MLIRContext *ctx, catalyst::CatalystDialect *dialect) { - 
//QubitUnitaryOp::attachInterface(*ctx); +void catalyst::registerBufferizableOpInterfaceExternalModels( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, CatalystDialect *dialect) { + CustomCallOp::attachInterface(*ctx); + PrintOp::attachInterface(*ctx); }); } From 4aeb8df7658c707719e9082f76a9ed2ee076fce4 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 2 May 2025 09:58:56 -0400 Subject: [PATCH 06/29] changelog --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 1e07e62cca..e9a481ad11 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -115,6 +115,7 @@ The new mlir bufferization interface is required by jax 0.4.29 or higher. [(#1027)](https://github.com/PennyLaneAI/catalyst/pull/1027) [(#1686)](https://github.com/PennyLaneAI/catalyst/pull/1686) + [(#1708)](https://github.com/PennyLaneAI/catalyst/pull/1708)

Documentation 📝

From 750891622ecb2f00ab45f03718bdf7fb2275b74d Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 2 May 2025 10:16:10 -0400 Subject: [PATCH 07/29] add remove hint in cmakelists --- mlir/lib/Catalyst/Transforms/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index 64c558db66..1cbe4b64c5 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -6,7 +6,7 @@ file(GLOB SRC AsyncUtils.cpp BufferizableOpInterfaceImpl.cpp BufferizationPatterns.cpp // remove - catalyst_bufferize.cpp + catalyst_bufferize.cpp // remove catalyst_to_llvm.cpp DetectQNodes.cpp DetensorizeSCFPass.cpp From dbf3ff7f9def3ab58b549377c0c456a996b0a946 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Tue, 6 May 2025 14:22:20 -0400 Subject: [PATCH 08/29] remove pattern header from quantum dialect --- mlir/include/Quantum/Transforms/Patterns.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/include/Quantum/Transforms/Patterns.h b/mlir/include/Quantum/Transforms/Patterns.h index d278f809b7..8edaf0ffe5 100644 --- a/mlir/include/Quantum/Transforms/Patterns.h +++ b/mlir/include/Quantum/Transforms/Patterns.h @@ -21,8 +21,6 @@ namespace catalyst { namespace quantum { -void populateBufferizationLegality(mlir::TypeConverter &, mlir::ConversionTarget &); -void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &); void populateQIRConversionPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &); void populateAdjointPatterns(mlir::RewritePatternSet &); void populateSelfInversePatterns(mlir::RewritePatternSet &); From 3f7c608903190622ec57ac59ef3f0c325a7a01fa Mon Sep 17 00:00:00 2001 From: paul0403 Date: Tue, 6 May 2025 16:12:15 -0400 Subject: [PATCH 09/29] add callback and callbackcall op --- frontend/catalyst/pipelines.py | 5 +- mlir/include/Catalyst/Transforms/Patterns.h | 1 + 
mlir/lib/Catalyst/IR/CatalystDialect.cpp | 2 +- .../BufferizableOpInterfaceImpl.cpp | 152 +++++++++++++++++- mlir/test/Catalyst/BufferizationTest.mlir | 28 ++-- 5 files changed, 170 insertions(+), 18 deletions(-) diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 49ece83c56..242c4a36a2 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -227,12 +227,13 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]: "empty-tensor-to-alloc-tensor", "func.func(bufferization-bufferize)", "func.func(tensor-bufferize)", - "catalyst-bufferize", # Must be run before -- func.func(linalg-bufferize) + #"catalyst-bufferize", # Must be run before -- func.func(linalg-bufferize) + "one-shot-bufferize{dialect-filter=catalyst}", # Must be run before -- func.func(linalg-bufferize) "func.func(linalg-bufferize)", "func.func(tensor-bufferize)", "one-shot-bufferize{dialect-filter=quantum}", "func-bufferize", - "func.func(finalizing-bufferize)", + #"func.func(finalizing-bufferize)", "canonicalize", # Remove dead memrefToTensorOp's "gradient-postprocess", # introduced during gradient-bufferize of callbacks diff --git a/mlir/include/Catalyst/Transforms/Patterns.h b/mlir/include/Catalyst/Transforms/Patterns.h index cdc5157806..7531929ba7 100644 --- a/mlir/include/Catalyst/Transforms/Patterns.h +++ b/mlir/include/Catalyst/Transforms/Patterns.h @@ -21,6 +21,7 @@ namespace catalyst { +// TODO: remove buffer pattern void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &); void populateScatterPatterns(mlir::RewritePatternSet &); diff --git a/mlir/lib/Catalyst/IR/CatalystDialect.cpp b/mlir/lib/Catalyst/IR/CatalystDialect.cpp index 41190fe130..a086096cb6 100644 --- a/mlir/lib/Catalyst/IR/CatalystDialect.cpp +++ b/mlir/lib/Catalyst/IR/CatalystDialect.cpp @@ -42,7 +42,7 @@ void CatalystDialect::initialize() #include "Catalyst/IR/CatalystOps.cpp.inc" >(); - declarePromisedInterfaces(); + 
declarePromisedInterfaces(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index ee8d688fe8..44871ecb97 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Transforms/DialectConversion.h" #include "Catalyst/IR/CatalystOps.h" #include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h" @@ -51,8 +52,9 @@ struct PrintOpInterface auto printOp = cast(op); if (printOp.getVal()) { FailureOr source = getBuffer(rewriter, printOp.getVal(), options); - if (failed(source)) + if (failed(source)) { return failure(); + } bufferization::replaceOpWithNewBufferizedOp(rewriter, op, *source, printOp.getConstValAttr(), printOp.getPrintDescriptorAttr()); } @@ -60,6 +62,7 @@ struct PrintOpInterface } }; +/// Bufferization of catalyst.custom_call. Mainly get buffers for arguments. struct CustomCallOpInterface : public bufferization::BufferizableOpInterface::ExternalModel { @@ -88,8 +91,9 @@ struct CustomCallOpInterface ValueRange operands = customCallOp.getOperands(); for (Value operand : operands) { FailureOr opBuffer = getBuffer(rewriter, operand, options); - if (failed(opBuffer)) + if (failed(opBuffer)) { return failure(); + } bufferArgs.push_back(*opBuffer); } @@ -126,6 +130,148 @@ struct CustomCallOpInterface } }; +struct CallbackOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool hasTensorSemantics(Operation *op) const + { + auto isaTensor = llvm::IsaPred; + + // A function has tensor semantics if it has tensor arguments/results. 
+ auto callbackOp = cast(op); + bool hasTensorArg = any_of(callbackOp.getArgumentTypes(), isaTensor); + bool hasTensorResult = any_of(callbackOp.getResultTypes(), isaTensor); + if (hasTensorArg || hasTensorResult) + return true; + + return false; + } + + bufferization::AliasingOpOperandList + getAliasingOpOperands(Operation *op, Value value, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto callbackOp = cast(op); + + auto argTys = callbackOp.getArgumentTypes(); + auto retTys = callbackOp.getResultTypes(); + SmallVector emptyRets; + SmallVector args(argTys.begin(), argTys.end()); + args.insert(args.end(), retTys.begin(), retTys.end()); + SmallVector bufferArgs; + for (Type ty : args) { + auto tensorType = dyn_cast(ty); + if (!tensorType) { + bufferArgs.push_back(ty); + } + else { + bufferArgs.push_back( + MemRefType::get(tensorType.getShape(), tensorType.getElementType())); + } + } + auto callbackTy = rewriter.getFunctionType(bufferArgs, emptyRets); + rewriter.modifyOpInPlace(op, [&] { callbackOp.setFunctionType(callbackTy); }); + + return success(); + } +}; + +void convertTypes(SmallVector inTypes, SmallVector &convertedResults){ + // See https://github.com/llvm/llvm-project/pull/114155/files + for (Type inType : inTypes) { + if (isa(inType)){ + convertedResults.push_back(bufferization::getMemRefTypeWithStaticIdentityLayout(cast(inType))); + } else { + convertedResults.push_back(inType); + } + } +} + +struct CallbackCallOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToAllocation(Operation *op, Value value) const { return true; } + + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const 
bufferization::AnalysisState &state) const + { + // We can safely say false because CallbackCallOp's memrefs + // will be put in a JAX array and JAX arrays are immutable. + // + // Unlike NumPy arrays, JAX arrays are always immutable. + // + // https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto callOp = cast(op); + + SmallVector convertedResults; + convertTypes(SmallVector(callOp.getResultTypes()), convertedResults); + if (callOp->getNumResults() != convertedResults.size()) { + return failure(); + } + + SmallVector newInputs; + auto operands = callOp.getOperands(); + for (Value operand : operands) { + FailureOr opBuffer = getBuffer(rewriter, operand, options); + if (failed(opBuffer)) { + return failure(); + } + newInputs.push_back(*opBuffer); + } + + auto results = callOp.getResults(); + auto loc = callOp->getLoc(); + SmallVector outmemrefs; + for (auto result : results) { + FailureOr tensorAlloc = + bufferization::allocateTensorForShapedValue(rewriter, loc, result, options, false); + if (failed(tensorAlloc)) { + return failure(); + } + + auto tensor = *tensorAlloc; + RankedTensorType tensorTy = cast(tensor.getType()); + auto shape = tensorTy.getShape(); + auto elementTy = tensorTy.getElementType(); + auto memrefType = MemRefType::get(shape, elementTy); + auto toMemrefOp = rewriter.create(loc, memrefType, tensor); + auto memref = toMemrefOp.getResult(); + outmemrefs.push_back(memref); + newInputs.push_back(memref); + } + + SmallVector emptyRets; + rewriter.create(loc, emptyRets, callOp.getCallee(), newInputs); + bufferization::replaceOpWithBufferizedValues(rewriter, op, outmemrefs); + return success(); + } +}; + } // namespace 
void catalyst::registerBufferizableOpInterfaceExternalModels( @@ -133,5 +279,7 @@ void catalyst::registerBufferizableOpInterfaceExternalModels( registry.addExtension(+[](MLIRContext *ctx, CatalystDialect *dialect) { CustomCallOp::attachInterface(*ctx); PrintOp::attachInterface(*ctx); + CallbackOp::attachInterface(*ctx); + CallbackCallOp::attachInterface(*ctx); }); } diff --git a/mlir/test/Catalyst/BufferizationTest.mlir b/mlir/test/Catalyst/BufferizationTest.mlir index bcff9bff98..e3f8751532 100644 --- a/mlir/test/Catalyst/BufferizationTest.mlir +++ b/mlir/test/Catalyst/BufferizationTest.mlir @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// RUN: quantum-opt --catalyst-bufferize --split-input-file %s | FileCheck %s +// RUN: quantum-opt --one-shot-bufferize --split-input-file %s | FileCheck %s ////////////////////// // Catalyst PrintOp // @@ -20,7 +20,8 @@ func.func @dbprint_val(%arg0: tensor) { - // CHECK: "catalyst.print"(%0) : (memref) -> () + // CHECK: %0 = bufferization.to_memref %arg0 + // CHECK: "catalyst.print"(%0) : (memref>) -> () "catalyst.print"(%arg0) : (tensor) -> () return @@ -30,7 +31,8 @@ func.func @dbprint_val(%arg0: tensor) { func.func @dbprint_memref(%arg0: tensor) { - // CHECK: "catalyst.print"(%0) <{print_descriptor}> : (memref) -> () + // CHECK: %0 = bufferization.to_memref %arg0 + // CHECK: "catalyst.print"(%0) <{print_descriptor}> : (memref>) -> () "catalyst.print"(%arg0) {print_descriptor} : (tensor) -> () return @@ -49,11 +51,12 @@ func.func @dbprint_str() { // ----- func.func @custom_call(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { - // CHECK: [[memrefArg:%.+]] = bufferization.to_memref %arg0 : memref<3x3xf64> - // CHECK: [[alloc:%.+]] = bufferization.alloc_tensor() {{.*}}: tensor<3x3xf64> - // CHECK: [[allocmemref:%.+]] = bufferization.to_memref [[alloc]] : memref<3x3xf64> - // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[memrefArg]], [[allocmemref]]) 
{number_original_arg = array} : (memref<3x3xf64>, memref<3x3xf64>) -> () - // CHECK: [[res:%.+]] = bufferization.to_tensor [[allocmemref]] : memref<3x3xf64> + // CHECK: [[memrefArg:%.+]] = bufferization.to_memref %arg0 : memref<3x3xf64, strided<[?, ?], offset: ?>> + // CHECK: [[sourceAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> + // CHECK: memref.copy [[memrefArg]], [[sourceAlloc]] + // CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> + // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} : (memref<3x3xf64>, memref<3x3xf64>) -> () + // CHECK: [[res:%.+]] = bufferization.to_tensor [[destAlloc]] : memref<3x3xf64> // CHECK: return [[res]] : tensor<3x3xf64> %0 = catalyst.custom_call fn("lapack_dgesdd") (%arg0) : (tensor<3x3xf64>) -> (tensor<3x3xf64>) @@ -77,12 +80,11 @@ module @test1 { // CHECK-LABEL: @foo( // CHECK-SAME: [[arg0:%.+]]: tensor) func.func private @foo(%arg0: tensor) -> tensor { - // CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]] - // CHECK-DAG: [[tensor1:%.+]] = bufferization.alloc_tensor - // CHECK: [[memref1:%.+]] = bufferization.to_memref [[tensor1]] - // CHECK: catalyst.callback_call @callback_1([[memref0]], [[memref1]]) + // CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]] : memref> + // CHECK-DAG: [[resAlloc:%.+]] = memref.alloc() {{.*}}: memref + // CHECK: catalyst.callback_call @callback_1([[memref0]], [[resAlloc]]) : (memref>, memref) -> () %1 = catalyst.callback_call @callback_1(%arg0) : (tensor) -> (tensor) - // CHECK: [[retval:%.+]] = bufferization.to_tensor [[memref1]] + // CHECK: [[retval:%.+]] = bufferization.to_tensor [[resAlloc]] // CHECK: return [[retval]] return %1 : tensor } From 8f593461f3c1e9e4ef49a1e2691d621ffff87a6a Mon Sep 17 00:00:00 2001 From: paul0403 Date: Tue, 6 May 2025 16:20:47 -0400 Subject: [PATCH 10/29] update cpp pipeline --- mlir/lib/Driver/Pipelines.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 
deletions(-) diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index 46a3079c69..1c85030b76 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -77,14 +77,17 @@ void createBufferizationPipeline(OpPassManager &pm) pm.addPass(mlir::bufferization::createEmptyTensorToAllocTensorPass()); pm.addNestedPass(mlir::bufferization::createBufferizationBufferizePass()); pm.addNestedPass(mlir::tensor::createTensorBufferizePass()); - pm.addPass(catalyst::createCatalystBufferizationPass()); + //pm.addPass(catalyst::createCatalystBufferizationPass()); + mlir::bufferization::OneShotBufferizationOptions catalyst_buffer_options; + catalyst_buffer_options.opFilter.allowDialect(); + pm.addPass(mlir::bufferization::createOneShotBufferizePass(catalyst_buffer_options)); pm.addNestedPass(mlir::createLinalgBufferizePass()); pm.addNestedPass(mlir::tensor::createTensorBufferizePass()); mlir::bufferization::OneShotBufferizationOptions quantum_buffer_options; quantum_buffer_options.opFilter.allowDialect(); pm.addPass(mlir::bufferization::createOneShotBufferizePass(quantum_buffer_options)); pm.addPass(mlir::func::createFuncBufferizePass()); - pm.addNestedPass(mlir::bufferization::createFinalizingBufferizePass()); + //pm.addNestedPass(mlir::bufferization::createFinalizingBufferizePass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(catalyst::createGradientPostprocessingPass()); pm.addNestedPass(mlir::bufferization::createBufferHoistingPass()); From e60360591a31b6515eee70daa084b8cc40df750c Mon Sep 17 00:00:00 2001 From: paul0403 Date: Tue, 6 May 2025 16:22:09 -0400 Subject: [PATCH 11/29] missed include --- mlir/lib/Driver/Pipelines.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index 1c85030b76..14a415abc7 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -13,6 +13,7 @@ // limitations under the License. 
#include "Driver/Pipelines.h" +#include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/Transforms/Passes.h" #include "Gradient/Transforms/Passes.h" #include "Mitigation/Transforms/Passes.h" From 69fcac26489290b840e505e37218880f954c6233 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 7 May 2025 11:00:26 -0400 Subject: [PATCH 12/29] remove old catalyst dialect bufferization pass --- frontend/catalyst/pipelines.py | 4 +- mlir/include/Catalyst/Transforms/Passes.td | 12 -- mlir/include/Catalyst/Transforms/Patterns.h | 3 - .../Transforms/BufferizationPatterns.cpp | 173 ------------------ mlir/lib/Catalyst/Transforms/CMakeLists.txt | 2 - .../Catalyst/Transforms/RegisterAllPasses.cpp | 1 - .../Transforms/catalyst_bufferize.cpp | 72 -------- mlir/lib/Driver/Pipelines.cpp | 2 - 8 files changed, 1 insertion(+), 268 deletions(-) delete mode 100644 mlir/lib/Catalyst/Transforms/BufferizationPatterns.cpp delete mode 100644 mlir/lib/Catalyst/Transforms/catalyst_bufferize.cpp diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 242c4a36a2..e7bc82b758 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -227,13 +227,11 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]: "empty-tensor-to-alloc-tensor", "func.func(bufferization-bufferize)", "func.func(tensor-bufferize)", - #"catalyst-bufferize", # Must be run before -- func.func(linalg-bufferize) - "one-shot-bufferize{dialect-filter=catalyst}", # Must be run before -- func.func(linalg-bufferize) + "one-shot-bufferize{dialect-filter=catalyst}", # Must be run before --func.func(linalg-bufferize) "func.func(linalg-bufferize)", "func.func(tensor-bufferize)", "one-shot-bufferize{dialect-filter=quantum}", "func-bufferize", - #"func.func(finalizing-bufferize)", "canonicalize", # Remove dead memrefToTensorOp's "gradient-postprocess", # introduced during gradient-bufferize of callbacks diff --git a/mlir/include/Catalyst/Transforms/Passes.td 
b/mlir/include/Catalyst/Transforms/Passes.td index 0f50286dac..d22246e3c8 100644 --- a/mlir/include/Catalyst/Transforms/Passes.td +++ b/mlir/include/Catalyst/Transforms/Passes.td @@ -27,18 +27,6 @@ def DetensorizeSCFPass : Pass<"detensorize-scf"> { let constructor = "catalyst::createDetensorizeSCFPass()"; } -def CatalystBufferizationPass : Pass<"catalyst-bufferize"> { - let summary = "Bufferize tensors in catalyst utility ops."; - - let dependentDialects = [ - "bufferization::BufferizationDialect", - "memref::MemRefDialect", - "index::IndexDialect" - ]; - - let constructor = "catalyst::createCatalystBufferizationPass()"; -} - def ArrayListToMemRefPass : Pass<"convert-arraylist-to-memref"> { let summary = "Lower array list operations to memref operations."; let description = [{ diff --git a/mlir/include/Catalyst/Transforms/Patterns.h b/mlir/include/Catalyst/Transforms/Patterns.h index 7531929ba7..6bbf3150ff 100644 --- a/mlir/include/Catalyst/Transforms/Patterns.h +++ b/mlir/include/Catalyst/Transforms/Patterns.h @@ -21,9 +21,6 @@ namespace catalyst { -// TODO: remove buffer pattern -void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &); - void populateScatterPatterns(mlir::RewritePatternSet &); void populateHloCustomCallPatterns(mlir::RewritePatternSet &); diff --git a/mlir/lib/Catalyst/Transforms/BufferizationPatterns.cpp b/mlir/lib/Catalyst/Transforms/BufferizationPatterns.cpp deleted file mode 100644 index 1fd2a436ca..0000000000 --- a/mlir/lib/Catalyst/Transforms/BufferizationPatterns.cpp +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright 2023 Xanadu Quantum Technologies Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "Catalyst/IR/CatalystDialect.h" -#include "Catalyst/IR/CatalystOps.h" - -using namespace mlir; -using namespace catalyst; - -namespace { - -struct BufferizePrintOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(PrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - if (op.getVal()) { - rewriter.replaceOpWithNewOp(op, adaptor.getVal(), adaptor.getConstValAttr(), - adaptor.getPrintDescriptorAttr()); - } - return success(); - } -}; - -struct BufferizeCustomCallOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(CustomCallOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - // Add bufferized arguments - SmallVector bufferArgs; - ValueRange operands = adaptor.getOperands(); - for (Value operand : operands) { - bufferArgs.push_back(operand); - } - - // Add bufferized return values to the arguments - ValueRange results = op.getResults(); - for (Value result : results) { - Type resultType = result.getType(); - RankedTensorType tensorType = dyn_cast(resultType); - if (!tensorType) { - return failure(); - } - auto options = bufferization::BufferizationOptions(); - FailureOr tensorAlloc = bufferization::allocateTensorForShapedValue( - rewriter, op->getLoc(), result, options, false); - MemRefType 
memrefType = - MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - auto newBuffer = - rewriter.create(op->getLoc(), memrefType, *tensorAlloc); - bufferArgs.push_back(newBuffer); - } - // Add the initial number of arguments - int32_t numArguments = static_cast(op.getNumOperands()); - DenseI32ArrayAttr numArgumentsDenseAttr = rewriter.getDenseI32ArrayAttr({numArguments}); - - // Create an updated custom call operation - rewriter.create(op->getLoc(), TypeRange{}, bufferArgs, op.getCallTargetName(), - numArgumentsDenseAttr); - size_t startIndex = bufferArgs.size() - op.getNumResults(); - SmallVector bufferResults(bufferArgs.begin() + startIndex, bufferArgs.end()); - rewriter.replaceOp(op, bufferResults); - return success(); - } -}; - -struct BufferizeCallbackOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult match(CallbackOp op) const override - { - // Only match here if we have all memref arguments and return values. - if (llvm::any_of(op.getArgumentTypes(), - [](Type argType) { return !isa(argType); })) { - return failure(); - } - if (llvm::any_of(op.getResultTypes(), - [](Type argType) { return !isa(argType); })) { - return failure(); - } - - // Only match if we have result types. - return op.getResultTypes().empty() ? 
failure() : success(); - } - - void rewrite(CallbackOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - auto argTys = op.getArgumentTypes(); - auto retTys = op.getResultTypes(); - SmallVector emptyRets; - SmallVector args(argTys.begin(), argTys.end()); - args.insert(args.end(), retTys.begin(), retTys.end()); - auto callbackTy = rewriter.getFunctionType(args, emptyRets); - rewriter.modifyOpInPlace(op, [&] { op.setFunctionType(callbackTy); }); - } -}; - -struct BufferizeCallbackCallOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(CallbackCallOp callOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - SmallVector convertedResults; - if (failed(typeConverter->convertTypes(callOp.getResultTypes(), convertedResults))) - return failure(); - - if (callOp->getNumResults() != convertedResults.size()) - return failure(); - - auto operands = adaptor.getOperands(); - SmallVector newInputs(operands.begin(), operands.end()); - auto results = callOp.getResults(); - - auto loc = callOp->getLoc(); - auto options = bufferization::BufferizationOptions(); - SmallVector outmemrefs; - for (auto result : results) { - FailureOr tensorAlloc = - bufferization::allocateTensorForShapedValue(rewriter, loc, result, options, false); - if (failed(tensorAlloc)) - return failure(); - - auto tensor = *tensorAlloc; - RankedTensorType tensorTy = cast(tensor.getType()); - auto shape = tensorTy.getShape(); - auto elementTy = tensorTy.getElementType(); - auto memrefType = MemRefType::get(shape, elementTy); - auto toMemrefOp = rewriter.create(loc, memrefType, tensor); - auto memref = toMemrefOp.getResult(); - outmemrefs.push_back(memref); - newInputs.push_back(memref); - } - - SmallVector emptyRets; - rewriter.create(loc, emptyRets, callOp.getCallee(), newInputs); - rewriter.replaceOp(callOp, outmemrefs); - return success(); - } -}; - -} // namespace - -namespace catalyst { - 
-void populateBufferizationPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns) -{ - patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext()); -} - -} // namespace catalyst diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index 1cbe4b64c5..48d4aa362b 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -5,8 +5,6 @@ file(GLOB SRC ArrayListToMemRefPass.cpp AsyncUtils.cpp BufferizableOpInterfaceImpl.cpp - BufferizationPatterns.cpp // remove - catalyst_bufferize.cpp // remove catalyst_to_llvm.cpp DetectQNodes.cpp DetensorizeSCFPass.cpp diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index cc1379a6d1..3ca100ceee 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -29,7 +29,6 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createAnnotateFunctionPass); mlir::registerPass(catalyst::createApplyTransformSequencePass); mlir::registerPass(catalyst::createArrayListToMemRefPass); - mlir::registerPass(catalyst::createCatalystBufferizationPass); mlir::registerPass(catalyst::createCatalystConversionPass); mlir::registerPass(catalyst::createCopyGlobalMemRefPass); mlir::registerPass(catalyst::createDetensorizeSCFPass); diff --git a/mlir/lib/Catalyst/Transforms/catalyst_bufferize.cpp b/mlir/lib/Catalyst/Transforms/catalyst_bufferize.cpp deleted file mode 100644 index 8363f4d39b..0000000000 --- a/mlir/lib/Catalyst/Transforms/catalyst_bufferize.cpp +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2023 Xanadu Quantum Technologies Inc. 
- -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "Catalyst/IR/CatalystOps.h" -#include "Catalyst/Transforms/Passes.h" -#include "Catalyst/Transforms/Patterns.h" - -using namespace mlir; -using namespace catalyst; - -namespace catalyst { - -#define GEN_PASS_DEF_CATALYSTBUFFERIZATIONPASS -#include "Catalyst/Transforms/Passes.h.inc" - -struct CatalystBufferizationPass : impl::CatalystBufferizationPassBase { - using CatalystBufferizationPassBase::CatalystBufferizationPassBase; - - void runOnOperation() final - { - MLIRContext *context = &getContext(); - bufferization::BufferizeTypeConverter typeConverter; - - RewritePatternSet patterns(context); - populateBufferizationPatterns(typeConverter, patterns); - populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); - - ConversionTarget target(*context); - bufferization::populateBufferizeMaterializationLegality(target); - target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - target.addDynamicallyLegalOp( - [&](PrintOp op) { return typeConverter.isLegal(op); }); - target.addDynamicallyLegalOp( - [&](CustomCallOp op) { return typeConverter.isLegal(op); }); - 
target.addDynamicallyLegalOp([&](CallbackOp op) { - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()) && op.getResultTypes().empty(); - }); - target.addDynamicallyLegalOp( - [&](CallbackCallOp op) { return typeConverter.isLegal(op); }); - - if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { - signalPassFailure(); - } - } -}; - -std::unique_ptr createCatalystBufferizationPass() -{ - return std::make_unique(); -} - -} // namespace catalyst diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index 14a415abc7..3b5b5cc6da 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -78,7 +78,6 @@ void createBufferizationPipeline(OpPassManager &pm) pm.addPass(mlir::bufferization::createEmptyTensorToAllocTensorPass()); pm.addNestedPass(mlir::bufferization::createBufferizationBufferizePass()); pm.addNestedPass(mlir::tensor::createTensorBufferizePass()); - //pm.addPass(catalyst::createCatalystBufferizationPass()); mlir::bufferization::OneShotBufferizationOptions catalyst_buffer_options; catalyst_buffer_options.opFilter.allowDialect(); pm.addPass(mlir::bufferization::createOneShotBufferizePass(catalyst_buffer_options)); @@ -88,7 +87,6 @@ void createBufferizationPipeline(OpPassManager &pm) quantum_buffer_options.opFilter.allowDialect(); pm.addPass(mlir::bufferization::createOneShotBufferizePass(quantum_buffer_options)); pm.addPass(mlir::func::createFuncBufferizePass()); - //pm.addNestedPass(mlir::bufferization::createFinalizingBufferizePass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(catalyst::createGradientPostprocessingPass()); pm.addNestedPass(mlir::bufferization::createBufferHoistingPass()); From 13037be4a1d0b5ef5591bfa99731ae5f2972909b Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 7 May 2025 11:15:35 -0400 Subject: [PATCH 13/29] format --- mlir/lib/Catalyst/IR/CatalystDialect.cpp | 3 +- .../BufferizableOpInterfaceImpl.cpp | 
58 +++++++++++-------- 2 files changed, 36 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Catalyst/IR/CatalystDialect.cpp b/mlir/lib/Catalyst/IR/CatalystDialect.cpp index a086096cb6..aae67c64ca 100644 --- a/mlir/lib/Catalyst/IR/CatalystDialect.cpp +++ b/mlir/lib/Catalyst/IR/CatalystDialect.cpp @@ -42,7 +42,8 @@ void CatalystDialect::initialize() #include "Catalyst/IR/CatalystOps.cpp.inc" >(); - declarePromisedInterfaces(); + declarePromisedInterfaces(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index 44871ecb97..21af8bdac1 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -29,34 +29,37 @@ namespace { /// Bufferization of catalyst.print. Get memref of printOp.val. struct PrintOpInterface - : public bufferization::BufferizableOpInterface::ExternalModel { + : public bufferization::BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const bufferization::AnalysisState &state) const { + const bufferization::AnalysisState &state) const + { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const bufferization::AnalysisState &state) const { + const bufferization::AnalysisState &state) const + { return false; } - bufferization::AliasingValueList getAliasingValues(Operation *op, - OpOperand &opOperand, - const bufferization::AnalysisState &state) const { + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { return {}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const { + const bufferization::BufferizationOptions &options) const + { auto printOp = 
cast(op); if (printOp.getVal()) { FailureOr source = getBuffer(rewriter, printOp.getVal(), options); if (failed(source)) { return failure(); } - bufferization::replaceOpWithNewBufferizedOp(rewriter, op, *source, - printOp.getConstValAttr(), printOp.getPrintDescriptorAttr()); + bufferization::replaceOpWithNewBufferizedOp( + rewriter, op, *source, printOp.getConstValAttr(), printOp.getPrintDescriptorAttr()); } return success(); } @@ -65,25 +68,29 @@ struct PrintOpInterface /// Bufferization of catalyst.custom_call. Mainly get buffers for arguments. struct CustomCallOpInterface : public bufferization::BufferizableOpInterface::ExternalModel { + CustomCallOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const bufferization::AnalysisState &state) const { + const bufferization::AnalysisState &state) const + { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const bufferization::AnalysisState &state) const { + const bufferization::AnalysisState &state) const + { return true; } - bufferization::AliasingValueList getAliasingValues(Operation *op, - OpOperand &opOperand, - const bufferization::AnalysisState &state) const { + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { return {}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const { + const bufferization::BufferizationOptions &options) const + { auto customCallOp = cast(op); // Add bufferized arguments @@ -115,7 +122,7 @@ struct CustomCallOpInterface bufferArgs.push_back(newBuffer); } - // Add the initial number of arguments + // Add the initial number of arguments int32_t numArguments = static_cast(customCallOp.getNumOperands()); DenseI32ArrayAttr numArgumentsDenseAttr = rewriter.getDenseI32ArrayAttr({numArguments}); @@ -182,12 +189,15 @@ struct CallbackOpInterface } }; -void 
convertTypes(SmallVector inTypes, SmallVector &convertedResults){ +void convertTypes(SmallVector inTypes, SmallVector &convertedResults) +{ // See https://github.com/llvm/llvm-project/pull/114155/files for (Type inType : inTypes) { - if (isa(inType)){ - convertedResults.push_back(bufferization::getMemRefTypeWithStaticIdentityLayout(cast(inType))); - } else { + if (isa(inType)) { + convertedResults.push_back( + bufferization::getMemRefTypeWithStaticIdentityLayout(cast(inType))); + } + else { convertedResults.push_back(inType); } } @@ -274,8 +284,8 @@ struct CallbackCallOpInterface } // namespace -void catalyst::registerBufferizableOpInterfaceExternalModels( - DialectRegistry ®istry) { +void catalyst::registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) +{ registry.addExtension(+[](MLIRContext *ctx, CatalystDialect *dialect) { CustomCallOp::attachInterface(*ctx); PrintOp::attachInterface(*ctx); From 66f4ab047469e492c8ec2b7f357afe7d33b4a033 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 7 May 2025 13:08:04 -0400 Subject: [PATCH 14/29] codefactor --- frontend/catalyst/pipelines.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index e7bc82b758..9b4e5900ea 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -97,7 +97,7 @@ class CompileOptions: def __post_init__(self): # Check that async runs must not be seeded - if self.async_qnodes and self.seed != None: + if self.async_qnodes and self.seed is not None: raise CompileError( """ Seeding has no effect on asynchronous QNodes, @@ -107,7 +107,7 @@ def __post_init__(self): ) # Check that seed is 32-bit unsigned int - if (self.seed != None) and (self.seed < 0 or self.seed > 2**32 - 1): + if (self.seed is not None) and (self.seed < 0 or self.seed > 2**32 - 1): raise ValueError( """ Seed must be an unsigned 32-bit integer! 
@@ -227,7 +227,8 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]: "empty-tensor-to-alloc-tensor", "func.func(bufferization-bufferize)", "func.func(tensor-bufferize)", - "one-shot-bufferize{dialect-filter=catalyst}", # Must be run before --func.func(linalg-bufferize) + # Catalyst dialect's bufferization must be run before --func.func(linalg-bufferize) + "one-shot-bufferize{dialect-filter=catalyst}", "func.func(linalg-bufferize)", "func.func(tensor-bufferize)", "one-shot-bufferize{dialect-filter=quantum}", From 1a3ff7dbd6bb83fe159956548d30917a4e69cbbc Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 7 May 2025 13:47:10 -0400 Subject: [PATCH 15/29] custom call also allocates --- mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index 21af8bdac1..a5e8ce0002 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -69,6 +69,8 @@ struct PrintOpInterface struct CustomCallOpInterface : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToAllocation(Operation *op, Value value) const { return true; } + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const bufferization::AnalysisState &state) const { From 045b15b3c4295ba2ac50bbf6516adf809de1e0a9 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 7 May 2025 15:50:07 -0400 Subject: [PATCH 16/29] do not hint memory write for custom op when not in memref land --- .../Transforms/BufferizableOpInterfaceImpl.cpp | 17 ++++++++++++++++- mlir/test/Catalyst/BufferizationTest.mlir | 5 ++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index 
a5e8ce0002..a1bba29843 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -74,13 +74,28 @@ struct CustomCallOpInterface bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const bufferization::AnalysisState &state) const { + // Custom Call Op always reads the operand memory no matter what. return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const bufferization::AnalysisState &state) const { - return true; + // Custom Call Op will only write to memory after bufferization is complete, + // and the op is already talking to memrefs. + // i.e. there will be no memory write when the op is still in tensor land. + // This assumption is ok, since bufferization is run after all the tensor + // abstract transformations are complete. + // See https://mlir.llvm.org/docs/Bufferization/#overview: + // * These [bufferization] passes typically run as one of the last steps in a pass + // * pipeline, right before lowering to memref ops to LLVM. That is because many + // * transformations are easier or only supported in tensor land; e.g., tile/fuse/… on + // * tensors first, then bufferize the remaining IR. 
+ + if (isa(opOperand.get().getType())) { + return true; + } + return false; } bufferization::AliasingValueList diff --git a/mlir/test/Catalyst/BufferizationTest.mlir b/mlir/test/Catalyst/BufferizationTest.mlir index e3f8751532..efc8bdfdae 100644 --- a/mlir/test/Catalyst/BufferizationTest.mlir +++ b/mlir/test/Catalyst/BufferizationTest.mlir @@ -52,10 +52,9 @@ func.func @dbprint_str() { func.func @custom_call(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { // CHECK: [[memrefArg:%.+]] = bufferization.to_memref %arg0 : memref<3x3xf64, strided<[?, ?], offset: ?>> - // CHECK: [[sourceAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> - // CHECK: memref.copy [[memrefArg]], [[sourceAlloc]] // CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> - // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} : (memref<3x3xf64>, memref<3x3xf64>) -> () + // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[memrefArg]], [[destAlloc]]) {number_original_arg = array} : + // CHECK-SAME: (memref<3x3xf64, strided<[?, ?], offset: ?>>, memref<3x3xf64>) -> () // CHECK: [[res:%.+]] = bufferization.to_tensor [[destAlloc]] : memref<3x3xf64> // CHECK: return [[res]] : tensor<3x3xf64> %0 = catalyst.custom_call fn("lapack_dgesdd") (%arg0) : (tensor<3x3xf64>) -> (tensor<3x3xf64>) From fdb16e04427a9a26f948e324225f1eb61fe14e45 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 8 May 2025 14:21:56 -0400 Subject: [PATCH 17/29] lapack kernels might write into source array thus memory write must be true for custom call --- .../BufferizableOpInterfaceImpl.cpp | 21 ++++++------------- mlir/test/Catalyst/BufferizationTest.mlir | 6 ++++-- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index a1bba29843..94baa8b911 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ 
b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -81,21 +81,12 @@ struct CustomCallOpInterface bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const bufferization::AnalysisState &state) const { - // Custom Call Op will only write to memory after bufferization is complete, - // and the op is already talking to memrefs. - // i.e. there will be no memory write when the op is still in tensor land. - // This assumption is ok, since bufferization is run after all the tensor - // abstract transformations are complete. - // See https://mlir.llvm.org/docs/Bufferization/#overview: - // * These [bufferization] passes typically run as one of the last steps in a pass - // * pipeline, right before lowering to memref ops to LLVM. That is because many - // * transformations are easier or only supported in tensor land; e.g., tile/fuse/… on - // * tensors first, then bufferize the remaining IR. - - if (isa(opOperand.get().getType())) { - return true; - } - return false; + // We use custom call op to call the lapack kernels. + // These kernels might write to the source array. 
+ // https://www.netlib.org/lapack/lug/node112.html + // * array or scalar arguments defining the input data; + // * some of them may be overwritten by results; + return true; } bufferization::AliasingValueList diff --git a/mlir/test/Catalyst/BufferizationTest.mlir b/mlir/test/Catalyst/BufferizationTest.mlir index efc8bdfdae..616cc202bd 100644 --- a/mlir/test/Catalyst/BufferizationTest.mlir +++ b/mlir/test/Catalyst/BufferizationTest.mlir @@ -52,9 +52,11 @@ func.func @dbprint_str() { func.func @custom_call(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { // CHECK: [[memrefArg:%.+]] = bufferization.to_memref %arg0 : memref<3x3xf64, strided<[?, ?], offset: ?>> + // CHECK: [[sourceAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> + // CHECK: memref.copy [[memrefArg]], [[sourceAlloc]] // CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> - // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[memrefArg]], [[destAlloc]]) {number_original_arg = array} : - // CHECK-SAME: (memref<3x3xf64, strided<[?, ?], offset: ?>>, memref<3x3xf64>) -> () + // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} : + // CHECK-SAME: (memref<3x3xf64>, memref<3x3xf64>) -> () // CHECK: [[res:%.+]] = bufferization.to_tensor [[destAlloc]] : memref<3x3xf64> // CHECK: return [[res]] : tensor<3x3xf64> %0 = catalyst.custom_call fn("lapack_dgesdd") (%arg0) : (tensor<3x3xf64>) -> (tensor<3x3xf64>) From 38bf1c425a1a6f13689f671c3738fb9c14f7aa2e Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 8 May 2025 15:12:54 -0400 Subject: [PATCH 18/29] add bufferization interface doc banner --- .../Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index 94baa8b911..706b91b796 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ 
b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -25,6 +25,13 @@ using namespace mlir; using namespace mlir::bufferization; using namespace catalyst; +/** + * Implementation of the BufferizableOpInterface for use with one-shot bufferization. + * For more information on the interface, refer to the documentation below: + * https://mlir.llvm.org/docs/Bufferization/#extending-one-shot-bufferize + * https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td#L14 + */ + namespace { /// Bufferization of catalyst.print. Get memref of printOp.val. From fe4e944ff42c4b551ed25a289b80c53d067a9815 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 8 May 2025 15:17:28 -0400 Subject: [PATCH 19/29] add {} to a one-line if block --- mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index 706b91b796..9c23f03ddd 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -163,8 +163,9 @@ struct CallbackOpInterface auto callbackOp = cast(op); bool hasTensorArg = any_of(callbackOp.getArgumentTypes(), isaTensor); bool hasTensorResult = any_of(callbackOp.getResultTypes(), isaTensor); - if (hasTensorArg || hasTensorResult) + if (hasTensorArg || hasTensorResult) { return true; + } return false; } From 3c6dfcdba80d173d8482eb0ab3a41d7075df3cdc Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 8 May 2025 15:21:22 -0400 Subject: [PATCH 20/29] remove aliasing operand method from callback op: it does not have tensor operands --- .../Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp 
b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index 9c23f03ddd..50518a58aa 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -170,13 +170,6 @@ struct CallbackOpInterface return false; } - bufferization::AliasingOpOperandList - getAliasingOpOperands(Operation *op, Value value, - const bufferization::AnalysisState &state) const - { - return {}; - } - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const bufferization::BufferizationOptions &options) const { From b39f28afce3bbc225c081d3156383674db3586df Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 9 May 2025 07:55:41 -0400 Subject: [PATCH 21/29] (prototype) make a white list of custom calls that won't copy --- .../BufferizableOpInterfaceImpl.cpp | 4 ++++ mlir/test/Catalyst/BufferizationTest.mlir | 20 ++++++++++++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index 50518a58aa..cf5c8dccfc 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -93,6 +93,10 @@ struct CustomCallOpInterface // https://www.netlib.org/lapack/lug/node112.html // * array or scalar arguments defining the input data; // * some of them may be overwritten by results; + if (cast(op).getCallTargetName().str() == "lapack_dgesdd") { + // white list lapack kernels that are known to not write into input tensors + return false; + } return true; } diff --git a/mlir/test/Catalyst/BufferizationTest.mlir b/mlir/test/Catalyst/BufferizationTest.mlir index 616cc202bd..dafe251e8f 100644 --- a/mlir/test/Catalyst/BufferizationTest.mlir +++ b/mlir/test/Catalyst/BufferizationTest.mlir @@ -50,16 +50,30 @@ func.func @dbprint_str() { // ----- -func.func @custom_call(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> 
{ +func.func @custom_call_no_write(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { + // CHECK: [[sourceAlloc:%.+]] = bufferization.to_memref %arg0 : memref<3x3xf64, strided<[?, ?], offset: ?>> + // CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> + // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} : + // CHECK-SAME: (memref<3x3xf64, strided<[?, ?], offset: ?>>, memref<3x3xf64>) -> () + // CHECK: [[res:%.+]] = bufferization.to_tensor [[destAlloc]] : memref<3x3xf64> + // CHECK: return [[res]] : tensor<3x3xf64> + %0 = catalyst.custom_call fn("lapack_dgesdd") (%arg0) : (tensor<3x3xf64>) -> (tensor<3x3xf64>) + + return %0 : tensor<3x3xf64> +} + +// ----- + +func.func @custom_call_with_write(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { // CHECK: [[memrefArg:%.+]] = bufferization.to_memref %arg0 : memref<3x3xf64, strided<[?, ?], offset: ?>> // CHECK: [[sourceAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> // CHECK: memref.copy [[memrefArg]], [[sourceAlloc]] // CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> - // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} : + // CHECK: catalyst.custom_call fn("write_to_arg") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} : // CHECK-SAME: (memref<3x3xf64>, memref<3x3xf64>) -> () // CHECK: [[res:%.+]] = bufferization.to_tensor [[destAlloc]] : memref<3x3xf64> // CHECK: return [[res]] : tensor<3x3xf64> - %0 = catalyst.custom_call fn("lapack_dgesdd") (%arg0) : (tensor<3x3xf64>) -> (tensor<3x3xf64>) + %0 = catalyst.custom_call fn("write_to_arg") (%arg0) : (tensor<3x3xf64>) -> (tensor<3x3xf64>) return %0 : tensor<3x3xf64> } From 3289d1966ae02d943f51f4137d273a3b02a75d76 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 9 May 2025 14:16:18 -0400 Subject: [PATCH 22/29] Set identity layout map option. This avoids the strides. 
--- frontend/catalyst/pipelines.py | 2 +- mlir/lib/Driver/Pipelines.cpp | 7 +++++++ mlir/test/Catalyst/BufferizationTest.mlir | 19 +++++++++++-------- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 9b4e5900ea..35a77cf28c 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -228,7 +228,7 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]: "func.func(bufferization-bufferize)", "func.func(tensor-bufferize)", # Catalyst dialect's bufferization must be run before --func.func(linalg-bufferize) - "one-shot-bufferize{dialect-filter=catalyst}", + "one-shot-bufferize{dialect-filter=catalyst unknown-type-conversion=identity-layout-map}", "func.func(linalg-bufferize)", "func.func(tensor-bufferize)", "one-shot-bufferize{dialect-filter=quantum}", diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index 3b5b5cc6da..e9c1df73e5 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -20,6 +20,7 @@ #include "Quantum/IR/QuantumDialect.h" #include "Quantum/Transforms/Passes.h" #include "mhlo/transforms/passes.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassManager.h" @@ -80,6 +81,12 @@ void createBufferizationPipeline(OpPassManager &pm) pm.addNestedPass(mlir::tensor::createTensorBufferizePass()); mlir::bufferization::OneShotBufferizationOptions catalyst_buffer_options; catalyst_buffer_options.opFilter.allowDialect(); + catalyst_buffer_options.unknownTypeConverterFn = + [=](Value value, Attribute memorySpace, + const mlir::bufferization::BufferizationOptions &options) { + auto tensorType = cast(value.getType()); + return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace); + }; pm.addPass(mlir::bufferization::createOneShotBufferizePass(catalyst_buffer_options)); 
pm.addNestedPass(mlir::createLinalgBufferizePass()); pm.addNestedPass(mlir::tensor::createTensorBufferizePass()); diff --git a/mlir/test/Catalyst/BufferizationTest.mlir b/mlir/test/Catalyst/BufferizationTest.mlir index dafe251e8f..a5ef009f06 100644 --- a/mlir/test/Catalyst/BufferizationTest.mlir +++ b/mlir/test/Catalyst/BufferizationTest.mlir @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -// RUN: quantum-opt --one-shot-bufferize --split-input-file %s | FileCheck %s +// RUN: quantum-opt --split-input-file \ +// RUN: --pass-pipeline="builtin.module( \ +// RUN: one-shot-bufferize{unknown-type-conversion=identity-layout-map} \ +// RUN: )" %s | FileCheck %s ////////////////////// // Catalyst PrintOp // @@ -21,7 +24,7 @@ func.func @dbprint_val(%arg0: tensor) { // CHECK: %0 = bufferization.to_memref %arg0 - // CHECK: "catalyst.print"(%0) : (memref>) -> () + // CHECK: "catalyst.print"(%0) : (memref) -> () "catalyst.print"(%arg0) : (tensor) -> () return @@ -32,7 +35,7 @@ func.func @dbprint_val(%arg0: tensor) { func.func @dbprint_memref(%arg0: tensor) { // CHECK: %0 = bufferization.to_memref %arg0 - // CHECK: "catalyst.print"(%0) <{print_descriptor}> : (memref>) -> () + // CHECK: "catalyst.print"(%0) <{print_descriptor}> : (memref) -> () "catalyst.print"(%arg0) {print_descriptor} : (tensor) -> () return @@ -51,10 +54,10 @@ func.func @dbprint_str() { // ----- func.func @custom_call_no_write(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { - // CHECK: [[sourceAlloc:%.+]] = bufferization.to_memref %arg0 : memref<3x3xf64, strided<[?, ?], offset: ?>> + // CHECK: [[sourceAlloc:%.+]] = bufferization.to_memref %arg0 : memref<3x3xf64> // CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} : - // CHECK-SAME: (memref<3x3xf64, strided<[?, ?], offset: ?>>, memref<3x3xf64>) -> () + // 
CHECK-SAME: (memref<3x3xf64>, memref<3x3xf64>) -> () // CHECK: [[res:%.+]] = bufferization.to_tensor [[destAlloc]] : memref<3x3xf64> // CHECK: return [[res]] : tensor<3x3xf64> %0 = catalyst.custom_call fn("lapack_dgesdd") (%arg0) : (tensor<3x3xf64>) -> (tensor<3x3xf64>) @@ -65,7 +68,7 @@ func.func @custom_call_no_write(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { // ----- func.func @custom_call_with_write(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { - // CHECK: [[memrefArg:%.+]] = bufferization.to_memref %arg0 : memref<3x3xf64, strided<[?, ?], offset: ?>> + // CHECK: [[memrefArg:%.+]] = bufferization.to_memref %arg0 : memref<3x3xf64> // CHECK: [[sourceAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> // CHECK: memref.copy [[memrefArg]], [[sourceAlloc]] // CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> @@ -95,9 +98,9 @@ module @test1 { // CHECK-LABEL: @foo( // CHECK-SAME: [[arg0:%.+]]: tensor) func.func private @foo(%arg0: tensor) -> tensor { - // CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]] : memref> + // CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]] : memref // CHECK-DAG: [[resAlloc:%.+]] = memref.alloc() {{.*}}: memref - // CHECK: catalyst.callback_call @callback_1([[memref0]], [[resAlloc]]) : (memref>, memref) -> () + // CHECK: catalyst.callback_call @callback_1([[memref0]], [[resAlloc]]) : (memref, memref) -> () %1 = catalyst.callback_call @callback_1(%arg0) : (tensor) -> (tensor) // CHECK: [[retval:%.+]] = bufferization.to_tensor [[resAlloc]] // CHECK: return [[retval]] From 4f48b1bd863859a6ee47ab0bd81e25cc517ac13b Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 9 May 2025 18:00:23 -0400 Subject: [PATCH 23/29] no copy: jax already does the copy around the lapack kernels --- .../BufferizableOpInterfaceImpl.cpp | 12 ++--------- mlir/test/Catalyst/BufferizationTest.mlir | 20 ++----------------- 2 files changed, 4 insertions(+), 28 deletions(-) diff --git 
a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index cf5c8dccfc..72a6856a35 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -88,16 +88,8 @@ struct CustomCallOpInterface bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const bufferization::AnalysisState &state) const { - // We use custom call op to call the lapack kernels. - // These kernels might write to the source array. - // https://www.netlib.org/lapack/lug/node112.html - // * array or scalar arguments defining the input data; - // * some of them may be overwritten by results; - if (cast(op).getCallTargetName().str() == "lapack_dgesdd") { - // white list lapack kernels that are known to not write into input tensors - return false; - } - return true; + // TODO: update explanation about our lapack already doing copy + return false; } bufferization::AliasingValueList diff --git a/mlir/test/Catalyst/BufferizationTest.mlir b/mlir/test/Catalyst/BufferizationTest.mlir index a5ef009f06..8b938f8f89 100644 --- a/mlir/test/Catalyst/BufferizationTest.mlir +++ b/mlir/test/Catalyst/BufferizationTest.mlir @@ -53,8 +53,8 @@ func.func @dbprint_str() { // ----- -func.func @custom_call_no_write(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { - // CHECK: [[sourceAlloc:%.+]] = bufferization.to_memref %arg0 : memref<3x3xf64> +func.func @custom_call_no_user(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { + // CHECK: [[sourceAlloc:%.+]] = bufferization.to_memref %arg0 // CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} : // CHECK-SAME: (memref<3x3xf64>, memref<3x3xf64>) -> () @@ -67,22 +67,6 @@ func.func @custom_call_no_write(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { // ----- -func.func @custom_call_with_write(%arg0: 
tensor<3x3xf64>) -> tensor<3x3xf64> { - // CHECK: [[memrefArg:%.+]] = bufferization.to_memref %arg0 : memref<3x3xf64> - // CHECK: [[sourceAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> - // CHECK: memref.copy [[memrefArg]], [[sourceAlloc]] - // CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> - // CHECK: catalyst.custom_call fn("write_to_arg") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} : - // CHECK-SAME: (memref<3x3xf64>, memref<3x3xf64>) -> () - // CHECK: [[res:%.+]] = bufferization.to_tensor [[destAlloc]] : memref<3x3xf64> - // CHECK: return [[res]] : tensor<3x3xf64> - %0 = catalyst.custom_call fn("write_to_arg") (%arg0) : (tensor<3x3xf64>) -> (tensor<3x3xf64>) - - return %0 : tensor<3x3xf64> -} - -// ----- - // CHECK-LABEL: @test0 module @test0 { // CHECK: catalyst.callback @callback_1(memref, memref) From 7c6759f7877fad025160f456ecc0f15f6f1e796e Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 9 May 2025 18:01:03 -0400 Subject: [PATCH 24/29] name --- mlir/test/Catalyst/BufferizationTest.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Catalyst/BufferizationTest.mlir b/mlir/test/Catalyst/BufferizationTest.mlir index 8b938f8f89..50c691fef9 100644 --- a/mlir/test/Catalyst/BufferizationTest.mlir +++ b/mlir/test/Catalyst/BufferizationTest.mlir @@ -53,7 +53,7 @@ func.func @dbprint_str() { // ----- -func.func @custom_call_no_user(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { +func.func @custom_call(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { // CHECK: [[sourceAlloc:%.+]] = bufferization.to_memref %arg0 // CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} : From c949ce7a251361d307ff86e28efbb38dd27d043e Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 12 May 2025 09:59:18 -0400 Subject: [PATCH 25/29] add comment about jax shim layer's copy --- 
.../Transforms/BufferizableOpInterfaceImpl.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index 72a6856a35..d02fb7c787 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -88,7 +88,21 @@ struct CustomCallOpInterface bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const bufferization::AnalysisState &state) const { - // TODO: update explanation about our lapack already doing copy + // We only use custom call for the jax lapack kernels. + // This is actually hard-guarded: in the lowering pattern for custom call + // we check that the name of the callee is a jax symbol for a lapack kernel. + // + // The lapack kernels themselves might overwrite some of the input arrays. + // However, in jax's shim wrapper layer, a memcpy is already performed. + // See + // https://github.com/PennyLaneAI/catalyst/blob/main/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels.cpp + // + // The arguments to the underlying lapack kernel are denoted by the jax wrapper + // function as `data`. The `data` args already contain the output array that + // the lapack kernel is supposed to write into. The other input arrays are all marked const. + // Jax then purifies the function by adding a new argument `out` to hold the + // output array. 
+ return false; } From 0b701127bb5ddea5645d6a29d49a19e4841a60e6 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 12 May 2025 10:26:24 -0400 Subject: [PATCH 26/29] a bit more comment --- mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index d02fb7c787..e5fb6b9bca 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -102,6 +102,9 @@ struct CustomCallOpInterface // the lapack kernel is supposed to write into. The other input arrays are all marked const. // Jax then purifies the function by adding a new argument `out` to hold the // output array. + // + // In other words, the jax wrappers we call here with custom call op + // are already pure, and we won't have side effects on the input tensors. return false; } From b06ec693569f0096fde0f5116d831b0aa2d1178e Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 12 May 2025 10:41:15 -0400 Subject: [PATCH 27/29] add back finalizing bufferize pass this PR should only be swapping out catalyst dialect for one shot bufferization --- frontend/catalyst/pipelines.py | 1 + 1 file changed, 1 insertion(+) diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 35a77cf28c..996f6a2c28 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -233,6 +233,7 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]: "func.func(tensor-bufferize)", "one-shot-bufferize{dialect-filter=quantum}", "func-bufferize", + "func.func(finalizing-bufferize)", "canonicalize", # Remove dead memrefToTensorOp's "gradient-postprocess", # introduced during gradient-bufferize of callbacks From d5f4d3d253f7dd8b7f324d4a66347c1d76c61b97 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 12 May 2025 10:44:39 -0400 Subject: 
[PATCH 28/29] add back finalizing pass in cpp pipeline --- mlir/lib/Driver/Pipelines.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index e9c1df73e5..e3d7f3ad55 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -94,6 +94,7 @@ void createBufferizationPipeline(OpPassManager &pm) quantum_buffer_options.opFilter.allowDialect(); pm.addPass(mlir::bufferization::createOneShotBufferizePass(quantum_buffer_options)); pm.addPass(mlir::func::createFuncBufferizePass()); + pm.addNestedPass(mlir::bufferization::createFinalizingBufferizePass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(catalyst::createGradientPostprocessingPass()); pm.addNestedPass(mlir::bufferization::createBufferHoistingPass()); From a752ead303f2b121e25bfa5ad3a4be7df158b65f Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 12 May 2025 11:13:14 -0400 Subject: [PATCH 29/29] CI