Skip to content

Commit 2e9d31b

Browse files
authored
[BACKEND] Cleanup redundant broadcast combine pattern (#5167)
Summary of changes: - Remove the `broadcast(cst) -> cst` pattern from the triton-combine pass, since it is redundant with the existing folder. - Reorder the triton-combine pass to run after the canonicalize pass, which simplifies pattern matching. - Clean up the patterns in triton-reorder-broadcast that called `Op::canonicalize`, replacing them with `Op::getCanonicalizationPatterns`.
1 parent fce3e6d commit 2e9d31b

File tree

7 files changed

+17
-64
lines changed

7 files changed

+17
-64
lines changed

lib/Dialect/Triton/Transforms/Combine.cpp

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include "mlir/Support/LLVM.h"
88
#include "mlir/Support/LogicalResult.h"
99
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
10-
#include "triton/Analysis/Utility.h"
1110
#include "triton/Dialect/Triton/IR/Dialect.h"
1211
#include "triton/Dialect/Triton/Transforms/Passes.h"
1312

@@ -18,35 +17,7 @@ namespace mlir::triton {
1817
namespace {
1918

2019
bool isZero(Value val) {
21-
if (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()))
22-
return true;
23-
// broadcast(constant_0)
24-
if (auto bc = val.getDefiningOp<BroadcastOp>()) {
25-
if (matchPattern(bc.getSrc(), m_Zero()) ||
26-
matchPattern(bc.getSrc(), m_AnyZeroFloat()))
27-
return true;
28-
}
29-
return false;
30-
}
31-
32-
bool isBroadcastConstantCombinable(Attribute value) {
33-
if (auto denseValue = dyn_cast<DenseElementsAttr>(value)) {
34-
return denseValue.isSplat();
35-
}
36-
return isa<FloatAttr, IntegerAttr>(value);
37-
}
38-
39-
DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
40-
Value bcast_res) {
41-
auto resType = cast<ShapedType>(bcast_res.getType());
42-
DenseElementsAttr res;
43-
if (auto denseValue = dyn_cast<DenseElementsAttr>(value)) {
44-
res =
45-
DenseElementsAttr::get(resType, denseValue.getSplatValue<Attribute>());
46-
} else {
47-
res = DenseElementsAttr::get(resType, value);
48-
}
49-
return res;
20+
return (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()));
5021
}
5122

5223
bool isAddPtrOffsetCombinable(Value first, Value second) {
@@ -231,7 +202,6 @@ class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
231202
// %}
232203
patterns.add<CombineSelectMaskedLoadPattern>(context);
233204
patterns.add<CombineAddPtrPattern>(context);
234-
patterns.add<CombineBroadcastConstantPattern>(context);
235205
patterns.add<CombineBroadcastMulReducePattern>(context);
236206

237207
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())

lib/Dialect/Triton/Transforms/Combine.td

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,4 @@ def CombineAddPtrPattern : Pat<
4444
(TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)),
4545
[(Constraint<CPred<"isAddPtrOffsetCombinable($0, $1)">> $idx0, $idx1)]>;
4646

47-
// broadcast(cst) => cst
48-
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
49-
def CombineBroadcastConstantPattern : Pat<
50-
(TT_BroadcastOp:$bcast_res (Arith_ConstantOp $value)),
51-
(Arith_ConstantOp (getConstantValue $value, $bcast_res), (location $bcast_res)),
52-
[(Constraint<CPred<"isBroadcastConstantCombinable($0)">> $value)]>;
53-
5447
#endif

lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -206,18 +206,6 @@ struct MoveBroadcastAfterElementwisePattern
206206
}
207207
};
208208

209-
template <typename OpType>
210-
class CanonicalizePattern : public OpRewritePattern<OpType> {
211-
public:
212-
explicit CanonicalizePattern(MLIRContext *context)
213-
: OpRewritePattern<OpType>(context) {}
214-
215-
LogicalResult matchAndRewrite(OpType op,
216-
PatternRewriter &rewriter) const override {
217-
return OpType::canonicalize(op, rewriter);
218-
}
219-
};
220-
221209
class ReorderBroadcastPass
222210
: public ::impl::TritonReorderBroadcastBase<ReorderBroadcastPass> {
223211
public:
@@ -226,8 +214,8 @@ class ReorderBroadcastPass
226214
RewritePatternSet patterns(context);
227215
ModuleOp m = getOperation();
228216

229-
patterns.add<CanonicalizePattern<BroadcastOp>>(context);
230-
patterns.add<CanonicalizePattern<ExpandDimsOp>>(context);
217+
BroadcastOp::getCanonicalizationPatterns(patterns, context);
218+
ExpandDimsOp::getCanonicalizationPatterns(patterns, context);
231219
// elementwise(broadcast(a)) => broadcast(elementwise(a))
232220
patterns.add<MoveBroadcastAfterElementwisePattern>(context);
233221
// elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...))

test/Triton/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
161161
tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked>
162162
}
163163
} // end module
164+
165+
// -----
166+
167+
// CHECK-LABEL: @fold_broadcast_constant_pattern
168+
tt.func @fold_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
169+
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
170+
%const = arith.constant dense<1.0> : tensor<8x1xf32>
171+
%bst_out = tt.broadcast %const : tensor<8x1xf32> -> tensor<8x2xf32>
172+
173+
// CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32>
174+
tt.return %bst_out : tensor<8x2xf32>
175+
}

test/Triton/combine.mlir

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -208,16 +208,6 @@ tt.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32
208208
tt.return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
209209
}
210210

211-
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
212-
tt.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
213-
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
214-
%const = arith.constant dense<1.0> : tensor<8x1xf32>
215-
%bst_out = tt.broadcast %const : tensor<8x1xf32> -> tensor<8x2xf32>
216-
217-
// CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32>
218-
tt.return %bst_out : tensor<8x2xf32>
219-
}
220-
221211
// CHECK-LABEL: @test_canonicalize_masked_load_pattern
222212
tt.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
223213
%true_mask = arith.constant dense<true> : tensor<8xi1>

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,8 @@ def make_ttir(mod, metadata, options):
189189
pm.enable_debug()
190190
passes.common.add_inliner(pm)
191191
passes.ttir.add_rewrite_tensor_pointer(pm)
192-
passes.ttir.add_combine(pm)
193192
passes.common.add_canonicalizer(pm)
193+
passes.ttir.add_combine(pm)
194194
passes.ttir.add_reorder_broadcast(pm)
195195
passes.common.add_cse(pm)
196196
passes.common.add_licm(pm)

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,8 @@ def make_ttir(mod, metadata, opt):
189189
pm.enable_debug()
190190
passes.common.add_inliner(pm)
191191
passes.ttir.add_rewrite_tensor_pointer(pm)
192-
passes.ttir.add_combine(pm)
193192
passes.common.add_canonicalizer(pm)
193+
passes.ttir.add_combine(pm)
194194
passes.ttir.add_reorder_broadcast(pm)
195195
passes.common.add_cse(pm)
196196
passes.common.add_licm(pm)

0 commit comments

Comments
 (0)