
Commit 4f30282

Propagate DotOp thru Join & improve shmem load into LinearEnc (#5924)
There are two parts to this PR:

- Propagate dotOp thru join, when dotOp is in the form of linearLayout (mostly reused @lezcano's logic for fp4ToFp).
- Add a rough optimization for the shmem -> LL load.

Motivation for the second part: currently, a shmem load into an LL falls back to the unswizzled shmem layout in the pipeliner, which results in poor performance. Not only does the `inline_asm` > `join` > `reshape` path suffer from this, but so does `fp4_to_fp`. I've added some basic swizzling logic for the shmem layout when loading into dotOp-like LLs.

As an example, for bf16xfp4 `dot_scaled` on a small-M, large-N/K shape, with fixed config (8, 128, 256) and `DISABLE_MMA_V3=1`:

* before this shmem optimization: **~160us**
* after this shmem optimization: **~124us**

Similar improvements can be observed for bf16xint4 (with inline_asm).

There's also a small change to increase kWidth in the case of `join` by halving `origBitWidth`. This should also matter for perf, since otherwise the shmem load width would be too small (a rough sketch of the arithmetic follows below).

I believe there's still significant room for improvement for small-M shapes, because shmem -> LL does not yet support `ldmatrix`. I can look into this next.

PTAL @lezcano @ThomasRaoux, thank you.
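A rough sketch of the kWidth arithmetic mentioned above (illustrative numbers only; the exact heuristic lives in the dot-operand/pipeliner logic and is not part of this diff): the vectorized shmem load per thread is roughly kWidth × element-bit-width wide. On the join path the tensor actually loaded from shmem has elements half as wide as the dot operands, so leaving kWidth unchanged would halve the load width; halving origBitWidth when computing kWidth doubles kWidth and keeps the product the same, e.g. 8 × 16 bits and 16 × 8 bits are both 128-bit loads.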
1 parent d71421d commit 4f30282

13 files changed, +397 -153 lines changed

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 2 additions & 0 deletions
@@ -72,10 +72,12 @@ class DialectInferLayoutInterface
 
   virtual LogicalResult
   inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
+                      ArrayRef<int64_t> shape,
                       std::optional<Location> loc) const = 0;
 
   virtual LogicalResult
   inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc,
+                       ArrayRef<int64_t> shape,
                        std::optional<Location> loc) const = 0;
 
   // Verify that the encoding are compatible to be used together in a dot

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 2 additions & 0 deletions
@@ -620,7 +620,9 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
     unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape) const;
     SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape) const;
 
+    SmallVector<unsigned int> getContig(const char *, SmallVector<unsigned int>) const;
     SmallVector<unsigned> getContigPerThread() const;
+    SmallVector<unsigned> getContigPerWarp() const;
     SmallVector<unsigned> getOrder() const;
 
     // Generalizes get{Warp,Thread,CTA}Order to linear layouts.

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 3 additions & 0 deletions
@@ -49,6 +49,9 @@ unsigned
 getNumElementsPerThread(Operation *op, SmallVector<unsigned> order,
                         triton::ModuleAxisInfoAnalysis &axisInfoAnalysis);
 
+// Returns whether the op is a "view op", i.e. doesn't move any data
+bool isView(Operation *op);
+
 /* Dump Triton IR in graphviz dot format.
  *
  * You can override `onValue` and `onOperation` in a subclass to mark

include/triton/Tools/LayoutUtils.h

Lines changed: 5 additions & 0 deletions
@@ -87,6 +87,11 @@ LinearLayout ensureLayoutNotSmallerThan(
 // are "dim0", "dim1", etc.
 SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank);
 
+// Return a vector of the standard out dimension name/value pairs, i.e.
+// ("dim0", dstShape[0]), ("dim1", dstShape[1]), etc.
+SmallVector<std::pair<StringAttr, int32_t>>
+standardOutDimPairs(MLIRContext *ctx, ArrayRef<int64_t> dstShape);
+
 // Return an identity mapping from `inDimName` to the standard out dimensions,
 // with the dimensions sized according to the shape. The bases are sorted
 // according to `order`, with the most minor dimension first.

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 2 additions & 2 deletions
@@ -1046,7 +1046,7 @@ JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
   Attribute retEnc;
   if (srcEnc) {
     if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
-            ->inferJoinOpEncoding(srcEnc, retEnc, location)
+            ->inferJoinOpEncoding(srcEnc, retEnc, srcTy.getShape(), location)
             .failed()) {
       return failure();
     }
@@ -1079,7 +1079,7 @@ LogicalResult SplitOp::inferReturnTypes(
   Attribute retEnc;
   if (srcEnc) {
     if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
-            ->inferSplitOpEncoding(srcEnc, retEnc, location)
+            ->inferSplitOpEncoding(srcEnc, retEnc, srcTy.getShape(), location)
             .failed()) {
       return failure();
     }

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 161 additions & 108 deletions
@@ -488,6 +488,42 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
   return encoding;
 }
 
+LogicalResult tryJoinOnAxis(MLIRContext *ctx, const LinearLayout &inLl,
+                            LinearLayout &outLl, bool fwdInference, int axis,
+                            std::optional<Location> loc) {
+  auto kRegister = StringAttr::get(ctx, "register");
+  auto outDims = llvm::to_vector(inLl.getOutDimNames());
+  if (fwdInference) {
+    auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]);
+    outLl = split * inLl;
+  } else {
+    // TODO This requires a division algorithm!
+    // Implement manually ll.divideLeft(split)
+    auto contiguousElems =
+        LinearEncodingAttr::get(ctx, inLl).getContigPerThread();
+    if (contiguousElems[axis] > 1) {
+      LinearLayout::BasesT newBases;
+      for (const auto &basesDim : inLl.getBases()) {
+        std::vector<std::vector<int32_t>> newBasesDim;
+        for (auto base : basesDim.second) {
+          if (base[axis] == 1) {
+            continue;
+          }
+          base[axis] /= 2;
+          newBasesDim.push_back(std::move(base));
+        }
+        newBases.insert({basesDim.first, std::move(newBasesDim)});
+      }
+      outLl = LinearLayout(std::move(newBases), std::move(outDims));
+    } else {
+      return emitOptionalError(loc,
+                               "Fp4ToFpOp/SplitOp requires at least 2 elements "
+                               "per thread in the axis/last dimension");
+    }
+  }
+  return success();
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
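For intuition, a hand-worked sketch of what tryJoinOnAxis does to the register bases (the bases below are hypothetical, chosen only to illustrate the two branches; basis vectors are written as [dim0, dim1]):

  source layout over (dim0, dim1), dim1 being the size-1 dimension appended before a join:
    register = [[1, 0], [2, 0]]            (4 consecutive dim0 elements per thread)
  fwdInference = true composes identity1D(2, register, dim1) on the left:
    register = [[0, 1], [1, 0], [2, 0]]    (register bit 0 now selects the element within the joined pair)

The backward branch is the manual divide-left: bases that are 1 along the axis are dropped and the remaining entries on that axis are halved, which undoes the composition above; the same rule is applied to the lane/warp/block bases.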
@@ -1239,28 +1275,39 @@ LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
   return scaledLayout.basesPerDim(kRegister, /*skipBroadcast=*/false);
 }
 
-SmallVector<unsigned> LinearEncodingAttr::getContigPerThread() const {
+SmallVector<unsigned>
+LinearEncodingAttr::getContig(const char *inDim,
+                              SmallVector<unsigned int> lowerContig) const {
   auto ll = getLinearLayout();
-  const auto &regs =
-      ll.getBases().find(StringAttr::get(getContext(), "register"))->second;
+  const auto &bases =
+      ll.getBases().find(StringAttr::get(getContext(), inDim))->second;
   auto order = getOrder();
   auto rank = order.size();
 
-  SmallVector<unsigned> contig(rank, 1);
-  auto regIt = regs.begin();
+  SmallVector<unsigned> contig(lowerContig);
+  auto basisIt = bases.begin();
   for (unsigned dim : order) {
     std::vector<int32_t> basis(rank, 0);
-    basis[dim] = 1;
+    basis[dim] = contig[dim];
 
-    while (regIt != regs.end() && *regIt == basis) {
+    while (basisIt != bases.end() && *basisIt == basis) {
       contig[dim] *= 2;
       basis[dim] *= 2;
-      ++regIt;
+      ++basisIt;
     }
   }
   return contig;
 }
 
+SmallVector<unsigned> LinearEncodingAttr::getContigPerThread() const {
+  SmallVector<unsigned> contig(getOrder().size(), 1);
+  return getContig("register", contig);
+}
+
+SmallVector<unsigned> LinearEncodingAttr::getContigPerWarp() const {
+  return getContig("lane", getContigPerThread());
+}
+
 unsigned
 LinearEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape) const {
   return product(getElemsPerThread(shape));
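A rough worked example of the new accessors (hypothetical layout, assuming the usual blocked-to-linear conversion where register bases cover sizePerThread and lane bases cover threadsPerWarp, most-minor dimension first):

  order          = [1, 0]
  register bases = [[0, 1], [0, 2]]                          (sizePerThread = [1, 4])
  lane bases     = [[0, 4], [0, 8], [1, 0], [2, 0], [4, 0]]  (threadsPerWarp = [8, 4])
  getContigPerThread()  ->  [1, 4]   (each thread owns 4 consecutive dim1 elements)
  getContigPerWarp()    ->  [8, 16]  (the lane bases extend the per-thread run to the full warp tile)

getContigPerWarp seeds getContig with the per-thread contiguity, so it measures how far the contiguous run extends once the lane bases are walked in the same order.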
@@ -2721,14 +2768,12 @@ struct TritonGPUInferLayoutInterface
     }
 
     auto newRank = dstShape.size();
-    SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
-    for (auto [dim, size] :
-         llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) {
-      newOutDims.emplace_back(dim, size);
-    }
-    auto srcOutDims = to_vector(src.getOutDimNames());
+
+    auto newOutDims = standardOutDimPairs(ctx, dstShape);
+
     // reshapeOp assumes minor-to-major, so we need to transpose the out dims
     // before the reshape
+    auto srcOutDims = to_vector(src.getOutDimNames());
     std::reverse(srcOutDims.begin(), srcOutDims.end());
     std::reverse(newOutDims.begin(), newOutDims.end());
     auto dst = src.transposeOuts(srcOutDims)
27402785

27412786
LogicalResult
27422787
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
2788+
ArrayRef<int64_t> shape,
27432789
std::optional<Location> loc) const override {
2744-
auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
2745-
if (!enc) {
2746-
return emitOptionalError(loc,
2747-
"JoinOp can only operate on BlockedEncoding");
2790+
if (auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc)) {
2791+
// JoinOp takes two tensors of shape AxBxC and generates a tensor of shape
2792+
// AxBxCx2. The encoding is the same as the input, but with 2 elems per
2793+
// thread in the new dimension. The new dimension is most-minor.
2794+
auto append = [](ArrayRef<unsigned> vals, int val) {
2795+
SmallVector<unsigned> ret(vals);
2796+
ret.push_back(val);
2797+
return ret;
2798+
};
2799+
auto appendMinorDim = [](ArrayRef<unsigned> order) {
2800+
SmallVector<unsigned> ret(order);
2801+
ret.insert(ret.begin(), ret.size());
2802+
return ret;
2803+
};
2804+
dstEnc = BlockedEncodingAttr::get(
2805+
enc.getContext(), //
2806+
append(enc.getSizePerThread(), 2), //
2807+
append(enc.getThreadsPerWarp(), 1), //
2808+
append(enc.getWarpsPerCTA(), 1), //
2809+
appendMinorDim(enc.getOrder()), //
2810+
CTALayoutAttr::get(enc.getContext(), //
2811+
append(enc.getCTAsPerCGA(), 1),
2812+
append(enc.getCTASplitNum(), 1),
2813+
appendMinorDim(enc.getCTAOrder())));
2814+
return success();
27482815
}
27492816

2750-
// JoinOp takes two tensors of shape AxBxC and generates a tensor of shape
2751-
// AxBxCx2. The encoding is the same as the input, but with 2 elems per
2752-
// thread in the new dimension. The new dimension is most-minor.
2753-
auto append = [](ArrayRef<unsigned> vals, int val) {
2754-
SmallVector<unsigned> ret(vals);
2755-
ret.push_back(val);
2756-
return ret;
2757-
};
2758-
auto appendMinorDim = [](ArrayRef<unsigned> order) {
2759-
SmallVector<unsigned> ret(order);
2760-
ret.insert(ret.begin(), ret.size());
2761-
return ret;
2762-
};
2763-
dstEnc = BlockedEncodingAttr::get(
2764-
enc.getContext(), //
2765-
append(enc.getSizePerThread(), 2), //
2766-
append(enc.getThreadsPerWarp(), 1), //
2767-
append(enc.getWarpsPerCTA(), 1), //
2768-
appendMinorDim(enc.getOrder()), //
2769-
CTALayoutAttr::get(enc.getContext(), //
2770-
append(enc.getCTAsPerCGA(), 1),
2771-
append(enc.getCTASplitNum(), 1),
2772-
appendMinorDim(enc.getCTAOrder())));
2817+
auto ctx = getContext();
2818+
2819+
// Append dim to shape
2820+
auto ll = toLinearLayout(shape, srcEnc);
2821+
SmallVector<int64_t> dstShape(shape.begin(), shape.end());
2822+
dstShape.push_back(1);
2823+
ll = ll.reshapeOuts(standardOutDimPairs(ctx, dstShape));
2824+
2825+
// Try join on last dim
2826+
auto axis = dstShape.size() - 1;
2827+
auto newLl = LinearLayout::empty();
2828+
auto result =
2829+
tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/true, axis, loc);
2830+
2831+
assert(result.succeeded());
2832+
dstEnc = LinearEncodingAttr::get(ctx, newLl);
27732833
return success();
27742834
}
27752835

27762836
LogicalResult
27772837
inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc,
2838+
ArrayRef<int64_t> shape,
27782839
std::optional<Location> loc) const override {
27792840
auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
2780-
if (!enc) {
2781-
return emitOptionalError(loc,
2782-
"SplitOp can only operate on BlockedEncoding");
2841+
if (enc) {
2842+
// SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of
2843+
// shape AxBxC. The input must have 2 elements per thread in the last
2844+
// dimension, which must be most-minor. The result encoding is the same
2845+
// as the input, but with the last dimension removed.
2846+
if (enc.getSizePerThread().back() != 2) {
2847+
return emitOptionalError(
2848+
loc, "SplitOp requires 2 elements per thread in the "
2849+
"last dimension of the input");
2850+
}
2851+
if (enc.getThreadsPerWarp().back() != 1 ||
2852+
enc.getWarpsPerCTA().back() != 1 || enc.getCTAsPerCGA().back() != 1) {
2853+
return emitOptionalError(
2854+
loc, "SplitOp requires threadsPerWarp, warpsPerCTA, "
2855+
"and CTAsPerCGA = 1 for the last dimension of the input");
2856+
}
2857+
if (enc.getCTALayout().getCTAsPerCGA().back() != 1) {
2858+
return emitOptionalError(
2859+
loc,
2860+
"SplitOp requires the last dimension to be most-minor in CTAOrder");
2861+
}
2862+
SmallVector<unsigned> newOrder(enc.getOrder());
2863+
int splitDim = newOrder.size() - 1;
2864+
// Remove splitDim from order.
2865+
newOrder.erase(std::remove(newOrder.begin(), newOrder.end(), splitDim),
2866+
newOrder.end());
2867+
dstEnc = BlockedEncodingAttr::get(
2868+
enc.getContext(), //
2869+
ArrayRef(enc.getSizePerThread()).drop_back(1),
2870+
ArrayRef(enc.getThreadsPerWarp()).drop_back(1),
2871+
ArrayRef(enc.getWarpsPerCTA()).drop_back(1), ArrayRef(newOrder),
2872+
CTALayoutAttr::get(enc.getContext(), //
2873+
ArrayRef(enc.getCTAsPerCGA()).drop_back(1),
2874+
ArrayRef(enc.getCTASplitNum()).drop_back(1),
2875+
ArrayRef(enc.getCTAOrder()).drop_front(1)));
2876+
return success();
27832877
}
27842878

2785-
// SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of
2786-
// shape AxBxC. The input must have 2 elements per thread in the last
2787-
// dimension, which must be most-minor. The result encoding is the same as
2788-
// the input, but with the last dimension removed.
2789-
if (enc.getSizePerThread().back() != 2) {
2790-
return emitOptionalError(loc,
2791-
"SplitOp requires 2 elements per thread in the "
2792-
"last dimension of the input");
2793-
}
2794-
if (enc.getThreadsPerWarp().back() != 1 ||
2795-
enc.getWarpsPerCTA().back() != 1 || enc.getCTAsPerCGA().back() != 1) {
2796-
return emitOptionalError(
2797-
loc, "SplitOp requires threadsPerWarp, warpsPerCTA, "
2798-
"and CTAsPerCGA = 1 for the last dimension of the input");
2879+
auto axis = shape.size() - 1;
2880+
assert(shape[axis] == 2 &&
2881+
"SplitOp input shape should have 2 in the last dim");
2882+
2883+
auto ctx = getContext();
2884+
2885+
// Split on last dim
2886+
auto ll = toLinearLayout(shape, srcEnc);
2887+
auto newLl = LinearLayout::empty();
2888+
auto result =
2889+
tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/false, axis, loc);
2890+
if (!result.succeeded()) {
2891+
return failure();
27992892
}
2800-
if (enc.getCTALayout().getCTAsPerCGA().back() != 1) {
2801-
return emitOptionalError(
2802-
loc,
2803-
"SplitOp requires the last dimension to be most-minor in CTAOrder");
2804-
}
2805-
SmallVector<unsigned> newOrder(enc.getOrder());
2806-
int splitDim = newOrder.size() - 1;
2807-
// Remove splitDim from order.
2808-
newOrder.erase(std::remove(newOrder.begin(), newOrder.end(), splitDim),
2809-
newOrder.end());
2810-
dstEnc = BlockedEncodingAttr::get(
2811-
enc.getContext(), //
2812-
ArrayRef(enc.getSizePerThread()).drop_back(1),
2813-
ArrayRef(enc.getThreadsPerWarp()).drop_back(1),
2814-
ArrayRef(enc.getWarpsPerCTA()).drop_back(1), ArrayRef(newOrder),
2815-
CTALayoutAttr::get(enc.getContext(), //
2816-
ArrayRef(enc.getCTAsPerCGA()).drop_back(1),
2817-
ArrayRef(enc.getCTASplitNum()).drop_back(1),
2818-
ArrayRef(enc.getCTAOrder()).drop_front(1)));
2893+
2894+
// Remove last dim from newLl (which should be 1)
2895+
SmallVector<int64_t> dstShape(shape.begin(), shape.end());
2896+
dstShape.pop_back();
2897+
newLl = newLl.reshapeOuts(standardOutDimPairs(ctx, dstShape));
2898+
dstEnc = LinearEncodingAttr::get(ctx, newLl);
28192899
return success();
28202900
}
28212901
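To make the backward (SplitOp) direction concrete, a small hand-worked sketch with hypothetical bases, for an input of shape [4, 2] split along the trailing size-2 dimension (basis vectors written as [dim0, dim1]):

  input register bases: [[0, 1], [1, 0]]
    [0, 1] is 1 along the split axis  ->  dropped by tryJoinOnAxis(..., /*fwdInference=*/false, axis=1)
    [1, 0] is 0 along the split axis  ->  entry halved (still 0) and kept
  surviving register bases: [[1, 0]]; reshaping away the now size-1 dim1 leaves [[1]]

Each split result therefore keeps half of the registers per thread, and the same drop-or-halve rule is applied to the lane/warp/block bases.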

@@ -2873,37 +2953,10 @@ struct TritonGPUInferLayoutInterface
     }
 
     auto ll = toLinearLayout(shape, inEnc);
-
-    auto kRegister = StringAttr::get(ctx, "register");
-    auto outDims = llvm::to_vector(ll.getOutDimNames());
-    LinearLayout newLl = LinearLayout::empty();
-    if (fwdInference) {
-      auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]);
-      newLl = split * ll;
-    } else {
-      // TODO This requires a division algorithm!
-      // Implement manually ll.divideLeft(split)
-      auto contiguousElems =
-          LinearEncodingAttr::get(ctx, ll).getContigPerThread();
-      if (contiguousElems[axis] > 1) {
-        LinearLayout::BasesT newBases;
-        for (const auto &basesDim : ll.getBases()) {
-          std::vector<std::vector<int32_t>> newBasesDim;
-          for (auto base : basesDim.second) {
-            if (base[axis] == 1) {
-              continue;
-            }
-            base[axis] /= 2;
-            newBasesDim.push_back(std::move(base));
-          }
-          newBases.insert({basesDim.first, std::move(newBasesDim)});
-        }
-        newLl = LinearLayout(std::move(newBases), std::move(outDims));
-      } else {
-        return emitOptionalError(loc, "Fp4ToFpOp requires at least 2 elements "
-                                      "per thread in the axis dimension");
-      }
-    }
+    auto newLl = LinearLayout::empty();
+    auto result = tryJoinOnAxis(ctx, ll, newLl, fwdInference, axis, loc);
+    if (!result.succeeded())
+      return result;
     outEnc = LinearEncodingAttr::get(ctx, newLl);
     return success();
   }
