Commit d48cd25

hyun gyu kim committed
[TIR][Schedule] FuseReductionEpilogue: Add Clipping pattern support
Currently, the FuseReductionEpilogue primitive only supports Bias (addition) and BiasReLU (addition + ReLU) epilogue patterns. However, clipping operations (min(max(x, lower), upper)) are commonly used in deep learning models and would benefit from the same fusion optimization.

This commit extends FuseReductionEpilogue to support Clipping patterns by:

1. Adding EpilogueType::Clipping to the enum to distinguish clipping patterns from other epilogue types.
2. Adding clipping_lower_ and clipping_upper_ members to ReductionEpilogueFuser to store clipping bounds extracted from the epilogue pattern.
3. Extending AnalyzeEpiloguePattern to detect clipping patterns:
   - min(max(temp, lower), upper)
   - max(min(temp, upper), lower)
   - All commutative variants of min/max at each level
4. Updating BiasReLU pattern matching to handle max(0, x) form in addition to max(x, 0) for better commutativity support.
5. Modifying CreateFusedReductionBlock to apply clipping to the init value:
   init = min(max(0, lower), upper)
6. Updating BufferReplacer to apply clipping per-iteration:
   value = min(max(value, lower), upper)
7. Adding validation in BodyPatternAllowFusion to ensure temp appears exactly once in clipping patterns.
8. Creating comprehensive test coverage with 8 test cases:
   - Basic fusion test
   - Numerical correctness verification
   - Multiple epilogue blocks test
   - 5 commutative variant tests

This implementation follows the same per-iteration semantics as BiasReLU, where clipping is applied at each reduction step rather than post-reduction. This semantic change is documented in the docstring with a warning about potential numerical differences.

The test suite verifies that all commutative forms of clipping patterns are correctly recognized and that the fused implementation produces numerically identical results to the per-iteration reference implementation.
1 parent dc296e6 commit d48cd25
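
Steps 5 and 6 of the commit message can be modeled with a plain-Python reference for what the fused block computes. This is an illustrative sketch only, assuming a matmul reduction whose init value is zero; `clip` and `fused_matmul_clip` are hypothetical names and are not part of TVM.

```python
def clip(x, lower, upper):
    """min(max(x, lower), upper), the Clipping epilogue from the commit message."""
    return min(max(x, lower), upper)


def fused_matmul_clip(a, b, lower, upper):
    """Hypothetical model of the fused block: clip the init value, then clip
    the accumulator again after every reduction step."""
    rows, depth, cols = len(a), len(b), len(b[0])
    out = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            # Step 5: init = min(max(0, lower), upper)
            acc = clip(0.0, lower, upper)
            for k in range(depth):
                # Step 6: value = min(max(value, lower), upper)
                acc = clip(acc + a[i][k] * b[k][j], lower, upper)
            out[i][j] = acc
    return out
```

When zero lies outside the bounds, the clipped init value is non-zero. With bounds [2, 10], the accumulator starts at 2, so `fused_matmul_clip([[1]], [[1]], 2, 10)` yields 3, whereas clipping the finished product 1 after the reduction would give 2.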

6 files changed: +431 -19 lines changed

3rdparty/tvm-ffi

Submodule tvm-ffi updated 61 files

ffi/3rdparty/dlpack

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Subproject commit 3ea601bb413074c49a77c4ce3218bc08f8c4703c

ffi/3rdparty/libbacktrace

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Subproject commit 793921876c981ce49759114d7bb89bb89b2d3a2d

python/tvm/tir/schedule/schedule.py

Lines changed: 29 additions & 2 deletions
@@ -2356,14 +2356,41 @@ def fuse_reduction_epilogue(
         It requires:
         1) The reduction block is a complete reduction block
         2) The epilogue block only reads from the reduction block's output
-        3) The epilogue performs a simple addition: output = reduction_result + bias
+        3) The epilogue matches one of the supported patterns:
+           - Bias: ``output = reduction_result + bias``
+           - BiasReLU: ``output = max(reduction_result + bias, 0)``
+           - Clipping: ``output = min(max(reduction_result, lower), upper)``
+             or their commutative variants
+
+        .. warning::
+
+            **Semantic Change for Non-Linear Epilogues (BiasReLU, Clipping):**
+
+            For non-linear epilogues (BiasReLU and Clipping), fusion changes the
+            computation semantics from post-reduction application to per-iteration
+            application. This can lead to different numerical results.
+
+            **Example with Clipping to [-5, 5] and inputs [6, -2]:**
+
+            - **Post-reduction clipping** (original): ``clip(sum([6, -2])) = clip(4) = 4``
+            - **Per-iteration clipping** (fused): ``acc=0 → clip(0+6)=5 → clip(5+(-2))=3``
+
+            The fused version applies clipping at each reduction iteration, which
+            may be an intended optimization for some models but can cause unexpected
+            correctness issues if users are not aware of this behavior.
+
+            For linear epilogues (Bias), fusion preserves exact numerical equivalence.
 
         Parameters
         ----------
         reduction_block : Union[BlockRV, str]
             The reduction block (e.g., matmul)
         epilogue_block : Union[BlockRV, str]
-            The epilogue block to be fused (e.g., bias add)
+            The epilogue block to be fused (e.g., bias add, ReLU, clipping)
+
+        Examples
+        --------
+        See :py:func:`test_tir_schedule_fuse_reduction_epilogue` for examples.
         """
         reduction_block = self._normalize_block_arg(reduction_block)
         epilogue_block = self._normalize_block_arg(epilogue_block)
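
The worked example in the docstring warning can be reproduced in a few lines of plain Python. This is a minimal sketch rather than TVM code: `clip`, `post_reduction`, and `per_iteration` are hypothetical helper names, and the reduction is assumed to be a plain sum initialized to 0 with bounds [-5, 5].

```python
def clip(x, lower=-5, upper=5):
    """Clamp x into [lower, upper]."""
    return min(max(x, lower), upper)


def post_reduction(values):
    """Original semantics: reduce first, clip the final result once."""
    return clip(sum(values))


def per_iteration(values):
    """Fused semantics: clip the accumulator after every reduction step."""
    acc = clip(0)
    for v in values:
        acc = clip(acc + v)
    return acc


print(post_reduction([6, -2]))  # 4
print(per_iteration([6, -2]))   # 3
```

The two functions agree whenever every partial sum stays inside the bounds, for example on `[1, 2]`; they diverge only once an intermediate value is clamped.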

src/tir/schedule/primitive/compute_inline.cc

Lines changed: 127 additions & 16 deletions
@@ -992,6 +992,7 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre
 enum class EpilogueType {
   Bias, // temp + C
   BiasReLU, // max(temp + C, 0)
+  Clipping, // min(max(temp, lower), upper)
 };
 
 class ReductionEpilogueFuser : public BaseInliner {
@@ -1058,6 +1059,8 @@ class ReductionEpilogueFuser : public BaseInliner {
   Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C
   BufferRegion epilogue_addend_region_{nullptr}; // Read region of C
   EpilogueType epilogue_type_; // Type of epilogue operation
+  PrimExpr clipping_lower_{nullptr}; // Lower bound for clipping
+  PrimExpr clipping_upper_{nullptr}; // Upper bound for clipping
 };
 
 bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) {
@@ -1080,19 +1083,28 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue
     return false;
   }
 
-  // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j]
+  // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] or
+  // D[i,j] = min(max(temp[i,j], lower), upper)
   if (!AnalyzeEpiloguePattern(inlined_store_->value)) {
-    // Failure: epilogue is not a simple addition pattern
+    // Failure: epilogue is not a supported pattern (Bias, BiasReLU, or Clipping)
     return false;
   }
 
-  // 5. Check if producer is a reduction block
+  // 5. For Clipping pattern, verify temp appears exactly once
+  if (epilogue_type_ == EpilogueType::Clipping) {
+    if (loads.size() != 1) {
+      // Failure: temp must appear exactly once in clipping pattern
+      return false;
+    }
+  }
+
+  // 6. Check if producer is a reduction block
   if (!IsReductionBlock(reduction_block_)) {
     // Failure: producer is not a reduction block
     return false;
   }
 
-  // 6. Extract epilogue information (output buffer, indices, regions, etc.)
+  // 7. Extract epilogue information (output buffer, indices, regions, etc.)
   ExtractEpilogueInfo();
 
   return true;
@@ -1115,19 +1127,97 @@ bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
     }
   }
 
-  // Pattern 2: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) (BiasReLU)
+  // Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j], upper), lower) (Clipping)
+  // Handle all commutative variants of min/max at each level.
+
+  // Helper to check if an expression is a load from the reduction buffer, and
+  // return the other operand as `other` if so.
+  auto match_buffer_in_commutative_op = [this](const PrimExpr& a, const PrimExpr& b,
+                                               PrimExpr* other) -> bool {
+    if (const auto* load_a = a.as<BufferLoadNode>()) {
+      if (load_a->buffer.same_as(inlined_buffer_)) {
+        *other = b;
+        return true;
+      }
+    }
+    if (const auto* load_b = b.as<BufferLoadNode>()) {
+      if (load_b->buffer.same_as(inlined_buffer_)) {
+        *other = a;
+        return true;
+      }
+    }
+    return false;
+  };
+
+  // Check for min(max(temp, lower), upper) and commutative variants
+  if (const auto* min_node = value.as<MinNode>()) {
+    const MaxNode* max_node = nullptr;
+    PrimExpr upper;
+    // Try both (a, b) as possible positions of the inner max
+    if ((max_node = min_node->a.as<MaxNode>())) {
+      upper = min_node->b;
+    } else if ((max_node = min_node->b.as<MaxNode>())) {
+      upper = min_node->a;
+    }
+    if (max_node != nullptr) {
+      PrimExpr lower;
+      if (match_buffer_in_commutative_op(max_node->a, max_node->b, &lower)) {
+        clipping_lower_ = lower;
+        clipping_upper_ = upper;
+        epilogue_type_ = EpilogueType::Clipping;
+        return true;
+      }
+    }
+  }
+
+  // Check for max(min(temp[i,j], upper), lower) and commutative variants
   if (const auto* max_node = value.as<MaxNode>()) {
-    // Check if second operand is zero (ReLU: max(x, 0))
-    // Support both integer and float zero constants
+    const MinNode* min_node = nullptr;
+    PrimExpr lower;
+    // Try both (a, b) as possible positions of the inner min
+    if ((min_node = max_node->a.as<MinNode>())) {
+      lower = max_node->b;
+    } else if ((min_node = max_node->b.as<MinNode>())) {
+      lower = max_node->a;
+    }
+    if (min_node != nullptr) {
+      PrimExpr upper;
+      if (match_buffer_in_commutative_op(min_node->a, min_node->b, &upper)) {
+        clipping_lower_ = lower;
+        clipping_upper_ = upper;
+        epilogue_type_ = EpilogueType::Clipping;
+        return true;
+      }
+    }
+  }
+
+  // Pattern 3: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) (BiasReLU)
+  // Also handle max(0, temp[i,j] + C[i,j]) or max(0, C[i,j] + temp[i,j])
+  if (const auto* max_node = value.as<MaxNode>()) {
+    // Check if either operand is zero (ReLU: max(x, 0) or max(0, x))
+    // Support both integer and float zero constants.
+    const PrimExpr* add_candidate = nullptr;
     bool is_zero_const = false;
-    if (tir::is_zero(max_node->b)) {
+    auto is_zero_expr = [](const PrimExpr& expr) -> bool {
+      if (tir::is_zero(expr)) {
+        return true;
+      }
+      if (const auto* float_imm = expr.as<FloatImmNode>()) {
+        return float_imm->value == 0.0;
+      }
+      return false;
+    };
+
+    if (is_zero_expr(max_node->a)) {
+      is_zero_const = true;
+      add_candidate = &max_node->b;
+    } else if (is_zero_expr(max_node->b)) {
       is_zero_const = true;
-    } else if (const auto* float_imm = max_node->b.as<FloatImmNode>()) {
-      is_zero_const = (float_imm->value == 0.0);
+      add_candidate = &max_node->a;
     }
-    if (is_zero_const) {
-      // Check if first operand is AddNode
-      if (const auto* add = max_node->a.as<AddNode>()) {
+
+    if (is_zero_const && add_candidate != nullptr) {
+      if (const auto* add = add_candidate->as<AddNode>()) {
         const auto* load_a = add->a.as<BufferLoadNode>();
         const auto* load_b = add->b.as<BufferLoadNode>();
 
@@ -1218,6 +1308,14 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
     PrimExpr zero = tir::make_zero(init_value.dtype());
     new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value, zero),
                                  Substitute(epilogue_output_indices_, var_map));
+  } else if (epilogue_type_ == EpilogueType::Clipping) {
+    // For Clipping, init should be min(max(init_value, lower), upper)
+    // Since init is typically 0, this becomes min(max(0, lower), upper)
+    PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype);
+    PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_, var_map)),
+                                Substitute(clipping_upper_, var_map));
+    new_init_store = BufferStore(epilogue_output_buffer_, clipped_init,
+                                 Substitute(epilogue_output_indices_, var_map));
   } else {
     // Bias: D[vi, vj] = C[vi, vj]
     new_init_store = BufferStore(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map),
@@ -1228,11 +1326,14 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
   // 3. Replace output buffer from temp to D in body
   class BufferReplacer : public StmtExprMutator {
    public:
-    BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, DataType dtype)
+    BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, DataType dtype,
+                   PrimExpr clipping_lower = PrimExpr(), PrimExpr clipping_upper = PrimExpr())
         : old_buffer_(old_buf),
           new_buffer_(new_buf),
           epilogue_type_(epilogue_type),
-          dtype_(dtype) {}
+          dtype_(dtype),
+          clipping_lower_(clipping_lower),
+          clipping_upper_(clipping_upper) {}
 
     Stmt VisitStmt_(const BufferStoreNode* op) final {
       BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
@@ -1242,6 +1343,9 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
         if (epilogue_type_ == EpilogueType::BiasReLU) {
           PrimExpr zero = tir::make_zero(dtype_);
           new_value = Max(new_value, zero);
+        } else if (epilogue_type_ == EpilogueType::Clipping) {
+          // For Clipping, apply min(max(value, lower), upper) per iteration
+          new_value = Min(Max(new_value, clipping_lower_), clipping_upper_);
         }
         return BufferStore(new_buffer_, new_value, store->indices);
       }
@@ -1261,10 +1365,17 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
     Buffer new_buffer_;
     EpilogueType epilogue_type_;
     DataType dtype_;
+    PrimExpr clipping_lower_;
+    PrimExpr clipping_upper_;
   };
 
   DataType dtype = epilogue_output_buffer_->dtype;
-  BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_, epilogue_type_, dtype);
+  PrimExpr clipping_lower_subst =
+      epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_lower_, var_map) : PrimExpr();
+  PrimExpr clipping_upper_subst =
+      epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_upper_, var_map) : PrimExpr();
+  BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_, epilogue_type_, dtype,
+                          clipping_lower_subst, clipping_upper_subst);
   new_block->body = replacer(reduction_block->body);
 
   // 4. Update write regions
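
The clipping detection added to AnalyzeEpiloguePattern can be modeled outside TVM on a toy expression tree. The sketch below is a Python analogue under that assumption: the `Load`, `Const`, `Min`, and `Max` classes and the `split_temp` and `match_clipping` functions are hypothetical stand-ins for the TIR nodes and C++ helpers, not TVM APIs.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass(frozen=True)
class Load:
    buffer: str  # name of the buffer being read


@dataclass(frozen=True)
class Const:
    value: float


@dataclass(frozen=True)
class Min:
    a: "Expr"
    b: "Expr"


@dataclass(frozen=True)
class Max:
    a: "Expr"
    b: "Expr"


Expr = Union[Load, Const, Min, Max]


def split_temp(a: Expr, b: Expr, temp: str) -> Optional[Expr]:
    """If one operand is a load of `temp`, return the other operand."""
    if isinstance(a, Load) and a.buffer == temp:
        return b
    if isinstance(b, Load) and b.buffer == temp:
        return a
    return None


def match_clipping(expr: Expr, temp: str) -> Optional[Tuple[Expr, Expr]]:
    """Return (lower, upper) if expr clips `temp`, else None."""
    for outer, inner in ((Min, Max), (Max, Min)):
        if not isinstance(expr, outer):
            continue
        # The inner min/max may sit in either operand position.
        if isinstance(expr.a, inner):
            node, outer_bound = expr.a, expr.b
        elif isinstance(expr.b, inner):
            node, outer_bound = expr.b, expr.a
        else:
            continue
        inner_bound = split_temp(node.a, node.b, temp)
        if inner_bound is None:
            continue
        # min(max(temp, lower), upper): the inner bound is the lower one.
        # max(min(temp, upper), lower): the inner bound is the upper one.
        if outer is Min:
            return inner_bound, outer_bound
        return outer_bound, inner_bound
    return None
```

In this model, `match_clipping(Min(Const(5), Max(Const(-5), Load("temp"))), "temp")` returns the pair `(Const(-5), Const(5))`, the same bounds as for the canonical `min(max(temp, lower), upper)` form, while a plain `Max(Load("temp"), Const(0))` is not treated as clipping.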
