Skip to content

Commit 47bb288

Browse files
committed
refactor: reuse batching interface for LU factorization
1 parent a7c38bc commit 47bb288

File tree

6 files changed

+182
-242
lines changed

6 files changed

+182
-242
lines changed

src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
#include "stablehlo/dialect/ChloOps.h"
3232
#include "stablehlo/dialect/StablehloOps.h"
3333

34+
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
35+
#include "src/enzyme_ad/jax/Dialect/Ops.h"
3436
#include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h"
3537
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
3638
#include "src/enzyme_ad/jax/Utils.h"
@@ -1962,6 +1964,31 @@ struct SHLOConstantOpBatchInterface
19621964
}
19631965
};
19641966

1967+
struct SHLOGetDimensionSizeOpBatchInterface
1968+
: public BatchOpInterface::ExternalModel<
1969+
SHLOGetDimensionSizeOpBatchInterface, GetDimensionSizeOp> {
1970+
1971+
mlir::LogicalResult createBatch(Operation *src, OpBuilder &builder,
1972+
IRMapping &mapper,
1973+
ArrayRef<int64_t> batchSizes) const {
1974+
auto getDimSizeOp = cast<GetDimensionSizeOp>(src);
1975+
1976+
auto newOp = builder.create<GetDimensionSizeOp>(
1977+
src->getLoc(), mapper.lookup(getDimSizeOp.getOperand()),
1978+
cast<IntegerAttr>(getDimSizeOp.getDimensionAttr()).getInt() +
1979+
batchSizes.size());
1980+
auto bcastOp = builder.create<BroadcastInDimOp>(
1981+
src->getLoc(),
1982+
RankedTensorType::get(
1983+
batchSizes,
1984+
cast<RankedTensorType>(newOp->getResult(0).getType())
1985+
.getElementType()),
1986+
newOp->getResult(0), builder.getDenseI64ArrayAttr({}));
1987+
mapper.map(src->getResult(0), bcastOp->getResult(0));
1988+
return success();
1989+
}
1990+
};
1991+
19651992
struct SHLOTransposeOpBatchInterface
19661993
: public BatchOpInterface::ExternalModel<SHLOTransposeOpBatchInterface,
19671994
TransposeOp> {
@@ -3918,6 +3945,8 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
39183945
*context);
39193946

39203947
ConstantOp::attachInterface<SHLOConstantOpBatchInterface>(*context);
3948+
GetDimensionSizeOp::attachInterface<SHLOGetDimensionSizeOpBatchInterface>(
3949+
*context);
39213950
TransposeOp::attachInterface<SHLOTransposeOpBatchInterface>(*context);
39223951
IfOp::attachInterface<SHLOGenericBatchOpInterface<IfOp>>(*context);
39233952
WhileOp::attachInterface<SHLOGenericBatchOpInterface<WhileOp>>(*context);
@@ -3947,6 +3976,10 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
39473976

39483977
AddOp::attachInterface<StablehloAddSimplifyMathInterface>(*context);
39493978
SubtractOp::attachInterface<StablehloSubSimplifyMathInterface>(*context);
3979+
3980+
// TODO: move into its own file
3981+
enzymexla::JITCallOp::attachInterface<
3982+
SHLOGenericBatchOpInterface<enzymexla::JITCallOp>>(*context);
39503983
});
39513984
}
39523985

src/enzyme_ad/jax/Passes/LinalgUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ mlir::ArrayAttr getSHLOLayout(PatternRewriter &rewriter,
3838
return rewriter.getArrayAttr(attrs);
3939
}
4040

41-
std::optional<std::string> lapack_precision_prefix(Type elementType) {
41+
std::optional<std::string> lapackPrecisionPrefix(Type elementType) {
4242

4343
// single-precision float
4444
if (elementType.isF32()) {

src/enzyme_ad/jax/Passes/LinalgUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@ mlir::ArrayAttr getSHLOLayout(mlir::PatternRewriter &rewriter,
1717
llvm::SmallVector<bool> isColMajorArr,
1818
int64_t maxNumDims);
1919

20-
std::optional<std::string> lapack_precision_prefix(mlir::Type elementType);
20+
std::optional<std::string> lapackPrecisionPrefix(mlir::Type elementType);
2121

2222
#endif // ENZYMEXLA_LINALGUTILS_H

src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ struct GeqrfOpLowering : public OpRewritePattern<enzymexla::GeqrfOp> {
8080
auto type_llvm_void = LLVM::LLVMVoidType::get(ctx);
8181

8282
std::string fn = "geqrf_";
83-
if (auto prefix = lapack_precision_prefix(inputElementType)) {
83+
if (auto prefix = lapackPrecisionPrefix(inputElementType)) {
8484
fn = *prefix + fn;
8585
} else {
8686
op->emitOpError() << "Unsupported element type: " << inputElementType;
@@ -351,7 +351,7 @@ struct GeqrtOpLowering : public OpRewritePattern<enzymexla::GeqrtOp> {
351351
auto type_llvm_void = LLVM::LLVMVoidType::get(ctx);
352352

353353
std::string fn = "geqrt_";
354-
if (auto prefix = lapack_precision_prefix(inputElementType)) {
354+
if (auto prefix = lapackPrecisionPrefix(inputElementType)) {
355355
fn = *prefix + fn;
356356
} else {
357357
op->emitOpError() << "Unsupported element type: " << inputElementType;
@@ -562,7 +562,7 @@ struct OrgqrOpLowering : public OpRewritePattern<enzymexla::OrgqrOp> {
562562
auto type_llvm_void = LLVM::LLVMVoidType::get(ctx);
563563

564564
std::string fn = "gqr_";
565-
if (auto prefix = lapack_precision_prefix(inputElementType)) {
565+
if (auto prefix = lapackPrecisionPrefix(inputElementType)) {
566566
if (prefix == "s" || prefix == "d")
567567
fn = *prefix + "or" + fn;
568568
else
@@ -868,7 +868,7 @@ struct OrmqrOpLowering : public OpRewritePattern<enzymexla::OrmqrOp> {
868868
auto type_llvm_char = rewriter.getIntegerType(8);
869869

870870
std::string fn = "mqr_";
871-
if (auto prefix = lapack_precision_prefix(A_eltype)) {
871+
if (auto prefix = lapackPrecisionPrefix(A_eltype)) {
872872
if (prefix == "s" || prefix == "d")
873873
fn = *prefix + "or" + fn;
874874
else
@@ -1136,7 +1136,7 @@ struct GemqrtOpLowering : public OpRewritePattern<enzymexla::GemqrtOp> {
11361136
auto type_llvm_char = rewriter.getIntegerType(8);
11371137

11381138
std::string fn = "gemqrt_";
1139-
if (auto prefix = lapack_precision_prefix(C_eltype)) {
1139+
if (auto prefix = lapackPrecisionPrefix(C_eltype)) {
11401140
fn = *prefix + fn;
11411141
} else {
11421142
op->emitOpError() << "Unsupported element type: " << C_eltype;

0 commit comments

Comments
 (0)