Skip to content

Commit 052c792

Browse files
[NFC][LoadStoreOpToLLVM] Reuse common utilities (#4393)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 24caa77 commit 052c792

File tree

1 file changed

+12
-48
lines changed

1 file changed

+12
-48
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 12 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "PatternTritonGPUOpToLLVM.h"
1010
#include "TargetInfo.h"
1111
#include "Utility.h"
12+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
1213

1314
#include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h"
1415
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
@@ -66,35 +67,6 @@ unsigned getCanonicalIndex(unsigned index, unsigned freeVarMask) {
6667
return index & ~freeVarMask;
6768
}
6869

69-
inline LLVM::AtomicBinOp matchAtomicOp(RMWOp atomicOp) {
70-
switch (atomicOp) {
71-
case RMWOp::AND:
72-
return LLVM::AtomicBinOp::_and;
73-
case RMWOp::OR:
74-
return LLVM::AtomicBinOp::_or;
75-
case RMWOp::XOR:
76-
return LLVM::AtomicBinOp::_xor;
77-
case RMWOp::ADD:
78-
return LLVM::AtomicBinOp::add;
79-
case RMWOp::FADD:
80-
return LLVM::AtomicBinOp::fadd;
81-
case RMWOp::MAX:
82-
return LLVM::AtomicBinOp::max;
83-
case RMWOp::MIN:
84-
return LLVM::AtomicBinOp::min;
85-
case RMWOp::UMAX:
86-
return LLVM::AtomicBinOp::umax;
87-
case RMWOp::UMIN:
88-
return LLVM::AtomicBinOp::umin;
89-
case RMWOp::XCHG:
90-
return LLVM::AtomicBinOp::xchg;
91-
}
92-
// Note that we should never hit this because all cases are covered above.
93-
// However, something is necessary after the switch in the function body to
94-
// avoid a compiler error.
95-
llvm_unreachable("Unhandled RMWOp in case statement");
96-
}
97-
9870
/// Holds the values related to a block pointer.
9971
/// It includes the base pointer, base width and height, row and column
10072
/// stride, and offset base for X and Y.
@@ -2679,21 +2651,6 @@ void createBarrier(ConversionPatternRewriter &rewriter, Location loc,
26792651
b.barrier();
26802652
}
26812653

2682-
static LLVM::AtomicOrdering getMemoryOrdering(MemSemantic memOrdering) {
2683-
switch (memOrdering) {
2684-
case MemSemantic::RELAXED:
2685-
return LLVM::AtomicOrdering::monotonic;
2686-
case MemSemantic::ACQUIRE:
2687-
return LLVM::AtomicOrdering::acquire;
2688-
case MemSemantic::RELEASE:
2689-
return LLVM::AtomicOrdering::release;
2690-
case MemSemantic::ACQUIRE_RELEASE:
2691-
return LLVM::AtomicOrdering::acq_rel;
2692-
default:
2693-
return LLVM::AtomicOrdering::acq_rel;
2694-
}
2695-
}
2696-
26972654
struct AtomicCASOpConversion
26982655
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>,
26992656
public LoadStoreConversionBase {
@@ -2749,7 +2706,9 @@ struct AtomicCASOpConversion
27492706
SmallVector<Value> resultVals(elemsPerThread);
27502707

27512708
MemSemantic memSem = op.getSem();
2752-
LLVM::AtomicOrdering successOrdering = getMemoryOrdering(memSem);
2709+
LLVM::AtomicOrdering successOrdering = getMemoryOrdering(memSem)
2710+
? *getMemoryOrdering(memSem)
2711+
: LLVM::AtomicOrdering::acq_rel;
27532712
LLVM::AtomicOrdering failureOrdering = LLVM::AtomicOrdering::monotonic;
27542713
for (size_t i = 0; i < elemsPerThread; i += vec) {
27552714
Value casVal = b.undef(vecTy);
@@ -2851,7 +2810,9 @@ struct AtomicRMWOpConversion
28512810

28522811
auto atomicRmwAttr = op.getAtomicRmwOp();
28532812
MemSemantic memSem = op.getSem();
2854-
LLVM::AtomicOrdering llvmMemOrdering = getMemoryOrdering(memSem);
2813+
LLVM::AtomicOrdering llvmMemOrdering = getMemoryOrdering(memSem)
2814+
? *getMemoryOrdering(memSem)
2815+
: LLVM::AtomicOrdering::acq_rel;
28552816

28562817
Value val = op.getVal();
28572818
Value ptr = op.getPtr();
@@ -2937,11 +2898,14 @@ struct AtomicRMWOpConversion
29372898
TritonGEN::MemFence::GLOBAL);
29382899

29392900
auto createAtomicBinOpInstruction = [&]() -> SmallVector<Value, 1> {
2940-
mlir::LLVM::AtomicBinOp rmwKind = matchAtomicOp(atomicRmwAttr);
2901+
std::optional<mlir::LLVM::AtomicBinOp> rmwKind =
2902+
matchAtomicOp(atomicRmwAttr);
2903+
if (!rmwKind)
2904+
llvm_unreachable("Unhandled RMWOp in case statement");
29412905

29422906
rmwVal = b.bitcast(rmwVal, valueElemTy);
29432907
auto atomRMW = rewriter.create<LLVM::AtomicRMWOp>(
2944-
loc, rmwKind, rmwPtr, rmwVal, llvmMemOrdering);
2908+
loc, *rmwKind, rmwPtr, rmwVal, llvmMemOrdering);
29452909
return {atomRMW.getRes()};
29462910
};
29472911

0 commit comments

Comments
 (0)