
Commit c210764

[SWP] split loads to handle incompatible shared encoding (#4784)
Summary: Split loads so that each group of uses with the same shared encoding gets its own load. This enables pipelining loads whose users require incompatible shared encodings. AMD has its own version of assignMemoryLayouts, so the load_two_users_incompatible_layouts test case has different expectations for AMD.
1 parent e65dd81 commit c210764
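
To see the effect concretely, here is a hedged before/after sketch in the spirit of the load_two_users_incompatible_layouts test; the value names, shapes, and encodings (#blocked, #mma, #mma1) are illustrative placeholders, not the exact test IR:

  // Before: one load feeds two convert_layout users whose dot operand
  // encodings imply incompatible shared encodings, so the load could not
  // be pipelined at all.
  scf.for %i = %c0 to %c8 step %c1 ... {
    %a  = tt.load %ptrs : tensor<64x16x!tt.ptr<f16>, #blocked>
    %u0 = triton_gpu.convert_layout %a : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
    %u1 = triton_gpu.convert_layout %a : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>>
    ...
  }

  // After splitLoadsWithIncompatibleEncoding: the load is cloned once per
  // additional encoding group, so every load now has users that agree on a
  // single shared encoding and each can be pipelined independently.
  scf.for %i = %c0 to %c8 step %c1 ... {
    %a0 = tt.load %ptrs : tensor<64x16x!tt.ptr<f16>, #blocked>
    %a1 = tt.load %ptrs : tensor<64x16x!tt.ptr<f16>, #blocked>
    %u0 = triton_gpu.convert_layout %a0 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
    %u1 = triton_gpu.convert_layout %a1 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>>
    ...
  }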

2 files changed: +98 −18 lines

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 88 additions & 15 deletions
@@ -216,11 +216,11 @@ static void createTMAAsyncCopy(
 // If all the transitive uses of the given value have are used by a convert to
 // the same dot operand encoding, return the shared encoding that needs to be
 // used to be compatible with users' layouts. If there are imcompatible shared
-// encodings set `incompatible` to true.
+// encodings, raise an assertion, since incompatible shared encodings have been
+// handled in splitLoadsForIncompatible.
 static std::optional<ttg::SharedEncodingAttr>
-getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
+getSharedEncIfAllUsersAreDotEnc(Value val) {
   ttg::SharedEncodingAttr attr;
-  incompatible = false;
   for (Operation *user : val.getUsers()) {
     ttg::SharedEncodingAttr tempAttr;
     if (user->getNumResults() != 1)
@@ -230,8 +230,7 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
       // First time we find a shared encoding in the chain, save it and try to
       // use it if it is compatible with the other users.
       tempAttr = cast<ttg::SharedEncodingAttr>(memDesc.getEncoding());
-      if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible)
-               .has_value())
+      if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value())
         return std::nullopt;
     } else {
       if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
@@ -249,10 +248,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
                                       bitWidth, /*needTrans=*/false);
     }
     // Check that the shared encodings needed by the users are compatible.
-    if (attr != nullptr && attr != tempAttr) {
-      incompatible = true;
-      return std::nullopt;
-    }
+    if (attr != nullptr)
+      assert(attr == tempAttr && "incompatible shared encoding");
     attr = tempAttr;
   }
   return attr;
@@ -442,13 +439,9 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
       loadInfo.sharedEncoding =
           getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr);
     } else if (auto dot = dyn_cast<tt::DotOp>(use)) {
-      bool incompatible = false;
       loadInfo.sharedEncoding =
-          getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible)
-              .value_or(nullptr);
-      // If we can't agree on a shared encoding skip pipelinig the load.
-      if (incompatible)
-        continue;
+          getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr);
+
       // HACK: Triton LLVM codegen has a bug where local_loads from #shared to
       // #mma layout can lead to invalid code if the loaded shape is smaller
       // than the mma tile (e.g. loading a 128x1 tensor for an MMAv2 dot with
@@ -514,9 +507,87 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
   return loadToInfo;
 }
 
+// Split users into groups; each group has the same shared encoding.
+// If not all users are dot encodings, return an empty map.
+static DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>>
+handleIncompatibleSharedEncoding(Operation *loadOp) {
+  DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>> loadGroups;
+  // Go through transitive uses of the loadOp in the same block.
+  for (Operation *user : loadOp->getUsers()) {
+    if (user->getBlock() != loadOp->getBlock())
+      continue;
+    if (user->getNumResults() != 1)
+      return loadGroups;
+
+    ttg::SharedEncodingAttr tempAttr;
+    if (auto memDesc =
+            dyn_cast<triton::MemDescType>(user->getResult(0).getType())) {
+      tempAttr = cast<ttg::SharedEncodingAttr>(memDesc.getEncoding());
+      loadGroups[tempAttr].push_back(user);
+    } else {
+      if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
+        return loadGroups;
+      auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
+          cast<TensorOrMemDesc>(user->getResult(0).getType()).getEncoding());
+      if (!dotOpEnc)
+        return loadGroups;
+      auto srcTy = cast<TensorOrMemDesc>(loadOp->getResult(0).getType());
+      auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
+      auto order = ttg::getOrder(srcTy.getEncoding());
+      unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
+      tempAttr = ttg::SharedEncodingAttr::get(
+          loadOp->getContext(), dotOpEnc, srcTy.getShape(),
+          ttg::getOrder(srcTy.getEncoding()),
+          ttg::getCTALayout(srcTy.getEncoding()),
+          srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false);
+      loadGroups[tempAttr].push_back(user);
+    }
+  }
+  return loadGroups;
+}
+
+// Clone loads so each group of uses with the same shared encoding will have a
+// corresponding load.
+static void splitLoadsForIncompatible(
+    OpBuilder &builder, Operation *loadOp,
+    DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>> &lGroups) {
+  // The first group will use the original load; create new loads for the
+  // other groups.
+  unsigned idx = 0;
+  builder.setInsertionPointAfter(loadOp);
+  for (auto pair : lGroups) {
+    SmallVector<Operation *> &group = pair.second;
+    if (idx++ == 0)
+      continue;
+    Operation *newLoad = builder.clone(*loadOp);
+    for (auto *user : group) {
+      user->replaceUsesOfWith(loadOp->getResult(0), newLoad->getResult(0));
+    }
+  }
+}
+
+static void splitLoadsWithIncompatibleEncoding(scf::ForOp forOp) {
+  // Get the list of all loads.
+  SmallVector<Operation *> loads;
+  for (Operation &op : forOp.getBody()->without_terminator()) {
+    if (isa<tt::LoadOp, tt::ExperimentalDescriptorLoadOp>(op)) {
+      loads.push_back(&op);
+    }
+  }
+  OpBuilder builder(forOp);
+  for (auto *loadOp : loads) {
+    auto lGroups = handleIncompatibleSharedEncoding(loadOp);
+    LDBG("groups with different encoding: " << lGroups.size() << " "
+                                            << *loadOp);
+    if (lGroups.size() > 1)
+      splitLoadsForIncompatible(builder, loadOp, lGroups);
+  }
+}
+
 static llvm::MapVector<Operation *, LoadInfo>
 scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule,
               DenseSet<Operation *> &rootUsers, int numStages) {
+
   ModuleOp moduleOp = forOp->getParentOfType<ModuleOp>();
   tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
 
@@ -1054,6 +1125,8 @@ static void invalidateBarriers(OpBuilder &builder,
 
 bool mlir::triton::preProcessLoopAndGetSchedule(
     scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) {
+  splitLoadsWithIncompatibleEncoding(forOp);
+
   // Schedule the loads and root ops (dot ops) in the loop. This will give us
   // a scaffold for the final schedule.
   DenseSet<Operation *> rootUsers;

test/TritonGPU/loop-pipeline.mlir

Lines changed: 10 additions & 3 deletions
@@ -844,9 +844,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr<f16>, #blocked> -> tensor<64x16x!tt.ptr<f16>, #blocked>
     %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked>
     %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
-    // check that the load didn't get pipelined.
-    // COMMON-NOT: alloc
-    // COMMON: scf.for
+    // check that the load with incompatible shared encoding gets cloned and feeds into uses with the same encoding
+    // AMD-NOT: alloc
+    // AMD: scf.for
+    // CHECK: local_alloc
+    // CHECK: local_alloc
+    // CHECK: scf.for
+    // CHECK: local_load {{.*}} tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1
+    // CHECK: convert_layout {{.*}} tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0
+    // CHECK: tt.dot
+    // CHECK: tt.trans %arg
     %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 {
       %18 = tt.load %16 : tensor<64x16x!tt.ptr<f16>, #blocked>
       %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
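
Read together, the new CHECK lines (the NVIDIA path) assert that both copies of the load now get pipelined: two shared-memory buffers are allocated before the loop, and each clone feeds only uses with a single shared encoding. Schematically, the expected post-pipelining structure is roughly the following (abbreviated; operand lists and most types elided, names illustrative):

  %b0 = triton_gpu.local_alloc ...   // buffer for the original load's encoding
  %b1 = triton_gpu.local_alloc ...   // buffer for the cloned load's encoding
  scf.for ... {
    %v   = triton_gpu.local_load ... -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, ...}>>
    %lhs = triton_gpu.convert_layout ... -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, ...}>>
    %acc = tt.dot %lhs, %v, ...
    %t   = tt.trans %arg...
    ...
  }

The AMD-NOT/AMD lines keep the old expectation for AMD, whose own assignMemoryLayouts does not perform this split: no shared-memory alloc may appear before the scf.for.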
