|
9 | 9 | #include "PatternTritonGPUOpToLLVM.h" |
10 | 10 | #include "TargetInfo.h" |
11 | 11 | #include "Utility.h" |
| 12 | +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" |
12 | 13 |
|
13 | 14 | #include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h" |
14 | 15 | #include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" |
@@ -66,35 +67,6 @@ unsigned getCanonicalIndex(unsigned index, unsigned freeVarMask) { |
66 | 67 | return index & ~freeVarMask; |
67 | 68 | } |
68 | 69 |
|
69 | | -inline LLVM::AtomicBinOp matchAtomicOp(RMWOp atomicOp) { |
70 | | - switch (atomicOp) { |
71 | | - case RMWOp::AND: |
72 | | - return LLVM::AtomicBinOp::_and; |
73 | | - case RMWOp::OR: |
74 | | - return LLVM::AtomicBinOp::_or; |
75 | | - case RMWOp::XOR: |
76 | | - return LLVM::AtomicBinOp::_xor; |
77 | | - case RMWOp::ADD: |
78 | | - return LLVM::AtomicBinOp::add; |
79 | | - case RMWOp::FADD: |
80 | | - return LLVM::AtomicBinOp::fadd; |
81 | | - case RMWOp::MAX: |
82 | | - return LLVM::AtomicBinOp::max; |
83 | | - case RMWOp::MIN: |
84 | | - return LLVM::AtomicBinOp::min; |
85 | | - case RMWOp::UMAX: |
86 | | - return LLVM::AtomicBinOp::umax; |
87 | | - case RMWOp::UMIN: |
88 | | - return LLVM::AtomicBinOp::umin; |
89 | | - case RMWOp::XCHG: |
90 | | - return LLVM::AtomicBinOp::xchg; |
91 | | - } |
92 | | - // Note that we should never hit this because all cases are covered above. |
93 | | - // However, something is necessary after the switch in the function body to |
94 | | - // avoid a compiler error. |
95 | | - llvm_unreachable("Unhandled RMWOp in case statement"); |
96 | | -} |
97 | | - |
98 | 70 | /// Holds the values related to a block pointer. |
99 | 71 | /// It includes the base pointer, base width and height, row and column |
100 | 72 | /// stride, and offset base for X and Y. |
@@ -2679,21 +2651,6 @@ void createBarrier(ConversionPatternRewriter &rewriter, Location loc, |
2679 | 2651 | b.barrier(); |
2680 | 2652 | } |
2681 | 2653 |
|
2682 | | -static LLVM::AtomicOrdering getMemoryOrdering(MemSemantic memOrdering) { |
2683 | | - switch (memOrdering) { |
2684 | | - case MemSemantic::RELAXED: |
2685 | | - return LLVM::AtomicOrdering::monotonic; |
2686 | | - case MemSemantic::ACQUIRE: |
2687 | | - return LLVM::AtomicOrdering::acquire; |
2688 | | - case MemSemantic::RELEASE: |
2689 | | - return LLVM::AtomicOrdering::release; |
2690 | | - case MemSemantic::ACQUIRE_RELEASE: |
2691 | | - return LLVM::AtomicOrdering::acq_rel; |
2692 | | - default: |
2693 | | - return LLVM::AtomicOrdering::acq_rel; |
2694 | | - } |
2695 | | -} |
2696 | | - |
2697 | 2654 | struct AtomicCASOpConversion |
2698 | 2655 | : public ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>, |
2699 | 2656 | public LoadStoreConversionBase { |
@@ -2749,7 +2706,9 @@ struct AtomicCASOpConversion |
2749 | 2706 | SmallVector<Value> resultVals(elemsPerThread); |
2750 | 2707 |
|
2751 | 2708 | MemSemantic memSem = op.getSem(); |
2752 | | - LLVM::AtomicOrdering successOrdering = getMemoryOrdering(memSem); |
| 2709 | + LLVM::AtomicOrdering successOrdering = getMemoryOrdering(memSem) |
| 2710 | + ? *getMemoryOrdering(memSem) |
| 2711 | + : LLVM::AtomicOrdering::acq_rel; |
2753 | 2712 | LLVM::AtomicOrdering failureOrdering = LLVM::AtomicOrdering::monotonic; |
2754 | 2713 | for (size_t i = 0; i < elemsPerThread; i += vec) { |
2755 | 2714 | Value casVal = b.undef(vecTy); |
@@ -2851,7 +2810,9 @@ struct AtomicRMWOpConversion |
2851 | 2810 |
|
2852 | 2811 | auto atomicRmwAttr = op.getAtomicRmwOp(); |
2853 | 2812 | MemSemantic memSem = op.getSem(); |
2854 | | - LLVM::AtomicOrdering llvmMemOrdering = getMemoryOrdering(memSem); |
| 2813 | + LLVM::AtomicOrdering llvmMemOrdering = getMemoryOrdering(memSem) |
| 2814 | + ? *getMemoryOrdering(memSem) |
| 2815 | + : LLVM::AtomicOrdering::acq_rel; |
2855 | 2816 |
|
2856 | 2817 | Value val = op.getVal(); |
2857 | 2818 | Value ptr = op.getPtr(); |
@@ -2937,11 +2898,14 @@ struct AtomicRMWOpConversion |
2937 | 2898 | TritonGEN::MemFence::GLOBAL); |
2938 | 2899 |
|
2939 | 2900 | auto createAtomicBinOpInstruction = [&]() -> SmallVector<Value, 1> { |
2940 | | - mlir::LLVM::AtomicBinOp rmwKind = matchAtomicOp(atomicRmwAttr); |
| 2901 | + std::optional<mlir::LLVM::AtomicBinOp> rmwKind = |
| 2902 | + matchAtomicOp(atomicRmwAttr); |
| 2903 | + if (!rmwKind) |
| 2904 | + llvm_unreachable("Unhandled RMWOp in case statement"); |
2941 | 2905 |
|
2942 | 2906 | rmwVal = b.bitcast(rmwVal, valueElemTy); |
2943 | 2907 | auto atomRMW = rewriter.create<LLVM::AtomicRMWOp>( |
2944 | | - loc, rmwKind, rmwPtr, rmwVal, llvmMemOrdering); |
| 2908 | + loc, *rmwKind, rmwPtr, rmwVal, llvmMemOrdering); |
2945 | 2909 | return {atomRMW.getRes()}; |
2946 | 2910 | }; |
2947 | 2911 |
|
|
0 commit comments