Skip to content

Commit b789363

Browse files
committed
refactor: move into separate functions
1 parent 8f54d70 commit b789363

File tree

2 files changed

+533
-537
lines changed

2 files changed

+533
-537
lines changed

src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp

Lines changed: 33 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,20 @@ struct GeqrfOpLowering : public OpRewritePattern<enzymexla::GeqrfOp> {
4141
LogicalResult matchAndRewrite(enzymexla::GeqrfOp op,
4242
PatternRewriter &rewriter) const override {
4343
if (backend == "cpu")
44-
return this->matchAndRewrite_cpu(op, rewriter);
45-
44+
return matchAndRewriteCPU(op, rewriter);
4645
else if (backend == "cuda")
47-
return this->matchAndRewrite_cuda(op, rewriter);
48-
46+
return matchAndRewriteCUDA(op, rewriter);
4947
else if (backend == "tpu")
50-
return this->matchAndRewrite_tpu(op, rewriter);
51-
48+
return matchAndRewriteTPU(op, rewriter);
5249
else
5350
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
5451
"\"");
5552
}
5653

5754
// TODO get matrix sizes dynamically so that we don't need to create a
5855
// function wrapper for each op instance
59-
LogicalResult matchAndRewrite_cpu(enzymexla::GeqrfOp op,
60-
PatternRewriter &rewriter) const {
56+
LogicalResult matchAndRewriteCPU(enzymexla::GeqrfOp op,
57+
PatternRewriter &rewriter) const {
6158
auto ctx = op->getContext();
6259
LLVMTypeConverter typeConverter(ctx);
6360

@@ -209,8 +206,8 @@ struct GeqrfOpLowering : public OpRewritePattern<enzymexla::GeqrfOp> {
209206
return success();
210207
}
211208

212-
LogicalResult matchAndRewrite_cuda(enzymexla::GeqrfOp op,
213-
PatternRewriter &rewriter) const {
209+
LogicalResult matchAndRewriteCUDA(enzymexla::GeqrfOp op,
210+
PatternRewriter &rewriter) const {
214211
auto ctx = op->getContext();
215212
LLVMTypeConverter typeConverter(ctx);
216213

@@ -265,8 +262,8 @@ struct GeqrfOpLowering : public OpRewritePattern<enzymexla::GeqrfOp> {
265262
return success();
266263
}
267264

268-
LogicalResult matchAndRewrite_tpu(enzymexla::GeqrfOp op,
269-
PatternRewriter &rewriter) const {
265+
LogicalResult matchAndRewriteTPU(enzymexla::GeqrfOp op,
266+
PatternRewriter &rewriter) const {
270267
auto ctx = op->getContext();
271268
LLVMTypeConverter typeConverter(ctx);
272269

@@ -316,23 +313,20 @@ struct GeqrtOpLowering : public OpRewritePattern<enzymexla::GeqrtOp> {
316313
LogicalResult matchAndRewrite(enzymexla::GeqrtOp op,
317314
PatternRewriter &rewriter) const override {
318315
if (backend == "cpu")
319-
return this->matchAndRewrite_cpu(op, rewriter);
320-
316+
return matchAndRewriteCPU(op, rewriter);
321317
// else if (backend == "cuda")
322-
// return this->matchAndRewrite_cuda(op, rewriter);
323-
318+
// return matchAndRewriteCUDA(op, rewriter);
324319
// else if (backend == "tpu")
325-
// return this->matchAndRewrite_tpu(op, rewriter);
326-
320+
// return matchAndRewriteTPU(op, rewriter);
327321
else
328322
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
329323
"\"");
330324
}
331325

332326
// TODO get matrix sizes dynamically so that we don't need to create a
333327
// function wrapper for each op instance
334-
LogicalResult matchAndRewrite_cpu(enzymexla::GeqrtOp op,
335-
PatternRewriter &rewriter) const {
328+
LogicalResult matchAndRewriteCPU(enzymexla::GeqrtOp op,
329+
PatternRewriter &rewriter) const {
336330
auto ctx = op->getContext();
337331
LLVMTypeConverter typeConverter(ctx);
338332

@@ -523,23 +517,20 @@ struct OrgqrOpLowering : public OpRewritePattern<enzymexla::OrgqrOp> {
523517
LogicalResult matchAndRewrite(enzymexla::OrgqrOp op,
524518
PatternRewriter &rewriter) const override {
525519
if (backend == "cpu")
526-
return this->matchAndRewrite_cpu(op, rewriter);
527-
520+
return matchAndRewriteCPU(op, rewriter);
528521
else if (backend == "cuda")
529-
return this->matchAndRewrite_cuda(op, rewriter);
530-
522+
return matchAndRewriteCUDA(op, rewriter);
531523
else if (backend == "tpu")
532-
return this->matchAndRewrite_tpu(op, rewriter);
533-
524+
return matchAndRewriteTPU(op, rewriter);
534525
else
535526
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
536527
"\"");
537528
}
538529

539530
// TODO get matrix sizes dynamically so that we don't need to create a
540531
// function wrapper for each op instance
541-
LogicalResult matchAndRewrite_cpu(enzymexla::OrgqrOp op,
542-
PatternRewriter &rewriter) const {
532+
LogicalResult matchAndRewriteCPU(enzymexla::OrgqrOp op,
533+
PatternRewriter &rewriter) const {
543534
auto ctx = op->getContext();
544535
LLVMTypeConverter typeConverter(ctx);
545536

@@ -688,8 +679,8 @@ struct OrgqrOpLowering : public OpRewritePattern<enzymexla::OrgqrOp> {
688679
return success();
689680
}
690681

691-
LogicalResult matchAndRewrite_cuda(enzymexla::OrgqrOp op,
692-
PatternRewriter &rewriter) const {
682+
LogicalResult matchAndRewriteCUDA(enzymexla::OrgqrOp op,
683+
PatternRewriter &rewriter) const {
693684
auto ctx = op->getContext();
694685
LLVMTypeConverter typeConverter(ctx);
695686

@@ -734,8 +725,8 @@ struct OrgqrOpLowering : public OpRewritePattern<enzymexla::OrgqrOp> {
734725
return success();
735726
}
736727

737-
LogicalResult matchAndRewrite_tpu(enzymexla::OrgqrOp op,
738-
PatternRewriter &rewriter) const {
728+
LogicalResult matchAndRewriteTPU(enzymexla::OrgqrOp op,
729+
PatternRewriter &rewriter) const {
739730
auto ctx = op->getContext();
740731
LLVMTypeConverter typeConverter(ctx);
741732

@@ -772,23 +763,20 @@ struct OrmqrOpLowering : public OpRewritePattern<enzymexla::OrmqrOp> {
772763
LogicalResult matchAndRewrite(enzymexla::OrmqrOp op,
773764
PatternRewriter &rewriter) const override {
774765
if (backend == "cpu")
775-
return this->matchAndRewrite_cpu(op, rewriter);
776-
766+
return matchAndRewriteCPU(op, rewriter);
777767
// else if (backend == "cuda")
778-
// return this->matchAndRewrite_cuda(op, rewriter);
779-
768+
// return matchAndRewriteCUDA(op, rewriter);
780769
// else if (backend == "tpu")
781-
// return this->matchAndRewrite_tpu(op, rewriter);
782-
770+
// return matchAndRewriteTPU(op, rewriter);
783771
else
784772
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
785773
"\"");
786774
}
787775

788776
// TODO get matrix sizes dynamically so that we don't need to create a
789777
// function wrapper for each op instance
790-
LogicalResult matchAndRewrite_cpu(enzymexla::OrmqrOp op,
791-
PatternRewriter &rewriter) const {
778+
LogicalResult matchAndRewriteCPU(enzymexla::OrmqrOp op,
779+
PatternRewriter &rewriter) const {
792780
auto ctx = op->getContext();
793781
LLVMTypeConverter typeConverter(ctx);
794782

@@ -1031,23 +1019,20 @@ struct GemqrtOpLowering : public OpRewritePattern<enzymexla::GemqrtOp> {
10311019
LogicalResult matchAndRewrite(enzymexla::GemqrtOp op,
10321020
PatternRewriter &rewriter) const override {
10331021
if (backend == "cpu")
1034-
return this->matchAndRewrite_cpu(op, rewriter);
1035-
1022+
return matchAndRewriteCPU(op, rewriter);
10361023
// else if (backend == "cuda")
1037-
// return this->matchAndRewrite_cuda(op, rewriter);
1038-
1024+
// return matchAndRewriteCUDA(op, rewriter);
10391025
// else if (backend == "tpu")
1040-
// return this->matchAndRewrite_tpu(op, rewriter);
1041-
1026+
// return matchAndRewriteTPU(op, rewriter);
10421027
else
10431028
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
10441029
"\"");
10451030
}
10461031

10471032
// TODO get matrix sizes dynamically so that we don't need to create a
10481033
// function wrapper for each op instance
1049-
LogicalResult matchAndRewrite_cpu(enzymexla::GemqrtOp op,
1050-
PatternRewriter &rewriter) const {
1034+
LogicalResult matchAndRewriteCPU(enzymexla::GemqrtOp op,
1035+
PatternRewriter &rewriter) const {
10511036
auto ctx = op->getContext();
10521037
LLVMTypeConverter typeConverter(ctx);
10531038

0 commit comments

Comments
 (0)