
Commit 67de05c

wdziurdz authored and Jokeren committed
[BACKEND] Disable vectorization for atomic_cas on all backends (#7711)
1 parent 7ec882d commit 67de05c
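
For context, this is the user-facing op whose lowering the commit changes. A minimal sketch adapted from the updated test below; the kernel name and literal data are illustrative, and it assumes triton, torch, and a CUDA device:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def cas_kernel(X, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    cmp = tl.full((BLOCK_SIZE, ), 0, dtype=tl.int64)  # expected old values
    val = tl.full((BLOCK_SIZE, ), 2, dtype=tl.int64)  # replacement values
    # Each element is compare-and-swapped independently; after this commit
    # every backend lowers this to one scalar CAS per element (no f16 pairing).
    tl.atomic_cas(X + offsets, cmp, val)


x = torch.tensor([0, 1, 0, 1], device="cuda", dtype=torch.int64)
cas_kernel[(1, )](x, BLOCK_SIZE=4)
print(x)  # tensor([2, 1, 2, 1]) -- only the slots that matched 0 were swapped
```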

File tree

3 files changed: +33 -75 lines changed

    python/test/unit/language/test_core.py
    third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
    third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

python/test/unit/language/test_core.py

Lines changed: 18 additions & 9 deletions
@@ -1911,24 +1911,33 @@ def serialized_add(data, Lock, SEM: tl.constexpr):
 
 
 @pytest.mark.interpreter
-@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
+@pytest.mark.parametrize("sem", [None, "acquire", "release", "acq_rel", "relaxed"])
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
-def test_tensor_atomic_cas(sem, num_ctas, device):
+@pytest.mark.parametrize("size", [4, 128, 512])
+@pytest.mark.parametrize("dtype_str", ['bfloat16', 'float16', 'float32', 'uint64', 'int64', 'float64'])
+def test_tensor_atomic_cas(sem, size, dtype_str, num_ctas, device):
+    check_type_supported(dtype_str, device)
+    if "float" in dtype_str and is_hip():
+        pytest.skip("HIP does not support atomic cas with float types")
 
     @triton.jit
-    def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr):
+    def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr, dtype: tl.constexpr):
         pid = tl.program_id(axis=0)
         block_start = pid * BLOCK_SIZE
         offsets = block_start + tl.arange(0, BLOCK_SIZE)
-        t1 = tl.full((BLOCK_SIZE, ), 0, dtype=tl.int64)
-        t2 = tl.full((BLOCK_SIZE, ), 2, dtype=tl.int64)
+        t1 = tl.full((BLOCK_SIZE, ), 0, dtype=dtype)
+        t2 = tl.full((BLOCK_SIZE, ), 2, dtype=dtype)
         tl.atomic_cas(X + offsets, t1, t2, sem=sem)
 
-    X = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], device=device, dtype=torch.int64)
-    Y = torch.tensor([2, 1, 2, 1, 2, 1, 2, 1], device=device, dtype=torch.int64)
+    torch_dtype = getattr(torch, dtype_str)
+    X = torch.zeros((size, ), device=device, dtype=torch_dtype)
+    X[1::2] = 1
+    Y = X.clone()
+    Y[0::2] = 2
 
-    change_value[(2, )](X, 4, sem)
-    assert (torch.equal(X, Y))
+    tl_dtype = getattr(tl, dtype_str)
+    change_value[(2, )](X, BLOCK_SIZE=size // 2, sem=sem, dtype=tl_dtype)
+    assert torch.equal(X, Y)
 
 
 @pytest.mark.interpreter
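
The reworked test now covers 16-, 32-, and 64-bit dtypes and several sizes instead of a single hard-coded int64 case. Its expected-result construction boils down to plain element-wise CAS semantics; a CPU-side sketch of that expectation (no GPU needed, names are illustrative):

```python
import torch

size = 8
X = torch.zeros((size, ), dtype=torch.int64)
X[1::2] = 1          # X = [0, 1, 0, 1, ...]
Y = X.clone()
Y[0::2] = 2          # Y = [2, 1, 2, 1, ...] -- the expected result

# atomic_cas(ptr, cmp=0, val=2) swaps exactly the elements equal to 0.
X[X == 0] = 2
assert torch.equal(X, Y)
```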

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 5 additions & 27 deletions
@@ -1419,33 +1419,16 @@ struct AtomicCASOpConversion
             : valueTy;
     auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
     auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType());
-    // vec = 1 for scalar
-    auto vec = getVectorSize(op.getPtr(), axisAnalysisPass);
-    // tensor
-    if (tensorTy) {
-      auto valTy = cast<RankedTensorType>(op.getVal().getType());
-      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
-    }
-
-    auto vecTy = vec_ty(valueElemTy, vec);
     SmallVector<Value> resultVals(elemsPerThread);
 
     // atomic ops
-    for (size_t i = 0; i < elemsPerThread; i += vec) {
-      Value casVal = b.undef(vecTy);
-      for (int ii = 0; ii < vec; ++ii) {
-        Value iiVal = createIndexAttrConstant(
-            rewriter, loc, getTypeConverter()->getIndexType(), ii);
-        casVal = b.insert_element(vecTy, casVal, valElements[i + ii], iiVal);
-      }
-
-      Value casPtr = ptrElements[i];
+    for (size_t i = 0; i < elemsPerThread; i += 1) {
+      Value casVal = valElements[i];
       Value casCmp = cmpElements[i];
-      casVal = valElements[i];
-
+      Value casPtr = ptrElements[i];
       // use op
       if (tensorTy) { // for tensor
-        auto retType = vec == 1 ? valueElemTy : vecTy;
+        auto retType = valueElemTy;
         // TODO: USE ATOMIC CAS OP on Tensor
         auto successOrdering = *atomicMemOrdering;
         auto failureOrdering = LLVM::AtomicOrdering::monotonic;
@@ -1455,12 +1438,7 @@ struct AtomicCASOpConversion
 
         // Extract the new_loaded value from the pair.
         Value ret = b.extract_val(valueElemTy, cmpxchg, i);
-
-        for (int ii = 0; ii < vec; ++ii) {
-          resultVals[i + ii] =
-              vec == 1 ? ret
-                       : b.extract_element(valueElemTy, ret, b.i32_val(ii));
-        }
+        resultVals[i] = ret;
       } else { // for scalar
         // Build blocks to bypass the atomic instruction for ~rmwMask.
         auto *curBlock = rewriter.getInsertionBlock();
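
The AMD path previously packed up to two f16 values into a vector, issued one cmpxchg per vector, and unpacked the result; it now emits one scalar llvm.cmpxchg per element. A Python model of the new indexing (the `cmpxchg` helper is a stand-in for the LLVM op, not a real API):

```python
def cmpxchg(memory, addr, cmp, val):
    """Stand-in for llvm.cmpxchg: swap if the current value matches cmp."""
    old = memory[addr]
    if old == cmp:
        memory[addr] = val
    return old

def lower_atomic_cas(memory, ptrs, cmps, vals):
    results = [None] * len(ptrs)
    # Old loop: range(0, len(ptrs), vec) with vector insert/extract around
    # each op; new loop: stride 1, one scalar CAS per element.
    for i in range(len(ptrs)):
        results[i] = cmpxchg(memory, ptrs[i], cmps[i], vals[i])
    return results

memory = {0: 0, 1: 1, 2: 0, 3: 1}
print(lower_atomic_cas(memory, [0, 1, 2, 3], [0] * 4, [2] * 4))  # [0, 1, 0, 1]
print(memory)  # {0: 2, 1: 1, 2: 2, 3: 1}
```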

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 10 additions & 39 deletions
@@ -611,52 +611,27 @@ struct AtomicCASOpConversion
             : valueTy;
     auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
     auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType());
-    // vec = 1 for scalar
-    auto vec = getVectorSize(op.getPtr());
-    auto vecOrig = vec;
-    // tensor
-    if (tensorTy) {
-      auto valTy = cast<RankedTensorType>(op.getVal().getType());
-      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
-    }
-
-    if (vec == 1 && elemsPerThread > 1)
-      op->emitRemark() << "Warning: vectorization fails vec = " << vec
-                       << " origin vec = " << vecOrig
-                       << " elemsPerThread = " << elemsPerThread << "\n";
-
     auto freeVarMasks = getFreeVariableMasks(op.getPtr().getType());
     Value threadPred =
         emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
     uint32_t regMask = freeVarMasks[str_attr("reg")];
 
-    auto vecTy = vec_ty(valueElemTy, vec);
     SmallVector<Value> resultVals(elemsPerThread);
 
-    for (size_t i = 0; i < elemsPerThread; i += vec) {
-      if (auto canonicalVecStart = getCanonicalIndex(i, regMask);
-          canonicalVecStart != i) {
+    for (size_t i = 0; i < elemsPerThread; i += 1) {
+      if (auto canonicalStart = getCanonicalIndex(i, regMask);
+          canonicalStart != i) {
         // For redundant registers, refer back to the canonical result
-        for (auto iVec = 0; iVec < vec; ++iVec) {
-          resultVals[i + iVec] = resultVals[canonicalVecStart + iVec];
-        }
+        resultVals[i] = resultVals[canonicalStart];
         continue;
       }
 
-      Value casVal = b.undef(vecTy);
-      for (int ii = 0; ii < vec; ++ii) {
-        Value iiVal = createIndexAttrConstant(
-            rewriter, loc, getTypeConverter()->getIndexType(), ii);
-        casVal = b.insert_element(vecTy, casVal, valElements[i + ii], iiVal);
-      }
-
-      Value casPtr = ptrElements[i];
+      Value casVal = valElements[i];
       Value casCmp = cmpElements[i];
-      casVal = valElements[i];
+      Value casPtr = ptrElements[i];
       PTXBuilder ptxBuilderAtomicCAS;
-      std::string tyId = valueElemNBits * vec == 64
-                             ? "l"
-                             : (valueElemNBits * vec == 32 ? "r" : "h");
+      std::string tyId =
+          valueElemNBits == 64 ? "l" : (valueElemNBits == 32 ? "r" : "h");
       auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=" + tyId, /*init=*/true);
       auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
       auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, tyId);
@@ -671,13 +646,9 @@ struct AtomicCASOpConversion
       atom(dstOpr, ptrOpr, cmpOpr, valOpr).maybePredicate(threadPred);
 
       if (tensorTy) {
-        auto retType = vec == 1 ? valueElemTy : vecTy;
+        auto retType = valueElemTy;
         auto ret = ptxBuilderAtomicCAS.launch(rewriter, loc, retType);
-        for (int ii = 0; ii < vec; ++ii) {
-          resultVals[i + ii] =
-              vec == 1 ? ret
-                       : b.extract_element(valueElemTy, ret, b.i32_val(ii));
-        }
+        resultVals[i] = ret;
       } else {
         auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
         if (op.getResult().use_empty()) {
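
On the NVIDIA path the PTX operand constraint used to be chosen from `valueElemNBits * vec`, so two packed f16 values ("h") could be promoted to one 32-bit register ("r"); with `vec` fixed at 1 it depends only on the element width. A small model of the simplified selection (function name is illustrative):

```python
def ty_id(value_elem_n_bits: int) -> str:
    """PTX register constraint: "l" = 64-bit, "r" = 32-bit, "h" = 16-bit."""
    if value_elem_n_bits == 64:
        return "l"
    return "r" if value_elem_n_bits == 32 else "h"

assert ty_id(64) == "l"
assert ty_id(32) == "r"
assert ty_id(16) == "h"
```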
