
Commit ebad1d9

wenqiny and Jokeren authored
[BACKEND] broadcast the result for atomic rmw op when necessary (#7460)
## Summary

Fix #7402

This PR tries to broadcast the result of `tl.atomic_add`. The SASS code looks like:

```
@p1 BRA LBB0;                                   <-- Only P1 of thread 0 is True
...
@p1 ATOMG.E.ADD.STRONG.GPU PT, R3, [R2.64], R7; <-- atomic add here
...
SHFL.IDX PT, R0, R3, R4, 0x1f;                  <-- shfl only works for thread 0, so it does nothing.
LBB0:                                           <-- Other threads arrive here to wait for thread 0.
BSYNC B0;
```

## Potential solution

We could manually broadcast the result of `tl.atomic_add` at `LBB0:` by adding broadcast logic here:
https://github.com/triton-lang/triton/blob/1ab4bb4a96b3561504110549d21398ba58e42a76/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp#L1093

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

---------

Co-authored-by: Jokeren <[email protected]>
Co-authored-by: Keren Zhou <[email protected]>
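For context on the summary above: the bug reproduces whenever the tensor returned by a tensor atomic is consumed by threads that did not execute the atomic. A minimal repro sketch in Triton, modeled on the end-to-end test this PR adds (`test_tensor_atomic_use_result`); the kernel name and sizes are illustrative:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def use_atomic_result(index_ptr, out_ptr, size: tl.constexpr):
    # Every thread consumes `write_index`, but with a small tensor only a few
    # threads actually perform the atomic; without broadcasting, the remaining
    # threads may read garbage instead of the value `tl.atomic_add` returned.
    write_index = tl.atomic_add(index_ptr + tl.arange(0, size)[:, None],
                                val=tl.arange(0, size)[:, None], sem="relaxed")
    tl.store(out_ptr + write_index.to(tl.uint32) * size +
             tl.arange(0, size)[None, :], 5)


index = torch.arange(0, 4, dtype=torch.int32, device="cuda")
out = torch.zeros((4, 4), dtype=torch.int32, device="cuda")
use_atomic_result[(1, )](index, out, 4)
assert (out == 5).all()  # could fail before this fix when the result is broadcast
```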
1 parent 8cb3a83 commit ebad1d9

11 files changed (+228 −109 lines)

include/triton/Analysis/Utility.h

Lines changed: 0 additions & 2 deletions
```diff
@@ -252,8 +252,6 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy);
 // warps, and possibly blocks.
 bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
 
-bool atomicNeedsSharedMemory(Value result);
-
 // Check if MFMA layout can be converted to the dot operand
 // layout using warp shuffle.
 bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
```

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 7 additions & 0 deletions
```diff
@@ -662,6 +662,13 @@ SmallVector<Value> inlineRegion(RewriterBase &rewriter, Region &region,
                                 mlir::TypeID::get<TerminatorOp>(), loc);
 }
 
+void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
+                                 ConversionPatternRewriter &rewriter,
+                                 SmallVector<Value> &resultVals,
+                                 Type valueElemTy, TritonLLVMOpBuilder &b,
+                                 Value threadPred,
+                                 const TargetInfoBase &targetInfo,
+                                 const LLVMTypeConverter *typeConverter);
 } // namespace mlir
 
 #endif
```

lib/Analysis/Allocation.cpp

Lines changed: 21 additions & 13 deletions
```diff
@@ -92,14 +92,26 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
   return repShape;
 }
 
-// Both `atomic_cas` and `atomic_rmw need a single scratch element if returning
-// a scalar value because Triton's block-based programming model ensures that
-// all threads in each block see the same return value, even those threads that
-// do not participate in the atomic operation
+// Both `atomic_cas` and `atomic_rmw` may need scratch memory to store values
+// because Triton's block-based programming model ensures that
+// all threads sharing the same partition of the tensor see the same values,
+// even for threads that do not participate in the atomic operation
 static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
   SmallVector<unsigned> smemShape;
-  if (atomicNeedsSharedMemory(result)) {
-    smemShape.push_back(1);
+  if (!result.use_empty()) {
+    if (auto tensorTy = dyn_cast<RankedTensorType>(result.getType())) {
+      auto freeVariableMasks =
+          gpu::toLinearLayout(tensorTy).getFreeVariableMasks();
+      if (llvm::any_of(freeVariableMasks, [](auto variableMask) {
+            return variableMask.second != 0;
+          })) {
+        // The tensor has broadcasted dimensions
+        smemShape = gpu::getShapePerCTATile(tensorTy);
+      }
+    } else {
+      // If the result is a scalar, we need to allocate a single element.
+      smemShape.push_back(1);
+    }
   }
   return smemShape;
 }
@@ -211,15 +223,11 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
   }
   if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
     auto value = op->getOperand(0);
-    // only scalar requires scratch memory
-    // make it explicit for readability
-    if (dyn_cast<RankedTensorType>(value.getType())) {
-      return 0;
-    }
     auto smemShape = getRepShapeForAtomic(op->getResult(0));
     auto elems = getNumScratchElements(smemShape);
-    auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
-    assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
+    if (elems == 0)
+      return 0;
+    auto elemTy = getElementTypeOrSelf(getPointeeType(value.getType()));
     return elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
   }
   if (isa<ttng::TensormapCreateOp>(op)) {
```

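For intuition, the scratch size returned at the end of the second hunk is `elems * max(8, bitwidth) / 8` bytes. A small Python sketch of that arithmetic (the helper name is made up for illustration):

```python
def atomic_scratch_bytes(num_elems: int, elem_bitwidth: int) -> int:
    # Mirrors defaultAllocationAnalysisScratchSizeFn: zero elements means no
    # scratch; otherwise each element is rounded up to at least one byte.
    if num_elems == 0:
        return 0
    return num_elems * max(8, elem_bitwidth) // 8

assert atomic_scratch_bytes(0, 32) == 0    # tensor layout with no broadcasting
assert atomic_scratch_bytes(1, 64) == 8    # scalar atomic_cas on i64
assert atomic_scratch_bytes(16, 32) == 64  # broadcast 16-element f32 result
```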
lib/Analysis/Utility.cpp

Lines changed: 0 additions & 7 deletions
```diff
@@ -797,13 +797,6 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
          !matchMFMAAndDotOperandShuffleCase(srcTy, dstTy);
 }
 
-bool atomicNeedsSharedMemory(Value value) {
-  auto type = value.getType();
-  if (isa<RankedTensorType>(type) || value.use_empty())
-    return false;
-  return true;
-}
-
 namespace {
 
 /// A data structure similar to SetVector but maintains
```

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 69 additions & 3 deletions
```diff
@@ -470,7 +470,7 @@ largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth,
     }
     return {v, permutation};
   }
-  llvm_unreachable("No vectorisation found");
+  llvm_unreachable("Vectorization < 1 is not valid");
 }
 } // namespace
 
@@ -538,8 +538,9 @@ SmallVector<Value> lowerLdSt(
   }
 
   auto tile = LinearLayout::identity1D(elemsPerVec, kReg, kOffset);
-  auto quot = *divideLeft(cvt, tile);
-  LinearLayout reps = zerosLike(tile) * quot;
+  auto quot = divideLeft(cvt, tile);
+  assert(quot.has_value() && "cvt must be divisible by tile");
+  LinearLayout reps = zerosLike(tile) * *quot;
 
   auto [nAdditive, permStrides] =
       actionAdditiveStrides(reps, maskSpanAffineOffset);
@@ -2020,4 +2021,69 @@ SmallVector<Value> inlineRegionImpl(RewriterBase &rewriter, Region &region,
   return vals;
 }
 
+void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
+                                 ConversionPatternRewriter &rewriter,
+                                 SmallVector<Value> &resultVals,
+                                 Type valueElemTy, TritonLLVMOpBuilder &b,
+                                 Value threadPred,
+                                 const TargetInfoBase &targetInfo,
+                                 const LLVMTypeConverter *typeConverter) {
+  auto *ctx = rewriter.getContext();
+  auto loc = op->getLoc();
+  Type structTy = typeConverter->convertType(tensorTy);
+  if (!op->hasAttr("allocation.offset")) {
+    // No broadcasting, just pack the values into a struct
+    Value resultStruct =
+        packLLElements(loc, typeConverter, resultVals, rewriter, structTy);
+    rewriter.replaceOp(op, {resultStruct});
+    return;
+  }
+
+  auto dstLayout = triton::gpu::toLinearLayout(tensorTy);
+  auto kReg = str_attr("register");
+  auto kLane = str_attr("lane");
+  auto kWarp = str_attr("warp");
+  dstLayout = dstLayout.sublayout({kReg, kLane, kWarp},
+                                  llvm::to_vector(dstLayout.getOutDimNames()));
+  dstLayout = dstLayout.reshapeOuts(
+      {{str_attr("offset"), dstLayout.getTotalOutDimSize()}});
+  auto smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op);
+
+  auto emitSt = [&](ConversionPatternRewriter &rewriter, Location loc,
+                    ArrayRef<Value> vals, Value shmemAddr, int idx,
+                    VectorType vecTy) -> SmallVector<Value> {
+    auto length = vecTy.getNumElements();
+    Value valsVec =
+        packLLVector(loc, ArrayRef<Value>(vals).slice(idx, length), rewriter);
+    targetInfo.storeDShared(rewriter, loc, shmemAddr, std::nullopt, valsVec,
+                            threadPred);
+    return {};
+  };
+
+  auto emitLd = [&](ConversionPatternRewriter &rewriter, Location loc,
+                    ArrayRef<Value> vals, Value shmemAddr, int idx,
+                    VectorType vecTy) -> SmallVector<Value> {
+    Value loadedVec = targetInfo.loadDShared(rewriter, loc, shmemAddr,
+                                             std::nullopt, vecTy, b.true_val());
+    return unpackLLVector(loc, loadedVec, rewriter);
+  };
+
+  auto noPaddingOffset = [](Value v) { return v; };
+  lowerLdSt(loc, ctx, dstLayout, resultVals, valueElemTy, smemBase,
+            /*calcPaddedOffset=*/noPaddingOffset, /*affineOffset=*/b.i32_val(0),
+            /*maskSpanAffineOffset=*/0, rewriter, targetInfo,
+            /*maybeMaxVecElems=*/{}, emitSt);
+  b.barrier();
+  resultVals = lowerLdSt(loc, ctx, dstLayout, resultVals, valueElemTy, smemBase,
+                         /*calcPaddedOffset=*/noPaddingOffset,
+                         /*affineOffset=*/b.i32_val(0),
+                         /*maskSpanAffineOffset=*/0, rewriter, targetInfo,
+                         /*maybeMaxVecElems=*/{}, emitLd);
+
+  // Create the result struct and replace the operation
+  Value resultStruct =
+      packLLElements(loc, typeConverter, resultVals, rewriter, structTy);
+  rewriter.replaceOp(op, {resultStruct});
+}
+
 } // namespace mlir
```

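The heart of `finalizeTensorAtomicResults` is a store/barrier/load round trip: threads that executed the atomic (gated by `threadPred`) store their results to shared memory, everyone synchronizes, then every thread loads the value for its element. A minimal NumPy model of that idea (the lane/element mapping names are invented for illustration, not Triton APIs):

```python
import numpy as np


def broadcast_atomic_results(lane_results, participates, elem_of_lane, num_elems):
    # "st.shared": only lanes passing the thread predicate store their result.
    smem = np.zeros(num_elems)  # the scratch allocation
    for lane, ok in enumerate(participates):
        if ok:
            smem[elem_of_lane[lane]] = lane_results[lane]
    # "nvvm.barrier0": in a real kernel all lanes synchronize here.
    # "llvm.load": every lane, predicated or not, reads the same value back.
    return [smem[elem_of_lane[lane]] for lane in range(len(lane_results))]


# Four lanes all hold element 0; only lane 0 performed the atomic.
print(broadcast_atomic_results([7.0, 0.0, 0.0, 0.0],
                               [True, False, False, False],
                               [0, 0, 0, 0], num_elems=1))
# -> [7.0, 7.0, 7.0, 7.0]
```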
python/test/unit/language/test_core.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -2003,6 +2003,36 @@ def kernel(I, O):
     kernel[(1, )](I, O)
 
 
+@pytest.mark.interpreter
+@pytest.mark.parametrize("dtype_str", ["int32", "float16"])
+@pytest.mark.parametrize("size", [1, 4, 16])
+@pytest.mark.parametrize("op", ["add", "cas"])
+def test_tensor_atomic_use_result(dtype_str, size, op, device):
+    if is_hip():
+        pytest.skip(
+            "HIP is broken because (1) it doesn't support thread predicate in atomic cas, and (2) it doesn't support"
+            " atomic rmw with float16")
+
+    @triton.jit
+    def kernel(index_ptr, out_ptr, size: tl.constexpr, op: tl.constexpr):
+        if op == "add":
+            write_index = tl.atomic_add(index_ptr + tl.arange(0, size)[:, None], val=tl.arange(0, size)[:, None],
+                                        sem="relaxed")
+        elif op == "cas":
+            write_index = tl.atomic_cas(
+                index_ptr + tl.arange(0, size)[:, None],
+                cmp=tl.zeros((size, ), dtype=index_ptr.dtype.element_ty)[:, None],
+                val=tl.arange(0, size).to(index_ptr.dtype.element_ty)[:, None],
+                sem="relaxed",
+            )
+        tl.store(out_ptr + write_index.to(tl.uint32) * size + tl.arange(0, size)[None, :], 5)
+
+    index = torch.arange(0, size, device=device).to(dtype=getattr(torch, dtype_str))
+    out = torch.zeros((size, size), device=device, dtype=getattr(torch, dtype_str))
+    kernel[(1, )](index, out, size, op)
+    assert (out == 5).all()
+
+
 # ---------------
 # test cast
 # ---------------
```

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 0 additions & 1 deletion
```diff
@@ -260,7 +260,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
   // CHECK-LABEL: atomic_runtime_lds_reduction
   tt.func @atomic_runtime_lds_reduction(%arg0 : tensor<64x!tt.ptr<f32>, #blocked5>, %arg2 : tensor<64xf32, #blocked5>) {
 
-    // CHECK: llvm.zext
     // CHECK-COUNT-7: rocdl.update.dpp
     // CHECK: llvm.bitcast
     // CHECK-COUNT: llvm.amdgcn.ds.permute
```

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 32 additions & 2 deletions
```diff
@@ -1449,9 +1449,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
   // CHECK-LABEL: atomic_add_f32
   tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
-    // CHECK: llvm.inline_asm
+    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;
     // CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32
-    // CHECK: llvm.inline_asm
+    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;
     // CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32
     %0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
     tt.return
@@ -1488,6 +1488,36 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.tar
 
 // -----
 
+#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
+  // CHECK-LABEL: atomic_add_use_result_broadcasting
+  tt.func @atomic_add_use_result_broadcasting(%arg0 : tensor<16x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<16xi1, #blocked0>, %arg2 : tensor<16xf32, #blocked0>) {
+    %0 = tt.atomic_rmw fadd, relaxed, sys, %arg0, %arg2, %arg1 : (tensor<16x!tt.ptr<f32>, #blocked0>, tensor<16xf32, #blocked0>, tensor<16xi1, #blocked0>) -> tensor<16xf32, #blocked0>
+    // CHECK: st.shared
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.load
+    tt.store %arg0, %0 : tensor<16x!tt.ptr<f32>, #blocked0>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
+  // CHECK-LABEL: atomic_add_use_result_no_broadcasting
+  tt.func @atomic_add_use_result_no_broadcasting(%arg0 : tensor<128x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<128xi1, #blocked0>, %arg2 : tensor<128xf32, #blocked0>) {
+    %0 = tt.atomic_rmw fadd, relaxed, sys, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0>
+    // CHECK-NOT: st.shared
+    // CHECK-NOT: nvvm.barrier0
+    // CHECK-NOT: llvm.load
+    tt.store %arg0, %0 : tensor<128x!tt.ptr<f32>, #blocked0>
+    tt.return
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @atomic_add_f16_nomask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>) attributes {noinline = false} {
```

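The two new tests above differ only in tensor size. With `ttg.num-warps = 4` and 32 threads per warp there are 128 threads: a 16-element tensor maps several threads to each element (the layout has free variables), so the lowering must round-trip through shared memory, while a 128-element tensor is one-to-one and needs no scratch. A rough Python check of that condition (a simplification of the LinearLayout free-variable test, not the actual API):

```python
def needs_result_broadcast(num_elems: int, num_warps: int, warp_size: int = 32) -> bool:
    # If more threads than elements cover the tensor, some thread-index bits
    # do not change which element is addressed, i.e. the layout broadcasts
    # and the atomic's result must be shared through scratch memory.
    return num_warps * warp_size > num_elems


assert needs_result_broadcast(16, num_warps=4)       # -> st.shared / barrier / load
assert not needs_result_broadcast(128, num_warps=4)  # -> no scratch round trip
```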
test/TritonGPU/atomic-cas.mlir

Lines changed: 20 additions & 18 deletions
```diff
@@ -1,27 +1,29 @@
-// RUN: triton-opt %s -convert-triton-to-tritongpu=target=cuda:80 2>&1 | FileCheck %s --check-prefix=GPU
-// RUN: triton-opt %s -convert-triton-to-tritongpu=target=cuda:80 -convert-triton-gpu-to-llvm 2>&1 | FileCheck %s --check-prefix=LLVM
+// RUN: triton-opt %s -convert-triton-gpu-to-llvm 2>&1 | FileCheck %s
 
-// GPU: %9 = tt.atomic_cas acq_rel, cta, %8, %cst_0, %cst : (tensor<2x!tt.ptr<i64>, #blocked>, tensor<2xi64, #blocked>, tensor<2xi64, #blocked>) -> tensor<2xi64, #blocked>
-// LLVM: llvm.inline_asm {{.*}} "mov.u64 $0, 0x0;\0A\09@$4 atom.global.acq_rel.cta.cas.b64 $0, [ $1 + 0 ], $2, $3;", "=l,l,l,l,b"
+// CHECK: llvm.inline_asm {{.*}} "mov.u64 $0, 0x0;\0A\09@$4 atom.global.acq_rel.cta.cas.b64 $0, [ $1 + 0 ], $2, $3;", "=l,l,l,l,b"
+// CHECK: st.shared
+// CHECK: nvvm.barrier0
+// CHECK: llvm.load
 
-module {
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @atomic_cas_kernel_0d1d2e(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
-    %cst = arith.constant dense<2> : tensor<2xi64>
-    %cst_0 = arith.constant dense<1> : tensor<2xi64>
+    %cst = arith.constant dense<2> : tensor<2xi64, #blocked>
+    %cst_0 = arith.constant dense<1> : tensor<2xi64, #blocked>
     %c2_i32 = arith.constant 2 : i32
     %0 = tt.get_program_id x : i32
     %1 = arith.muli %0, %c2_i32 : i32
-    %2 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32>
-    %3 = tt.splat %1 : i32 -> tensor<2xi32>
-    %4 = arith.addi %3, %2 : tensor<2xi32>
-    %5 = tt.splat %arg2 : i32 -> tensor<2xi32>
-    %6 = arith.cmpi slt, %4, %5 : tensor<2xi32>
-    %7 = tt.splat %arg0 : !tt.ptr<i64> -> tensor<2x!tt.ptr<i64>>
-    %8 = tt.addptr %7, %4 : tensor<2x!tt.ptr<i64>>, tensor<2xi32>
-    %9 = tt.atomic_cas acq_rel, cta, %8, %cst_0, %cst : (tensor<2x!tt.ptr<i64>>, tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64>
-    %10 = tt.splat %arg1 : !tt.ptr<i64> -> tensor<2x!tt.ptr<i64>>
-    %11 = tt.addptr %10, %4 : tensor<2x!tt.ptr<i64>>, tensor<2xi32>
-    tt.store %11, %9, %6 : tensor<2x!tt.ptr<i64>>
+    %2 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #blocked>
+    %3 = tt.splat %1 : i32 -> tensor<2xi32, #blocked>
+    %4 = arith.addi %3, %2 : tensor<2xi32, #blocked>
+    %5 = tt.splat %arg2 : i32 -> tensor<2xi32, #blocked>
+    %6 = arith.cmpi slt, %4, %5 : tensor<2xi32, #blocked>
+    %7 = tt.splat %arg0 : !tt.ptr<i64> -> tensor<2x!tt.ptr<i64>, #blocked>
+    %8 = tt.addptr %7, %4 : tensor<2x!tt.ptr<i64>, #blocked>, tensor<2xi32, #blocked>
+    %9 = tt.atomic_cas acq_rel, cta, %8, %cst_0, %cst {allocation.offset = 0 : i32} : (tensor<2x!tt.ptr<i64>, #blocked>, tensor<2xi64, #blocked>, tensor<2xi64, #blocked>) -> tensor<2xi64, #blocked>
+    %10 = tt.splat %arg1 : !tt.ptr<i64> -> tensor<2x!tt.ptr<i64>, #blocked>
+    %11 = tt.addptr %10, %4 : tensor<2x!tt.ptr<i64>, #blocked>, tensor<2xi32, #blocked>
+    tt.store %11, %9, %6 : tensor<2x!tt.ptr<i64>, #blocked>
     tt.return
   }
 }
```
