Commit de5ec1b

Merge commit 'ebad1d975ae31bdb3d50a786bd89a21a6a4f24d8'
2 parents: 77c9289 + ebad1d9

11 files changed: +228 −109 lines

include/triton/Analysis/Utility.h

Lines changed: 0 additions & 2 deletions

@@ -252,8 +252,6 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy);
 // warps, and possibly blocks.
 bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
 
-bool atomicNeedsSharedMemory(Value result);
-
 // Check if MFMA layout can be converted to the dot operand
 // layout using warp shuffle.
 bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 7 additions & 0 deletions

@@ -655,6 +655,13 @@ SmallVector<Value> inlineRegion(RewriterBase &rewriter, Region &region,
                                 mlir::TypeID::get<TerminatorOp>(), loc);
 }
 
+void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
+                                 ConversionPatternRewriter &rewriter,
+                                 SmallVector<Value> &resultVals,
+                                 Type valueElemTy, TritonLLVMOpBuilder &b,
+                                 Value threadPred,
+                                 const TargetInfoBase &targetInfo,
+                                 const LLVMTypeConverter *typeConverter);
 } // namespace mlir
 
 #endif
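Only the declaration lands in this header; the definition follows in lib/Conversion/TritonGPUToLLVM/Utility.cpp below. For orientation, here is a hedged sketch of what a call site in a backend lowering might look like. The names `atomicOp`, `resultVals`, `valueElemTy`, `b`, `threadPred`, and `targetInfo` are illustrative placeholders, not identifiers from this patch:

// Hypothetical call site inside an atomic lowering pattern (a sketch, not
// code from this commit): after computing the per-thread results, hand
// everything to the helper, which packs the result struct and, when the op
// carries "allocation.offset", round-trips the values through shared memory.
auto tensorTy = cast<RankedTensorType>(atomicOp->getResult(0).getType());
finalizeTensorAtomicResults(atomicOp, tensorTy, rewriter, resultVals,
                            valueElemTy, b, threadPred, targetInfo,
                            getTypeConverter());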

lib/Analysis/Allocation.cpp

Lines changed: 21 additions & 13 deletions

@@ -92,14 +92,26 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
   return repShape;
 }
 
-// Both `atomic_cas` and `atomic_rmw need a single scratch element if returning
-// a scalar value because Triton's block-based programming model ensures that
-// all threads in each block see the same return value, even those threads that
-// do not participate in the atomic operation
+// Both `atomic_cas` and `atomic_rmw` may need scratch memory to store values
+// because Triton's block-based programming model ensures that
+// all threads sharing the same partition of the tensor see the same values,
+// even for threads that do not participate in the atomic operation
 static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
   SmallVector<unsigned> smemShape;
-  if (atomicNeedsSharedMemory(result)) {
-    smemShape.push_back(1);
+  if (!result.use_empty()) {
+    if (auto tensorTy = dyn_cast<RankedTensorType>(result.getType())) {
+      auto freeVariableMasks =
+          gpu::toLinearLayout(tensorTy).getFreeVariableMasks();
+      if (llvm::any_of(freeVariableMasks, [](auto variableMask) {
+            return variableMask.second != 0;
+          })) {
+        // The tensor has broadcasted dimensions
+        smemShape = gpu::getShapePerCTATile(tensorTy);
+      }
+    } else {
+      // If the result is a scalar, we need to allocate a single element.
+      smemShape.push_back(1);
+    }
   }
   return smemShape;
 }

@@ -211,15 +223,11 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
   }
   if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
     auto value = op->getOperand(0);
-    // only scalar requires scratch memory
-    // make it explicit for readability
-    if (dyn_cast<RankedTensorType>(value.getType())) {
-      return 0;
-    }
     auto smemShape = getRepShapeForAtomic(op->getResult(0));
     auto elems = getNumScratchElements(smemShape);
-    auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
-    assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
+    if (elems == 0)
+      return 0;
+    auto elemTy = getElementTypeOrSelf(getPointeeType(value.getType()));
     return elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
   }
   if (isa<ttng::TensormapCreateOp>(op)) {
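The byte count returned above is `elems * max(8, bitwidth) / 8`. A self-contained worked example of that formula under assumed shapes (the tensor sizes are illustrative, not taken from the patch):

#include <algorithm>
#include <cstdio>

int main() {
  // Assumed broadcast-tensor case: getShapePerCTATile yields {16} for an
  // f32 tensor, so 16 elements of 32 bits each.
  int elems = 16, bitwidth = 32;
  std::printf("tensor scratch: %d bytes\n",
              elems * std::max(8, bitwidth) / 8); // 64
  // Scalar case: one element; sub-byte types still round up to one byte.
  std::printf("scalar i1 scratch: %d bytes\n",
              1 * std::max(8, 1) / 8); // 1
  return 0;
}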

lib/Analysis/Utility.cpp

Lines changed: 0 additions & 7 deletions

@@ -804,13 +804,6 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
          !matchMFMAAndDotOperandShuffleCase(srcTy, dstTy);
 }
 
-bool atomicNeedsSharedMemory(Value value) {
-  auto type = value.getType();
-  if (isa<RankedTensorType>(type) || value.use_empty())
-    return false;
-  return true;
-}
-
 namespace {
 
 /// A data structure similar to SetVector but maintains
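The predicate deleted here answered only the scalar case: tensor results never needed scratch. The replacement logic in `getRepShapeForAtomic` (lib/Analysis/Allocation.cpp above) generalizes the decision to broadcast tensor layouts. A compact restatement as a sketch; the enum and function names are invented for illustration, not code from this patch:

#include <cstdio>

enum class AtomicScratch { None, OneElement, CTATile };

// Sketch of the decision now implemented by getRepShapeForAtomic.
AtomicScratch scratchKind(bool hasUses, bool isTensor, bool layoutHasFreeVars) {
  if (!hasUses)
    return AtomicScratch::None;       // result never read: no scratch
  if (!isTensor)
    return AtomicScratch::OneElement; // scalar result: broadcast one element
  return layoutHasFreeVars ? AtomicScratch::CTATile // replicated lanes
                           : AtomicScratch::None;   // bijective layout
}

int main() {
  // The old atomicNeedsSharedMemory(value) covered only the scalar branch;
  // tensors now allocate a CTA tile when their layout has free variables.
  std::printf("%d\n", scratchKind(true, true, true) == AtomicScratch::CTATile);
  return 0;
}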

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 69 additions & 3 deletions

@@ -490,7 +490,7 @@ largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth,
     }
     return {v, permutation};
   }
-  llvm_unreachable("No vectorisation found");
+  llvm_unreachable("Vectorization < 1 is not valid");
 }
 } // namespace

@@ -558,8 +558,9 @@ SmallVector<Value> lowerLdSt(
   }
 
   auto tile = LinearLayout::identity1D(elemsPerVec, kReg, kOffset);
-  auto quot = *divideLeft(cvt, tile);
-  LinearLayout reps = zerosLike(tile) * quot;
+  auto quot = divideLeft(cvt, tile);
+  assert(quot.has_value() && "cvt must be divisible by tile");
+  LinearLayout reps = zerosLike(tile) * *quot;
 
   auto [nAdditive, permStrides] =
       actionAdditiveStrides(reps, maskSpanAffineOffset);

@@ -1892,4 +1893,69 @@ SmallVector<Value> inlineRegionImpl(RewriterBase &rewriter, Region &region,
   return vals;
 }
 
+void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
+                                 ConversionPatternRewriter &rewriter,
+                                 SmallVector<Value> &resultVals,
+                                 Type valueElemTy, TritonLLVMOpBuilder &b,
+                                 Value threadPred,
+                                 const TargetInfoBase &targetInfo,
+                                 const LLVMTypeConverter *typeConverter) {
+  auto *ctx = rewriter.getContext();
+  auto loc = op->getLoc();
+  Type structTy = typeConverter->convertType(tensorTy);
+  if (!op->hasAttr("allocation.offset")) {
+    // No broadcasting, just pack the values into a struct
+    Value resultStruct =
+        packLLElements(loc, typeConverter, resultVals, rewriter, structTy);
+    rewriter.replaceOp(op, {resultStruct});
+    return;
+  }
+
+  auto dstLayout = triton::gpu::toLinearLayout(tensorTy);
+  auto kReg = str_attr("register");
+  auto kLane = str_attr("lane");
+  auto kWarp = str_attr("warp");
+  dstLayout = dstLayout.sublayout({kReg, kLane, kWarp},
+                                  llvm::to_vector(dstLayout.getOutDimNames()));
+  dstLayout = dstLayout.reshapeOuts(
+      {{str_attr("offset"), dstLayout.getTotalOutDimSize()}});
+  auto smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op);
+
+  auto emitSt = [&](ConversionPatternRewriter &rewriter, Location loc,
+                    ArrayRef<Value> vals, Value shmemAddr, int idx,
+                    VectorType vecTy) -> SmallVector<Value> {
+    auto length = vecTy.getNumElements();
+    Value valsVec =
+        packLLVector(loc, ArrayRef<Value>(vals).slice(idx, length), rewriter);
+    targetInfo.storeDShared(rewriter, loc, shmemAddr, std::nullopt, valsVec,
+                            threadPred);
+    return {};
+  };
+
+  auto emitLd = [&](ConversionPatternRewriter &rewriter, Location loc,
+                    ArrayRef<Value> vals, Value shmemAddr, int idx,
+                    VectorType vecTy) -> SmallVector<Value> {
+    Value loadedVec = targetInfo.loadDShared(rewriter, loc, shmemAddr,
+                                             std::nullopt, vecTy, b.true_val());
+    return unpackLLVector(loc, loadedVec, rewriter);
+  };
+
+  auto noPaddingOffset = [](Value v) { return v; };
+  lowerLdSt(loc, ctx, dstLayout, resultVals, valueElemTy, smemBase,
+            /*calcPaddedOffset=*/noPaddingOffset, /*affineOffset=*/b.i32_val(0),
+            /*maskSpanAffineOffset=*/0, rewriter, targetInfo,
+            /*maybeMaxVecElems=*/{}, emitSt);
+  b.barrier();
+  resultVals = lowerLdSt(loc, ctx, dstLayout, resultVals, valueElemTy, smemBase,
+                         /*calcPaddedOffset=*/noPaddingOffset,
+                         /*affineOffset=*/b.i32_val(0),
+                         /*maskSpanAffineOffset=*/0, rewriter, targetInfo,
+                         /*maybeMaxVecElems=*/{}, emitLd);
+
+  // Create the result struct and replace the operation
+  Value resultStruct =
+      packLLElements(loc, typeConverter, resultVals, rewriter, structTy);
+  rewriter.replaceOp(op, {resultStruct});
+}
+
 } // namespace mlir
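The helper's broadcast path stores each owned value to shared memory (emitSt, gated by `threadPred`), synchronizes, then reloads unconditionally (emitLd, predicated on `b.true_val()`), so every replicated lane observes the committed result. A minimal CPU analogue of that round-trip, with all sizes assumed for illustration and no claim to match the emitted IR:

#include <cstdio>

int main() {
  const int numThreads = 128, numElems = 16; // assumed broadcast layout
  const int repl = numThreads / numElems;    // replicas per element
  float shared[numElems] = {};               // stands in for the scratch base
  for (int tid = 0; tid < numThreads; ++tid) // "store" phase (emitSt)
    if (tid % repl == 0)                     // threadPred: one owner per slot
      shared[tid / repl] = static_cast<float>(tid);
  // b.barrier() would separate the two phases on a GPU.
  float v = 0.f;
  for (int tid = 0; tid < numThreads; ++tid) // "load" phase (emitLd)
    v = shared[tid / repl];                  // all replicas read the same slot
  std::printf("last thread read %g\n", v);
  return 0;
}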

python/test/unit/language/test_core.py

Lines changed: 30 additions & 0 deletions

@@ -2064,6 +2064,36 @@ def kernel(I, O):
     kernel[(1, )](I, O)
 
 
+@pytest.mark.interpreter
+@pytest.mark.parametrize("dtype_str", ["int32", "float16"])
+@pytest.mark.parametrize("size", [1, 4, 16])
+@pytest.mark.parametrize("op", ["add", "cas"])
+def test_tensor_atomic_use_result(dtype_str, size, op, device):
+    if is_hip():
+        pytest.skip(
+            "HIP is broken because (1) it doesn't support thread predicate in atomic cas, and (2) it doesn't support"
+            " atomic rmw with float16")
+
+    @triton.jit
+    def kernel(index_ptr, out_ptr, size: tl.constexpr, op: tl.constexpr):
+        if op == "add":
+            write_index = tl.atomic_add(index_ptr + tl.arange(0, size)[:, None], val=tl.arange(0, size)[:, None],
+                                        sem="relaxed")
+        elif op == "cas":
+            write_index = tl.atomic_cas(
+                index_ptr + tl.arange(0, size)[:, None],
+                cmp=tl.zeros((size, ), dtype=index_ptr.dtype.element_ty)[:, None],
+                val=tl.arange(0, size).to(index_ptr.dtype.element_ty)[:, None],
+                sem="relaxed",
+            )
+        tl.store(out_ptr + write_index.to(tl.uint32) * size + tl.arange(0, size)[None, :], 5)
+
+    index = torch.arange(0, size, device=device).to(dtype=getattr(torch, dtype_str))
+    out = torch.zeros((size, size), device=device, dtype=getattr(torch, dtype_str))
+    kernel[(1, )](index, out, size, op)
+    assert (out == 5).all()
+
+
 # ---------------
 # test cast
 # ---------------

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 0 additions & 1 deletion

@@ -260,7 +260,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
 // CHECK-LABEL: atomic_runtime_lds_reduction
   tt.func @atomic_runtime_lds_reduction(%arg0 : tensor<64x!tt.ptr<f32>, #blocked5>, %arg2 : tensor<64xf32, #blocked5>) {
 
-    // CHECK: llvm.zext
     // CHECK-COUNT-7: rocdl.update.dpp
     // CHECK: llvm.bitcast
     // CHECK-COUNT: llvm.amdgcqn.ds.permute

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 32 additions & 2 deletions

@@ -1449,9 +1449,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
   // CHECK-LABEL: atomic_add_f32
   tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
-    // CHECK: llvm.inline_asm
+    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;
     // CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32
-    // CHECK: llvm.inline_asm
+    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;
     // CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32
     %0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
     tt.return

@@ -1488,6 +1488,36 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.tar
 
 // -----
 
+#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
+  // CHECK-LABEL: atomic_add_use_result_broadcasting
+  tt.func @atomic_add_use_result_broadcasting(%arg0 : tensor<16x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<16xi1, #blocked0>, %arg2 : tensor<16xf32, #blocked0>) {
+    %0 = tt.atomic_rmw fadd, relaxed, sys, %arg0, %arg2, %arg1 : (tensor<16x!tt.ptr<f32>, #blocked0>, tensor<16xf32, #blocked0>, tensor<16xi1, #blocked0>) -> tensor<16xf32, #blocked0>
+    // CHECK: st.shared
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.load
+    tt.store %arg0, %0 : tensor<16x!tt.ptr<f32>, #blocked0>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
+  // CHECK-LABEL: atomic_add_use_result_no_broadcasting
+  tt.func @atomic_add_use_result_no_broadcasting(%arg0 : tensor<128x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<128xi1, #blocked0>, %arg2 : tensor<128xf32, #blocked0>) {
+    %0 = tt.atomic_rmw fadd, relaxed, sys, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0>
+    // CHECK-NOT: st.shared
+    // CHECK-NOT: nvvm.barrier0
+    // CHECK-NOT: llvm.load
+    tt.store %arg0, %0 : tensor<128x!tt.ptr<f32>, #blocked0>
+    tt.return
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @atomic_add_f16_nomask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>) attributes {noinline = false} {
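The two new tests differ only in tensor size, which is exactly what flips the broadcast condition. A self-contained sketch of the arithmetic (not the LinearLayout API): with 4 warps of 32 lanes and one element per thread, a 16-element tensor leaves log2(128/16) = 3 thread bits selecting no element, so those bits are free variables and replicas must agree via shared memory, while a 128-element tensor maps threads bijectively.

#include <cstdio>

int main() {
  const int threads = 4 * 32;        // warpsPerCTA * threadsPerWarp, as above
  const int sizes[] = {16, 128};     // the two test tensors
  for (int elems : sizes) {
    bool hasFreeVars = threads > elems; // some thread bits select no element
    std::printf("tensor<%dxf32>: %s\n", elems,
                hasFreeVars ? "st.shared + barrier + load expected"
                            : "no shared-memory round-trip");
  }
  return 0;
}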

test/TritonGPU/atomic-cas.mlir

Lines changed: 20 additions & 18 deletions

@@ -1,27 +1,29 @@
-// RUN: triton-opt %s -convert-triton-to-tritongpu=target=cuda:80 2>&1 | FileCheck %s --check-prefix=GPU
-// RUN: triton-opt %s -convert-triton-to-tritongpu=target=cuda:80 -convert-triton-gpu-to-llvm 2>&1 | FileCheck %s --check-prefix=LLVM
+// RUN: triton-opt %s -convert-triton-gpu-to-llvm 2>&1 | FileCheck %s
 
-// GPU: %9 = tt.atomic_cas acq_rel, cta, %8, %cst_0, %cst : (tensor<2x!tt.ptr<i64>, #blocked>, tensor<2xi64, #blocked>, tensor<2xi64, #blocked>) -> tensor<2xi64, #blocked>
-// LLVM: llvm.inline_asm {{.*}} "mov.u64 $0, 0x0;\0A\09@$4 atom.global.acq_rel.cta.cas.b64 $0, [ $1 + 0 ], $2, $3;", "=l,l,l,l,b"
+// CHECK: llvm.inline_asm {{.*}} "mov.u64 $0, 0x0;\0A\09@$4 atom.global.acq_rel.cta.cas.b64 $0, [ $1 + 0 ], $2, $3;", "=l,l,l,l,b"
+// CHECK: st.shared
+// CHECK: nvvm.barrier0
+// CHECK: llvm.load
 
-module {
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @atomic_cas_kernel_0d1d2e(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
-    %cst = arith.constant dense<2> : tensor<2xi64>
-    %cst_0 = arith.constant dense<1> : tensor<2xi64>
+    %cst = arith.constant dense<2> : tensor<2xi64, #blocked>
+    %cst_0 = arith.constant dense<1> : tensor<2xi64, #blocked>
     %c2_i32 = arith.constant 2 : i32
     %0 = tt.get_program_id x : i32
     %1 = arith.muli %0, %c2_i32 : i32
-    %2 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32>
-    %3 = tt.splat %1 : i32 -> tensor<2xi32>
-    %4 = arith.addi %3, %2 : tensor<2xi32>
-    %5 = tt.splat %arg2 : i32 -> tensor<2xi32>
-    %6 = arith.cmpi slt, %4, %5 : tensor<2xi32>
-    %7 = tt.splat %arg0 : !tt.ptr<i64> -> tensor<2x!tt.ptr<i64>>
-    %8 = tt.addptr %7, %4 : tensor<2x!tt.ptr<i64>>, tensor<2xi32>
-    %9 = tt.atomic_cas acq_rel, cta, %8, %cst_0, %cst : (tensor<2x!tt.ptr<i64>>, tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64>
-    %10 = tt.splat %arg1 : !tt.ptr<i64> -> tensor<2x!tt.ptr<i64>>
-    %11 = tt.addptr %10, %4 : tensor<2x!tt.ptr<i64>>, tensor<2xi32>
-    tt.store %11, %9, %6 : tensor<2x!tt.ptr<i64>>
+    %2 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #blocked>
+    %3 = tt.splat %1 : i32 -> tensor<2xi32, #blocked>
+    %4 = arith.addi %3, %2 : tensor<2xi32, #blocked>
+    %5 = tt.splat %arg2 : i32 -> tensor<2xi32, #blocked>
+    %6 = arith.cmpi slt, %4, %5 : tensor<2xi32, #blocked>
+    %7 = tt.splat %arg0 : !tt.ptr<i64> -> tensor<2x!tt.ptr<i64>, #blocked>
+    %8 = tt.addptr %7, %4 : tensor<2x!tt.ptr<i64>, #blocked>, tensor<2xi32, #blocked>
+    %9 = tt.atomic_cas acq_rel, cta, %8, %cst_0, %cst {allocation.offset = 0 : i32} : (tensor<2x!tt.ptr<i64>, #blocked>, tensor<2xi64, #blocked>, tensor<2xi64, #blocked>) -> tensor<2xi64, #blocked>
+    %10 = tt.splat %arg1 : !tt.ptr<i64> -> tensor<2x!tt.ptr<i64>, #blocked>
+    %11 = tt.addptr %10, %4 : tensor<2x!tt.ptr<i64>, #blocked>, tensor<2xi32, #blocked>
+    tt.store %11, %9, %6 : tensor<2x!tt.ptr<i64>, #blocked>
     tt.return
   }
 }
