Skip to content
Merged
Show file tree
Hide file tree
Changes from 56 commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
0d83653
init; changelog
paul0403 May 12, 2025
fb9c6db
boilerplate
paul0403 May 12, 2025
62a3f5a
missing namespace
paul0403 May 12, 2025
81e6e97
missing include
paul0403 May 12, 2025
6242a0b
typo
paul0403 May 12, 2025
7c149d0
(cherry pick) Add ForwardOp Bufferization
tzunghanjuang Sep 6, 2024
5504e29
move over Tzunghan's old work
paul0403 May 13, 2025
fcc0424
adjoint op
paul0403 May 13, 2025
a310d05
backprop op
paul0403 May 13, 2025
2e26140
add adjoint test with multiple results
paul0403 May 14, 2025
b7f62ae
some cleanup on backprop
paul0403 May 14, 2025
d113b4a
clean up backprop
paul0403 May 14, 2025
6282518
test file rename
paul0403 May 14, 2025
277f547
changelog
paul0403 May 14, 2025
5419116
update backprop test
paul0403 May 14, 2025
19e25f8
do not manually copy the cotangents for backprop
paul0403 May 14, 2025
0962b81
use ValueRange instead of vector
paul0403 May 14, 2025
7273e4e
update comment about backprop's mem write
paul0403 May 14, 2025
4714f3c
move over Tzunghan's forward and reverse
paul0403 May 15, 2025
4d995e7
try CI
paul0403 May 15, 2025
d037193
format
paul0403 May 15, 2025
14b8599
remove old gradient bufferization
paul0403 May 15, 2025
b98c9ea
clean prints
paul0403 May 16, 2025
d3389a9
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 16, 2025
89e0bbd
format
paul0403 May 16, 2025
491cfa1
add gradient preprocessing test
paul0403 May 16, 2025
c7b57f2
easier on the eyes
paul0403 May 16, 2025
128978b
reverse op preprocessing test
paul0403 May 16, 2025
bc3caf9
add bufferization test for forward and reverse
paul0403 May 16, 2025
1eb49c1
add post processing test for forward and reverse
paul0403 May 16, 2025
eff06f4
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 16, 2025
a194c06
make most dialects into one-shot-bufferize(dialect)
paul0403 May 19, 2025
66f043f
a stable version before trying one-shot-bufferize pass
paul0403 May 19, 2025
e35b25e
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 20, 2025
dd04fb9
change the functionArgTypeConverterFn in gradient bufferization to ta…
paul0403 May 20, 2025
5e9d75c
use one-shot bufferization in python pipeline
paul0403 May 20, 2025
e306252
update llvm to the commit that has
paul0403 May 20, 2025
83f6fc2
add restrict unitattr to the ToTensorOps in gradient lowering pass
paul0403 May 20, 2025
ffb2c60
GreedyRewriteConfig.enableRegionSimplification is no longer just a pl…
paul0403 May 20, 2025
1848bed
update TopologicalSortUtils.h location
paul0403 May 20, 2025
9b090f3
include <variant>
paul0403 May 20, 2025
93ff924
track llvm and mhlo versions to jax 0.4.32
paul0403 May 20, 2025
ece1c6d
update llvm and mhlo submodules to jax 0.4.32 versions
paul0403 May 20, 2025
7dfd0cc
.dep-versions format :sweat-smile:
paul0403 May 20, 2025
b628f47
enzymestatic-19 -> 20
paul0403 May 20, 2025
d2bc0b6
MhloQuantToIntConversion is removed
paul0403 May 20, 2025
3790a68
`translateModuleToLLVMIR` got a new argument `disableVerification`
paul0403 May 20, 2025
f370449
just comment out old bufferization passes in cpp pipeline for now
paul0403 May 20, 2025
268bc64
update cpp pipeline
paul0403 May 21, 2025
d3a75ab
turn on `copy-before-write` for async
paul0403 May 21, 2025
6e0d8df
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 21, 2025
a55e941
Add Tzunghan as author
paul0403 May 21, 2025
0739553
changelog
paul0403 May 21, 2025
0fc83a5
Use `eliminate-empty-tensors` pass instead of `empty-tensor-to-alloc-…
paul0403 May 21, 2025
b8895ed
add `restrict` attr to to_tensor ops in mlir lit test
paul0403 May 21, 2025
26d8f50
line-too-long on the python bufferization options string
paul0403 May 21, 2025
d5cdb13
skip upstream ml_dtypes lit test
paul0403 May 21, 2025
ef323be
Update mlir lit tests impacted by mlir update.
paul0403 May 21, 2025
692cf99
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 21, 2025
e04b79f
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 22, 2025
1550045
Merge remote-tracking branch 'origin/paul0403/new_bufferize_gradient_…
paul0403 May 22, 2025
88d28ae
Merge remote-tracking branch 'origin/main' into paul0403/one-shot-buf…
paul0403 May 23, 2025
80cec7b
format
paul0403 May 23, 2025
c07b826
just cleaning cache and checking wheels; I will respond to comments o…
paul0403 May 23, 2025
90a2546
Merge remote-tracking branch 'origin/main' into paul0403/one-shot-buf…
paul0403 May 26, 2025
1e1435d
add `promote-buffers-to-stack` pass
paul0403 May 26, 2025
7826069
Merge remote-tracking branch 'origin/main' into paul0403/one-shot-buf…
paul0403 May 26, 2025
d3132f6
promote buffer to stack pass in cpp pipeline
paul0403 May 26, 2025
dc2f7b8
move createBufferizationToMemRefPass to the end of the bufferization …
paul0403 May 26, 2025
2cbdeb8
Merge remote-tracking branch 'origin/main' into paul0403/one-shot-buf…
paul0403 May 26, 2025
546cd2c
move bufferization-to-memref back
paul0403 May 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions .dep-versions
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
# Always update the version check in catalyst.__init__ when changing the JAX version.

#############
# We track mlir submodule versions from jax 0.4.32 for now
# These are the earliest versions with complete upstream bufferization changes
# Versions are retrieved from
# python3 .github/workflows/set_dep_versions.py 0.4.32
#############

jax=0.6.0
mhlo=89a891c986650c33df76885f5620e0a92150d90f
llvm=3a8316216807d64a586b971f51695e23883331f7
mhlo=25b008569f413d76cfa8f481f3a84e82b89c47f4
llvm=5f74671c85877e03622e8d308aee15ed73ccee7c
enzyme=v0.0.149

# Always remove custom PL/LQ versions before release.
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build-wheel-linux-arm64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ jobs:
-DCMAKE_CXX_VISIBILITY_PRESET=default \
-DCMAKE_CXX_FLAGS="-fuse-ld=lld"

cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-19
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20

- name: Save Enzyme Build
id: save-enzyme-build
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build-wheel-linux-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ jobs:
-DCMAKE_CXX_VISIBILITY_PRESET=default \
-DCMAKE_CXX_FLAGS="-fuse-ld=lld"

cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-19
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20

- name: Save Enzyme Build
id: save-enzyme-build
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build-wheel-macos-arm64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ jobs:
-DENZYME_STATIC_LIB=ON \
-DCMAKE_CXX_VISIBILITY_PRESET=default

cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-19
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20

- name: Save Enzyme Build
id: save-enzyme-build
Expand Down
2 changes: 2 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@
[(#1027)](https://github.com/PennyLaneAI/catalyst/pull/1027)
[(#1686)](https://github.com/PennyLaneAI/catalyst/pull/1686)
[(#1708)](https://github.com/PennyLaneAI/catalyst/pull/1708)
[(#1740)](https://github.com/PennyLaneAI/catalyst/pull/1740)
[(#1751)](https://github.com/PennyLaneAI/catalyst/pull/1751)

* Redundant `OptionalAttr` is removed from `adjoint` argument in `QuantumOps.td` TableGen file
[(#1746)](https://github.com/PennyLaneAI/catalyst/pull/1746)
Expand Down
35 changes: 18 additions & 17 deletions frontend/catalyst/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,27 +213,27 @@
return list(filter(partial(is_not, None), quantum_compilation))


def get_bufferization_stage(_options: CompileOptions) -> List[str]:
def get_bufferization_stage(options: CompileOptions) -> List[str]:
"""Returns the list of passes that performs bufferization"""

bufferization_options = """bufferize-function-boundaries
allow-return-allocs-from-loops
function-boundary-type-conversion=identity-layout-map
unknown-type-conversion=identity-layout-map""".replace(
"\n", " "
)
if options.async_qnodes:
bufferization_options += " copy-before-write"

Check warning on line 226 in frontend/catalyst/pipelines.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/pipelines.py#L226

Added line #L226 was not covered by tests

bufferization = [
"one-shot-bufferize{dialect-filter=memref}",
"inline",
"gradient-preprocess",
"gradient-bufferize",
"scf-bufferize",
"convert-tensor-to-linalg", # tensor.pad
"convert-elementwise-to-linalg", # Must be run before --arith-bufferize
"arith-bufferize",
"empty-tensor-to-alloc-tensor",
"func.func(bufferization-bufferize)",
"func.func(tensor-bufferize)",
# Catalyst dialect's bufferization must be run before --func.func(linalg-bufferize)
"one-shot-bufferize{dialect-filter=catalyst unknown-type-conversion=identity-layout-map}",
"func.func(linalg-bufferize)",
"func.func(tensor-bufferize)",
"one-shot-bufferize{dialect-filter=quantum}",
"func-bufferize",
"func.func(finalizing-bufferize)",
"convert-elementwise-to-linalg", # Must be run before --one-shot-bufferize
"gradient-preprocess",
"eliminate-empty-tensors",
####################
"one-shot-bufferize{" + bufferization_options + "}",
####################
"canonicalize", # Remove dead memrefToTensorOp's
"gradient-postprocess",
# introduced during gradient-bufferize of callbacks
Expand All @@ -247,6 +247,7 @@
# "cse",
"cp-global-memref",
]

return bufferization


Expand Down
1 change: 0 additions & 1 deletion mlir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ set(ALL_MHLO_PASSES
HloToLinalgUtils
MhloToLinalg
MhloToStablehlo
MhloQuantToIntConversion
StablehloToMhlo
)

Expand Down
7 changes: 5 additions & 2 deletions mlir/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ llvm:

# TODO: when updating LLVM, test to see if mlir/unittests/Bytecode/BytecodeTest.cpp:55 is passing
# and remove filter. This tests fails on CI/CD not locally.
LIT_FILTER_OUT="Bytecode|tosa-to-tensor" cmake --build $(LLVM_BUILD_DIR) --target $(LLVM_TARGETS)
# Note: the upstream lit test llvm-project/mlir/test/python/execution_engine.py requries
# the python package `ml_dtypes`. We don't actually use the execution engine, so we skip the
# test to reduce unnecessary dependencies.
LIT_FILTER_OUT="Bytecode|tosa-to-tensor|execution_engine" cmake --build $(LLVM_BUILD_DIR) --target $(LLVM_TARGETS)

.PHONY: mhlo
mhlo: TARGET_FILE := $(MK_DIR)/mlir-hlo/mhlo/transforms/CMakeLists.txt
Expand Down Expand Up @@ -130,7 +133,7 @@ enzyme:
-DCMAKE_CXX_VISIBILITY_PRESET=$(SYMBOL_VISIBILITY) \
-DCMAKE_POLICY_DEFAULT_CMP0116=NEW

cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-19
cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-20

.PHONY: plugin
plugin:
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/Gradient/IR/GradientOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"

#include "Gradient/IR/GradientInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"

#include "Gradient/IR/GradientDialect.h"
#include "Gradient/IR/GradientInterfaces.h"

#define GET_OP_CLASSES
#include "Gradient/IR/GradientOps.h.inc"
3 changes: 2 additions & 1 deletion mlir/include/Gradient/IR/GradientOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/OpBase.td"
Expand Down Expand Up @@ -388,7 +389,7 @@ def ReverseOp : Gradient_Op<"reverse",
}

def ReturnOp : Gradient_Op<"return",
[Terminator, ParentOneOf<["ForwardOp", "ReverseOp"]>]> {
[ReturnLike, Terminator, ParentOneOf<["ForwardOp", "ReverseOp"]>]> {

let summary = "Return tapes or nothing";

Expand Down
27 changes: 27 additions & 0 deletions mlir/include/Gradient/Transforms/BufferizableOpInterfaceImpl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright 2024-2025 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

using namespace mlir;

namespace catalyst {

namespace gradient {

void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry &registry);

}

} // namespace catalyst
12 changes: 0 additions & 12 deletions mlir/include/Gradient/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,6 @@

include "mlir/Pass/PassBase.td"

def GradientBufferizationPass : Pass<"gradient-bufferize"> {
let summary = "Bufferize tensors in quantum operations.";

let dependentDialects = [
"bufferization::BufferizationDialect",
"memref::MemRefDialect",
"index::IndexDialect"
];

let constructor = "catalyst::createGradientBufferizationPass()";
}

def GradientPreprocessingPass : Pass<"gradient-preprocess"> {
let summary = "Insert Func.CallOp for ForwardOp and ReverseOp.";

Expand Down
1 change: 0 additions & 1 deletion mlir/include/Gradient/Transforms/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
namespace catalyst {
namespace gradient {

void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &);
void populatePreprocessingPatterns(mlir::RewritePatternSet &);
void populatePostprocessingPatterns(mlir::RewritePatternSet &);
void populateLoweringPatterns(mlir::RewritePatternSet &);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Catalyst/Transforms/DetectQNodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ struct AddExceptionHandlingPass : impl::AddExceptionHandlingPassBase<AddExceptio

GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
config.enableRegionSimplification = false;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;

if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns1), config))) {
signalPassFailure();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ struct AnnotateWithFullyQualifiedNamePass
// Do not fold to save in compile time.
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
config.enableRegionSimplification = false;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;

RewritePatternSet annotate(context);
auto root = getOperation();
Expand All @@ -409,7 +409,7 @@ struct InlineNestedSymbolTablePass : PassWrapper<InlineNestedSymbolTablePass, Op

GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
config.enableRegionSimplification = false;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;

RewritePatternSet renameFunctions(context);

Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ void catalyst::registerAllCatalystPasses()
mlir::registerPass(catalyst::createDisentangleSWAPPass);
mlir::registerPass(catalyst::createEmitCatalystPyInterfacePass);
mlir::registerPass(catalyst::createGEPInboundsPass);
mlir::registerPass(catalyst::createGradientBufferizationPass);
mlir::registerPass(catalyst::createGradientConversionPass);
mlir::registerPass(catalyst::createGradientPreprocessingPass);
mlir::registerPass(catalyst::createGradientPostprocessingPass);
Expand Down
5 changes: 4 additions & 1 deletion mlir/lib/Driver/CompilerDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
#include "Driver/Support.h"
#include "Gradient/IR/GradientDialect.h"
#include "Gradient/IR/GradientInterfaces.h"
#include "Gradient/Transforms/BufferizableOpInterfaceImpl.h"
#include "Gradient/Transforms/Passes.h"
#include "Ion/IR/IonDialect.h"
#include "MBQC/IR/MBQCDialect.h"
Expand Down Expand Up @@ -737,7 +738,8 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput &
TimingScope translateTiming = timing.nest("Translate");
llvmModule =
timer::timer(translateModuleToLLVMIR, "translateModuleToLLVMIR",
/* add_endl */ false, *mlirModule, llvmContext, "LLVMDialectModule");
/* add_endl */ false, *mlirModule, llvmContext, "LLVMDialectModule",
/* disableVerification */ true);
if (!llvmModule) {
CO_MSG(options, Verbosity::Urgent, "Failed to translate LLVM module\n");
return failure();
Expand Down Expand Up @@ -966,6 +968,7 @@ int QuantumDriverMainFromCL(int argc, char **argv)

// Register bufferization interfaces
catalyst::registerBufferizableOpInterfaceExternalModels(registry);
catalyst::gradient::registerBufferizableOpInterfaceExternalModels(registry);
catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry);

// Register and parse command line options.
Expand Down
42 changes: 16 additions & 26 deletions mlir/lib/Driver/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "Driver/Pipelines.h"
#include "Catalyst/IR/CatalystDialect.h"
#include "Catalyst/Transforms/Passes.h"
#include "Gradient/IR/GradientDialect.h"
#include "Gradient/Transforms/Passes.h"
#include "Mitigation/Transforms/Passes.h"
#include "Quantum/IR/QuantumDialect.h"
Expand Down Expand Up @@ -66,35 +67,24 @@ void createQuantumCompilationPipeline(OpPassManager &pm)
}
void createBufferizationPipeline(OpPassManager &pm)
{
mlir::bufferization::OneShotBufferizationOptions options;
options.opFilter.allowDialect<mlir::bufferization::BufferizationDialect>();
pm.addPass(mlir::bufferization::createOneShotBufferizePass(options));
pm.addPass(mlir::createInlinerPass());
pm.addPass(catalyst::createGradientPreprocessingPass());
pm.addPass(catalyst::createGradientBufferizationPass());
pm.addPass(mlir::createSCFBufferizePass());
pm.addPass(mlir::createConvertTensorToLinalgPass());
pm.addPass(mlir::createConvertElementwiseToLinalgPass());
pm.addPass(mlir::arith::createArithBufferizePass());
pm.addPass(mlir::bufferization::createEmptyTensorToAllocTensorPass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::bufferization::createBufferizationBufferizePass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::tensor::createTensorBufferizePass());
mlir::bufferization::OneShotBufferizationOptions catalyst_buffer_options;
catalyst_buffer_options.opFilter.allowDialect<catalyst::CatalystDialect>();
catalyst_buffer_options.unknownTypeConverterFn =
[=](Value value, Attribute memorySpace,
const mlir::bufferization::BufferizationOptions &options) {
auto tensorType = cast<TensorType>(value.getType());
return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
};
pm.addPass(mlir::bufferization::createOneShotBufferizePass(catalyst_buffer_options));
pm.addNestedPass<mlir::func::FuncOp>(mlir::createLinalgBufferizePass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::tensor::createTensorBufferizePass());
mlir::bufferization::OneShotBufferizationOptions quantum_buffer_options;
quantum_buffer_options.opFilter.allowDialect<catalyst::quantum::QuantumDialect>();
pm.addPass(mlir::bufferization::createOneShotBufferizePass(quantum_buffer_options));
pm.addPass(mlir::func::createFuncBufferizePass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::bufferization::createFinalizingBufferizePass());
pm.addPass(catalyst::createGradientPreprocessingPass());
pm.addPass(mlir::bufferization::createEmptyTensorEliminationPass());
///////////
mlir::bufferization::OneShotBufferizationOptions options;
options.bufferizeFunctionBoundaries = true;
options.allowReturnAllocsFromLoops = true;
options.setFunctionBoundaryTypeConversion(
mlir::bufferization::LayoutMapOption::IdentityLayoutMap);
options.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
const mlir::bufferization::BufferizationOptions &options) {
auto tensorType = cast<TensorType>(value.getType());
return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
};
pm.addPass(mlir::bufferization::createOneShotBufferizePass(options));
//////////////
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(catalyst::createGradientPostprocessingPass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::bufferization::createBufferHoistingPass());
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Driver/Timer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <string>
#include <string_view>
#include <thread>
#include <utility> // std::forward

#include <ctime>

Expand Down
6 changes: 5 additions & 1 deletion mlir/lib/Gradient/IR/GradientDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"

#include "Gradient/IR/GradientDialect.h"
#include "Gradient/IR/GradientOps.h"
#include "mlir/Interfaces/FunctionImplementation.h"

using namespace mlir;
using namespace catalyst::gradient;
Expand Down Expand Up @@ -50,6 +51,9 @@ void GradientDialect::initialize()
#include "Gradient/IR/GradientOps.cpp.inc"
>();
addInterface<GradientInlinerInterface>();

declarePromisedInterfaces<bufferization::BufferizableOpInterface, AdjointOp, BackpropOp,
ForwardOp, ReverseOp>();
}

//===----------------------------------------------------------------------===//
Expand Down
Loading