Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .dep-versions
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# To update JAX version alongside compatible dependency tags, run the following script:
# python3 .github/workflows/set_dep_versions.py {JAX_version}
jax=0.6.2
stablehlo=69d6dae46e1c7de36e6e6973654754f05353cba5
llvm=f8cb7987c64dcffb72414a40560055cb717dbf74
enzyme=v0.0.186
stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d
llvm=113f01aa82d055410f22a9d03b3468fa68600589
enzyme=v0.0.203

# Always remove custom PL/LQ versions before release.

Expand Down
2 changes: 1 addition & 1 deletion mlir/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ enzyme:
-DCMAKE_CXX_VISIBILITY_PRESET=$(SYMBOL_VISIBILITY) \
-DCMAKE_POLICY_DEFAULT_CMP0116=NEW

cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-21
cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-22

.PHONY: plugin
plugin:
Expand Down
13 changes: 11 additions & 2 deletions mlir/include/Catalyst/IR/CatalystOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,17 @@ def CallbackOp : Catalyst_Op<"callback",

let builders = [OpBuilder<(ins
"mlir::StringRef":$name, "mlir::FunctionType":$type,
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs)
>];
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs), [{
$_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
$_builder.getStringAttr(name));
$_state.addAttribute("function_type", mlir::TypeAttr::get(type));
$_state.addAttribute("id", $_builder.getI64IntegerAttr(0));
$_state.addAttribute("argc", $_builder.getI64IntegerAttr(type.getNumInputs()));
$_state.addAttribute("resc", $_builder.getI64IntegerAttr(type.getNumResults()));
$_state.attributes.append(attrs.begin(), attrs.end());
$_state.addRegion();
}]>
];

let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
Expand Down
35 changes: 27 additions & 8 deletions mlir/include/Gradient/IR/GradientOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ include "Gradient/IR/GradientInterfaces.td"

def GradOp : Gradient_Op<"grad", [
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
GradientOpInterface
]> {
let summary = "Compute the gradient of a function.";
Expand Down Expand Up @@ -287,7 +287,7 @@ def ForwardOp : Gradient_Op<"forward",
Then:

followed by the original return type, if any.

since there is none, then:

%returnTy = { %tape }
Expand All @@ -302,7 +302,7 @@ def ForwardOp : Gradient_Op<"forward",
One thing that was found experimentally and through tests in Enzyme is that the tape can also be a pointer.
We use this in the case when there is no tape to return. Instead of returning an empty struct, we return a null
pointer that is just never dereferenced.

}];

let arguments = (ins
Expand All @@ -320,8 +320,18 @@ def ForwardOp : Gradient_Op<"forward",

let builders = [OpBuilder<(ins
"mlir::StringRef":$name, "mlir::FunctionType":$type,
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs)
>];
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs), [{
$_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
$_builder.getStringAttr(name));
$_state.addAttribute("function_type", mlir::TypeAttr::get(type));
$_state.addAttribute("implementation", mlir::FlatSymbolRefAttr::get($_builder.getStringAttr("")));
$_state.addAttribute("argc", $_builder.getI64IntegerAttr(0));
$_state.addAttribute("resc", $_builder.getI64IntegerAttr(0));
$_state.addAttribute("tape", $_builder.getI64IntegerAttr(0));
$_state.attributes.append(attrs.begin(), attrs.end());
$_state.addRegion();
}]>
];

let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
Expand Down Expand Up @@ -358,7 +368,6 @@ def ReverseOp : Gradient_Op<"reverse",

%returnTy = { %tape }


}];

let arguments = (ins
Expand All @@ -376,8 +385,18 @@ def ReverseOp : Gradient_Op<"reverse",

let builders = [OpBuilder<(ins
"mlir::StringRef":$name, "mlir::FunctionType":$type,
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs)
>];
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs), [{
$_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
$_builder.getStringAttr(name));
$_state.addAttribute("function_type", mlir::TypeAttr::get(type));
$_state.addAttribute("implementation", mlir::FlatSymbolRefAttr::get($_builder.getStringAttr("")));
$_state.addAttribute("argc", $_builder.getI64IntegerAttr(0));
$_state.addAttribute("resc", $_builder.getI64IntegerAttr(0));
$_state.addAttribute("tape", $_builder.getI64IntegerAttr(0));
$_state.attributes.append(attrs.begin(), attrs.end());
$_state.addRegion();
}]>
];

let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ set(LIBS
${translation_libs}
ExternalStablehloLib
MLIROptLib
MLIRRegisterAllDialects
MLIRRegisterAllPasses
MLIRRegisterAllExtensions
MLIRCatalyst
catalyst-transforms
MLIRQuantum
Expand Down
19 changes: 16 additions & 3 deletions mlir/lib/Driver/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,22 @@

#include <memory>

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "stablehlo/conversions/linalg/transforms/Passes.h"
Expand Down
20 changes: 10 additions & 10 deletions mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ void TensorType2MemrefType(const TypeRange &inTypes, SmallVector<Type> &converte
}
}

