Commit 5c632cc

Cache batched functions and recursively batch. (#2222)
* don't batch the same function twice.
* batch functions that are called within batched function
* initial tests
* support recursive functions
* recursion test
* formatting
1 parent ec3a788 commit 5c632cc

4 files changed: +174 −12 lines changed
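The heart of the change, per the first two commit bullets, is memoizing batched clones on a (function, batch sizes) key; recursive and repeated calls then resolve through the cache instead of triggering another clone. A minimal standalone sketch of the idiom, using toy string-based stand-ins rather than Enzyme's actual types (note the cache entry is inserted before the body is processed, which is what makes recursion terminate):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for (FunctionOpInterface, batchSizes): the function is
// represented by its symbol name only.
using Key = std::pair<std::string, std::vector<int64_t>>;

std::map<Key, std::string> cache;                        // key -> batched clone
std::map<std::string, std::string> calls = {{"f", "g"},  // f calls g,
                                            {"g", "g"}}; // g calls itself

std::string batchClone(const std::string &fn,
                       const std::vector<int64_t> &sizes) {
  Key key{fn, sizes};
  auto it = cache.find(key);
  if (it != cache.end())
    return it->second;             // already batched: reuse, don't re-clone
  std::string batched = "batched_" + fn;
  cache[key] = batched;            // insert BEFORE cloning the body...
  batchClone(calls.at(fn), sizes); // ...so this recursive lookup terminates
  return batched;
}

int main() {
  batchClone("f", {4});
  for (const auto &[key, clone] : cache)
    std::cout << key.first << " -> " << clone << "\n"; // f, g each cloned once
}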

enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp

Lines changed: 106 additions & 12 deletions
@@ -28,6 +28,24 @@ using namespace enzyme;
 
 namespace {
 
+struct BatchCacheKey {
+  FunctionOpInterface function;
+  SmallVector<int64_t> batchSizes;
+
+  // for use in std::map:
+  bool operator<(const BatchCacheKey &other) const {
+    if (const_cast<FunctionOpInterface &>(function).getName() !=
+        const_cast<FunctionOpInterface &>(other.function).getName())
+      return const_cast<FunctionOpInterface &>(function).getName() <
+             const_cast<FunctionOpInterface &>(other.function).getName();
+    return batchSizes < other.batchSizes;
+  }
+};
+
+static FunctionOpInterface batchCloneFunction(
+    FunctionOpInterface F, Twine name, llvm::ArrayRef<int64_t> batchSizes,
+    std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache);
+
 static mlir::TensorType applyBatchSizes(mlir::Type Ty,
                                         llvm::ArrayRef<int64_t> batchSizes) {
   auto T = dyn_cast<TensorType>(Ty);
@@ -41,8 +59,56 @@ static mlir::TensorType applyBatchSizes(mlir::Type Ty,
   return T2;
 }
 
-static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,
-                             llvm::ArrayRef<int64_t> batchSizes) {
+static LogicalResult handleCallOp(
+    func::CallOp callOp, OpBuilder &builder, IRMapping &mapper,
+    llvm::ArrayRef<int64_t> batchSizes,
+    std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
+  // Get the called function
+  auto moduleOp = callOp->getParentOfType<ModuleOp>();
+  auto calledFunc =
+      dyn_cast<FunctionOpInterface>(moduleOp.lookupSymbol(callOp.getCallee()));
+  if (!calledFunc)
+    return failure();
+
+  // Create cache key for this function and batch size combination
+  BatchCacheKey key{calledFunc,
+                    SmallVector<int64_t>(batchSizes.begin(), batchSizes.end())};
+
+  // Look up or create batched version of the called function
+  FunctionOpInterface batchedFunc;
+  auto it = batchedFunctionCache.find(key);
+  if (it != batchedFunctionCache.end()) {
+    batchedFunc = it->second;
+  } else {
+    batchedFunc =
+        batchCloneFunction(calledFunc, "batched_" + calledFunc.getName(),
+                           batchSizes, batchedFunctionCache);
+    if (!batchedFunc)
+      return failure();
+    batchedFunctionCache[key] = batchedFunc;
+  }
+
+  // Create new call operation to the batched function
+  SmallVector<Value> newOperands;
+  for (auto operand : callOp->getOperands())
+    newOperands.push_back(mapper.lookup(operand));
+
+  auto newCall =
+      builder.create<func::CallOp>(callOp.getLoc(), batchedFunc.getName(),
+                                   batchedFunc.getResultTypes(), newOperands);
+
+  // Map the results
+  for (auto [oldResult, newResult] :
+       llvm::zip(callOp.getResults(), newCall.getResults()))
+    mapper.map(oldResult, newResult);
+
+  return success();
+}
+
+static void batchCloneRegion(
+    Region *src, Region *dest, IRMapping &mapper,
+    llvm::ArrayRef<int64_t> batchSizes,
+    std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
   // For each block in src, generate a corresponding block in the dest region.
   for (auto &blk : *src) {
     auto newBlk = new Block();
@@ -61,6 +127,12 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,
     OpBuilder builder(&newBlk, newBlk.end());
     for (auto &src : blk) {
 
+      if (auto callOp = dyn_cast<func::CallOp>(&src)) {
+        if (succeeded(handleCallOp(callOp, builder, mapper, batchSizes,
+                                   batchedFunctionCache)))
+          continue;
+      }
+
       if (auto ifaceOp = dyn_cast<BatchOpInterface>(&src)) {
         auto res = ifaceOp.createBatch(builder, mapper, batchSizes);
         if (res.succeeded())
@@ -93,7 +165,8 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,
       // Clone the regions.
       for (auto &&[oldReg, newReg] :
            llvm::zip(src.getRegions(), newOp->getRegions())) {
-        batchCloneRegion(&oldReg, &newReg, mapper, batchSizes);
+        batchCloneRegion(&oldReg, &newReg, mapper, batchSizes,
+                         batchedFunctionCache);
       }
 
       // Remember the mapping of any results.
@@ -105,9 +178,9 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,
   }
 }
 
-static FunctionOpInterface
-batchCloneFunction(FunctionOpInterface F, Twine name,
-                   llvm::ArrayRef<int64_t> batchSizes) {
+static FunctionOpInterface batchCloneFunction(
+    FunctionOpInterface F, Twine name, llvm::ArrayRef<int64_t> batchSizes,
+    std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
   assert(!F.getFunctionBody().empty());
 
   auto FTy = F.getFunctionType().cast<FunctionType>();
@@ -138,30 +211,51 @@ batchCloneFunction(FunctionOpInterface F, Twine name,
   table.insert(NewF);
   SymbolTable::setSymbolVisibility(NewF, SymbolTable::Visibility::Private);
 
+  // Add the function to the cache BEFORE processing its body to support
+  // recursion.
+  BatchCacheKey key{F,
+                    SmallVector<int64_t>(batchSizes.begin(), batchSizes.end())};
+  batchedFunctionCache[key] = NewF;
+
   auto &origReg = F.getFunctionBody();
   auto &newReg = NewF.getFunctionBody();
 
   IRMapping mapper;
-  batchCloneRegion(&origReg, &newReg, mapper, batchSizes);
+  batchCloneRegion(&origReg, &newReg, mapper, batchSizes, batchedFunctionCache);
 
   return NewF;
 }
 
 struct BatchPass : public BatchPassBase<BatchPass> {
   void runOnOperation() override;
 
+  // Cache mapping original function and batch sizes to batched function
+  std::map<BatchCacheKey, FunctionOpInterface> batchedFunctionCache;
+
   template <typename T>
   LogicalResult HandleBatch(SymbolTableCollection &symbolTable, T CI) {
     SmallVector<mlir::Value, 2> args;
 
     auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr());
     auto fn = cast<FunctionOpInterface>(symbolOp);
 
-    FunctionOpInterface newFunc =
-        batchCloneFunction(fn, "batched_" + fn.getName(), CI.getBatchShape());
-
-    if (!newFunc)
-      return failure();
+    BatchCacheKey key{fn, SmallVector<int64_t>(CI.getBatchShape().begin(),
+                                               CI.getBatchShape().end())};
+
+    // Check if we already have a batched version
+    auto it = batchedFunctionCache.find(key);
+    FunctionOpInterface newFunc;
+
+    if (it != batchedFunctionCache.end()) {
+      newFunc = it->second;
+    } else {
+      // Create new batched function and store in cache
+      newFunc = batchCloneFunction(fn, "batched_" + fn.getName(),
+                                   CI.getBatchShape(), batchedFunctionCache);
+      if (!newFunc) {
+        return failure();
+      }
+    }
 
     OpBuilder builder(CI);
     auto dCI =
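A note on BatchCacheKey: the const_cast calls exist because getName() is a non-const member on FunctionOpInterface while operator< receives both keys as const. The resulting order compares symbol names first and batch sizes second, giving std::map the strict weak ordering it requires. The same ordering on plain stand-in types, as a standalone sketch (hypothetical names, for illustration only):

#include <cstdint>
#include <map>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical stand-in for BatchCacheKey: the function shrinks to its
// symbol name, since only the name participates in the comparison.
struct Key {
  std::string name;
  std::vector<int64_t> batchSizes;

  bool operator<(const Key &other) const {
    // std::tie yields the same name-first, then-sizes lexicographic order.
    return std::tie(name, batchSizes) < std::tie(other.name, other.batchSizes);
  }
};

int main() {
  std::map<Key, int> cache;
  cache[{"f", {4}}] = 0;    // @f batched over shape [4]
  cache[{"f", {2, 3}}] = 1; // same @f, different shape: a distinct entry
  cache[{"f", {4}}] = 2;    // same key as the first insert: overwrites
  return cache.size() == 2 ? 0 : 1; // exits 0: two distinct cache entries
}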
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+module {
+  func.func private @f(%arg0: tensor<16xf32>, %arg1: tensor<16xf32>) -> tensor<16xf32> {
+    return %arg0 : tensor<16xf32>
+  }
+  func.func @main(%arg0: tensor<4x16xf32>, %arg1: tensor<4x16xf32>) {
+    %2 = enzyme.batch @f(%arg0, %arg1) {batch_shape = array<i64: 4>} : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
+    %3 = enzyme.batch @f(%arg0, %arg1) {batch_shape = array<i64: 4>} : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
+    return
+  }
+}
+
+// CHECK: func.func @main(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) {
+// CHECK-NEXT: %[[v0:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
+// CHECK-NEXT: %[[v1:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func.func private @batched_f(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) -> tensor<4x16xf32> {
+// CHECK-NEXT: return %[[arg0]] : tensor<4x16xf32>
+// CHECK-NEXT: }

enzyme/test/MLIR/Batch/call.mlir

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+// RUN: %eopt -enzyme-batch %s | FileCheck %s
+
+module {
+  func.func private @g(%arg0: tensor<16xf32>) -> tensor<16xf32> {
+    return %arg0 : tensor<16xf32>
+  }
+  func.func private @f(%arg0: tensor<16xf32>, %arg1: tensor<16xf32>) -> tensor<16xf32> {
+    %1 = func.call @g(%arg0) : (tensor<16xf32>) -> tensor<16xf32>
+    return %1 : tensor<16xf32>
+  }
+  func.func @main(%arg0: tensor<4x16xf32>, %arg1: tensor<4x16xf32>) {
+    %2 = enzyme.batch @f(%arg0, %arg1) {batch_shape = array<i64: 4>} : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
+    return
+  }
+}
+
+// CHECK: func.func @main(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) {
+// CHECK-NEXT: %[[v0:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func.func private @batched_f(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) -> tensor<4x16xf32> {
+// CHECK-NEXT: %[[v0:.+]] = call @batched_g(%[[arg0]]) : (tensor<4x16xf32>) -> tensor<4x16xf32>
+// CHECK-NEXT: return %[[v0]] : tensor<4x16xf32>
+// CHECK-NEXT: }
+// CHECK: func.func private @batched_g(%[[arg0:.+]]: tensor<4x16xf32>) -> tensor<4x16xf32> {
+// CHECK-NEXT: return %[[arg0]] : tensor<4x16xf32>
+// CHECK-NEXT: }
+
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+// RUN: %eopt -enzyme-batch %s | FileCheck %s
+
+module {
+  func.func private @f(%arg0: tensor<16xf32>, %arg1: tensor<16xf32>) -> tensor<16xf32> {
+    %0 = func.call @f(%arg0, %arg1) : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
+    return %0 : tensor<16xf32>
+  }
+  func.func @main(%arg0: tensor<4x16xf32>, %arg1: tensor<4x16xf32>) {
+    %0 = enzyme.batch @f(%arg0, %arg1) {batch_shape = array<i64: 4>} : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
+    return
+  }
+}
+
+// CHECK: func.func @main(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) {
+// CHECK-NEXT: %[[v0:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func.func private @batched_f(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) -> tensor<4x16xf32> {
+// CHECK-NEXT: %[[v0:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
+// CHECK-NEXT: return %[[v0]] : tensor<4x16xf32>
+// CHECK-NEXT: }
