Skip to content

Commit a3f33e8

Browse files
Tessera struct returns (#2334)
* Handled struct returns in LLVM to Tessera conversion * Handled struct returns in lowering from Tessera back to LLVM
1 parent 886abc5 commit a3f33e8

File tree

5 files changed

+254
-51
lines changed

5 files changed

+254
-51
lines changed

src/enzyme_ad/jax/Dialect/Tessera/Ops.cpp

Lines changed: 59 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -123,30 +123,6 @@ DefineOp DefineOp::clone() {
123123
return clone(mapper);
124124
}
125125

126-
//===----------------------------------------------------------------------===//
127-
// ReturnOp
128-
//===----------------------------------------------------------------------===//
129-
130-
LogicalResult ReturnOp::verify() {
131-
auto function = cast<DefineOp>((*this)->getParentOp());
132-
133-
// The operand number and types must match the function signature.
134-
const auto &results = function.getFunctionType().getResults();
135-
if (getNumOperands() != results.size())
136-
return emitOpError("has ")
137-
<< getNumOperands() << " operands, but enclosing function (@"
138-
<< function.getName() << ") returns " << results.size();
139-
140-
for (unsigned i = 0, e = results.size(); i != e; ++i)
141-
if (getOperand(i).getType() != results[i])
142-
return emitError() << "type of return operand " << i << " ("
143-
<< getOperand(i).getType()
144-
<< ") doesn't match function result type ("
145-
<< results[i] << ")"
146-
<< " in function @" << function.getName();
147-
return success();
148-
}
149-
150126
//===----------------------------------------------------------------------===//
151127
// CallOp
152128
//===----------------------------------------------------------------------===//
@@ -161,31 +137,77 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
161137
return emitOpError() << "'" << fnAttr.getValue()
162138
<< "' does not reference a valid function";
163139

164-
// Verify that the operand and result types match the callee.
140+
// Verify that the operand and result types match the callee,
141+
// unless callee has attribute to indicate struct return.
142+
bool has_sret = (fn->hasAttr("tessera.sret_attrs"));
165143
auto fnType = fn.getFunctionType();
166-
if (fnType.getNumInputs() != getNumOperands())
144+
145+
// If tessera.define has sret attribute,
146+
// tessera.call operand count = tessera.define input count - 1
147+
if (has_sret && (fnType.getNumInputs() - 1) != getNumOperands())
148+
return emitOpError("incorrect number of operands for callee");
149+
if (!has_sret && fnType.getNumInputs() != getNumOperands())
167150
return emitOpError("incorrect number of operands for callee");
168151

169-
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
170-
if (getOperand(i).getType() != fnType.getInput(i))
152+
int startIdx = has_sret ? 1 : 0;
153+
for (unsigned i = startIdx, e = fnType.getNumInputs(); i != e; ++i)
154+
if (getOperand(i - startIdx).getType() != fnType.getInput(i))
171155
return emitOpError("operand type mismatch: expected operand type ")
172156
<< fnType.getInput(i) << ", but provided "
173-
<< getOperand(i).getType() << " for operand number " << i;
157+
<< getOperand(i - startIdx).getType() << " for operand number "
158+
<< i - startIdx;
174159

175-
if (fnType.getNumResults() != getNumResults())
160+
// If tessera.define has sret attribute,
161+
// tessera.call result count = tessera.define result count + 1
162+
if (has_sret && getNumResults() != 1)
163+
return emitOpError("incorrect number of results for callee");
164+
if (!has_sret && fnType.getNumResults() != getNumResults())
176165
return emitOpError("incorrect number of results for callee");
177166

178-
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
179-
if (getResult(i).getType() != fnType.getResult(i)) {
180-
auto diag = emitOpError("result type mismatch at index ") << i;
181-
diag.attachNote() << " op result types: " << getResultTypes();
182-
diag.attachNote() << "function result types: " << fnType.getResults();
183-
return diag;
184-
}
167+
if (has_sret) {
168+
auto argAttrs = fn.getArgAttrsAttr();
169+
auto firstArgAttr = cast<DictionaryAttr>(argAttrs[0]);
170+
auto sretType = cast<TypeAttr>(firstArgAttr.get("llvm.sret")).getValue();
171+
if (getResult(0).getType() != sretType)
172+
return emitOpError("result type mismatch: expected ")
173+
<< sretType << " but got " << getResult(0).getType();
174+
} else {
175+
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
176+
if (getResult(i).getType() != fnType.getResult(i)) {
177+
auto diag = emitOpError("result type mismatch at index ") << i;
178+
diag.attachNote() << " op result types: " << getResultTypes();
179+
diag.attachNote() << "function result types: " << fnType.getResults();
180+
return diag;
181+
}
182+
}
185183

186184
return success();
187185
}
188186

