@@ -2,7 +2,6 @@
 
 #include "Dialect/TritonIntelGPU/Transforms/Utility.h"
 #include "Utils/LLVMIntr.h"
-#include "Utils/Mangling.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Support/LLVM.h"
@@ -13,6 +12,7 @@
 #include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
 
 using namespace mlir;
+using namespace triton::gpu::intel;
 
 namespace {
 static bool isBF16OrTensorOf(Type type) {
@@ -79,74 +79,35 @@ struct TruncBF16 : ConvertOpToLLVMPattern<arith::TruncFOp> {
 namespace mlir::triton::intel {
 Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
                         Value v) {
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  if (auto definingOp = v.getDefiningOp()) {
-    auto moduleOp = definingOp->getParentWithTrait<OpTrait::SymbolTable>();
-    if (moduleOp->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
-                              getSupportBF16ConversionAttrName())) {
-      // For SPIRV target, use specialized intrinsic call for conversion.
-      // Otherwise, use fpext operation.
-      if (gpu::intel::hasSpirvTargetArch(moduleOp)) {
-        constexpr StringLiteral baseName = "__spirv_ConvertBF16ToFINTEL";
-        Type inTy = getTypeWithSameShape(v.getType(), i16_ty);
-        Type outTy = getTypeWithSameShape(inTy, f32_ty);
-        std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
-
-        auto bitcastValue = b.bitcast(v, inTy).getResult();
-
-        auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
-            /*other=*/LLVM::ModRefInfo::NoModRef,
-            /*argMem=*/LLVM::ModRefInfo::NoModRef,
-            /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
-        auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
-        funcAttrs.memEffectsAttr = memAttr;
-
-        auto call = gpu::intel::createDeviceFunctionCall(
-            rewriter, funcName, outTy, {inTy}, {bitcastValue}, {}, funcAttrs);
-        return call.getResult();
-      }
-
-      return LLVM::FPExtOp::create(rewriter, loc, f32_ty, v);
-    }
-  }
+  TritonLLVMIRRewriter b(loc, rewriter);
+  auto as_int16 = b.bitcast(v, i16_ty).getResult();
+  auto result = convertWithFunctionCall(
+      b, as_int16, "__spirv_ConvertBF16ToFINTEL", i16_ty, f32_ty,
+      TritonIntelGPUDialect::getSupportBF16ConversionAttrName());
+  if (result)
+    return result;
 
-  auto as_int16 = b.bitcast(v, i16_ty);
   auto as_int32 = b.zext(i32_ty, as_int16);
   auto shifted = b.shl(i32_ty, as_int32, b.i32_val(16));
   return (b.bitcast(shifted, f32_ty));
 }
 
 Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
                         Value v, RoundingMode rounding) {
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  if (auto definingOp = v.getDefiningOp()) {
-    auto moduleOp = definingOp->getParentWithTrait<OpTrait::SymbolTable>();
-    if (moduleOp->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
-                              getSupportBF16ConversionAttrName()) &&
-        rounding == RoundingMode::RTNE) {
-      // Intel SPIR-V extension only supports round-to-nearest-even
-      // LLVM fptrunc operation also assumes round-to-nearest mode
-      if (gpu::intel::hasSpirvTargetArch(moduleOp)) {
-        constexpr StringLiteral baseName = "__spirv_ConvertFToBF16INTEL";
-        Type inTy = v.getType();
-        Type funcOutTy = getTypeWithSameShape(inTy, i16_ty);
-        Type outTy = getTypeWithSameShape(inTy, bf16_ty);
-        std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
-
-        auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
-            /*other=*/LLVM::ModRefInfo::NoModRef,
-            /*argMem=*/LLVM::ModRefInfo::NoModRef,
-            /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
-        auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
-        funcAttrs.memEffectsAttr = memAttr;
-
-        auto call = gpu::intel::createDeviceFunctionCall(
-            rewriter, funcName, funcOutTy, {inTy}, {v}, {}, funcAttrs);
-        return b.bitcast(call.getResult(), outTy);
-      }
-
+  TritonLLVMIRRewriter b(loc, rewriter);
+  // Intel SPIR-V extension only supports round-to-nearest-even
+  // LLVM fptrunc operation also assumes round-to-nearest mode
+  if (rounding == RoundingMode::RTNE) {
+    std::string attrName = "__spirv_ConvertFToBF16INTEL";
+    auto result = convertWithFunctionCall(
+        b, v, attrName, f32_ty, i16_ty,
+        TritonIntelGPUDialect::getSupportBF16ConversionAttrName());
+    if (result)
+      return b.bitcast(result, bf16_ty);
+
+    auto op = v.getDefiningOp();
+    if (mlir::LLVM::intel::hasModuleAttr(op, attrName))
       return LLVM::FPTruncOp::create(rewriter, loc, bf16_ty, v);
-    }
   }
 
   assert(!isa<VectorType>(v.getType()) && "Not yet supported");
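
The helper convertWithFunctionCall that both call sites now share is not part of this hunk. Judging from its arguments and from the inline code the change deletes, it plausibly folds together the module-attribute check, the SPIR-V target check, the name mangling, and the attributed device-function call, returning a null Value when the intrinsic path does not apply. The sketch below is inferred from this diff only; the signature, the guard order, and the exact rewriter type are assumptions, not the actual implementation.

static Value convertWithFunctionCall(TritonLLVMIRRewriter &b, Value v,
                                     StringRef baseName, Type inElemTy,
                                     Type outElemTy, StringRef moduleAttrName) {
  // Inferred shape only: return a null Value unless the module opted into
  // BF16 conversion support and targets SPIR-V, mirroring the deleted guards.
  Operation *defOp = v.getDefiningOp();
  if (!defOp)
    return Value();
  Operation *moduleOp = defOp->getParentWithTrait<OpTrait::SymbolTable>();
  if (!moduleOp->hasAttr(moduleAttrName) ||
      !gpu::intel::hasSpirvTargetArch(moduleOp))
    return Value();

  Type inTy = getTypeWithSameShape(v.getType(), inElemTy);
  Type outTy = getTypeWithSameShape(inTy, outElemTy);
  std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);

  // The SPIR-V conversion intrinsic is a pure function: no memory effects,
  // nounwind, willreturn (the same attributes the deleted code attached).
  auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>(
      /*other=*/LLVM::ModRefInfo::NoModRef,
      /*argMem=*/LLVM::ModRefInfo::NoModRef,
      /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
  auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
  funcAttrs.memEffectsAttr = memAttr;

  auto call = gpu::intel::createDeviceFunctionCall(
      b, funcName, outTy, {inTy}, {v}, {}, funcAttrs);
  return call.getResult();
}

One observation from the diff itself: the second call site names the variable attrName but passes the intrinsic name "__spirv_ConvertFToBF16INTEL" in the position where the first call site passes its intrinsic name, and later reuses the same string for hasModuleAttr; whether such a module attribute is set elsewhere is not visible in this hunk.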
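
The software fallback in convertBf16ToFp32 (zext to i32, shift left by 16, bitcast to f32) is exact because bfloat16 keeps float32's sign bit and 8-bit exponent and simply truncates the mantissa to 7 bits, so a bf16 payload is the high half of the corresponding float32 bit pattern. A standalone host-side illustration of the same bit manipulation (plain C++, not part of the lowering):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Widen the 16 bf16 payload bits into the upper half of a 32-bit word and
// reinterpret it as a float: the same zext + shl 16 + bitcast sequence the
// lowering emits.
float bf16BitsToFloat(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &widened, sizeof(f));
  return f;
}

int main() {
  // 0x3FC0 encodes 1.5 in bf16 (sign 0, biased exponent 127, mantissa 0x40).
  std::printf("%f\n", bf16BitsToFloat(0x3FC0)); // prints 1.500000
  return 0;
}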