Skip to content

Commit 097a106

Browse files
Revert "[BACKEND] Disable vectorization for atomic_cas on all backends (#7711)"
This reverts commit a372a80.
1 parent b7eda0b commit 097a106

File tree

3 files changed

+75
-33
lines changed

3 files changed

+75
-33
lines changed

python/test/unit/language/test_core.py

Lines changed: 9 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1969,33 +1969,24 @@ def serialized_add(data, Lock, SEM: tl.constexpr):
19691969

19701970

19711971
@pytest.mark.interpreter
1972-
@pytest.mark.parametrize("sem", [None, "acquire", "release", "acq_rel", "relaxed"])
1972+
@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
19731973
@pytest.mark.parametrize("num_ctas", num_ctas_list)
1974-
@pytest.mark.parametrize("size", [4, 128, 512])
1975-
@pytest.mark.parametrize("dtype_str", ['bfloat16', 'float16', 'float32', 'uint64', 'int64', 'float64'])
1976-
def test_tensor_atomic_cas(sem, size, dtype_str, num_ctas, device):
1977-
check_type_supported(dtype_str, device)
1978-
if "float" in dtype_str and is_hip():
1979-
pytest.skip("HIP does not support atomic cas with float types")
1974+
def test_tensor_atomic_cas(sem, num_ctas, device):
19801975

19811976
@triton.jit
1982-
def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr, dtype: tl.constexpr):
1977+
def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr):
19831978
pid = tl.program_id(axis=0)
19841979
block_start = pid * BLOCK_SIZE
19851980
offsets = block_start + tl.arange(0, BLOCK_SIZE)
1986-
t1 = tl.full((BLOCK_SIZE, ), 0, dtype=dtype)
1987-
t2 = tl.full((BLOCK_SIZE, ), 2, dtype=dtype)
1981+
t1 = tl.full((BLOCK_SIZE, ), 0, dtype=tl.int64)
1982+
t2 = tl.full((BLOCK_SIZE, ), 2, dtype=tl.int64)
19881983
tl.atomic_cas(X + offsets, t1, t2, sem=sem)
19891984

1990-
torch_dtype = getattr(torch, dtype_str)
1991-
X = torch.zeros((size, ), device=device, dtype=torch_dtype)
1992-
X[1::2] = 1
1993-
Y = X.clone()
1994-
Y[0::2] = 2
1985+
X = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], device=device, dtype=torch.int64)
1986+
Y = torch.tensor([2, 1, 2, 1, 2, 1, 2, 1], device=device, dtype=torch.int64)
19951987

1996-
tl_dtype = getattr(tl, dtype_str)
1997-
change_value[(2, )](X, BLOCK_SIZE=size // 2, sem=sem, dtype=tl_dtype)
1998-
assert torch.equal(X, Y)
1988+
change_value[(2, )](X, 4, sem)
1989+
assert (torch.equal(X, Y))
19991990

20001991

20011992
@pytest.mark.interpreter

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 27 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1361,16 +1361,33 @@ struct AtomicCASOpConversion
13611361
: valueTy;
13621362
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
13631363
auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType());
1364+
// vec = 1 for scalar
1365+
auto vec = getVectorSize(op.getPtr(), axisAnalysisPass);
1366+
// tensor
1367+
if (tensorTy) {
1368+
auto valTy = cast<RankedTensorType>(op.getVal().getType());
1369+
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
1370+
}
1371+
1372+
auto vecTy = vec_ty(valueElemTy, vec);
13641373
SmallVector<Value> resultVals(elemsPerThread);
13651374

