
Commit 690f690

[GLUON] Add atomic operators for gluon (#7737)
1 parent 7a83ab7 · commit 690f690

2 files changed (+69 -1 lines changed)


python/test/gluon/test_frontend.py

Lines changed: 60 additions & 0 deletions
@@ -1343,6 +1343,66 @@ def test_auto_layout_broadcast():
     _ = y * x
 
 
+@filecheck_test
+@gluon.jit
+def test_atomic_rmw():
+    x0 = ttgl.full([1], 1, ttgl.int64, layout=ttgl.AutoLayout())
+    ptr0 = x0.cast(ttgl.pointer_type(ttgl.int32), bitcast=True).item()
+    # CHECK: [[c1:%.*]] = arith.constant 1 : i32
+    # CHECK: {{.*}} = tt.atomic_rmw exch, acq_rel, gpu, %{{.*}}, [[c1]], %true : (!tt.ptr<i32>, i32, i1) -> i32
+    ttgl.atomic_xchg(ptr0, 1)
+
+    BLOCK: ttgl.constexpr = 128
+    x = ttgl.full([BLOCK], 0, ttgl.int64, layout=ttgl.AutoLayout())
+    ptr = x.cast(ttgl.pointer_type(ttgl.int32), bitcast=True)
+    val = ttgl.full([BLOCK], 1, ttgl.int32, layout=ttgl.AutoLayout())
+    mask = ttgl.full([BLOCK], True, ttgl.int1, layout=ttgl.AutoLayout())
+    offset = ttgl.arange(0, BLOCK, layout=ttgl.AutoLayout())
+    # CHECK: [[val:%.*]] = arith.constant dense<1> : tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw min, acq_rel, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw max, acq_rel, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw and, acq_rel, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw or, acq_rel, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw xor, acq_rel, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw max, acq_rel, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw add, relaxed, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    ttgl.atomic_min(offset + ptr, val)
+    ttgl.atomic_max(offset + ptr, val)
+    ttgl.atomic_add(offset + ptr, val)
+    ttgl.atomic_and(offset + ptr, val)
+    ttgl.atomic_or(offset + ptr, val)
+    ttgl.atomic_xor(offset + ptr, val)
+    ttgl.atomic_max(offset + ptr, val, mask=mask)
+    ttgl.atomic_add(offset + ptr, val, mask=mask, sem="relaxed")
+
+
+@filecheck_test
+@gluon.jit
+def test_atomic_cas():
+    # CHECK: {{.*}} = arith.constant dense<1> : tensor<1xi64, #gluon.auto_encoding>
+    x0 = ttgl.full([1], 1, ttgl.int64, layout=ttgl.AutoLayout())
+    ptr0 = x0.cast(ttgl.pointer_type(ttgl.int32), bitcast=True).item()
+    # CHECK: [[c0:%.*]] = arith.constant 0 : i32
+    # CHECK: [[c1:%.*]] = arith.constant 1 : i32
+    # CHECK: {{.*}} = tt.atomic_cas acq_rel, gpu, %{{.*}}, [[c0]], [[c1]] : (!tt.ptr<i32>, i32, i32) -> i32
+    ttgl.atomic_cas(ptr0, 0, 1)
+
+    BLOCK: ttgl.constexpr = 128
+    x = ttgl.full([BLOCK], 0, ttgl.int64, layout=ttgl.AutoLayout())
+    ptr = x.cast(ttgl.pointer_type(ttgl.int32), bitcast=True)
+    # CHECK: {{.*}} = arith.constant dense<0> : tensor<128xi64, #gluon.auto_encoding>
+    offset = ttgl.arange(0, BLOCK, layout=ttgl.AutoLayout())
+    old = ttgl.full([BLOCK], 0, ttgl.int32, layout=ttgl.AutoLayout())
+    new = ttgl.full([BLOCK], 1, ttgl.int32, layout=ttgl.AutoLayout())
+    # CHECK: [[old:%.*]] = arith.constant dense<0> : tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: [[new:%.*]] = arith.constant dense<1> : tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_cas relaxed, gpu, %{{.*}}, [[old]], [[new]] : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_cas acq_rel, gpu, %{{.*}}, [[old]], [[new]] : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    ttgl.atomic_cas(offset + ptr, old, new, sem="relaxed")
+    ttgl.atomic_cas(offset + ptr, old, new)
+
+
 @gluon.jit
 def amd_mfma_layout_kernel():
     mfma_layout_fp32: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
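
These additions are FileCheck tests: each # CHECK: comment is matched against the TTGIR that the Gluon frontend emits for the call that follows, pinning down, for example, that ttgl.atomic_xchg lowers to tt.atomic_rmw exch with the default acq_rel ordering and gpu scope, and that sem="relaxed" switches the ordering to relaxed. Outside the test harness the operators are called the same way; the sketch below is only an illustration, and the kernel name, the out_ptr argument, and the BLOCK parameter are assumptions rather than part of this commit.

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def atomic_accumulate_kernel(out_ptr, BLOCK: ttgl.constexpr):
    # Hypothetical kernel, not from the commit: atomically update the first
    # BLOCK int32 counters behind out_ptr.
    offsets = ttgl.arange(0, BLOCK, layout=ttgl.AutoLayout())
    ones = ttgl.full([BLOCK], 1, ttgl.int32, layout=ttgl.AutoLayout())
    mask = ttgl.full([BLOCK], True, ttgl.int1, layout=ttgl.AutoLayout())
    # Defaults match the CHECK lines above: acq_rel ordering, gpu scope.
    ttgl.atomic_add(out_ptr + offsets, ones)
    # Ordering can be relaxed and individual lanes masked off, as in the tests.
    ttgl.atomic_add(out_ptr + offsets, ones, mask=mask, sem="relaxed")
    # Compare-and-swap: store 1 wherever the counter still holds 0.
    old = ttgl.full([BLOCK], 0, ttgl.int32, layout=ttgl.AutoLayout())
    new = ttgl.full([BLOCK], 1, ttgl.int32, layout=ttgl.AutoLayout())
    ttgl.atomic_cas(out_ptr + offsets, old, new)

Each call returns the value the memory held before the update, which is why every CHECK line above expects an i32 or tensor<128xi32> result.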

python/triton/experimental/gluon/language/_core.py

Lines changed: 9 additions & 1 deletion
@@ -46,15 +46,23 @@
 )
 
 _IMPORT_FROM_TRITON: List[str] = [
+    "atomic_add",
+    "atomic_and",
+    "atomic_cas",
+    "atomic_max",
+    "atomic_min",
+    "atomic_or",
+    "atomic_xchg",
+    "atomic_xor",
     "broadcast",
     "expand_dims",
     "inline_asm_elementwise",
     "join",
     "load",
     "map_elementwise",
-    "maximum",
     "max_constancy",
     "max_contiguous",
+    "maximum",
     "minimum",
     "multiple_of",
     "permute",
