Skip to content

Commit 838368d

Browse files
Added lit tests
1 parent e4591a1 commit 838368d

File tree

6 files changed

+151
-12
lines changed

6 files changed

+151
-12
lines changed

src/enzyme_ad/jax/Dialect/Tessera/Ops.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ def CallOp : TesseraOp<"call",
119119

120120
let results = (outs Variadic<AnyType>);
121121

122+
let assemblyFormat = [{
123+
$callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
124+
}];
125+
122126
let builders = [
123127
OpBuilder<(ins "DefineOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
124128
$_state.addOperands(operands);

src/enzyme_ad/jax/Passes/Tessera/FuncToTessera.cpp

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,19 @@
66
//===----------------------------------------------------------------------===//
77

88
#include "mlir/Bytecode/BytecodeOpInterface.h"
9-
#include "src/enzyme_ad/jax/Dialect/Tessera/Ops.h"
109
#include "mlir/Dialect/Func/IR/FuncOps.h"
11-
#include "mlir/IR/PatternMatch.h"
10+
#include "mlir/IR/BuiltinDialect.h"
11+
#include "mlir/IR/BuiltinOps.h"
1212
#include "mlir/IR/IRMapping.h"
13-
#include "mlir/Transforms/DialectConversion.h"
14-
#include "mlir/Interfaces/FunctionInterfaces.h"
13+
#include "mlir/IR/PatternMatch.h"
1514
#include "mlir/Interfaces/CallInterfaces.h"
15+
#include "mlir/Interfaces/FunctionInterfaces.h"
1616
#include "mlir/Pass/Pass.h"
17-
#include "mlir/IR/BuiltinOps.h"
17+
#include "mlir/Transforms/DialectConversion.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1819
#include "src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"
20+
#include "src/enzyme_ad/jax/Dialect/Tessera/Ops.h"
1921
#include "src/enzyme_ad/jax/Passes/Tessera/Passes.h"
20-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2122

2223
using namespace mlir;
2324
using namespace mlir::enzyme;
@@ -59,6 +60,22 @@ class FuncOpRewrite final : public OpRewritePattern<func::FuncOp> {
5960
funcOp.getBody().cloneInto(&tesseraDefineOp.getBody(),
6061
tesseraDefineOp.getBody().end(),
6162
mapper);
63+
64+
// Now walk through the cloned operations and convert func.return to
65+
// tessera.return
66+
tesseraDefineOp.walk([&](func::ReturnOp returnOp) {
67+
rewriter.setInsertionPoint(returnOp);
68+
rewriter.replaceOpWithNewOp<tessera::ReturnOp>(returnOp,
69+
returnOp.getOperands());
70+
});
71+
72+
// Convert func.call to tessera.call
73+
tesseraDefineOp.walk([&](func::CallOp callOp) {
74+
rewriter.setInsertionPoint(callOp);
75+
rewriter.replaceOpWithNewOp<tessera::CallOp>(
76+
callOp, callOp.getResultTypes(), callOp.getOperands(),
77+
callOp->getAttrs());
78+
});
6279
}
6380

6481
rewriter.eraseOp(funcOp);
@@ -81,7 +98,7 @@ class CallOpRewrite final : public OpRewritePattern<func::CallOp> {
8198
Operation *calleeOp = SymbolTable::lookupSymbolIn(moduleOp, calleeAttr);
8299

83100
// Only convert if the callee is a Tessera DefineOp
84-
if (isa<tessera::DefineOp>(calleeOp))
101+
if (!isa<tessera::DefineOp>(calleeOp))
85102
return rewriter.notifyMatchFailure(callOp, "Callee is not a Tessera DefineOp");
86103

87104
rewriter.replaceOpWithNewOp<tessera::CallOp>(callOp, callOp.getResultTypes(),
@@ -122,14 +139,28 @@ namespace mlir::enzyme::tessera {
122139
struct FuncToTesseraPass
123140
: public PassWrapper<FuncToTesseraPass, OperationPass<ModuleOp>> {
124141

142+
StringRef getArgument() const final { return "func-to-tessera"; }
143+
StringRef getDescription() const final {
144+
return "Convert func dialect to tessera dialect.";
145+
}
146+
147+
void getDependentDialects(DialectRegistry &registry) const override {
148+
registry.insert<tessera::TesseraDialect>();
149+
}
150+
125151
void runOnOperation() override {
126152
MLIRContext *ctx = &getContext();
153+
154+
ConversionTarget target(*ctx);
155+
target.addLegalDialect<tessera::TesseraDialect>();
156+
target.addLegalDialect<BuiltinDialect>();
157+
target.addIllegalDialect<func::FuncDialect>();
158+
127159
RewritePatternSet patterns(ctx);
128160

129161
patterns.add<FuncOpRewrite, CallOpRewrite, ReturnOpRewrite>(ctx);
130162

131-
if (failed(applyPatternsAndFoldGreedily(getOperation(),
132-
std::move(patterns))))
163+
if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
133164
signalPassFailure();
134165
}
135166
};

src/enzyme_ad/jax/Passes/Tessera/TesseraToFunc.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,8 @@ class DefineOpRewrite final : public OpRewritePattern<tessera::DefineOp> {
4444

4545

4646
// Create the `func.func` op
47-
auto funcOp = rewriter.create<tessera::DefineOp>(
48-
defineOp.getLoc(), defineOp.getName(), fnType);
49-
47+
auto funcOp = rewriter.create<func::FuncOp>(defineOp.getLoc(),
48+
defineOp.getName(), fnType);
5049

5150
// Copy over all attributes other than the function name and type.
5251
for (const auto &namedAttr : defineOp->getAttrs()) {
@@ -111,6 +110,15 @@ namespace mlir::enzyme::tessera {
111110
struct TesseraToFuncPass
112111
: public PassWrapper<TesseraToFuncPass, OperationPass<ModuleOp>> {
113112

113+
StringRef getArgument() const final { return "tessera-to-func"; }
114+
StringRef getDescription() const final {
115+
return "Convert tessera dialect to func dialect.";
116+
}
117+
118+
void getDependentDialects(DialectRegistry &registry) const override {
119+
registry.insert<func::FuncDialect>();
120+
}
121+
114122
void runOnOperation() override {
115123
MLIRContext *ctx = &getContext();
116124
RewritePatternSet patterns(ctx);
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: enzymexlamlir-opt %s -func-to-tessera | FileCheck %s
2+
3+
// CHECK-LABEL: tessera.define @simple_func
4+
func.func @simple_func() {
5+
// CHECK: tessera.return
6+
func.return
7+
}
8+
9+
// -----
10+
11+
// CHECK-LABEL: tessera.define @func_with_args
12+
func.func @func_with_args(%arg0: i32, %arg1: f32) -> i32 {
13+
// CHECK: tessera.return %arg0 : i32
14+
func.return %arg0 : i32
15+
}
16+
17+
// -----
18+
19+
// CHECK-LABEL: tessera.define @helper
20+
func.func @helper() {
21+
func.return
22+
}
23+
24+
// CHECK-LABEL: tessera.define @func_with_call
25+
func.func @func_with_call() {
26+
// CHECK: tessera.call @helper() : () -> ()
27+
func.call @helper() : () -> ()
28+
// CHECK: tessera.return
29+
func.return
30+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: enzymexlamlir-opt %s | FileCheck %s
2+
3+
// CHECK-LABEL: tessera.define @foo
4+
tessera.define @foo() {
5+
// CHECK: tessera.return
6+
tessera.return
7+
}
8+
9+
// -----
10+
11+
// CHECK-LABEL: tessera.define @bar
12+
tessera.define @bar() -> i32 {
13+
%c42_i32 = arith.constant 42 : i32
14+
// CHECK: tessera.return %{{.*}} : i32
15+
tessera.return %c42_i32 : i32
16+
}
17+
18+
// -----
19+
20+
// CHECK-LABEL: tessera.define @caller
21+
tessera.define @caller() {
22+
// CHECK: tessera.call @foo() : () -> ()
23+
tessera.call @foo() : () -> ()
24+
// CHECK: tessera.return
25+
tessera.return
26+
}
27+
28+
// -----
29+
30+
// CHECK-LABEL: tessera.define @with_args
31+
tessera.define @with_args(%arg0: i32, %arg1: f32) -> i32 {
32+
// CHECK: %[[V0:.*]] = tessera.call @bar() : () -> i32
33+
%0 = tessera.call @bar() : () -> i32
34+
// CHECK: tessera.return %[[V0]] : i32
35+
tessera.return %0 : i32
36+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: enzymexlamlir-opt %s -tessera-to-func | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @simple_func
4+
tessera.define @simple_func() {
5+
// CHECK: func.return
6+
tessera.return
7+
}
8+
9+
// -----
10+
11+
// CHECK-LABEL: func.func @func_with_args
12+
tessera.define @func_with_args(%arg0: i32, %arg1: f32) -> i32 {
13+
// CHECK: func.return %arg0 : i32
14+
tessera.return %arg0 : i32
15+
}
16+
17+
// -----
18+
19+
// CHECK-LABEL: func.func @helper
20+
tessera.define @helper() {
21+
tessera.return
22+
}
23+
24+
// CHECK-LABEL: func.func @func_with_call
25+
tessera.define @func_with_call() {
26+
// CHECK: func.call @helper() : () -> ()
27+
tessera.call @helper() : () -> ()
28+
// CHECK: func.return
29+
tessera.return
30+
}

0 commit comments

Comments
 (0)