Skip to content

Commit 38f8167

Browse files
authored
[LAYOUTS] Fix mixed precision swizzling (#6565) (#7032)
Reland of triton-lang/triton#6565
1 parent 6af4919 commit 38f8167

File tree

5 files changed

+101
-57
lines changed

5 files changed

+101
-57
lines changed

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,11 @@ LinearLayout chooseScaledMfmaScaleLayout(
287287
const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
288288
ArrayRef<int64_t> dotOperandShape, unsigned mfmaMDim);
289289

290+
// Create LinearLayout for nvidia mma tile.
291+
LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
292+
unsigned kWidth, ArrayRef<unsigned> order,
293+
ArrayRef<unsigned> repOrder);
294+
290295
// Create a LinearLayout similar to mfmaLayout, but changing each thread to hold
291296
// 8 elements. This layout is useful for emitting the widest 128-bit global
292297
// store instructions. Since it closely resembles mfmaLayout, conversion between

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -312,54 +312,46 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
312312
if(!mmaEnc)
313313
return get(context, 1, 1, 1, order, CTALayout);
314314

315-
int opIdx = dotOpEnc.getOpIdx();
316-
auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
317-
318-
// number of rows per phase
319-
320-
// index of the inner dimension in `order`
321-
unsigned inner = (opIdx == 0) ? 0 : 1;
322-
323315
// ---- begin Ampere & Hopper ----
324316
if (mmaEnc.isAmpere() || mmaEnc.isHopper()) {
325-
int perPhase = 128 / (std::max<int>(1, shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()));
326-
perPhase = std::max<int>(perPhase, 1);
327-
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
328-
int vecWidth = 32 / typeWidthInBit;
329-
if (vecWidth != dotOpEnc.getKWidth() && order[0] == inner) {
330-
perPhase = std::max<int>(perPhase, 2 * vecWidth);
331-
}
332-
int rank = order.size();
333-
// --- handle A operand ---
334-
if (opIdx == 0) { // compute swizzling for A operand
335-
int m = (needTrans) ? matShape[2] : matShape[0];
336-
int k = (needTrans) ? matShape[0] : matShape[2];
337-
int vec = (order[0] == rank-1) ? k : m;
338-
int mmaStride = (order[0] == rank-1) ? m : k;
339-
int maxPhase = std::max(mmaStride / perPhase, 1);
340-
return get(context, vec, perPhase, maxPhase, order, CTALayout);
341-
}
342-
343-
// --- handle B operand ---
344-
if (opIdx == 1) {
345-
// we compute vec and maxPhase m, n and k size of the mma
346-
// instruction. when matmul operands is transposed, we should
347-
// consider that to get m, n and k.
348-
int n = needTrans ? matShape[2] : matShape[1];
349-
int k = needTrans ? matShape[1] : matShape[2];
350-
int vec = (order[0] == rank-1) ? n : k;
351-
int mmaStride = (order[0] == rank-1) ? k : n;
352-
int maxPhase = std::max(mmaStride / perPhase, 1);
353-
return get(context, vec, perPhase, maxPhase, order, CTALayout);
354-
}
355-
356-
llvm_unreachable("invalid operand index");
317+
return get(context, dotOpEnc.getOpIdx(), dotOpEnc.getKWidth(), shape, order, CTALayout, typeWidthInBit, needTrans);
357318
}
358319

359320
// ---- not implemented ----
360321
llvm_unreachable("unsupported swizzling for provided MMA version");
361322
}]>,
362323

