Commit a372a80

[BACKEND] Disable vectorization for atomic_cas on all backends (#7711)
1 parent 8d3f09e commit a372a80
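
For context, here is a minimal sketch (not part of the commit) of the kind of tensor atomic_cas kernel whose lowering this change affects: each element performs its own compare-and-swap, and with this patch both the AMD and NVIDIA backends emit one CAS per element instead of packing pairs of f16 elements into a vectorized CAS. The kernel name, buffer, and launch configuration below are illustrative assumptions, not code from the repository.

import torch
import triton
import triton.language as tl


@triton.jit
def cas_kernel(ptr, BLOCK: tl.constexpr):
    # One compare-and-swap per element: where the value equals 0, store 2.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    expected = tl.full((BLOCK, ), 0, dtype=tl.float16)
    desired = tl.full((BLOCK, ), 2, dtype=tl.float16)
    tl.atomic_cas(ptr + offs, expected, desired, sem="acq_rel")


# Illustrative launch: a float16 buffer split across two programs.
x = torch.zeros(128, device="cuda", dtype=torch.float16)
cas_kernel[(2, )](x, BLOCK=64)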

3 files changed: +33 −75 lines changed

python/test/unit/language/test_core.py

Lines changed: 18 additions & 9 deletions

@@ -1908,24 +1908,33 @@ def serialized_add(data, Lock, SEM: tl.constexpr):
 
 
 @pytest.mark.interpreter
-@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
+@pytest.mark.parametrize("sem", [None, "acquire", "release", "acq_rel", "relaxed"])
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
-def test_tensor_atomic_cas(sem, num_ctas, device):
+@pytest.mark.parametrize("size", [4, 128, 512])
+@pytest.mark.parametrize("dtype_str", ['bfloat16', 'float16', 'float32', 'uint64', 'int64', 'float64'])
+def test_tensor_atomic_cas(sem, size, dtype_str, num_ctas, device):
+    check_type_supported(dtype_str, device)
+    if "float" in dtype_str and is_hip():
+        pytest.skip("HIP does not support atomic cas with float types")
 
     @triton.jit
-    def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr):
+    def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr, dtype: tl.constexpr):
         pid = tl.program_id(axis=0)
         block_start = pid * BLOCK_SIZE
         offsets = block_start + tl.arange(0, BLOCK_SIZE)
-        t1 = tl.full((BLOCK_SIZE, ), 0, dtype=tl.int64)
-        t2 = tl.full((BLOCK_SIZE, ), 2, dtype=tl.int64)
+        t1 = tl.full((BLOCK_SIZE, ), 0, dtype=dtype)
+        t2 = tl.full((BLOCK_SIZE, ), 2, dtype=dtype)
         tl.atomic_cas(X + offsets, t1, t2, sem=sem)
 
-    X = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], device=device, dtype=torch.int64)
-    Y = torch.tensor([2, 1, 2, 1, 2, 1, 2, 1], device=device, dtype=torch.int64)
+    torch_dtype = getattr(torch, dtype_str)
+    X = torch.zeros((size, ), device=device, dtype=torch_dtype)
+    X[1::2] = 1
+    Y = X.clone()
+    Y[0::2] = 2
 
-    change_value[(2, )](X, 4, sem)
-    assert (torch.equal(X, Y))
+    tl_dtype = getattr(tl, dtype_str)
+    change_value[(2, )](X, BLOCK_SIZE=size // 2, sem=sem, dtype=tl_dtype)
+    assert torch.equal(X, Y)
 
 
 @pytest.mark.interpreter

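As a host-side sanity check on what the rewritten test asserts (plain PyTorch, no Triton; the snippet is illustrative and not part of the commit): a compare-and-swap replaces only the elements that equal the expected value, so the zeros at even indices become 2 while the ones at odd indices are left untouched.

import torch

size = 8
X = torch.zeros(size, dtype=torch.int64)
X[1::2] = 1                    # X = [0, 1, 0, 1, 0, 1, 0, 1]
expected, desired = 0, 2
# Elements matching expected are swapped to desired; the rest keep their value.
ref = torch.where(X == expected, torch.full_like(X, desired), X)
print(ref)                     # tensor([2, 1, 2, 1, 2, 1, 2, 1])
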
third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 5 additions & 27 deletions

@@ -1361,33 +1361,16 @@ struct AtomicCASOpConversion
                  : valueTy;
     auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
     auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType());
-    // vec = 1 for scalar
-    auto vec = getVectorSize(op.getPtr(), axisAnalysisPass);
-    // tensor
-    if (tensorTy) {
-      auto valTy = cast<RankedTensorType>(op.getVal().getType());
-      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
-    }
-
-    auto vecTy = vec_ty(valueElemTy, vec);
     SmallVector<Value> resultVals(elemsPerThread);
 
     // atomic ops
-    for (size_t i = 0; i < elemsPerThread; i += vec) {
-      Value casVal = b.undef(vecTy);
-      for (int ii = 0; ii < vec; ++ii) {
-        Value iiVal = createIndexAttrConstant(
-            rewriter, loc, getTypeConverter()->getIndexType(), ii);
-        casVal = b.insert_element(vecTy, casVal, valElements[i + ii], iiVal);
-      }
-
-      Value casPtr = ptrElements[i];
+    for (size_t i = 0; i < elemsPerThread; i += 1) {
+      Value casVal = valElements[i];
       Value casCmp = cmpElements[i];
-      casVal = valElements[i];
-
+      Value casPtr = ptrElements[i];
       // use op
       if (tensorTy) { // for tensor
-        auto retType = vec == 1 ? valueElemTy : vecTy;
+        auto retType = valueElemTy;
         // TODO: USE ATOMIC CAS OP on Tensor
         auto successOrdering = *atomicMemOrdering;
         auto failureOrdering = LLVM::AtomicOrdering::monotonic;
@@ -1397,12 +1380,7 @@ struct AtomicCASOpConversion
 
         // Extract the new_loaded value from the pair.
         Value ret = b.extract_val(valueElemTy, cmpxchg, i);
-
-        for (int ii = 0; ii < vec; ++ii) {
-          resultVals[i + ii] =
-              vec == 1 ? ret
-                       : b.extract_element(valueElemTy, ret, b.i32_val(ii));
-        }
+        resultVals[i] = ret;
       } else { // for scalar
         // Build blocks to bypass the atomic instruction for ~rmwMask.
         auto *curBlock = rewriter.getInsertionBlock();

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 10 additions & 39 deletions

@@ -611,52 +611,27 @@ struct AtomicCASOpConversion
                  : valueTy;
     auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
     auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType());
-    // vec = 1 for scalar
-    auto vec = getVectorSize(op.getPtr());
-    auto vecOrig = vec;
-    // tensor
-    if (tensorTy) {
-      auto valTy = cast<RankedTensorType>(op.getVal().getType());
-      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
-    }
-
-    if (vec == 1 && elemsPerThread > 1)
-      op->emitRemark() << "Warning: vectorization fails vec = " << vec
-                       << " origin vec = " << vecOrig
-                       << " elemsPerThread = " << elemsPerThread << "\n";
-
     auto freeVarMasks = getFreeVariableMasks(op.getPtr().getType());
     Value threadPred =
         emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
     uint32_t regMask = freeVarMasks[str_attr("reg")];
 
-    auto vecTy = vec_ty(valueElemTy, vec);
     SmallVector<Value> resultVals(elemsPerThread);
 
-    for (size_t i = 0; i < elemsPerThread; i += vec) {
-      if (auto canonicalVecStart = getCanonicalIndex(i, regMask);
-          canonicalVecStart != i) {
+    for (size_t i = 0; i < elemsPerThread; i += 1) {
+      if (auto canonicalStart = getCanonicalIndex(i, regMask);
+          canonicalStart != i) {
         // For redundant registers, refer back to the canonical result
-        for (auto iVec = 0; iVec < vec; ++iVec) {
-          resultVals[i + iVec] = resultVals[canonicalVecStart + iVec];
-        }
+        resultVals[i] = resultVals[canonicalStart];
         continue;
       }
 
-      Value casVal = b.undef(vecTy);
-      for (int ii = 0; ii < vec; ++ii) {
-        Value iiVal = createIndexAttrConstant(
-            rewriter, loc, getTypeConverter()->getIndexType(), ii);
-        casVal = b.insert_element(vecTy, casVal, valElements[i + ii], iiVal);
-      }
-
-      Value casPtr = ptrElements[i];
+      Value casVal = valElements[i];
       Value casCmp = cmpElements[i];
-      casVal = valElements[i];
+      Value casPtr = ptrElements[i];
       PTXBuilder ptxBuilderAtomicCAS;
-      std::string tyId = valueElemNBits * vec == 64
-                             ? "l"
-                             : (valueElemNBits * vec == 32 ? "r" : "h");
+      std::string tyId =
+          valueElemNBits == 64 ? "l" : (valueElemNBits == 32 ? "r" : "h");
       auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=" + tyId, /*init=*/true);
       auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
       auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, tyId);
@@ -671,13 +646,9 @@ struct AtomicCASOpConversion
       atom(dstOpr, ptrOpr, cmpOpr, valOpr).maybePredicate(threadPred);
 
       if (tensorTy) {
-        auto retType = vec == 1 ? valueElemTy : vecTy;
+        auto retType = valueElemTy;
         auto ret = ptxBuilderAtomicCAS.launch(rewriter, loc, retType);
-        for (int ii = 0; ii < vec; ++ii) {
-          resultVals[i + ii] =
-              vec == 1 ? ret
-                       : b.extract_element(valueElemTy, ret, b.i32_val(ii));
-        }
+        resultVals[i] = ret;
       } else {
         auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
         if (op.getResult().use_empty()) {
