Skip to content

Commit 71ee6b1

Browse files
author
hyun gyu kim
committed
[TIR][Schedule] Support multiple epilogue blocks in FuseReductionEpilogue
- Add a CheckBufferStillUsed helper function to check whether the reduction buffer is still referenced by other blocks after fusion.
- Only remove the intermediate temp buffer if no other block references it.
- Add a test case for the multiple-epilogue-blocks scenario, where one epilogue is fused while another still uses the intermediate buffer.
- This addresses the case where multiple epilogue blocks use the same reduction output, ensuring the temp buffer is preserved when needed.

Related issue: https://discuss.tvm.apache.org/t/...
1 parent 0fc40e7 commit 71ee6b1

File tree

2 files changed

+302
-167
lines changed

2 files changed

+302
-167
lines changed

src/tir/schedule/primitive/compute_inline.cc

Lines changed: 88 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,6 +1236,82 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
12361236
return Block(new_block);
12371237
}
12381238

1239+
/*!
1240+
* \brief Check if a buffer is still referenced by other blocks in the scope
1241+
*/
1242+
static bool CheckBufferStillUsed(const Block& scope_root, const Buffer& buffer) {
1243+
class BufferUsageChecker : public StmtVisitor {
1244+
public:
1245+
explicit BufferUsageChecker(const Buffer& buffer) : buffer_(buffer) {}
1246+
1247+
bool CheckStmt(const Stmt& stmt) {
1248+
found_usage_ = false;
1249+
VisitStmt(stmt);
1250+
return found_usage_;
1251+
}
1252+
1253+
private:
1254+
void VisitStmt_(const BlockRealizeNode* op) final {
1255+
if (found_usage_) return;
1256+
1257+
if (!op || !op->block.defined()) {
1258+
StmtVisitor::VisitStmt_(op);
1259+
return;
1260+
}
1261+
1262+
const BlockNode* block = op->block.get();
1263+
if (!block) {
1264+
StmtVisitor::VisitStmt_(op);
1265+
return;
1266+
}
1267+
1268+
// Check reads
1269+
for (const BufferRegion& read : block->reads) {
1270+
if (read->buffer.same_as(buffer_)) {
1271+
found_usage_ = true;
1272+
return;
1273+
}
1274+
}
1275+
1276+
// Check writes
1277+
for (const BufferRegion& write : block->writes) {
1278+
if (write->buffer.same_as(buffer_)) {
1279+
found_usage_ = true;
1280+
return;
1281+
}
1282+
}
1283+
1284+
// Continue visiting nested blocks
1285+
StmtVisitor::VisitStmt_(op);
1286+
}
1287+
1288+
void VisitStmt_(const BlockNode* op) final {
1289+
if (found_usage_) return;
1290+
if (!op) return;
1291+
1292+
// Check alloc_buffers
1293+
for (const Buffer& buf : op->alloc_buffers) {
1294+
if (buf.same_as(buffer_)) {
1295+
found_usage_ = true;
1296+
return;
1297+
}
1298+
}
1299+
1300+
StmtVisitor::VisitStmt_(op);
1301+
}
1302+
1303+
const Buffer& buffer_;
1304+
bool found_usage_{false};
1305+
};
1306+
1307+
if (!scope_root->body.defined()) {
1308+
return false;
1309+
}
1310+
1311+
BufferUsageChecker checker(buffer);
1312+
return checker.CheckStmt(scope_root->body);
1313+
}
1314+
12391315
/*!
12401316
* \brief Helper class to replace reduction and epilogue blocks with a single fused block
12411317
*/
@@ -1247,15 +1323,20 @@ class SingleBlockFusionReplacer : public StmtMutator {
12471323
std::move(old_epilogue_block), std::move(reduction_buffer));
12481324
Block result = Downcast<Block>(replacer(std::move(old_scope_root)));
12491325

1250-
// Remove intermediate temp buffer
1251-
BlockNode* p = result.CopyOnWrite();
1252-
ffi::Array<Buffer> new_alloc_buffers;
1253-
for (const Buffer& buf : p->alloc_buffers) {
1254-
if (!buf.same_as(replacer.reduction_buffer_)) {
1255-
new_alloc_buffers.push_back(buf);
1326+
// Check if reduction_buffer is still referenced by other blocks
1327+
bool buffer_still_used = CheckBufferStillUsed(result, reduction_buffer);
1328+
1329+
// Remove intermediate temp buffer only if it's not used by other blocks
1330+
if (!buffer_still_used) {
1331+
BlockNode* p = result.CopyOnWrite();
1332+
ffi::Array<Buffer> new_alloc_buffers;
1333+
for (const Buffer& buf : p->alloc_buffers) {
1334+
if (!buf.same_as(reduction_buffer)) {
1335+
new_alloc_buffers.push_back(buf);
1336+
}
12561337
}
1338+
p->alloc_buffers = new_alloc_buffers;
12571339
}
1258-
p->alloc_buffers = new_alloc_buffers;
12591340

12601341
return result;
12611342
}

0 commit comments

Comments
 (0)