static BaseMemRefType
static bufferization::BufferLikeType
getBufferizedFunctionArgType(FunctionOpInterface funcOp, int64_t index,
const bufferization::BufferizationOptions &options)
{
Expand All @@ -134,7 +134,7 @@ getBufferizedFunctionArgType(FunctionOpInterface funcOp, int64_t index,
BaseMemRefType memrefType = options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), nullptr, options);

return memrefType;
return cast<bufferization::BufferLikeType>(memrefType);
}

static ReturnOp getAssumedUniqueReturnOp(FunctionOpInterface funcOp)
Expand Down Expand Up @@ -402,10 +402,10 @@ struct ForwardOpInterface
return {};
}

FailureOr<BaseMemRefType> getBufferType(Operation *op, Value value,
const bufferization::BufferizationOptions &options,
const bufferization::BufferizationState &state,
SmallVector<Value> &invocationStack) const
FailureOr<bufferization::BufferLikeType>
getBufferType(Operation *op, Value value, const bufferization::BufferizationOptions &options,
const bufferization::BufferizationState &state,
SmallVector<Value> &invocationStack) const
{
// The getBufferType() method is called on either BlockArguments or OpResults.
// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td#L506
Expand Down Expand Up @@ -526,10 +526,10 @@ struct ReverseOpInterface
return {};
}

FailureOr<BaseMemRefType> getBufferType(Operation *op, Value value,
const bufferization::BufferizationOptions &options,
const bufferization::BufferizationState &state,
SmallVector<Value> &invocationStack) const
FailureOr<bufferization::BufferLikeType>
getBufferType(Operation *op, Value value, const bufferization::BufferizationOptions &options,
const bufferization::BufferizationState &state,
SmallVector<Value> &invocationStack) const
{
// See comment on the getBufferType() method on forward op.
auto reverseOp = cast<ReverseOp>(op);
Expand Down
22 changes: 7 additions & 15 deletions mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/SymbolTable.h"

#include "Gradient/Utils/EinsumLinalgGeneric.h"
Expand Down Expand Up @@ -60,8 +61,6 @@ template <class T> std::vector<int64_t> _tovec(const T &x)

LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rewriter) const
{
MLIRContext *ctx = getContext();

Location loc = op.getLoc();

auto func_diff_operand_indices = computeDiffArgIndices(op.getDiffArgIndices());
Expand Down Expand Up @@ -159,12 +158,9 @@ LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rew
}
else {
assert(acc.value().getType() == res.getType());

auto add_op = rewriter.create<linalg::ElemwiseBinaryOp>(
loc, res.getType(), ValueRange({acc.value(), res}), acc.value(),
linalg::BinaryFnAttr::get(ctx, linalg::BinaryFn::add),
linalg::TypeFnAttr::get(ctx, linalg::TypeFn::cast_signed));
acc = add_op.getResultTensors()[0];
auto addOp = rewriter.create<linalg::AddOp>(
loc, res.getType(), ValueRange{acc.value(), res}, ValueRange{acc.value()});
acc = addOp.getResultTensors()[0];
}
}
assert(acc.has_value());
Expand All @@ -181,8 +177,6 @@ LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rew

LogicalResult VJPLoweringPattern::matchAndRewrite(VJPOp op, PatternRewriter &rewriter) const
{
MLIRContext *ctx = getContext();

Location loc = op.getLoc();

auto func_diff_operand_indices = computeDiffArgIndices(op.getDiffArgIndices());
Expand Down Expand Up @@ -278,11 +272,9 @@ LogicalResult VJPLoweringPattern::matchAndRewrite(VJPOp op, PatternRewriter &rew
else {
assert(acc.value().getType() == res.getType());

auto add_op = rewriter.create<linalg::ElemwiseBinaryOp>(
loc, res.getType(), ValueRange({acc.value(), res}), acc.value(),
linalg::BinaryFnAttr::get(ctx, linalg::BinaryFn::add),
linalg::TypeFnAttr::get(ctx, linalg::TypeFn::cast_signed));
acc = add_op.getResultTensors()[0];
auto addOp = rewriter.create<linalg::AddOp>(
loc, res.getType(), ValueRange{acc.value(), res}, ValueRange{acc.value()});
acc = addOp.getResultTensors()[0];
}
}
assert(acc.has_value());
Expand Down
1 change: 1 addition & 0 deletions mlir/tools/quantum-lsp-server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ set(LIBS
${conversion_libs}
ExternalStablehloLib
MLIRLspServerLib
MLIRRegisterAllDialects
MLIRCatalyst
MLIRQuantum
MLIRQEC
Expand Down
2 changes: 2 additions & 0 deletions mlir/tools/quantum-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ set(LIBS
${extension_libs}
ExternalStablehloLib
MLIROptLib
MLIRRegisterAllDialects
MLIRRegisterAllPasses
MLIRCatalyst
catalyst-transforms
catalyst-stablehlo-transforms
Expand Down