Skip to content

Commit 0ec07e7

Browse files
authored
[Warp Spec] Optimize partitioning by hoisting above broadcasts (#7692)
This PR adds a small optimization step to partitioning that transforms ```mlir %x = producer() partition=0 %y = broadcast %x : <Axf32> -> <AxBxf32> partition=0 use(%y) partition=1 ``` into ```mlir %x = producer() partition=0 %y = broadcast %x : <Axf32> -> <AxBxf32> partition=1 use(%y) partition=1 ``` in order to reduce the amount of shared memory needed for inter-partition data transfer.
1 parent 460deb6 commit 0ec07e7

File tree

4 files changed

+72
-13
lines changed

4 files changed

+72
-13
lines changed

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
33
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
44
#include "llvm/ADT/SCCIterator.h"
5+
#include "llvm/IR/Use.h"
56

67
using namespace mlir;
78
using namespace triton;
@@ -119,16 +120,14 @@ bool WarpSchedule::trySchedule(Partition *partition, Operation *op) {
119120

120121
FailureOr<WarpSchedule> WarpSchedule::deserialize(scf::ForOp loop) {
121122
auto stages = loop->getAttrOfType<ArrayAttr>(kPartitionStagesAttrName);
122-
if (!stages) {
123-
return mlir::emitWarning(loop.getLoc(), "missing '")
124-
<< kPartitionStagesAttrName << "' attribute";
125-
}
123+
if (!stages)
124+
return failure();
126125

127126
WarpSchedule result;
128127
for (auto [idx, attr] : llvm::enumerate(stages)) {
129128
auto stage = dyn_cast<IntegerAttr>(attr);
130129
if (!stage || stage.getInt() < 0) {
131-
return mlir::emitWarning(loop.getLoc(), "partition stages attribute '")
130+
return mlir::emitError(loop.getLoc(), "partition stages attribute '")
132131
<< kPartitionStagesAttrName << "' has invalid element " << attr;
133132
}
134133

@@ -140,10 +139,8 @@ FailureOr<WarpSchedule> WarpSchedule::deserialize(scf::ForOp loop) {
140139
Partition *partition = result.getRootPartition();
141140
if (auto attr = op.getAttrOfType<IntegerAttr>(kPartitionAttrName)) {
142141
int64_t idx = attr.getInt();
143-
if (idx < 0 || idx >= result.partitions.size()) {
144-
return mlir::emitWarning(op.getLoc(), "invalid partition index ")
145-
<< idx;
146-
}
142+
if (idx < 0 || idx >= result.partitions.size())
143+
return mlir::emitError(op.getLoc(), "invalid partition index ") << idx;
147144
partition = result.partitions[idx].get();
148145
}
149146
result.insert(partition, &op);

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,14 @@ static void scheduleUsers(scf::ForOp loop, WarpSchedule &schedule,
149149
// first-order partition assignment to the operations in the scheme and its
150150
// users and/or dependencies. This sets up the initial partitioning of the ops.
151151
static std::optional<WarpSchedule> getInitialSchedule(scf::ForOp loop) {
152-
WarpSchedule schedule;
152+
// Check for an existing schedule.
153+
if (FailureOr<WarpSchedule> scheduleOr = WarpSchedule::deserialize(loop);
154+
succeeded(scheduleOr))
155+
return {std::move(*scheduleOr)};
153156

154157
// Start by creating the default partition, a partition for all loads, and
155158
// a partition for all MMAs.
159+
WarpSchedule schedule;
156160
Partition *defaultPartition = schedule.addPartition(0);
157161
Partition *mmaPartition = schedule.addPartition(1);
158162
Partition *loadPartition = schedule.addPartition(0);
@@ -479,6 +483,39 @@ void propagatePartitions(scf::ForOp loop, WarpSchedule &schedule) {
479483
}
480484
}
481485

486+
// Rematerialize chains of broadcasts where the user is in a different partition
487+
// than the broadcast to reduce the amount of data that needs to be transferred.
488+
void rematerializeBroadcasts(WarpSchedule &schedule, OpOperand *use) {
489+
static_assert(
490+
std::is_base_of_v<OpTrait::OneResult<BroadcastOp>, BroadcastOp> &&
491+
std::is_base_of_v<OpTrait::OneResult<ExpandDimsOp>, ExpandDimsOp>);
492+
493+
Operation *defOp = use->get().getDefiningOp();
494+
while (isa_and_nonnull<BroadcastOp, ExpandDimsOp>(defOp)) {
495+
Operation *clone = OpBuilder(defOp).clone(*defOp);
496+
Partition *userPartition = schedule.getPartition(use->getOwner());
497+
assert(userPartition && "user not scheduled");
498+
schedule.insert(userPartition, clone);
499+
use->set(clone->getResult(0));
500+
501+
defOp = clone->getOperand(0).getDefiningOp();
502+
use = &clone->getOpOperand(0);
503+
}
504+
}
505+
506+
void optimizeSchedule(scf::ForOp loop, WarpSchedule &schedule) {
507+
for (Partition &partition : schedule.getPartitions()) {
508+
SmallVector<OpOperand *> uses;
509+
schedule.iterateOutputs(loop, &partition,
510+
[&](Operation *defOp, OpOperand &use) {
511+
if (!isa<scf::YieldOp>(use.getOwner()))
512+
uses.push_back(&use);
513+
});
514+
for (OpOperand *use : uses)
515+
rematerializeBroadcasts(schedule, use);
516+
}
517+
}
518+
482519
//===----------------------------------------------------------------------===//
483520
// Pass Definition
484521
//===----------------------------------------------------------------------===//
@@ -507,6 +544,7 @@ void PartitionScheduling::runOnOperation() {
507544
for (scf::ForOp loop : loops) {
508545
if (std::optional<WarpSchedule> schedule = getInitialSchedule(loop)) {
509546
propagatePartitions(loop, *schedule);
547+
optimizeSchedule(loop, *schedule);
510548
schedule->serialize(loop);
511549
}
512550
}

test/TritonGPU/partition-scheduling.mlir

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ tt.func public @attention_forward(
2828
%zero = arith.constant dense<0.0> : tensor<256x64xf32, #blocked>
2929
%one = arith.constant dense<1.0> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
3030

31-
%QK_tmem, %QK_tok = ttng.tmem_alloc : () -> (!ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>, !ttg.async.token)
3231

3332
%loop_outs:4 = scf.for %i = %c0_i32 to %n_tiles step %c64_i32 iter_args(
3433
%l_i = %one,
@@ -46,6 +45,7 @@ tt.func public @attention_forward(
4645
%K_shared = ttg.local_alloc %K : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem>
4746

4847
%K_trans = ttg.memdesc_trans %K_shared {order = array<i32: 1, 0>} : !ttg.memdesc<64x64xf16, #shared, #smem> -> !ttg.memdesc<64x64xf16, #shared_T, #smem>
48+
%QK_tmem, %QK_tok = ttng.tmem_alloc : () -> (!ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>, !ttg.async.token)
4949
%QK_mma_tok = ttng.tc_gen5_mma %Q_shared, %K_trans, %QK_tmem[%QK_tok], %false, %true : !ttg.memdesc<256x64xf16, #shared, #smem>, !ttg.memdesc<64x64xf16, #shared_T, #smem>, !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>
5050

5151
%QK, %QK_load_tok = ttng.tmem_load %QK_tmem[%QK_mma_tok] : !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable> -> tensor<256x64xf32, #blocked>
@@ -138,4 +138,28 @@ tt.func public @mma_operand_view(
138138
tt.return
139139
}
140140

141+
// CHECK-LABEL: @optimize_broadcast
142+
tt.func @optimize_broadcast(%arg0: i32) {
143+
%c0_i32 = arith.constant 0 : i32
144+
%c1_i32 = arith.constant 1 : i32
145+
// CHECK: scf.for
146+
scf.for %i = %c0_i32 to %arg0 step %c1_i32 : i32 {
147+
// CHECK: [[X:%.*]] = "producer"{{.*}}partition = 0
148+
%x = "producer"() {ttg.partition = 0 : i32} : () -> tensor<128xf32>
149+
150+
// CHECK-DAG: [[X0_P0:%.*]] = tt.expand_dims [[X]] {{.*}}partition = 0
151+
// CHECK-DAG: [[X0_P1:%.*]] = tt.expand_dims [[X]] {{.*}}partition = 1
152+
%x0 = tt.expand_dims %x {axis = 0 : i32} : tensor<128xf32> -> tensor<1x128xf32>
153+
// CHECK-DAG: [[X1_P0:%.*]] = tt.broadcast [[X0_P0]] {{.*}}partition = 0
154+
// CHECK-DAG: [[X1_P1:%.*]] = tt.broadcast [[X0_P1]] {{.*}}partition = 1
155+
%x1 = tt.broadcast %x0 : tensor<1x128xf32> -> tensor<128x128xf32>
156+
157+
// CHECK: "use"([[X1_P0]]) {{.*}}partition = 0
158+
"use"(%x1) {ttg.partition = 0 : i32} : (tensor<128x128xf32>) -> ()
159+
// CHECK: "use"([[X1_P1]]) {{.*}}partition = 1
160+
"use"(%x1) {ttg.partition = 1 : i32} : (tensor<128x128xf32>) -> ()
161+
} {tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32]}
162+
tt.return
163+
}
164+
141165
}

test/TritonGPU/rewrite-partition-dependencies.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ tt.func @no_def_op(%lb: i32, %ub: i32, %step: i32) {
337337
module attributes {"ttg.num-warps" = 4 : i32} {
338338

339339
tt.func @invalid_attribute(%lb: i32, %ub: i32, %step: i32) {
340-
// expected-warning @below {{partition stages attribute 'ttg.partition.stages' has invalid element "a"}}
340+
// expected-error @below {{partition stages attribute 'ttg.partition.stages' has invalid element "a"}}
341341
scf.for %i = %lb to %ub step %step : i32 {
342342
scf.yield
343343
} {ttg.partition.stages = ["a"]}
@@ -359,7 +359,7 @@ module attributes {"ttg.num-warps" = 4 : i32} {
359359

360360
tt.func @invalid_attribute(%lb: i32, %ub: i32, %step: i32) {
361361
scf.for %k = %lb to %ub step %step : i32 {
362-
// expected-warning @below {{invalid partition index -1}}
362+
// expected-error @below {{invalid partition index -1}}
363363
"op"() {ttg.partition = -1} : () -> ()
364364
scf.yield
365365
} {ttg.partition.stages = [2, 2]}

0 commit comments

Comments (0)