@@ -988,14 +988,32 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre
988988 * \brief Helper to fuse epilogue block into reduction block
989989 * Analyzes epilogue pattern and transforms reduction init/update
990990 */
991+ // Epilogue type enumeration
992+ enum class EpilogueType {
993+ Bias, // temp + C
994+ BiasReLU, // max(temp + C, 0)
995+ };
996+
991997class ReductionEpilogueFuser : public BaseInliner {
992998 public:
993999 explicit ReductionEpilogueFuser (const Buffer& reduction_buffer, const BlockNode* reduction_block,
9941000 const BlockRealize& epilogue_block_realize,
9951001 const StmtSRef& scope_root_sref)
9961002 : BaseInliner(reduction_buffer, epilogue_block_realize->block, scope_root_sref),
9971003 reduction_block_(reduction_block),
998- epilogue_block_(epilogue_block_realize->block.get()) {}
1004+ epilogue_block_(epilogue_block_realize->block.get()),
1005+ epilogue_type_(EpilogueType::Bias) {
1006+ // Disable opaque access check for epilogue fusion
1007+ // Epilogue blocks can read multiple buffers (temp + bias), which is allowed
1008+ has_opaque_access = false ;
1009+ }
1010+
1011+ // Override CheckOpaqueAccess to allow multiple buffer reads
1012+ void CheckOpaqueAccess (const VarNode* buffer_var) {
1013+ // For epilogue fusion, we allow multiple buffer reads (temp + bias)
1014+ // So we don't check for opaque access
1015+ // BaseInliner::CheckOpaqueAccess(buffer_var); // Don't call base class
1016+ }
9991017
10001018 bool BodyPatternAllowFusion (const BlockRealize& epilogue_block_realize);
10011019
@@ -1012,18 +1030,21 @@ class ReductionEpilogueFuser : public BaseInliner {
10121030 const BufferStoreNode* from) {
10131031 struct Extractor : public ExprVisitor {
10141032 void VisitExpr_ (const BufferLoadNode* load) final {
1015- if (load->buffer .get () == buffer) {
1033+ if (load->buffer .same_as ( buffer) ) {
10161034 result.push_back (load);
10171035 }
1036+ // Continue visiting child nodes (indices)
10181037 ExprVisitor::VisitExpr_ (load);
10191038 }
1020- const BufferNode* buffer;
1039+ Buffer buffer;
10211040 std::vector<const BufferLoadNode*> result;
10221041 } extractor;
1023- extractor.buffer = buffer.get ();
1042+ extractor.buffer = buffer;
1043+ // Visit indices first (though they typically don't contain BufferLoad)
10241044 for (const PrimExpr& expr : from->indices ) {
10251045 extractor (expr);
10261046 }
1047+ // Visit the value expression (e.g., max(temp + C, 0) for ReLU)
10271048 extractor (from->value );
10281049 return std::move (extractor.result );
10291050 }
@@ -1036,6 +1057,7 @@ class ReductionEpilogueFuser : public BaseInliner {
10361057 BufferRegion epilogue_output_region_{nullptr }; // Write region of D
10371058 Buffer epilogue_addend_buffer_{nullptr }; // Addend buffer C
10381059 BufferRegion epilogue_addend_region_{nullptr }; // Read region of C
1060+ EpilogueType epilogue_type_; // Type of epilogue operation
10391061};
10401062
10411063bool ReductionEpilogueFuser::BodyPatternAllowFusion (const BlockRealize& epilogue_block_realize) {
@@ -1077,7 +1099,7 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue
10771099}
10781100
10791101bool ReductionEpilogueFuser::AnalyzeEpiloguePattern (const PrimExpr& value) {
1080- // Pattern: temp[i,j] + C[i,j] or C[i,j] + temp[i,j]
1102+ // Pattern 1 : temp[i,j] + C[i,j] or C[i,j] + temp[i,j] (Bias)
10811103 if (const auto * add = value.as <AddNode>()) {
10821104 const auto * load_a = add->a .as <BufferLoadNode>();
10831105 const auto * load_b = add->b .as <BufferLoadNode>();
@@ -1088,10 +1110,40 @@ bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
10881110 // Ensure exactly one operand is from the reduction buffer
10891111 if (a_is_target != b_is_target) {
10901112 epilogue_addend_ = a_is_target ? add->b : add->a ;
1113+ epilogue_type_ = EpilogueType::Bias;
10911114 return true ;
10921115 }
10931116 }
10941117
1118+ // Pattern 2: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) (BiasReLU)
1119+ if (const auto * max_node = value.as <MaxNode>()) {
1120+ // Check if second operand is zero (ReLU: max(x, 0))
1121+ // Support both integer and float zero constants
1122+ bool is_zero_const = false ;
1123+ if (tir::is_zero (max_node->b )) {
1124+ is_zero_const = true ;
1125+ } else if (const auto * float_imm = max_node->b .as <FloatImmNode>()) {
1126+ is_zero_const = (float_imm->value == 0.0 );
1127+ }
1128+ if (is_zero_const) {
1129+ // Check if first operand is AddNode
1130+ if (const auto * add = max_node->a .as <AddNode>()) {
1131+ const auto * load_a = add->a .as <BufferLoadNode>();
1132+ const auto * load_b = add->b .as <BufferLoadNode>();
1133+
1134+ bool a_is_target = load_a && load_a->buffer .same_as (inlined_buffer_);
1135+ bool b_is_target = load_b && load_b->buffer .same_as (inlined_buffer_);
1136+
1137+ // Ensure exactly one operand is from the reduction buffer
1138+ if (a_is_target != b_is_target) {
1139+ epilogue_addend_ = a_is_target ? add->b : add->a ;
1140+ epilogue_type_ = EpilogueType::BiasReLU;
1141+ return true ;
1142+ }
1143+ }
1144+ }
1145+ }
1146+
10951147 return false ;
10961148}
10971149
@@ -1158,20 +1210,40 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
11581210 var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
11591211 }
11601212
1161- // 2. Change init to epilogue value: D[vi, vj] = C[vi, vj]
1162- BufferStore new_init_store (epilogue_output_buffer_, Substitute (epilogue_addend_, var_map),
1163- Substitute (epilogue_output_indices_, var_map));
1213+ // 2. Change init to epilogue value based on epilogue type
1214+ BufferStore new_init_store;
1215+ if (epilogue_type_ == EpilogueType::BiasReLU) {
1216+ // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU semantics
1217+ PrimExpr init_value = Substitute (epilogue_addend_, var_map);
1218+ PrimExpr zero = tir::make_zero (init_value.dtype ());
1219+ new_init_store = BufferStore (epilogue_output_buffer_, Max (init_value, zero),
1220+ Substitute (epilogue_output_indices_, var_map));
1221+ } else {
1222+ // Bias: D[vi, vj] = C[vi, vj]
1223+ new_init_store = BufferStore (epilogue_output_buffer_, Substitute (epilogue_addend_, var_map),
1224+ Substitute (epilogue_output_indices_, var_map));
1225+ }
11641226 new_block->init = new_init_store;
11651227
11661228 // 3. Replace output buffer from temp to D in body
11671229 class BufferReplacer : public StmtExprMutator {
11681230 public:
1169- BufferReplacer (Buffer old_buf, Buffer new_buf) : old_buffer_(old_buf), new_buffer_(new_buf) {}
1231+ BufferReplacer (Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, DataType dtype)
1232+ : old_buffer_(old_buf),
1233+ new_buffer_ (new_buf),
1234+ epilogue_type_(epilogue_type),
1235+ dtype_(dtype) {}
11701236
11711237 Stmt VisitStmt_ (const BufferStoreNode* op) final {
11721238 BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_ (op));
11731239 if (store->buffer .same_as (old_buffer_)) {
1174- return BufferStore (new_buffer_, store->value , store->indices );
1240+ PrimExpr new_value = store->value ;
1241+ // For ReLU, apply max per iteration to match per-iteration ReLU semantics
1242+ if (epilogue_type_ == EpilogueType::BiasReLU) {
1243+ PrimExpr zero = tir::make_zero (dtype_);
1244+ new_value = Max (new_value, zero);
1245+ }
1246+ return BufferStore (new_buffer_, new_value, store->indices );
11751247 }
11761248 return store;
11771249 }
@@ -1187,9 +1259,12 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
11871259 private:
11881260 Buffer old_buffer_;
11891261 Buffer new_buffer_;
1262+ EpilogueType epilogue_type_;
1263+ DataType dtype_;
11901264 };
11911265
1192- BufferReplacer replacer (inlined_buffer_, epilogue_output_buffer_);
1266+ DataType dtype = epilogue_output_buffer_->dtype;
1267+ BufferReplacer replacer (inlined_buffer_, epilogue_output_buffer_, epilogue_type_, dtype);
11931268 new_block->body = replacer(reduction_block->body);
11941269
11951270 // 4. Update write regions
0 commit comments