@@ -1343,6 +1343,66 @@ def test_auto_layout_broadcast():
     _ = y * x


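+# The two tests below check IR lowering for Gluon atomics: the "CHECK" lines
+# are FileCheck patterns matched against the Triton IR emitted for each kernel.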
+@filecheck_test
+@gluon.jit
+def test_atomic_rmw():
+    x0 = ttgl.full([1], 1, ttgl.int64, layout=ttgl.AutoLayout())
+    ptr0 = x0.cast(ttgl.pointer_type(ttgl.int32), bitcast=True).item()
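+    # The i64 fill value is reinterpreted as a pointer via bitcast; .item()
+    # extracts it as a scalar, so the atomic below takes the scalar (non-tensor) path.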
+    # CHECK: [[c1:%.*]] = arith.constant 1 : i32
+    # CHECK: {{.*}} = tt.atomic_rmw exch, acq_rel, gpu, %{{.*}}, [[c1]], %true : (!tt.ptr<i32>, i32, i1) -> i32
+    ttgl.atomic_xchg(ptr0, 1)
+
+    BLOCK: ttgl.constexpr = 128
+    x = ttgl.full([BLOCK], 0, ttgl.int64, layout=ttgl.AutoLayout())
+    ptr = x.cast(ttgl.pointer_type(ttgl.int32), bitcast=True)
+    val = ttgl.full([BLOCK], 1, ttgl.int32, layout=ttgl.AutoLayout())
+    mask = ttgl.full([BLOCK], True, ttgl.int1, layout=ttgl.AutoLayout())
+    offset = ttgl.arange(0, BLOCK, layout=ttgl.AutoLayout())
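+    # Tensor path: each call below should lower to one tt.atomic_rmw over the
+    # 128-element block. The default memory semantic is acq_rel; the final
+    # atomic_add overrides it with sem="relaxed".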
+    # CHECK: [[val:%.*]] = arith.constant dense<1> : tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw min, acq_rel, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw max, acq_rel, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw and, acq_rel, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw or, acq_rel, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw xor, acq_rel, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw max, acq_rel, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_rmw add, relaxed, gpu, %{{.*}}, [[val]], %{{.*}} : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi1, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    ttgl.atomic_min(offset + ptr, val)
+    ttgl.atomic_max(offset + ptr, val)
+    ttgl.atomic_add(offset + ptr, val)
+    ttgl.atomic_and(offset + ptr, val)
+    ttgl.atomic_or(offset + ptr, val)
+    ttgl.atomic_xor(offset + ptr, val)
+    ttgl.atomic_max(offset + ptr, val, mask=mask)
+    ttgl.atomic_add(offset + ptr, val, mask=mask, sem="relaxed")
+
+
+@filecheck_test
+@gluon.jit
+def test_atomic_cas():
+    # CHECK: {{.*}} = arith.constant dense<1> : tensor<1xi64, #gluon.auto_encoding>
+    x0 = ttgl.full([1], 1, ttgl.int64, layout=ttgl.AutoLayout())
+    ptr0 = x0.cast(ttgl.pointer_type(ttgl.int32), bitcast=True).item()
+    # CHECK: [[c0:%.*]] = arith.constant 0 : i32
+    # CHECK: [[c1:%.*]] = arith.constant 1 : i32
+    # CHECK: {{.*}} = tt.atomic_cas acq_rel, gpu, %{{.*}}, [[c0]], [[c1]] : (!tt.ptr<i32>, i32, i32) -> i32
+    ttgl.atomic_cas(ptr0, 0, 1)
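+    # atomic_cas(ptr, cmp, val) stores val only if the current value equals
+    # cmp and returns the value previously in memory, as reflected by the i32
+    # result of the scalar tt.atomic_cas above; the calls below take the tensor path.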
+
+    BLOCK: ttgl.constexpr = 128
+    x = ttgl.full([BLOCK], 0, ttgl.int64, layout=ttgl.AutoLayout())
+    ptr = x.cast(ttgl.pointer_type(ttgl.int32), bitcast=True)
+    # CHECK: {{.*}} = arith.constant dense<0> : tensor<128xi64, #gluon.auto_encoding>
+    offset = ttgl.arange(0, BLOCK, layout=ttgl.AutoLayout())
+    old = ttgl.full([BLOCK], 0, ttgl.int32, layout=ttgl.AutoLayout())
+    new = ttgl.full([BLOCK], 1, ttgl.int32, layout=ttgl.AutoLayout())
+    # CHECK: [[old:%.*]] = arith.constant dense<0> : tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: [[new:%.*]] = arith.constant dense<1> : tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_cas relaxed, gpu, %{{.*}}, [[old]], [[new]] : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    # CHECK: {{.*}} = tt.atomic_cas acq_rel, gpu, %{{.*}}, [[old]], [[new]] : (tensor<128x!tt.ptr<i32>, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>, tensor<128xi32, #gluon.auto_encoding>) -> tensor<128xi32, #gluon.auto_encoding>
+    ttgl.atomic_cas(offset + ptr, old, new, sem="relaxed")
+    ttgl.atomic_cas(offset + ptr, old, new)
+
+
 @gluon.jit
 def amd_mfma_layout_kernel():
     mfma_layout_fp32: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,