Skip to content

Commit a90ac86

Browse files
authored
[BACKEND] Fix 64 bit atomic_cas (triton-lang#8105)
1 parent adf3999 commit a90ac86

File tree

2 files changed, with 19 additions and 11 deletions.

2 files changed, with 19 additions and 11 deletions.

python/test/unit/language/test_core.py

Lines changed: 18 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1615,22 +1615,30 @@ def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
16151615
@pytest.mark.interpreter
16161616
@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
16171617
@pytest.mark.parametrize("num_ctas", num_ctas_list)
1618-
def test_atomic_cas(sem, num_ctas, device):
1618+
@pytest.mark.parametrize("dtype_str", ["int32", "int64"])
1619+
def test_atomic_cas(sem, num_ctas, dtype_str, device):
16191620
# 1. make sure that atomic_cas changes the original value (Lock)
16201621
@triton.jit
1621-
def change_value(Lock):
1622-
tl.atomic_cas(Lock, 0, 1)
1622+
def change_value(Lock, triton_dtype: tl.constexpr):
1623+
num0 = tl.full((1, ), 0, dtype=triton_dtype).item()
1624+
num1 = tl.full((1, ), 1, dtype=triton_dtype).item()
1625+
tl.atomic_cas(Lock, num0, num1)
16231626

1624-
Lock = torch.zeros((1, ), device=device, dtype=torch.int32)
1625-
change_value[(1, )](Lock)
1627+
torch_dtype = getattr(torch, dtype_str)
1628+
triton_dtype = getattr(tl, dtype_str)
1629+
Lock = torch.zeros((1, ), device=device, dtype=torch_dtype)
1630+
change_value[(1, )](Lock, triton_dtype)
16261631

16271632
assert (Lock[0] == 1)
16281633

16291634
# 2. only one block enters the critical section
16301635
@triton.jit
1631-
def serialized_add(data, Lock, SEM: tl.constexpr):
1636+
def serialized_add(data, Lock, triton_dtype: tl.constexpr, SEM: tl.constexpr):
1637+
num0 = tl.full((1, ), 0, dtype=triton_dtype).item()
1638+
num1 = tl.full((1, ), 1, dtype=triton_dtype).item()
1639+
16321640
ptrs = data + tl.arange(0, 128)
1633-
while tl.atomic_cas(Lock, 0, 1, SEM) == 1:
1641+
while tl.atomic_cas(Lock, num0, num1, SEM) == 1:
16341642
pass
16351643

16361644
tl.store(ptrs, tl.load(ptrs) + 1.0)
@@ -1640,12 +1648,12 @@ def serialized_add(data, Lock, SEM: tl.constexpr):
16401648
tl.debug_barrier()
16411649

16421650
# release lock
1643-
tl.atomic_xchg(Lock, 0)
1651+
tl.atomic_xchg(Lock, num0)
16441652

1645-
Lock = torch.zeros((1, ), device=device, dtype=torch.int32)
1653+
Lock = torch.zeros((1, ), device=device, dtype=torch_dtype)
16461654
data = torch.zeros((128, ), device=device, dtype=torch.float32)
16471655
ref = torch.full((128, ), 2000.0)
1648-
h = serialized_add[(2000, )](data, Lock, SEM=sem, num_ctas=num_ctas)
1656+
h = serialized_add[(2000, )](data, Lock, triton_dtype=triton_dtype, SEM=sem, num_ctas=num_ctas)
16491657
sem_str = "acq_rel" if sem is None else sem
16501658
np.testing.assert_allclose(to_numpy(data), to_numpy(ref))
16511659
if not is_cuda():

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -661,7 +661,7 @@ struct AtomicCASOpConversion
661661
// Only threads with mask = True store the result
662662
PTXBuilder ptxBuilderStore;
663663
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r");
664-
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
664+
auto *valOprStore = ptxBuilderStore.newOperand(old, tyId);
665665
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
666666
st.shared().o(sTy);
667667
st(dstOprStore, valOprStore).maybePredicate(threadPred);

0 commit comments

Comments
 (0)