Skip to content

Commit bca378d

Browse files
authored
[BACKEND] Revert smem layout heuristic added in PR#5924 (#5983)
Partial revert of triton-lang/triton#5924. The changed heuristic for picking the swizzling causes performance regressions in some cases, so it is being reverted for now. cc: @ggengnv
1 parent 2b2a872 commit bca378d

File tree

1 file changed

+8
-58
lines changed

1 file changed

+8
-58
lines changed

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 8 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,46 +1030,6 @@ StringRef getAMDArch(Operation *module) {
10301030
return ref.drop_front(4); // drop the "hip:"
10311031
}
10321032

1033-
// Rough utility for obtaining a SharedEnc for a LinearEncoding,
1034-
// as we've replaced DotOpEnc with Linear in some cases
1035-
// (specifically, fp4ToFp and similar unpack-upcast thru join)
1036-
std::optional<ttg::SwizzledSharedEncodingAttr>
1037-
getSharedForLinear(ttg::LinearEncodingAttr enc,
1038-
ArrayRef<unsigned int> globalOrder, ArrayRef<int64_t> shape,
1039-
unsigned elemBitWidth, ttg::CTALayoutAttr ctaLayout) {
1040-
auto ctx = enc.getContext();
1041-
auto ll = enc.getLinearLayout();
1042-
auto rank = shape.size();
1043-
1044-
if (rank != 2)
1045-
return std::nullopt;
1046-
1047-
auto order = enc.getOrder();
1048-
assert(globalOrder.size() == rank);
1049-
// TODO add memdesc_trans support for dot(trans(cvt(src) #linear) #dot_op)
1050-
if (order != globalOrder)
1051-
return std::nullopt;
1052-
1053-
auto innerDim = order[0];
1054-
auto outerDim = order[1];
1055-
auto contigPerWarp = enc.getContigPerWarp();
1056-
1057-
constexpr unsigned BANK_SIZE{128};
1058-
auto elemBytes = elemBitWidth / 8;
1059-
1060-
auto vec = contigPerWarp[innerDim];
1061-
auto rowSize = elemBytes * (unsigned)shape[innerDim];
1062-
auto perPhase = std::max(BANK_SIZE / rowSize, 1u);
1063-
auto maxPhase = std::max(contigPerWarp[outerDim] / perPhase, 1u);
1064-
1065-
// cp.async does not support transfer size < 4B
1066-
if (vec * elemBytes < 4 && perPhase < maxPhase)
1067-
return std::nullopt;
1068-
1069-
return ttg::SwizzledSharedEncodingAttr::get(ctx, vec, perPhase, maxPhase,
1070-
order, ctaLayout);
1071-
}
1072-
10731033
// If all the transitive uses of the given value are used by a convert to
10741034
// the same dot operand encoding, return the shared encoding that needs to be
10751035
// used to be compatible with users' layouts. If there are incompatible shared
@@ -1096,28 +1056,18 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
10961056
} else {
10971057
if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
10981058
return std::nullopt;
1099-
auto enc =
1059+
auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
11001060
cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType())
1101-
.getEncoding();
1061+
.getEncoding());
1062+
if (!dotOpEnc)
1063+
return std::nullopt;
11021064
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
1103-
auto ctaLayout = ttg::getCTALayout(srcTy.getEncoding());
1065+
auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
11041066
auto order = ttg::getOrder(srcTy.getEncoding());
11051067
unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
1106-
1107-
if (auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(enc)) {
1108-
tempAttr = ttg::SwizzledSharedEncodingAttr::get(
1109-
val.getContext(), dotOpEnc, srcTy.getShape(), order, ctaLayout,
1110-
bitWidth, /*needTrans=*/false);
1111-
} else if (auto linearEnc = dyn_cast<ttg::LinearEncodingAttr>(enc)) {
1112-
1113-
auto attrOpt = getSharedForLinear(linearEnc, order, srcTy.getShape(),
1114-
bitWidth, ctaLayout);
1115-
if (!attrOpt)
1116-
return std::nullopt;
1117-
tempAttr = *attrOpt;
1118-
} else {
1119-
return std::nullopt;
1120-
}
1068+
tempAttr = ttg::SwizzledSharedEncodingAttr::get(
1069+
val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
1070+
bitWidth, /*needTrans=*/false);
11211071
}
11221072
// Check that the shared encodings needed by the users are compatible.
11231073
if (attr != nullptr && attr != tempAttr) {

0 commit comments

Comments
 (0)