Skip to content

Commit 23f4117

Browse files
authored
Lower Triton Atomic memory semantics to llvm memory ordering (#3374)
Triton Atomic ops accept memory semantics values, including "acquire", "release", "acq_rel", and "relaxed". Currently we ignore the `memSemantic` from the frontend and only use "acq_rel" by default. This PR enables passing `memSemantic` through to the backend.
1 parent 488474f commit 23f4117

File tree

2 files changed

+68
-10
lines changed

2 files changed

+68
-10
lines changed

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,7 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
10741074
// CHECK-NEXT: llvm.br ^bb2([[CMPXCHG_RES]] : i32)
10751075
// CHECK-NEXT: ^bb2([[RES:%.*]]: i32):
10761076
// CHECK-NEXT: [[RES_CAST:%.*]] = llvm.bitcast [[RES]] : i32 to f32
1077-
%0 = "tt.atomic_cas" (%ptr, %cmp, %val) {sem = 1 : i32, scope = 1 : i32} : (!tt.ptr<f32>, f32, f32) -> f32
1077+
%0 = "tt.atomic_cas" (%ptr, %cmp, %val) {sem = 4 : i32, scope = 1 : i32} : (!tt.ptr<f32>, f32, f32) -> f32
10781078
tt.return
10791079
}
10801080
}
@@ -1109,7 +1109,7 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
11091109
// CHECK-NEXT: ^bb4:
11101110
// CHECK-NEXT: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32
11111111
// CHECK-NEXT llvm.call spir_funccc @_Z7barrierj([[ONE]]) {{.*}} : (i32) -> ()
1112-
%0 = "tt.atomic_cas" (%ptr, %cmp, %val) {sem = 1 : i32, scope = 1 : i32} : (!tt.ptr<f32>, f32, f32) -> f32
1112+
%0 = "tt.atomic_cas" (%ptr, %cmp, %val) {sem = 4 : i32, scope = 1 : i32} : (!tt.ptr<f32>, f32, f32) -> f32
11131113
tt.store %ptr, %0 : !tt.ptr<f32>
11141114
tt.return
11151115
}
@@ -1133,7 +1133,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
11331133
// CHECK: llvm.cond_br [[PRED1]], ^bb1, ^bb2([[ZERO1]] : f32)
11341134
// CHECK-NEXT: ^bb1:
11351135
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[IE1]] : vector<1xf32> to f32
1136-
// CHECK-NEXT: [[RMW_RES1:%.*]] = llvm.atomicrmw fadd [[EV0_ARG0]], [[BCAST2]] acq_rel : !llvm.ptr<1>, f32
1136+
// CHECK-NEXT: [[RMW_RES1:%.*]] = llvm.atomicrmw fadd [[EV0_ARG0]], [[BCAST2]] monotonic : !llvm.ptr<1>, f32
11371137
// CHECK-NEXT: llvm.br ^bb2([[RMW_RES1]] : f32)
11381138
// CHECK-NEXT: ^bb2([[RMW_PHI1:%.*]]: f32):
11391139
// CHECK-NEXT: [[RMW_CAST:%.*]] = llvm.bitcast [[RMW_PHI1]] : f32 to f32
@@ -1148,7 +1148,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
11481148
// CHECK-NEXT: llvm.cond_br [[PRED2]], ^bb3, ^bb4([[ZERO2]] : f32)
11491149
// CHECK-NEXT: ^bb3:
11501150
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[IE2]] : vector<1xf32> to f32
1151-
// CHECK-NEXT: [[RMW_RES2:%.*]] = llvm.atomicrmw fadd [[EV1_ARG0]], [[BCAST2]] acq_rel : !llvm.ptr<1>, f32
1151+
// CHECK-NEXT: [[RMW_RES2:%.*]] = llvm.atomicrmw fadd [[EV1_ARG0]], [[BCAST2]] monotonic : !llvm.ptr<1>, f32
11521152
// CHECK-NEXT: llvm.br ^bb4([[RMW_RES2]] : f32)
11531153
// CHECK-NEXT: ^bb4([[RMW_PHI2:%.*]]: f32):
11541154
%0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
@@ -1177,7 +1177,7 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
11771177
// CHECK-NEXT: llvm.cond_br [[PRED]], ^bb1, ^bb2([[ZERO]] : f32)
11781178
// CHECK-NEXT: ^bb1:
11791179
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[IE1]] : vector<1xf32> to f32
1180-
// CHECK-NEXT: [[RMW_RES:%.*]] = llvm.atomicrmw fadd %arg0, [[BCAST2]] acq_rel : !llvm.ptr<1>, f32
1180+
// CHECK-NEXT: [[RMW_RES:%.*]] = llvm.atomicrmw fadd %arg0, [[BCAST2]] monotonic : !llvm.ptr<1>, f32
11811181
// CHECK-NEXT: llvm.br ^bb2([[RMW_RES]] : f32)
11821182
// CHECK-NEXT: ^bb2([[RMW_PHI:%.*]]: f32):
11831183
// CHECK-NEXT: [[RMW_CAST:%.*]] = llvm.bitcast [[RMW_PHI]] : f32 to f32
@@ -1204,7 +1204,7 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
12041204
// CHECK-NEXT: llvm.cond_br [[PRED]], ^bb1, ^bb2([[ZERO]] : f32)
12051205
// CHECK-NEXT: ^bb1:
12061206
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[IE1]] : vector<1xf32> to f32
1207-
// CHECK-NEXT: [[RMW_RES:%.*]] = llvm.atomicrmw fadd %arg0, [[BCAST2]] acq_rel : !llvm.ptr<1>, f32
1207+
// CHECK-NEXT: [[RMW_RES:%.*]] = llvm.atomicrmw fadd %arg0, [[BCAST2]] monotonic : !llvm.ptr<1>, f32
12081208
// CHECK-NEXT: llvm.br ^bb2([[RMW_RES]] : f32)
12091209
// CHECK-NEXT: ^bb2([[RMW_PHI:%.*]]: f32):
12101210
// CHECK-NEXT: [[RMW_CAST:%.*]] = llvm.bitcast [[RMW_PHI]] : f32 to f32
@@ -1225,13 +1225,52 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
12251225

