
Commit ca70f08

Duplicating loads significantly increases shared memory usage, which is likely to cause out-of-memory problems. We should find an alternative that does not require duplicating shared memory allocations.
1 parent 53166ef commit ca70f08
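
For intuition about the memory cost motivating this revert (the tile shape, element width, and stage count below are hypothetical, not taken from this commit): a pipelined load keeps one copy of its tile in shared memory per pipeline stage, so cloning a load to satisfy a second, incompatible shared encoding roughly doubles that footprint. A minimal C++ sketch of the arithmetic:

#include <cstdio>

// Back-of-the-envelope shared-memory footprint of one pipelined load:
// the pipeliner keeps numStages copies of a rows x cols tile of
// elemBytes-wide elements in shared memory.
static unsigned bufferBytes(unsigned rows, unsigned cols, unsigned elemBytes,
                            unsigned numStages) {
  return rows * cols * elemBytes * numStages;
}

int main() {
  // Hypothetical 64x16 f16 tile pipelined over 3 stages (values assumed).
  unsigned single = bufferBytes(64, 16, 2, 3); // 6144 bytes
  unsigned cloned = 2 * single;                // 12288 bytes once the load is duplicated
  std::printf("single buffer: %u bytes, duplicated: %u bytes\n", single, cloned);
  return 0;
}

Skipping pipelining for the load with the incompatible encoding, as this commit does, avoids the second buffer entirely, at the cost of not overlapping that load with compute.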

File tree

2 files changed: 17 additions & 94 deletions


lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 14 additions & 84 deletions
@@ -219,8 +219,9 @@ static void createTMAAsyncCopy(
 // encodings, raise assertion, since incompatible shared encoding has been
 // handled in splitLoadsForIncompatible.
 static std::optional<ttg::SharedEncodingAttr>
-getSharedEncIfAllUsersAreDotEnc(Value val) {
+getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
   ttg::SharedEncodingAttr attr;
+  incompatible = false;
   for (Operation *user : val.getUsers()) {
     ttg::SharedEncodingAttr tempAttr;
     if (user->getNumResults() != 1)
@@ -230,7 +231,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
       // First time we find a shared encoding in the chain, save it and try to
       // use it if it is compatible with the other users.
       tempAttr = cast<ttg::SharedEncodingAttr>(memDesc.getEncoding());
-      if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value())
+      if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible)
+               .has_value())
         return std::nullopt;
     } else {
       if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
@@ -248,8 +250,10 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
           bitWidth, /*needTrans=*/false);
     }
     // Check that the shared encodings needed by the users are compatible.
-    if (attr != nullptr)
-      assert(attr == tempAttr && "incompatible shared encoding");
+    if (attr != nullptr && attr != tempAttr) {
+      incompatible = true;
+      return std::nullopt;
+    }
     attr = tempAttr;
   }
   return attr;
@@ -439,8 +443,13 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
         loadInfo.sharedEncoding =
             getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr);
       } else if (auto dot = dyn_cast<tt::DotOp>(use)) {
+        bool incompatible = false;
         loadInfo.sharedEncoding =
-            getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr);
+            getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible)
+                .value_or(nullptr);
+        // If we can't agree on a shared encoding skip pipelinig the load.
+        if (incompatible)
+          continue;
       }
     } else if (auto loadOp = dyn_cast<tt::LoadOp>(use)) {
       // The use of this loadOp is another loadOp. If the use is not in the
@@ -476,83 +485,6 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
   return loadToInfo;
 }
 
-// Split users to groups, each group has the same shared encoding.
-// If not all users are Dot encoding, return empty vector.
-static DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>>
-handleIncompatibleSharedEncoding(Operation *loadOp) {
-  DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>> loadGroups;
-  // Go through transitive uses of the loadOp in the same block.
-  for (Operation *user : loadOp->getUsers()) {
-    if (user->getBlock() != loadOp->getBlock())
-      continue;
-    if (user->getNumResults() != 1)
-      return loadGroups;
-
-    ttg::SharedEncodingAttr tempAttr;
-    if (auto memDesc =
-            dyn_cast<triton::MemDescType>(user->getResult(0).getType())) {
-      tempAttr = cast<ttg::SharedEncodingAttr>(memDesc.getEncoding());
-      loadGroups[tempAttr].push_back(user);
-    } else {
-      if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
-        return loadGroups;
-      auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
-          cast<TensorOrMemDesc>(user->getResult(0).getType()).getEncoding());
-      if (!dotOpEnc)
-        return loadGroups;
-      auto srcTy = cast<TensorOrMemDesc>(loadOp->getResult(0).getType());
-      auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
-      auto order = ttg::getOrder(srcTy.getEncoding());
-      unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
-      tempAttr = ttg::SharedEncodingAttr::get(
-          loadOp->getContext(), dotOpEnc, srcTy.getShape(),
-          ttg::getOrder(srcTy.getEncoding()),
-          ttg::getCTALayout(srcTy.getEncoding()),
-          srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false);
-      loadGroups[tempAttr].push_back(user);
-    }
-  }
-  return loadGroups;
-}
-
-// Clone loads so each group of uses with same shared encoding will have a
-// corresponding Load.
-static void splitLoadsForIncompatible(
-    OpBuilder &builder, Operation *loadOp,
-    DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>> &lGroups) {
-  // The first group will use the original load, create new loads for other
-  // groups.
-  unsigned idx = 0;
-  builder.setInsertionPointAfter(loadOp);
-  for (auto pair : lGroups) {
-    SmallVector<Operation *> &group = pair.second;
-    if (idx++ == 0)
-      continue;
-    Operation *newLoad = builder.clone(*loadOp);
-    for (auto *user : group) {
-      user->replaceUsesOfWith(loadOp->getResult(0), newLoad->getResult(0));
-    }
-  }
-}
-
-static void splitLoadsWithIncompatibleEncoding(scf::ForOp forOp) {
-  // Get the list of all loads.
-  SmallVector<Operation *> loads;
-  for (Operation &op : forOp.getBody()->without_terminator()) {
-    if (isa<tt::LoadOp, tt::ExperimentalDescriptorLoadOp>(op)) {
-      loads.push_back(&op);
-    }
-  }
-  OpBuilder builder(forOp);
-  for (auto *loadOp : loads) {
-    auto lGroups = handleIncompatibleSharedEncoding(loadOp);
-    LDBG("groups with different encoding: " << lGroups.size() << " "
-                                            << *loadOp);
-    if (lGroups.size() > 1)
-      splitLoadsForIncompatible(builder, loadOp, lGroups);
-  }
-}
-
 static llvm::MapVector<Operation *, LoadInfo>
 scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule,
               DenseSet<Operation *> &rootUsers, int numStages) {
@@ -1106,8 +1038,6 @@ static void invalidateBarriers(OpBuilder &builder,
 
 bool mlir::triton::preProcessLoopAndGetSchedule(
     scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) {
-  splitLoadsWithIncompatibleEncoding(forOp);
-
   // Schedule the loads and root ops (dot ops) in the loop. This will give us
   // a scaffold for the final schedule.
   DenseSet<Operation *> rootUsers;

test/TritonGPU/loop-pipeline.mlir

Lines changed: 3 additions & 10 deletions
@@ -844,16 +844,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr<f16>, #blocked> -> tensor<64x16x!tt.ptr<f16>, #blocked>
     %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked>
     %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
-    // check that the load with incompatiable shared encoding gets cloned and feeds into uses with same encoding
-    // AMD-NOT: alloc
-    // AMD: scf.for
-    // CHECK: local_alloc
-    // CHECK: local_alloc
-    // CHECK: scf.for
-    // CHECK: local_load {{.*}} tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1
-    // CHECK: convert_layout {{.*}} tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0
-    // CHECK: tt.dot
-    // CHECK: tt.trans %arg
+    // check that the load didn't get pipelined.
+    // COMMON-NOT: alloc
+    // COMMON: scf.for
     %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 {
       %18 = tt.load %16 : tensor<64x16x!tt.ptr<f16>, #blocked>
       %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>

0 commit comments