Skip to content

Commit 582e713

Browse files
authored
refactor: cleanup linalg/lapack lowering (#1501)
* feat: get_dimension_size batch interface * feat: implement jitcall batching with shlo_generic_batch_op_interface * refactor: reuse batching interface for LU factorization * fix: remove old changes * refactor: move into separate functions * test: update LU tests * feat: dynamic slice simplify * feat: mark memory effects * fix: update to new API * fix: use correct return * chore: run fmt * test: fix
1 parent 563c1b3 commit 582e713

File tree

8 files changed

+682
-730
lines changed

8 files changed

+682
-730
lines changed

src/enzyme_ad/jax/Dialect/TritonExt/Ops.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@ template <>
1515
triton_ext::TritonCallOp ReadOnlyArg<triton_ext::TritonCallOp>::create(
1616
PatternRewriter &rewriter, triton_ext::TritonCallOp launchOp,
1717
ArrayRef<Type> resTys, ArrayAttr outputAliases) const {
18-
return rewriter.create<triton_ext::TritonCallOp>(
19-
launchOp.getLoc(), resTys, launchOp.getFn(), launchOp.getGridx(),
20-
launchOp.getGridy(), launchOp.getGridz(), launchOp.getClusterx(),
21-
launchOp.getClustery(), launchOp.getClusterz(), launchOp.getInputs(),
22-
launchOp.getBackendConfigAttr(), launchOp.getOperandLayoutsAttr(),
18+
return triton_ext::TritonCallOp::create(
19+
rewriter, launchOp.getLoc(), resTys, launchOp.getFn(),
20+
launchOp.getGridx(), launchOp.getGridy(), launchOp.getGridz(),
21+
launchOp.getClusterx(), launchOp.getClustery(), launchOp.getClusterz(),
22+
launchOp.getInputs(), launchOp.getBackendConfigAttr(),
23+
launchOp.getOperandLayoutsAttr(),
2324
/*resultLayouts*/ nullptr, launchOp.getArgAttrsAttr(),
2425
launchOp.getResAttrsAttr(), outputAliases,
2526
launchOp.getXlaSideEffectFreeAttr());

src/enzyme_ad/jax/Passes/LinalgUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ mlir::ArrayAttr getSHLOLayout(PatternRewriter &rewriter,
3838
return rewriter.getArrayAttr(attrs);
3939
}
4040

41-
std::optional<std::string> lapack_precision_prefix(Type elementType) {
41+
std::optional<std::string> lapackPrecisionPrefix(Type elementType) {
4242

4343
// single-precision float
4444
if (elementType.isF32()) {

src/enzyme_ad/jax/Passes/LinalgUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@ mlir::ArrayAttr getSHLOLayout(mlir::PatternRewriter &rewriter,
1717
llvm::SmallVector<bool> isColMajorArr,
1818
int64_t maxNumDims);
1919

20-
std::optional<std::string> lapack_precision_prefix(mlir::Type elementType);
20+
std::optional<std::string> lapackPrecisionPrefix(mlir::Type elementType);
2121

2222
#endif // ENZYMEXLA_LINALGUTILS_H

src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp

Lines changed: 38 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,20 @@ struct GeqrfOpLowering : public OpRewritePattern<enzymexla::GeqrfOp> {
4141
LogicalResult matchAndRewrite(enzymexla::GeqrfOp op,
4242
PatternRewriter &rewriter) const override {
4343
if (backend == "cpu")
44-
return this->matchAndRewrite_cpu(op, rewriter);
45-
44+
return matchAndRewriteCPU(op, rewriter);
4645
else if (backend == "cuda")
47-
return this->matchAndRewrite_cuda(op, rewriter);
48-
46+
return matchAndRewriteCUDA(op, rewriter);
4947
else if (backend == "tpu")
50-
return this->matchAndRewrite_tpu(op, rewriter);
51-
48+
return matchAndRewriteTPU(op, rewriter);
5249
else
5350
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
5451
"\"");
5552
}
5653

5754
// TODO get matrix sizes dynamically so that we don't need to create a
5855
// function wrapper for each op instance
59-
LogicalResult matchAndRewrite_cpu(enzymexla::GeqrfOp op,
60-
PatternRewriter &rewriter) const {
56+
LogicalResult matchAndRewriteCPU(enzymexla::GeqrfOp op,
57+
PatternRewriter &rewriter) const {
6158
auto ctx = op->getContext();
6259
LLVMTypeConverter typeConverter(ctx);
6360

@@ -80,7 +77,7 @@ struct GeqrfOpLowering : public OpRewritePattern<enzymexla::GeqrfOp> {
8077
auto type_llvm_void = LLVM::LLVMVoidType::get(ctx);
8178

8279
std::string fn = "geqrf_";
83-
if (auto prefix = lapack_precision_prefix(inputElementType)) {
80+
if (auto prefix = lapackPrecisionPrefix(inputElementType)) {
8481
fn = *prefix + fn;
8582
} else {
8683
op->emitOpError() << "Unsupported element type: " << inputElementType;
@@ -209,8 +206,8 @@ struct GeqrfOpLowering : public OpRewritePattern<enzymexla::GeqrfOp> {
209206
return success();
210207
}
211208

212-
LogicalResult matchAndRewrite_cuda(enzymexla::GeqrfOp op,
213-
PatternRewriter &rewriter) const {
209+
LogicalResult matchAndRewriteCUDA(enzymexla::GeqrfOp op,
210+
PatternRewriter &rewriter) const {
214211
auto ctx = op->getContext();
215212
LLVMTypeConverter typeConverter(ctx);
216213

@@ -265,8 +262,8 @@ struct GeqrfOpLowering : public OpRewritePattern<enzymexla::GeqrfOp> {
265262
return success();
266263
}
267264

268-
LogicalResult matchAndRewrite_tpu(enzymexla::GeqrfOp op,
269-
PatternRewriter &rewriter) const {
265+
LogicalResult matchAndRewriteTPU(enzymexla::GeqrfOp op,
266+
PatternRewriter &rewriter) const {
270267
auto ctx = op->getContext();
271268
LLVMTypeConverter typeConverter(ctx);
272269

@@ -316,23 +313,20 @@ struct GeqrtOpLowering : public OpRewritePattern<enzymexla::GeqrtOp> {
316313
LogicalResult matchAndRewrite(enzymexla::GeqrtOp op,
317314
PatternRewriter &rewriter) const override {
318315
if (backend == "cpu")
319-
return this->matchAndRewrite_cpu(op, rewriter);
320-
316+
return matchAndRewriteCPU(op, rewriter);
321317
// else if (backend == "cuda")
322-
// return this->matchAndRewrite_cuda(op, rewriter);
323-
318+
// return matchAndRewriteCUDA(op, rewriter);
324319
// else if (backend == "tpu")
325-
// return this->matchAndRewrite_tpu(op, rewriter);
326-
320+
// return matchAndRewriteTPU(op, rewriter);
327321
else
328322
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
329323
"\"");
330324
}
331325

332326
// TODO get matrix sizes dynamically so that we don't need to create a
333327
// function wrapper for each op instance
334-
LogicalResult matchAndRewrite_cpu(enzymexla::GeqrtOp op,
335-
PatternRewriter &rewriter) const {
328+
LogicalResult matchAndRewriteCPU(enzymexla::GeqrtOp op,
329+
PatternRewriter &rewriter) const {
336330
auto ctx = op->getContext();
337331
LLVMTypeConverter typeConverter(ctx);
338332

@@ -355,7 +349,7 @@ struct GeqrtOpLowering : public OpRewritePattern<enzymexla::GeqrtOp> {
355349
auto type_llvm_void = LLVM::LLVMVoidType::get(ctx);
356350

357351
std::string fn = "geqrt_";
358-
if (auto prefix = lapack_precision_prefix(inputElementType)) {
352+
if (auto prefix = lapackPrecisionPrefix(inputElementType)) {
359353
fn = *prefix + fn;
360354
} else {
361355
op->emitOpError() << "Unsupported element type: " << inputElementType;
@@ -523,23 +517,20 @@ struct OrgqrOpLowering : public OpRewritePattern<enzymexla::OrgqrOp> {
523517
LogicalResult matchAndRewrite(enzymexla::OrgqrOp op,
524518
PatternRewriter &rewriter) const override {
525519
if (backend == "cpu")
526-
return this->matchAndRewrite_cpu(op, rewriter);
527-
520+
return matchAndRewriteCPU(op, rewriter);
528521
else if (backend == "cuda")
529-
return this->matchAndRewrite_cuda(op, rewriter);
530-
522+
return matchAndRewriteCUDA(op, rewriter);
531523
else if (backend == "tpu")
532-
return this->matchAndRewrite_tpu(op, rewriter);
533-
524+
return matchAndRewriteTPU(op, rewriter);
534525
else
535526
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
536527
"\"");
537528
}
538529

539530
// TODO get matrix sizes dynamically so that we don't need to create a
540531
// function wrapper for each op instance
541-
LogicalResult matchAndRewrite_cpu(enzymexla::OrgqrOp op,
542-
PatternRewriter &rewriter) const {
532+
LogicalResult matchAndRewriteCPU(enzymexla::OrgqrOp op,
533+
PatternRewriter &rewriter) const {
543534
auto ctx = op->getContext();
544535
LLVMTypeConverter typeConverter(ctx);
545536

@@ -567,7 +558,7 @@ struct OrgqrOpLowering : public OpRewritePattern<enzymexla::OrgqrOp> {
567558
auto type_llvm_void = LLVM::LLVMVoidType::get(ctx);
568559

569560
std::string fn = "gqr_";
570-
if (auto prefix = lapack_precision_prefix(inputElementType)) {
561+
if (auto prefix = lapackPrecisionPrefix(inputElementType)) {
571562
if (prefix == "s" || prefix == "d")
572563
fn = *prefix + "or" + fn;
573564
else
@@ -688,8 +679,8 @@ struct OrgqrOpLowering : public OpRewritePattern<enzymexla::OrgqrOp> {
688679
return success();
689680
}
690681

691-
LogicalResult matchAndRewrite_cuda(enzymexla::OrgqrOp op,
692-
PatternRewriter &rewriter) const {
682+
LogicalResult matchAndRewriteCUDA(enzymexla::OrgqrOp op,
683+
PatternRewriter &rewriter) const {
693684
auto ctx = op->getContext();
694685
LLVMTypeConverter typeConverter(ctx);
695686

@@ -734,8 +725,8 @@ struct OrgqrOpLowering : public OpRewritePattern<enzymexla::OrgqrOp> {
734725
return success();
735726
}
736727

737-
LogicalResult matchAndRewrite_tpu(enzymexla::OrgqrOp op,
738-
PatternRewriter &rewriter) const {
728+
LogicalResult matchAndRewriteTPU(enzymexla::OrgqrOp op,
729+
PatternRewriter &rewriter) const {
739730
auto ctx = op->getContext();
740731
LLVMTypeConverter typeConverter(ctx);
741732

@@ -772,23 +763,20 @@ struct OrmqrOpLowering : public OpRewritePattern<enzymexla::OrmqrOp> {
772763
LogicalResult matchAndRewrite(enzymexla::OrmqrOp op,
773764
PatternRewriter &rewriter) const override {
774765
if (backend == "cpu")
775-
return this->matchAndRewrite_cpu(op, rewriter);
776-
766+
return matchAndRewriteCPU(op, rewriter);
777767
// else if (backend == "cuda")
778-
// return this->matchAndRewrite_cuda(op, rewriter);
779-
768+
// return matchAndRewriteCUDA(op, rewriter);
780769
// else if (backend == "tpu")
781-
// return this->matchAndRewrite_tpu(op, rewriter);
782-
770+
// return matchAndRewriteTPU(op, rewriter);
783771
else
784772
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
785773
"\"");
786774
}
787775

788776
// TODO get matrix sizes dynamically so that we don't need to create a
789777
// function wrapper for each op instance
790-
LogicalResult matchAndRewrite_cpu(enzymexla::OrmqrOp op,
791-
PatternRewriter &rewriter) const {
778+
LogicalResult matchAndRewriteCPU(enzymexla::OrmqrOp op,
779+
PatternRewriter &rewriter) const {
792780
auto ctx = op->getContext();
793781
LLVMTypeConverter typeConverter(ctx);
794782

@@ -873,7 +861,7 @@ struct OrmqrOpLowering : public OpRewritePattern<enzymexla::OrmqrOp> {
873861
auto type_llvm_char = rewriter.getIntegerType(8);
874862

875863
std::string fn = "mqr_";
876-
if (auto prefix = lapack_precision_prefix(A_eltype)) {
864+
if (auto prefix = lapackPrecisionPrefix(A_eltype)) {
877865
if (prefix == "s" || prefix == "d")
878866
fn = *prefix + "or" + fn;
879867
else
@@ -1031,23 +1019,20 @@ struct GemqrtOpLowering : public OpRewritePattern<enzymexla::GemqrtOp> {
10311019
LogicalResult matchAndRewrite(enzymexla::GemqrtOp op,
10321020
PatternRewriter &rewriter) const override {
10331021
if (backend == "cpu")
1034-
return this->matchAndRewrite_cpu(op, rewriter);
1035-
1022+
return matchAndRewriteCPU(op, rewriter);
10361023
// else if (backend == "cuda")
1037-
// return this->matchAndRewrite_cuda(op, rewriter);
1038-
1024+
// return matchAndRewriteCUDA(op, rewriter);
10391025
// else if (backend == "tpu")
1040-
// return this->matchAndRewrite_tpu(op, rewriter);
1041-
1026+
// return matchAndRewriteTPU(op, rewriter);
10421027
else
10431028
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
10441029
"\"");
10451030
}
10461031

10471032
// TODO get matrix sizes dynamically so that we don't need to create a
10481033
// function wrapper for each op instance
1049-
LogicalResult matchAndRewrite_cpu(enzymexla::GemqrtOp op,
1050-
PatternRewriter &rewriter) const {
1034+
LogicalResult matchAndRewriteCPU(enzymexla::GemqrtOp op,
1035+
PatternRewriter &rewriter) const {
10511036
auto ctx = op->getContext();
10521037
LLVMTypeConverter typeConverter(ctx);
10531038

@@ -1141,7 +1126,7 @@ struct GemqrtOpLowering : public OpRewritePattern<enzymexla::GemqrtOp> {
11411126
auto type_llvm_char = rewriter.getIntegerType(8);
11421127

11431128
std::string fn = "gemqrt_";
1144-
if (auto prefix = lapack_precision_prefix(C_eltype)) {
1129+
if (auto prefix = lapackPrecisionPrefix(C_eltype)) {
11451130
fn = *prefix + fn;
11461131
} else {
11471132
op->emitOpError() << "Unsupported element type: " << C_eltype;

0 commit comments

Comments
 (0)