
Commit a89b3b4

Merge branch 'shared/triton-gfx950-launch' into shared/triton-gfx950-launch-update-rebase
2 parents: c5ceb64 + 77c00fa

15 files changed: +534, -197 lines changed


fa/flash-attention.py
Lines changed: 6 additions & 4 deletions

@@ -243,9 +243,11 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
                     BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr,
                     PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr,
                     RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr,
-                    QK_SCALE: tl.constexpr, INT8_GEMM: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr):
+                    QK_SCALE: tl.constexpr, INT8_GEMM: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr,
+                    ENABLE_PIPELINING: tl.constexpr):
     # loop over k, v, and update accumulator
-    for start_n in range(block_min, block_max, BLOCK_N):
+    num_stages: tl.constexpr = None if ENABLE_PIPELINING else 1  # Set num_stages==1 if we want to disable pipelining
+    for start_n in tl.range(block_min, block_max, BLOCK_N, num_stages=num_stages):
         # For padded blocks, we will overrun the tensor size if
         # we load all BLOCK_N. For others, the blocks are all within range.
         if MASK_STEPS:
@@ -674,7 +676,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
             # _, MASK_STEPS, ...
             PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX,
             PADDED_HEAD, ACTUAL_BLOCK_DMODEL, QK_SCALE, INT8_GEMM, USE_P_SCALE,
-            INT8_KV)
+            INT8_KV, True)
         block_min = block_max
         block_max = n_blocks * BLOCK_N
 
@@ -698,7 +700,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
             p_scale, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
             # _, MASK_STEPS, ...
             PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, ACTUAL_BLOCK_DMODEL,
-            QK_SCALE, INT8_GEMM, USE_P_SCALE, INT8_KV)
+            QK_SCALE, INT8_GEMM, USE_P_SCALE, INT8_KV, False)
 
     if INT8 and not INT8_KV:
         if USE_P_SCALE:
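
For context, the new ENABLE_PIPELINING flag simply forwards a num_stages constexpr into tl.range: num_stages=1 disables software pipelining for that loop, while None leaves the staging decision to the compiler. The two call sites above pass True for the unmasked steady-state loop and False for the masked tail, so only the former is pipelined. Below is a minimal, hypothetical Triton kernel (the _sum_blocks name, shapes, and launch parameters are made up for illustration; only the toggle pattern mirrors the change above):

import torch
import triton
import triton.language as tl

@triton.jit
def _sum_blocks(x_ptr, out_ptr, N, BLOCK_N: tl.constexpr, ENABLE_PIPELINING: tl.constexpr):
    # Same toggle as in _attn_fwd_inner: num_stages == 1 disables pipelining,
    # None lets the compiler choose the number of stages.
    num_stages: tl.constexpr = None if ENABLE_PIPELINING else 1
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for start in tl.range(0, N, BLOCK_N, num_stages=num_stages):
        offs = start + tl.arange(0, BLOCK_N)
        acc += tl.load(x_ptr + offs, mask=offs < N, other=0.0)
    tl.store(out_ptr + tl.arange(0, BLOCK_N), acc)

x = torch.randn(4096, device="cuda")
out = torch.empty(256, device="cuda")
_sum_blocks[(1,)](x, out, x.numel(), BLOCK_N=256, ENABLE_PIPELINING=True)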

include/triton/Conversion/TritonGPUToLLVM/Utility.h
Lines changed: 4 additions & 2 deletions

@@ -698,13 +698,15 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
     Type elemLlvmTy, std::optional<int32_t> maxVecElems,
     const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
     const TargetInfoBase &target,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
+    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback,
+    bool forceLane0 = false);
 
 [[nodiscard]] bool emitTransferBetweenRegistersAndShared(
     LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
     std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
     Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
+    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback,
+    bool forceLane0 = false);
 
 SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
                                            Type elemLlvmTy,

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
Lines changed: 1 addition & 2 deletions

@@ -287,8 +287,7 @@ LinearLayout chooseScaledMfmaScaleLayout(
 // 8 elements. This layout is useful for emitting the widest 128-bit global
 // store instructions. Since it closely resembles mfmaLayout, conversion between
 // the two can be done using transferWithinWarp, without involving LDS
-LinearLayout chooseMfmaLikeStoreLayout(AMDMfmaEncodingAttr mfmaLayout,
-                                       ArrayRef<int64_t> shape);
+std::optional<LinearLayout> chooseMfmaLikeStoreLayout(RankedTensorType valType);
 
 } // namespace mlir::triton::gpu
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

include/triton/Tools/Sys/GetEnv.hpp
Lines changed: 2 additions & 0 deletions

@@ -35,9 +35,11 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_HIP_LOCAL_PREFETCH",
     "TRITON_HIP_USE_ASYNC_COPY",
     "TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE",
+    "TRITON_HIP_ASYNC_COPY_OVERLAP",
     "TRITON_HIP_ENABLE_F16_ASYNC_PINGPONG",
     "TRITON_HIP_USE_BLOCK_PINGPONG",
     "TRITON_HIP_USE_IN_THREAD_TRANSPOSE",
+    "TRITON_HIP_ASYNC_FAST_SWIZZLE",
     "TRITON_LLVM_DEBUG_ONLY",
     "TRITON_ENABLE_ASAN",
     "TRITON_OVERRIDE_ARCH",

lib/Conversion/TritonGPUToLLVM/Utility.cpp
Lines changed: 16 additions & 3 deletions

@@ -409,7 +409,8 @@ bool emitTransferBetweenRegistersAndShared(
     LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
     std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
     Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
+    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback,
+    bool forceLane0) {
   MLIRContext *ctx = rewriter.getContext();
   auto b = TritonLLVMOpBuilder(loc, rewriter);
 
@@ -452,6 +453,17 @@ bool emitTransferBetweenRegistersAndShared(
 
   auto withCTAOffset = triton::gpu::getNumCTAs(sharedTy.getEncoding()) > 1;
   auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
+  if (forceLane0) {
+    laneId = b.i32_val(0);
+    // NFC it's copied from getLaneAndWarpId but we add a shuffleIdx(0) to the
+    // tid so LLVM sees that warpId is a scalar
+    // This is not optimal as it adds a readlane which is not necessary but
+    // better than getting readfirstlanes for every direct-to-lds load
+    Value tid = target.shuffleIdx(rewriter, loc, getThreadId(rewriter, loc), 0);
+    int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter);
+    Value warpSizeVal = b.i32_val(threadsPerWarp);
+    warpId = b.udiv(tid, warpSizeVal);
+  }
   Value blockId =
       withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0);
 
@@ -486,12 +498,13 @@ bool emitTransferBetweenRegistersAndShared(
     Type elemLlvmTy, std::optional<int32_t> maxVecElems,
     const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
     const TargetInfoBase &target,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
+    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback,
+    bool forceLane0) {
   auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(),
                                                registerTy.getEncoding());
   return emitTransferBetweenRegistersAndShared(
       regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
-      target, perVectorCallback);
+      target, perVectorCallback, forceLane0);
 }
 
 SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
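
The motivation for forceLane0, per the comment above, is that a direct-to-LDS load takes a single wave-uniform shared-memory address, so the per-lane laneId must not feed the address computation; laneId is pinned to 0 and warpId is recomputed from a shuffled tid so LLVM can prove it is scalar, avoiding a readfirstlane per load. A rough plain-Python illustration of the uniformity argument (the 64-thread wave size matches CDNA; the tile stride and slot size below are made-up numbers):

THREADS_PER_WARP = 64  # wave size on CDNA GPUs

def lds_addr(tid: int, force_lane0: bool) -> int:
    lane = 0 if force_lane0 else tid % THREADS_PER_WARP
    warp = tid // THREADS_PER_WARP
    # Hypothetical addressing: 1024-byte LDS tile per warp, 16-byte slot per lane.
    return warp * 1024 + lane * 16

# With force_lane0 every lane of a warp computes the same LDS address,
# which is what a direct-to-LDS transfer requires.
addrs = {lds_addr(t, force_lane0=True) for t in range(THREADS_PER_WARP)}
assert len(addrs) == 1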

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Lines changed: 33 additions & 31 deletions

@@ -1533,37 +1533,39 @@ LinearLayout chooseScaledMfmaScaleLayout(
   return newLL;
 }
 
-LinearLayout chooseMfmaLikeStoreLayout(AMDMfmaEncodingAttr mfmaLayout,
-                                       ArrayRef<int64_t> shape) {
-  assert(shape.size() == 2 && mfmaLayout.getMDim() == 32 &&
-         mfmaLayout.getNDim() == 32 && mfmaLayout.getIsTransposed());
-
-  MLIRContext *ctx = mfmaLayout.getContext();
-  StringAttr kRegister = S("register");
-  StringAttr kLane = S("lane");
-  StringAttr kWarp = S("warp");
-  StringAttr kBlock = S("block");
-
-  SmallVector<unsigned> order = getDefaultMmaOrder(mfmaLayout);
-  auto standardOutDims = standardOutDimNames(ctx, 2);
-  // We make each thread handle 8 consecutive elements to enable 128-bit
-  // global stores for [b]f16 types and keep the thread pattern in each lane
-  // similar to the canonical mfmaLayout.
-  LinearLayout mfma8Layout = LinearLayout::empty();
-  mfma8Layout =
-      LinearLayout({{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
-                    {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}},
-                    {kWarp, {}},
-                    {kBlock, {}}},
-                   {standardOutDims[order[0]], standardOutDims[order[1]]});
-
-  LinearLayout warpLayout =
-      identityStandardND(kWarp, mfmaLayout.getWarpsPerCTA(), order);
-  LinearLayout ctaLayout = mfma8Layout.transposeOuts(standardOutDims) *
-                           warpLayout.transposeOuts(standardOutDims);
-  mfma8Layout =
-      combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
-  return mfma8Layout;
+std::optional<LinearLayout>
+chooseMfmaLikeStoreLayout(RankedTensorType valType) {
+  auto mfmaLayout = cast<AMDMfmaEncodingAttr>(valType.getEncoding());
+
+  // We currently only support transposed [B]F16 MFMA32x32 on CDNA4.
+  bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32;
+  Type elemType = valType.getElementType();
+  if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) &&
+        mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() &&
+        isMfma32))
+    return {};
+
+  auto valShape = valType.getShape();
+  LinearLayout mfmaLL = mfmaLayout.toLinearLayout(valShape);
+  auto mfmaOutDims = llvm::to_vector(mfmaLL.getOutDimNames());
+  StringAttr dimM = mfmaOutDims[0];
+  StringAttr dimN = mfmaOutDims[1];
+
+  auto swapLL = LinearLayout::empty();
+  // The rows are kept as is with an identity linear layout.
+  swapLL *= LinearLayout::identity1D(valShape[0], dimM, dimM);
+  // In transposed mfma32 layout, each thread holds 4 consecutive values along N
+  // dim. We want to exchange column 4-7 (owned by thread 32-63) and column 8-11
+  // (owned by thread 0-31) every 16 columns to make each thread holds 8
+  // elements. This would mean exchange the 2nd and 3rd basis vector from an
+  // identity linear layout.
+  std::vector<std::vector<int32_t>> dimNBases(mfmaLL.getOutDimSizeLog2(dimN));
+  std::generate(dimNBases.begin(), dimNBases.end(),
+                [i = 0]() mutable { return std::vector<int32_t>{1 << i++}; });
+  std::swap(dimNBases[2], dimNBases[3]);
+  swapLL *= LinearLayout({{dimN, dimNBases}}, {dimN});
+
+  return mfmaLL.compose(swapLL);
 }
 
 LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
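
The basis swap has a simple index-level reading: an identity layout on the N dimension uses the power-of-two basis vectors 1, 2, 4, 8, ..., so exchanging the 2nd and 3rd of them swaps bits 2 and 3 of the column index, i.e. exchanges columns 4-7 and 8-11 inside every 16-column group. A small plain-Python check of that permutation (the per-thread column sets come from the comment in the code; the helper below is illustrative, not compiler code):

def swap_cols_4_7_with_8_11(col: int) -> int:
    # Swap bits 2 and 3 of the column index: columns 4-7 and 8-11 trade places
    # within each 16-column group, while 0-3 and 12-15 stay put.
    b2 = (col >> 2) & 1
    b3 = (col >> 3) & 1
    return (col & ~0b1100) | (b3 << 2) | (b2 << 3)

# Per the comment above, threads 0-31 of the transposed mfma32 layout own
# columns 0-3 and 8-11 of each 16-column group. After the swap they own
# columns 0-7: 8 consecutive [b]f16 elements, enough for one 128-bit store.
owned = sorted(swap_cols_4_7_with_8_11(c) for c in [0, 1, 2, 3, 8, 9, 10, 11])
assert owned == list(range(8))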

test/Conversion/amd/buffer_load_to_local_to_llvm.mlir
Lines changed: 4 additions & 2 deletions

@@ -145,10 +145,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
                       %arg2: !ttg.memdesc<64xf32, #shared, #smem, mutable>) {
     %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
     // The first constant 0 skips the LDS offset which is also 0
-    // COMMON: llvm.getelementptr
+    // COMMON: rocdl.make.buffer.rsrc
+    // COMMON: llvm.select
     // COMMON: llvm.mlir.constant(0 : i32) : i32
     // COMMON: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32
-    // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]]
+    // COMMON: llvm.mlir.constant(0 : i32) : i32
+    // COMMON-: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]]
     %1 = amdgpu.buffer_load_to_local %arg0[%0] cacheModifier = ca into %arg2: <f32>[tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable>
     // COMMON: llvm.getelementptr
     // COMMON: %[[aux_cg:.*]] = llvm.mlir.constant(3 : i32) : i32
