@@ -28,6 +28,24 @@ using namespace enzyme;
2828
2929namespace {
3030
31+ struct BatchCacheKey {
32+ FunctionOpInterface function;
33+ SmallVector<int64_t > batchSizes;
34+
35+ // for use in std::map:
36+ bool operator <(const BatchCacheKey &other) const {
37+ if (const_cast <FunctionOpInterface &>(function).getName () !=
38+ const_cast <FunctionOpInterface &>(other.function ).getName ())
39+ return const_cast <FunctionOpInterface &>(function).getName () <
40+ const_cast <FunctionOpInterface &>(other.function ).getName ();
41+ return batchSizes < other.batchSizes ;
42+ }
43+ };
44+
45+ static FunctionOpInterface batchCloneFunction (
46+ FunctionOpInterface F, Twine name, llvm::ArrayRef<int64_t > batchSizes,
47+ std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache);
48+
3149static mlir::TensorType applyBatchSizes (mlir::Type Ty,
3250 llvm::ArrayRef<int64_t > batchSizes) {
3351 auto T = dyn_cast<TensorType>(Ty);
@@ -41,8 +59,56 @@ static mlir::TensorType applyBatchSizes(mlir::Type Ty,
4159 return T2;
4260}
4361
44- static void batchCloneRegion (Region *src, Region *dest, IRMapping &mapper,
45- llvm::ArrayRef<int64_t > batchSizes) {
62+ static LogicalResult handleCallOp (
63+ func::CallOp callOp, OpBuilder &builder, IRMapping &mapper,
64+ llvm::ArrayRef<int64_t > batchSizes,
65+ std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
66+ // Get the called function
67+ auto moduleOp = callOp->getParentOfType <ModuleOp>();
68+ auto calledFunc =
69+ dyn_cast<FunctionOpInterface>(moduleOp.lookupSymbol (callOp.getCallee ()));
70+ if (!calledFunc)
71+ return failure ();
72+
73+ // Create cache key for this function and batch size combination
74+ BatchCacheKey key{calledFunc,
75+ SmallVector<int64_t >(batchSizes.begin (), batchSizes.end ())};
76+
77+ // Look up or create batched version of the called function
78+ FunctionOpInterface batchedFunc;
79+ auto it = batchedFunctionCache.find (key);
80+ if (it != batchedFunctionCache.end ()) {
81+ batchedFunc = it->second ;
82+ } else {
83+ batchedFunc =
84+ batchCloneFunction (calledFunc, " batched_" + calledFunc.getName (),
85+ batchSizes, batchedFunctionCache);
86+ if (!batchedFunc)
87+ return failure ();
88+ batchedFunctionCache[key] = batchedFunc;
89+ }
90+
91+ // Create new call operation to the batched function
92+ SmallVector<Value> newOperands;
93+ for (auto operand : callOp->getOperands ())
94+ newOperands.push_back (mapper.lookup (operand));
95+
96+ auto newCall =
97+ builder.create <func::CallOp>(callOp.getLoc (), batchedFunc.getName (),
98+ batchedFunc.getResultTypes (), newOperands);
99+
100+ // Map the results
101+ for (auto [oldResult, newResult] :
102+ llvm::zip (callOp.getResults (), newCall.getResults ()))
103+ mapper.map (oldResult, newResult);
104+
105+ return success ();
106+ }
107+
108+ static void batchCloneRegion (
109+ Region *src, Region *dest, IRMapping &mapper,
110+ llvm::ArrayRef<int64_t > batchSizes,
111+ std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
46112 // For each block in src, generate a corresponding block in the dest region.
47113 for (auto &blk : *src) {
48114 auto newBlk = new Block ();
@@ -61,6 +127,12 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,
61127 OpBuilder builder (&newBlk, newBlk.end ());
62128 for (auto &src : blk) {
63129
130+ if (auto callOp = dyn_cast<func::CallOp>(&src)) {
131+ if (succeeded (handleCallOp (callOp, builder, mapper, batchSizes,
132+ batchedFunctionCache)))
133+ continue ;
134+ }
135+
64136 if (auto ifaceOp = dyn_cast<BatchOpInterface>(&src)) {
65137 auto res = ifaceOp.createBatch (builder, mapper, batchSizes);
66138 if (res.succeeded ())
@@ -93,7 +165,8 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,
93165 // Clone the regions.
94166 for (auto &&[oldReg, newReg] :
95167 llvm::zip (src.getRegions (), newOp->getRegions ())) {
96- batchCloneRegion (&oldReg, &newReg, mapper, batchSizes);
168+ batchCloneRegion (&oldReg, &newReg, mapper, batchSizes,
169+ batchedFunctionCache);
97170 }
98171
99172 // Remember the mapping of any results.
@@ -105,9 +178,9 @@ static void batchCloneRegion(Region *src, Region *dest, IRMapping &mapper,
105178 }
106179}
107180
108- static FunctionOpInterface
109- batchCloneFunction ( FunctionOpInterface F, Twine name,
110- llvm::ArrayRef< int64_t > batchSizes ) {
181+ static FunctionOpInterface batchCloneFunction (
182+ FunctionOpInterface F, Twine name, llvm::ArrayRef< int64_t > batchSizes ,
183+ std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache ) {
111184 assert (!F.getFunctionBody ().empty ());
112185
113186 auto FTy = F.getFunctionType ().cast <FunctionType>();
@@ -138,30 +211,51 @@ batchCloneFunction(FunctionOpInterface F, Twine name,
138211 table.insert (NewF);
139212 SymbolTable::setSymbolVisibility (NewF, SymbolTable::Visibility::Private);
140213
214+ // Add the function to the cache BEFORE processing its body to support
215+ // recursion.
216+ BatchCacheKey key{F,
217+ SmallVector<int64_t >(batchSizes.begin (), batchSizes.end ())};
218+ batchedFunctionCache[key] = NewF;
219+
141220 auto &origReg = F.getFunctionBody ();
142221 auto &newReg = NewF.getFunctionBody ();
143222
144223 IRMapping mapper;
145- batchCloneRegion (&origReg, &newReg, mapper, batchSizes);
224+ batchCloneRegion (&origReg, &newReg, mapper, batchSizes, batchedFunctionCache );
146225
147226 return NewF;
148227}
149228
150229struct BatchPass : public BatchPassBase <BatchPass> {
151230 void runOnOperation () override ;
152231
232+ // Cache mapping original function and batch sizes to batched function
233+ std::map<BatchCacheKey, FunctionOpInterface> batchedFunctionCache;
234+
153235 template <typename T>
154236 LogicalResult HandleBatch (SymbolTableCollection &symbolTable, T CI) {
155237 SmallVector<mlir::Value, 2 > args;
156238
157239 auto *symbolOp = symbolTable.lookupNearestSymbolFrom (CI, CI.getFnAttr ());
158240 auto fn = cast<FunctionOpInterface>(symbolOp);
159241
160- FunctionOpInterface newFunc =
161- batchCloneFunction (fn, " batched_" + fn.getName (), CI.getBatchShape ());
162-
163- if (!newFunc)
164- return failure ();
242+ BatchCacheKey key{fn, SmallVector<int64_t >(CI.getBatchShape ().begin (),
243+ CI.getBatchShape ().end ())};
244+
245+ // Check if we already have a batched version
246+ auto it = batchedFunctionCache.find (key);
247+ FunctionOpInterface newFunc;
248+
249+ if (it != batchedFunctionCache.end ()) {
250+ newFunc = it->second ;
251+ } else {
252+ // Create new batched function and store in cache
253+ newFunc = batchCloneFunction (fn, " batched_" + fn.getName (),
254+ CI.getBatchShape (), batchedFunctionCache);
255+ if (!newFunc) {
256+ return failure ();
257+ }
258+ }
165259
166260 OpBuilder builder (CI);
167261 auto dCI =
0 commit comments