12261226
// -----
12271227

1228+
#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
1229+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
1230+
// CHECK-LABEL: atomic_add_f32
1231+
tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
1232+
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} monotonic : !llvm.ptr<1>, f32
1233+
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} monotonic : !llvm.ptr<1>, f32
1234+
%0 = tt.atomic_rmw fadd, relaxed, sys, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
1235+
tt.return
1236+
}
1237+
}
1238+
1239+
// -----
1240+
1241+
#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
1242+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
1243+
// CHECK-LABEL: atomic_add_f32
1244+
tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
1245+
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acquire : !llvm.ptr<1>, f32
1246+
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acquire : !llvm.ptr<1>, f32
1247+
%0 = tt.atomic_rmw fadd, acquire, sys, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
1248+
tt.return
1249+
}
1250+
}
1251+
1252+
// -----
1253+
1254+
#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
1255+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
1256+
// CHECK-LABEL: atomic_add_f32
1257+
tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
1258+
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} release : !llvm.ptr<1>, f32
1259+
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} release : !llvm.ptr<1>, f32
1260+
%0 = tt.atomic_rmw fadd, release, sys, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
1261+
tt.return
1262+
}
1263+
}
1264+
1265+
// -----
1266+
12281267
#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
12291268
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
12301269
// CHECK-LABEL: atomic_add_f32
12311270
tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
12321271
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel : !llvm.ptr<1>, f32
12331272
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel : !llvm.ptr<1>, f32
1234-
%0 = tt.atomic_rmw fadd, relaxed, sys, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
1273+
%0 = tt.atomic_rmw fadd, acq_rel, sys, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
12351274
tt.return
12361275
}
12371276
}

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,6 +1416,21 @@ void createBarrier(ConversionPatternRewriter &rewriter, Location loc,
14161416
b.barrier();
14171417
}
14181418

1419+
static LLVM::AtomicOrdering getMemoryOrdering(MemSemantic memOrdering) {
1420+
switch (memOrdering) {
1421+
case MemSemantic::RELAXED:
1422+
return LLVM::AtomicOrdering::monotonic;
1423+
case MemSemantic::ACQUIRE:
1424+
return LLVM::AtomicOrdering::acquire;
1425+
case MemSemantic::RELEASE:
1426+
return LLVM::AtomicOrdering::release;
1427+
case MemSemantic::ACQUIRE_RELEASE:
1428+
return LLVM::AtomicOrdering::acq_rel;
1429+
default:
1430+
return LLVM::AtomicOrdering::acq_rel;
1431+
}
1432+
}
1433+
14191434
struct AtomicCASOpConversion
14201435
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>,
14211436
public LoadStoreConversionBase {
@@ -1469,6 +1484,9 @@ struct AtomicCASOpConversion
14691484
auto vecTy = vec_ty(valueElemTy, vec);
14701485
SmallVector<Value> resultVals(elemsPerThread);
14711486

1487+
MemSemantic memSem = op.getSem();
1488+
LLVM::AtomicOrdering successOrdering = getMemoryOrdering(memSem);
1489+
LLVM::AtomicOrdering failureOrdering = LLVM::AtomicOrdering::monotonic;
14721490
for (size_t i = 0; i < elemsPerThread; i += vec) {
14731491
Value casVal = b.undef(vecTy);
14741492
for (int ii = 0; ii < vec; ++ii) {
@@ -1497,8 +1515,7 @@ struct AtomicCASOpConversion
14971515
casVal = b.bitcast(casVal, zero.getType());
14981516

14991517
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
1500-
loc, casPtr, casCmp, casVal, LLVM::AtomicOrdering::acq_rel,
1501-
LLVM::AtomicOrdering::monotonic);
1518+
loc, casPtr, casCmp, casVal, successOrdering, failureOrdering);
15021519
Value newLoaded =
15031520
rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
15041521
return SmallVector<Value, 1>{newLoaded};
@@ -1566,6 +1583,8 @@ struct AtomicRMWOpConversion
15661583
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
15671584

15681585
auto atomicRmwAttr = op.getAtomicRmwOp();
1586+
MemSemantic memSem = op.getSem();
1587+
LLVM::AtomicOrdering llvmMemOrdering = getMemoryOrdering(memSem);
15691588

15701589
Value val = op.getVal();
15711590
Value ptr = op.getPtr();
@@ -1682,7 +1701,7 @@ struct AtomicRMWOpConversion
16821701

16831702
rmwVal = b.bitcast(rmwVal, valueElemTy);
16841703
auto atomRMW = rewriter.create<LLVM::AtomicRMWOp>(
1685-
loc, rmwKind, rmwPtr, rmwVal, LLVM::AtomicOrdering::acq_rel);
1704+
loc, rmwKind, rmwPtr, rmwVal, llvmMemOrdering);
16861705
return SmallVector<Value, 1>{atomRMW.getRes()};
16871706
});
16881707
}

0 commit comments

Comments
 (0)