189187
FunctionType CallOp::getCalleeType() {
190188
return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
191189
}
190+
191+
//===----------------------------------------------------------------------===//
192+
// ReturnOp
193+
//===----------------------------------------------------------------------===//
194+
195+
LogicalResult ReturnOp::verify() {
196+
auto function = cast<DefineOp>((*this)->getParentOp());
197+
198+
// The operand number and types must match the function signature.
199+
const auto &results = function.getFunctionType().getResults();
200+
if (getNumOperands() != results.size())
201+
return emitOpError("has ")
202+
<< getNumOperands() << " operands, but enclosing function (@"
203+
<< function.getName() << ") returns " << results.size();
204+
205+
for (unsigned i = 0, e = results.size(); i != e; ++i)
206+
if (getOperand(i).getType() != results[i])
207+
return emitError() << "type of return operand " << i << " ("
208+
<< getOperand(i).getType()
209+
<< ") doesn't match function result type ("
210+
<< results[i] << ")"
211+
<< " in function @" << function.getName();
212+
return success();
213+
}

src/enzyme_ad/jax/Passes/Tessera/LLVMToTessera.cpp

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,22 @@ class FuncOpRewrite final : public OpRewritePattern<LLVM::LLVMFuncOp> {
5757
auto *ctx = funcOp->getContext();
5858
auto funcName = funcOp.getName();
5959
auto llvmFuncType = funcOp.getFunctionType();
60-
auto fnType =
61-
FunctionType::get(ctx, llvmFuncType.getParams(),
62-
isa<LLVM::LLVMVoidType>(llvmFuncType.getReturnType())
63-
? TypeRange{}
64-
: TypeRange{llvmFuncType.getReturnType()});
60+
auto params = llvmFuncType.getParams();
61+
auto retType = llvmFuncType.getReturnType();
62+
63+
// Check if first argument has sret attribute
64+
bool hasSret = false;
65+
auto argAttrs = funcOp.getArgAttrsAttr();
66+
if (!params.empty() && argAttrs) {
67+
auto firstArgAttrs = cast<DictionaryAttr>(argAttrs[0]);
68+
if (auto sretAttr =
69+
firstArgAttrs.get(LLVM::LLVMDialect::getStructRetAttrName()))
70+
hasSret = true;
71+
}
72+
73+
auto fnType = FunctionType::get(
74+
ctx, params,
75+
isa<LLVM::LLVMVoidType>(retType) ? TypeRange{} : TypeRange{retType});
6576

6677
// Replace current function name with tessera name defined in
6778
// tessera.convert attribute
@@ -84,6 +95,11 @@ class FuncOpRewrite final : public OpRewritePattern<LLVM::LLVMFuncOp> {
8495
tesseraDefineOp->setAttr("tessera.original_name",
8596
rewriter.getStringAttr(funcName));
8697

98+
// Add attribute if function uses struct return and store the first arg's
99+
// attributes for exact reconstruction later
100+
if (hasSret)
101+
tesseraDefineOp->setAttr("tessera.sret_attrs", argAttrs[0]);
102+
87103
// Clone body of function
88104
if (!funcOp.isExternal()) {
89105
rewriter.inlineRegionBefore(funcOp.getBody(), tesseraDefineOp.getBody(),
@@ -114,9 +130,50 @@ class CallOpRewrite final : public OpRewritePattern<LLVM::CallOp> {
114130
if (!isa_and_nonnull<tessera::DefineOp>(callee))
115131
return failure();
116132

117-
rewriter.replaceOpWithNewOp<tessera::CallOp>(
118-
callOp, callOp.getResultTypes(), callOp.getOperands(),
119-
callOp->getAttrs());
133+
// Check if first operand has sret attribute. If so, remove it from
134+
// the operand list and use its pointed-to type as the SSA return type,
135+
// since tessera.call returns values directly rather than writing through
136+
// a pointer.
137+
Value sretPtr;
138+
Type sretType;
139+
auto operands = callOp.getOperands();
140+
auto argAttrs = callOp.getArgAttrsAttr();
141+
SmallVector<Value> newOperands;
142+
SmallVector<Attribute> newArgAttrs;
143+
SmallVector<NamedAttribute> newAttrs;
144+
145+
if (!operands.empty() && argAttrs) {
146+
auto firstArgAttrs = cast<DictionaryAttr>(argAttrs[0]);
147+
if (auto sretAttr =
148+
firstArgAttrs.get(LLVM::LLVMDialect::getStructRetAttrName())) {
149+
sretPtr = callOp.getOperand(0);
150+
sretType = cast<TypeAttr>(sretAttr).getValue();
151+
// Build operands and arg attributes without first element
152+
for (int i = 1; i < operands.size(); i++)
153+
newOperands.push_back(callOp.getOperand(i));
154+
for (int j = 1; j < argAttrs.size(); j++)
155+
newArgAttrs.push_back(argAttrs[j]);
156+
// Filter out arg_attrs from attributes
157+
for (auto attr : callOp->getAttrs()) {
158+
if (attr.getName() != callOp.getArgAttrsAttrName())
159+
newAttrs.push_back(attr);
160+
}
161+
}
162+
}
163+
164+
// Create tessera.call op with SSA return type
165+
if (sretPtr) {
166+
auto newCall = rewriter.create<tessera::CallOp>(
167+
callOp.getLoc(), TypeRange{sretType}, newOperands, newAttrs);
168+
rewriter.create<LLVM::StoreOp>(callOp.getLoc(), newCall.getResult(0),
169+
sretPtr);
170+
newCall->setAttr(newCall.getArgAttrsAttrName(),
171+
rewriter.getArrayAttr(newArgAttrs));
172+
rewriter.eraseOp(callOp);
173+
} else {
174+
rewriter.replaceOpWithNewOp<tessera::CallOp>(
175+
callOp, callOp.getResultTypes(), operands, callOp->getAttrs());
176+
}
120177

121178
return success();
122179
}
@@ -154,7 +211,10 @@ struct LLVMToTesseraPass
154211

155212
patterns.add<FuncOpRewrite, CallOpRewrite, ReturnOpRewrite>(ctx);
156213

157-
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
214+
GreedyRewriteConfig config;
215+
config.setUseTopDownTraversal(true);
216+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
217+
config))) {
158218
llvm::errs() << "Failed to convert LLVM dialect operations to tessera "
159219
"dialect operations\n";
160220
signalPassFailure();

src/enzyme_ad/jax/Passes/Tessera/TesseraToLLVM.cpp

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,13 @@ class DefineOpRewrite final : public OpRewritePattern<tessera::DefineOp> {
7878
auto funcOp = LLVM::LLVMFuncOp::create(rewriter, defineOp.getLoc(),
7979
funcName, llvmFuncType);
8080

81-
// Copy over all attributes other than the function name and type.
81+
// Copy over all attributes other than the function name and type and
82+
// attributes used only for tessera conversion
8283
for (const auto &namedAttr : defineOp->getAttrs()) {
8384
if (namedAttr.getName() != defineOp.getFunctionTypeAttrName() &&
8485
namedAttr.getName() != SymbolTable::getSymbolAttrName() &&
85-
namedAttr.getName() != "tessera.original_name")
86+
namedAttr.getName() != "tessera.original_name" &&
87+
namedAttr.getName() != "tessera.sret_attrs")
8688
funcOp->setAttr(namedAttr.getName(), namedAttr.getValue());
8789
}
8890

@@ -109,9 +111,73 @@ class CallOpRewrite final : public OpRewritePattern<tessera::CallOp> {
109111
LogicalResult matchAndRewrite(tessera::CallOp callOp,
110112
PatternRewriter &rewriter) const override {
111113

112-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, callOp.getResultTypes(),
113-
callOp.getOperands(),
114-
callOp->getAttrs());
114+
auto calleeAttr = callOp.getCalleeAttr();
115+
if (!calleeAttr)
116+
return failure();
117+
118+
auto callee = SymbolTable::lookupSymbolIn(
119+
callOp->getParentOfType<ModuleOp>(), calleeAttr);
120+
121+
// Check if callee has sret attribute. If so, allocate new pointer to
122+
// contain result of tessera.call and insert as first argument in llvm.call.
123+
auto defineOp = dyn_cast_or_null<tessera::DefineOp>(callee);
124+
if (!defineOp)
125+
return failure();
126+
127+
auto sretAttrs =
128+
defineOp->getAttrOfType<DictionaryAttr>("tessera.sret_attrs");
129+
if (sretAttrs) {
130+
if (callOp.getNumResults() == 0)
131+
return callOp.emitOpError(
132+
"tessera.call to sret function must have a result");
133+
auto sretType = callOp.getResult(0).getType();
134+
int64_t alignment = 0;
135+
if (auto alignAttr = sretAttrs.get(LLVM::LLVMDialect::getAlignAttrName()))
136+
alignment = cast<IntegerAttr>(alignAttr).getInt();
137+
Value one = rewriter.create<LLVM::ConstantOp>(
138+
callOp.getLoc(), rewriter.getI32Type(),
139+
rewriter.getI32IntegerAttr(1));
140+
141+
// Allocate stack storage for the sret return value
142+
Value sretPtr = rewriter.create<LLVM::AllocaOp>(
143+
callOp.getLoc(), LLVM::LLVMPointerType::get(callOp->getContext()),
144+
sretType, one, alignment);
145+
146+
// Build new operands with sretPtr as first arg
147+
SmallVector<Value> newOperands;
148+
newOperands.push_back(sretPtr);
149+
for (auto operand : callOp.getOperands())
150+
newOperands.push_back(operand);
151+
152+
// Reconstruct arg attributes with sret attr first
153+
SmallVector<Attribute> newArgAttrs;
154+
newArgAttrs.push_back(sretAttrs);
155+
if (auto argAttrs = callOp.getArgAttrsAttr()) {
156+
for (auto argAttr : argAttrs)
157+
newArgAttrs.push_back(argAttr);
158+
}
159+
160+
// Filter out arg_attrs from attributes
161+
SmallVector<NamedAttribute> newAttrs;
162+
for (auto attr : callOp->getAttrs()) {
163+
if (attr.getName() != callOp.getArgAttrsAttrName())
164+
newAttrs.push_back(attr);
165+
}
166+
167+
auto newCall = rewriter.create<LLVM::CallOp>(callOp.getLoc(), TypeRange{},
168+
newOperands, newAttrs);
169+
newCall->setAttr(newCall.getArgAttrsAttrName(),
170+
rewriter.getArrayAttr(newArgAttrs));
171+
172+
// Load result from sret pointer and replace uses
173+
auto loadedResult =
174+
rewriter.create<LLVM::LoadOp>(callOp.getLoc(), sretType, sretPtr);
175+
rewriter.replaceOp(callOp, loadedResult.getResult());
176+
} else {
177+
rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, callOp.getResultTypes(),
178+
callOp.getOperands(),
179+
callOp->getAttrs());
180+
}
115181

116182
return success();
117183
}

test/lit_tests/tessera/llvm_to_tessera.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ llvm.func @simple_func() attributes {tessera.convert = #tessera<convert "tessera
55
}
66

77
// CHECK-LABEL: tessera.define @tessera_simple_func
8+
// CHECK-SAME: tessera.original_name = "simple_func"
89
// CHECK: tessera.return
910

1011
// -----
@@ -53,3 +54,28 @@ llvm.func @func_with_indirect_call(%arg0: !llvm.ptr) {
5354

5455
// CHECK-LABEL: llvm.func @func_with_indirect_call
5556
// CHECK: llvm.call %arg0() : !llvm.ptr, () -> ()
57+
58+
// -----
59+
60+
llvm.func @sret_func(%arg0: !llvm.ptr {llvm.sret = !llvm.struct<(f32, f32)>, llvm.align = 8 : i64, llvm.nonnull}, %arg1: !llvm.ptr {llvm.noundef, llvm.readonly}) attributes {tessera.convert = #tessera<convert "tessera_sret_func">} {
61+
%0 = llvm.load %arg1 {alignment = 8 : i64} : !llvm.ptr -> f32
62+
llvm.store %0, %arg0 {alignment = 8 : i64} : f32, !llvm.ptr
63+
llvm.return
64+
}
65+
66+
llvm.func @caller() {
67+
%0 = llvm.mlir.constant(1 : i32) : i32
68+
%1 = llvm.alloca %0 x !llvm.struct<(f32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
69+
%2 = llvm.alloca %0 x !llvm.struct<(f32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
70+
llvm.call @sret_func(%1, %2) : (!llvm.ptr {llvm.align = 8 : i64, llvm.nonnull, llvm.sret = !llvm.struct<(f32, f32)>}, !llvm.ptr {llvm.nonnull, llvm.noundef}) -> ()
71+
llvm.return
72+
}
73+
74+
// CHECK-LABEL: tessera.define @tessera_sret_func
75+
// CHECK-SAME: tessera.sret_attrs = {llvm.align = 8 : i64, llvm.nonnull, llvm.sret = !llvm.struct<(f32, f32)>}
76+
// CHECK: tessera.return
77+
78+
// CHECK-LABEL: llvm.func @caller
79+
// CHECK: %[[RES:.*]] = tessera.call @tessera_sret_func
80+
// CHECK-SAME: -> !llvm.struct<(f32, f32)>
81+
// CHECK: llvm.store %[[RES]], %{{.*}} : !llvm.struct<(f32, f32)>, !llvm.ptr

test/lit_tests/tessera/tessera_to_llvm.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,32 @@ tessera.define @tessera_func_with_call() attributes {tessera.original_name = "fu
3333
// CHECK-LABEL: llvm.func @func_with_call
3434
// CHECK: llvm.call @helper() : () -> ()
3535
// CHECK: llvm.return
36+
37+
// -----
38+
39+
tessera.define @tessera_sret_func(%arg0: !llvm.ptr {llvm.align = 8 : i64, llvm.nonnull, llvm.sret = !llvm.struct<(f32, f32)>}, %arg1: !llvm.ptr {llvm.noundef, llvm.readonly}) attributes {CConv = #llvm.cconv<ccc>, linkage = #llvm.linkage<external>, tessera.original_name = "sret_func", tessera.sret_attrs = {llvm.align = 8 : i64, llvm.nonnull, llvm.sret = !llvm.struct<(f32, f32)>}, unnamed_addr = 0 : i64, visibility_ = 0 : i64} {
40+
%0 = llvm.load %arg1 {alignment = 8 : i64} : !llvm.ptr -> f32
41+
llvm.store %0, %arg0 {alignment = 8 : i64} : f32, !llvm.ptr
42+
tessera.return
43+
}
44+
45+
llvm.func @caller() {
46+
%0 = llvm.mlir.constant(1 : i32) : i32
47+
%1 = llvm.alloca %0 x !llvm.struct<(f32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
48+
%2 = llvm.alloca %0 x !llvm.struct<(f32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
49+
%3 = tessera.call @tessera_sret_func(%2) {CConv = #llvm.cconv<ccc>, TailCallKind = #llvm.tailcallkind<none>, arg_attrs = [{llvm.nonnull, llvm.noundef}], fastmathFlags = #llvm.fastmath<none>, op_bundle_sizes = array<i32>, operandSegmentSizes = array<i32: 2, 0>} : (!llvm.ptr) -> !llvm.struct<(f32, f32)>
50+
llvm.store %3, %1 : !llvm.struct<(f32, f32)>, !llvm.ptr
51+
llvm.return
52+
}
53+
54+
// CHECK-LABEL: llvm.func @sret_func
55+
// CHECK-SAME: !llvm.ptr {llvm.align = 8 : i64, llvm.nonnull, llvm.sret = !llvm.struct<(f32, f32)>}
56+
// CHECK-LABEL: llvm.func @caller
57+
// CHECK: %[[A1:.*]] = llvm.alloca
58+
// CHECK: %[[A2:.*]] = llvm.alloca
59+
// CHECK: %[[SRET:.*]] = llvm.alloca
60+
// CHECK: llvm.call @sret_func(%[[SRET]], %[[A2]])
61+
// CHECK-SAME: !llvm.ptr {llvm.align = 8 : i64, llvm.nonnull, llvm.sret = !llvm.struct<(f32, f32)>}
62+
// CHECK: %[[LOADED:.*]] = llvm.load %[[SRET]]
63+
// CHECK: llvm.store %[[LOADED]], {{.*}}
64+
// CHECK-NOT: tessera.call

0 commit comments

Comments
 (0)