13661375
// atomic ops
1367-
for (size_t i = 0; i < elemsPerThread; i += 1) {
1368-
Value casVal = valElements[i];
1369-
Value casCmp = cmpElements[i];
1376+
for (size_t i = 0; i < elemsPerThread; i += vec) {
1377+
Value casVal = b.undef(vecTy);
1378+
for (int ii = 0; ii < vec; ++ii) {
1379+
Value iiVal = createIndexAttrConstant(
1380+
rewriter, loc, getTypeConverter()->getIndexType(), ii);
1381+
casVal = b.insert_element(vecTy, casVal, valElements[i + ii], iiVal);
1382+
}
1383+
13701384
Value casPtr = ptrElements[i];
1385+
Value casCmp = cmpElements[i];
1386+
casVal = valElements[i];
1387+
13711388
// use op
13721389
if (tensorTy) { // for tensor
1373-
auto retType = valueElemTy;
1390+
auto retType = vec == 1 ? valueElemTy : vecTy;
13741391
// TODO: USE ATOMIC CAS OP on Tensor
13751392
auto successOrdering = *atomicMemOrdering;
13761393
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
@@ -1380,7 +1397,12 @@ struct AtomicCASOpConversion
13801397

13811398
// Extract the new_loaded value from the pair.
13821399
Value ret = b.extract_val(valueElemTy, cmpxchg, i);
1383-
resultVals[i] = ret;
1400+
1401+
for (int ii = 0; ii < vec; ++ii) {
1402+
resultVals[i + ii] =
1403+
vec == 1 ? ret
1404+
: b.extract_element(valueElemTy, ret, b.i32_val(ii));
1405+
}
13841406
} else { // for scalar
13851407
// Build blocks to bypass the atomic instruction for ~rmwMask.
13861408
auto *curBlock = rewriter.getInsertionBlock();

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 39 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -611,27 +611,52 @@ struct AtomicCASOpConversion
611611
: valueTy;
612612
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
613613
auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType());
614+
// vec = 1 for scalar
615+
auto vec = getVectorSize(op.getPtr());
616+
auto vecOrig = vec;
617+
// tensor
618+
if (tensorTy) {
619+
auto valTy = cast<RankedTensorType>(op.getVal().getType());
620+
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
621+
}
622+
623+
if (vec == 1 && elemsPerThread > 1)
624+
op->emitRemark() << "Warning: vectorization fails vec = " << vec
625+
<< " origin vec = " << vecOrig
626+
<< " elemsPerThread = " << elemsPerThread << "\n";
627+
614628
auto freeVarMasks = getFreeVariableMasks(op.getPtr().getType());
615629
Value threadPred =
616630
emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
617631
uint32_t regMask = freeVarMasks[str_attr("reg")];
618632

633+
auto vecTy = vec_ty(valueElemTy, vec);
619634
SmallVector<Value> resultVals(elemsPerThread);
620635

621-
for (size_t i = 0; i < elemsPerThread; i += 1) {
622-
if (auto canonicalStart = getCanonicalIndex(i, regMask);
623-
canonicalStart != i) {
636+
for (size_t i = 0; i < elemsPerThread; i += vec) {
637+
if (auto canonicalVecStart = getCanonicalIndex(i, regMask);
638+
canonicalVecStart != i) {
624639
// For redundant registers, refer back to the canonical result
625-
resultVals[i] = resultVals[canonicalStart];
640+
for (auto iVec = 0; iVec < vec; ++iVec) {
641+
resultVals[i + iVec] = resultVals[canonicalVecStart + iVec];
642+
}
626643
continue;
627644
}
628645

629-
Value casVal = valElements[i];
630-
Value casCmp = cmpElements[i];
646+
Value casVal = b.undef(vecTy);
647+
for (int ii = 0; ii < vec; ++ii) {
648+
Value iiVal = createIndexAttrConstant(
649+
rewriter, loc, getTypeConverter()->getIndexType(), ii);
650+
casVal = b.insert_element(vecTy, casVal, valElements[i + ii], iiVal);
651+
}
652+
631653
Value casPtr = ptrElements[i];
654+
Value casCmp = cmpElements[i];
655+
casVal = valElements[i];
632656
PTXBuilder ptxBuilderAtomicCAS;
633-
std::string tyId =
634-
valueElemNBits == 64 ? "l" : (valueElemNBits == 32 ? "r" : "h");
657+
std::string tyId = valueElemNBits * vec == 64
658+
? "l"
659+
: (valueElemNBits * vec == 32 ? "r" : "h");
635660
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=" + tyId, /*init=*/true);
636661
auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
637662
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, tyId);
@@ -646,9 +671,13 @@ struct AtomicCASOpConversion
646671
atom(dstOpr, ptrOpr, cmpOpr, valOpr).maybePredicate(threadPred);
647672

648673
if (tensorTy) {
649-
auto retType = valueElemTy;
674+
auto retType = vec == 1 ? valueElemTy : vecTy;
650675
auto ret = ptxBuilderAtomicCAS.launch(rewriter, loc, retType);
651-
resultVals[i] = ret;
676+
for (int ii = 0; ii < vec; ++ii) {
677+
resultVals[i + ii] =
678+
vec == 1 ? ret
679+
: b.extract_element(valueElemTy, ret, b.i32_val(ii));
680+
}
652681
} else {
653682
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
654683
if (op.getResult().use_empty()) {

0 commit comments

Comments (0)