324+
// NVIDIA constructor!
325+
// TODO(lezcano): We should totally get rid of all these constructors...
326+
AttrBuilder<(ins "int":$opIdx,
327+
"unsigned":$kWidth,
328+
"ArrayRef<int64_t>":$shape,
329+
"ArrayRef<unsigned>":$order,
330+
"CTALayoutAttr":$CTALayout,
331+
"unsigned":$bitwidth,
332+
"bool":$needTrans), [{
333+
int K = getShapePerCTA(CTALayout.getCTASplitNum(), shape)[order[0]];
334+
// Elems necessary to cover all the banks divided by the inner dimension
335+
// This packs a few rows together for small K
336+
int perPhase = std::max<int>(1024 / (bitwidth * K), 1);
337+
338+
int mmaStride = 8;
339+
int vec = 4 * kWidth;
340+
// needTrans is equiv. to flipping the opIdx
341+
if (needTrans)
342+
std::swap(vec, mmaStride);
343+
assert(opIdx == 0 || opIdx == 1);
344+
int rank = order.size();
345+
int kDim = opIdx == 0 ? rank-1 : rank-2;
346+
if (order[0] != kDim)
347+
std::swap(vec, mmaStride);
348+
// Count how many vec elements are needed to cover all the banks
349+
int maxPhase = std::max(std::min<int>(mmaStride, 1024 / (vec * bitwidth)), 1);
350+
// Account for the row packing from perPhase: mmaStride / perPhase
351+
maxPhase = std::max(maxPhase / perPhase, 1);
352+
return get(context, vec, perPhase, maxPhase, order, CTALayout);
353+
}]>,
354+
363355
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
364356
"ArrayRef<int64_t>":$shape,
365357
"ArrayRef<unsigned>":$order,

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,8 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
153153
bool kContig) {
154154
// kContig: if true, the matrix is fastest-running on k,
155155
// otherwise it is on m (resp. n)
156-
// opIdx=0: [batch, m, k] if rank == 3 else [m, k]
157-
// opIdx=1: [batch, k, n] if rank == 3 else [k, n]
158-
// batch (if rank == 3) is always the slowest running dimension
159-
assert(rank == 2 || rank == 3);
156+
// opIdx=0: [*batch, m, k]
157+
// opIdx=1: [*batch, k, n]
160158
assert(opIdx == 0 || opIdx == 1);
161159
auto rowMajor = bool(opIdx) != kContig;
162160
return getMatrixOrder(rank, rowMajor);

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,6 +1058,45 @@ StringRef getAMDArch(Operation *module) {
10581058
return ref.drop_front(4); // drop the "hip:"
10591059
}
10601060

1061+
inline ttg::SwizzledSharedEncodingAttr
1062+
swizzleDotOperandLike(RankedTensorType type, ttg::CTALayoutAttr ctaLayout) {
1063+
// We want to see if the linear layout has the same order as an mma microtile
1064+
// of shape (8, 4*kWidth) or (4*kWidth, 8). If so, we return a
1065+
DotOperandEncodingAttr with a tile of this shape. This works because
1066+
// SwizzledSharedEncodingAttr::get just looks at the microtile to determine
1067+
// the swizzling
1068+
1069+
auto *ctx = type.getContext();
1070+
auto layout = ttg::toLinearEncoding(type);
1071+
auto order = layout.getThreadOrder();
1072+
auto rank = order.size();
1073+
if (rank < 2) {
1074+
return {};
1075+
}
1076+
int opIdx;
1077+
if (ttg::getOrderForDotOperand(0, rank, /*kContig=*/true) == order) {
1078+
opIdx = 0;
1079+
} else if (ttg::getOrderForDotOperand(1, rank, /*kContig=*/true) == order) {
1080+
opIdx = 1;
1081+
} else {
1082+
return {};
1083+
}
1084+
auto kWidth = layout.getContigPerThread()[order[0]];
1085+
SmallVector<unsigned> microtileShape(rank, 1);
1086+
microtileShape[order[0]] = 4 * kWidth;
1087+
microtileShape[order[1]] = 8;
1088+
// All the LinearLayouts contained within LinearEncodingAttr have order [0, 1,
1089+
// 2, ...]
1090+
auto repOrder = to_vector(llvm::seq<unsigned>(rank));
1091+
auto tile = ttg::nvidiaMmaTile(ctx, microtileShape, kWidth, order, repOrder);
1092+
if (!divideLeft(layout.getLinearLayout(), tile).has_value()) {
1093+
return {};
1094+
}
1095+
return ttg::SwizzledSharedEncodingAttr::get(
1096+
ctx, opIdx, kWidth, type.getShape(), order, ctaLayout,
1097+
type.getElementTypeBitWidth(), false);
1098+
}
1099+
10611100
// If all the transitive uses of the given value are used by a convert to
10621101
// the same dot operand encoding, return the shared encoding that needs to be
10631102
// used to be compatible with users' layouts. If there are incompatible shared
@@ -1084,18 +1123,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
10841123
} else {
10851124
if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
10861125
return std::nullopt;
1087-
auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
1088-
cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType())
1089-
.getEncoding());
1090-
if (!dotOpEnc)
1091-
return std::nullopt;
10921126
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
1093-
auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
1094-
auto order = getOrderForMemory(srcTy);
1095-
unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
1096-
tempAttr = ttg::SwizzledSharedEncodingAttr::get(
1097-
val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
1098-
bitWidth, /*needTrans=*/false);
1127+
auto dstTy = cast<RankedTensorType>(user->getResult(0).getType());
1128+
1129+
// FIXME This may not be correct for multiple CTA, but getCTALayout is NYI
1130+
// for LinearEncodingAttr
1131+
auto CTALayout = isa<ttg::LinearEncodingAttr>(dstTy.getEncoding())
1132+
? ttg::getCTALayout(srcTy.getEncoding())
1133+
: ttg::getCTALayout(dstTy.getEncoding());
1134+
1135+
if (auto dot =
1136+
dyn_cast<ttg::DotOperandEncodingAttr>(dstTy.getEncoding())) {
1137+
auto order = getOrderForMemory(srcTy);
1138+
unsigned bitWidth = srcTy.getElementTypeBitWidth();
1139+
tempAttr = ttg::SwizzledSharedEncodingAttr::get(
1140+
val.getContext(), dot, srcTy.getShape(), order, CTALayout, bitWidth,
1141+
/*needTrans=*/false);
1142+
} else {
1143+
// Try to see if the layout is like an mma microtile
1144+
tempAttr = swizzleDotOperandLike(dstTy, CTALayout);
1145+
}
1146+
if (!tempAttr)
1147+
return std::nullopt;
10991148
}
11001149
// Check that the shared encodings needed by the users are compatible.
11011150
if (attr != nullptr && attr != tempAttr) {

test/TritonGPU/reduce-data-duplication.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: triton-opt %s -split-input-file -tritongpu-reduce-data-duplication | FileCheck %s
22

3-
// CHECK: #[[$SHARED:.*]] = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}
3+
// CHECK: #[[$SHARED:.*]] = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1]}
44
// CHECK-LABEL: apply_swizzle
55
// CHECK: %{{.*}} = ttg.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !ttg.memdesc<16x256xf16, #[[$SHARED]], #smem>
66

@@ -29,7 +29,7 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-
2929

3030
// -----
3131

32-
// CHECK: #[[$SHARED:.*]] = #ttg.swizzled_shared<{vec = 32, perPhase = 128, maxPhase = 1, order = [1, 0]}>
32+
// CHECK: #[[$SHARED:.*]] = #ttg.swizzled_shared<{vec = 32, perPhase = 64, maxPhase = 1, order = [1, 0]}>
3333
// CHECK-LABEL: handles_small_contiguous_dim
3434
// CHECK: %{{.*}} = ttg.local_alloc %{{.*}} : (tensor<32x1xf16, #{{.*}}>) -> !ttg.memdesc<32x1xf16, #[[$SHARED]], #smem>
3535

0 commit comments

Comments
 (0)