diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/BF16Casts.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/BF16Casts.cpp
index 225c60c1e3..25f91ca97a 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/BF16Casts.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/BF16Casts.cpp
@@ -2,7 +2,6 @@
 
 #include "Dialect/TritonIntelGPU/Transforms/Utility.h"
 #include "Utils/LLVMIntr.h"
-#include "Utils/Mangling.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Support/LLVM.h"
 
@@ -13,6 +12,7 @@
 #include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
 
 using namespace mlir;
+using namespace triton::gpu::intel;
 
 namespace {
 static bool isBF16OrTensorOf(Type type) {
@@ -79,38 +79,15 @@ struct TruncBF16 : ConvertOpToLLVMPattern {
 namespace mlir::triton::intel {
 Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
                         Value v) {
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  if (auto definingOp = v.getDefiningOp()) {
-    auto moduleOp = definingOp->getParentWithTrait<OpTrait::SymbolTable>();
-    if (moduleOp->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
-                              getSupportBF16ConversionAttrName())) {
-      // For SPIRV target, use specialized intrinsic call for conversion.
-      // Otherwise, use fpext operation.
-      if (gpu::intel::hasSpirvTargetArch(moduleOp)) {
-        constexpr StringLiteral baseName = "__spirv_ConvertBF16ToFINTEL";
-        Type inTy = getTypeWithSameShape(v.getType(), i16_ty);
-        Type outTy = getTypeWithSameShape(inTy, f32_ty);
-        std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
-
-        auto bitcastValue = b.bitcast(v, inTy).getResult();
-
-        auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
-            /*other=*/LLVM::ModRefInfo::NoModRef,
-            /*argMem=*/LLVM::ModRefInfo::NoModRef,
-            /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
-        auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
-        funcAttrs.memEffectsAttr = memAttr;
-
-        auto call = gpu::intel::createDeviceFunctionCall(
-            rewriter, funcName, outTy, {inTy}, {bitcastValue}, {}, funcAttrs);
-        return call.getResult();
-      }
-
-      return LLVM::FPExtOp::create(rewriter, loc, f32_ty, v);
-    }
-  }
+  TritonLLVMIRRewriter b(loc, rewriter);
+  Type int16Ty = mlir::LLVM::intel::getTypeWithSameShape(v.getType(), i16_ty);
+  auto as_int16 = b.bitcast(v, int16Ty).getResult();
+  auto result = convertWithFunctionCall(
+      b, as_int16, "__spirv_ConvertBF16ToFINTEL", i16_ty, f32_ty,
+      TritonIntelGPUDialect::getSupportBF16ConversionAttrName());
+  if (result)
+    return result;
 
-  auto as_int16 = b.bitcast(v, i16_ty);
   auto as_int32 = b.zext(i32_ty, as_int16);
   auto shifted = b.shl(i32_ty, as_int32, b.i32_val(16));
   return (b.bitcast(shifted, f32_ty));
@@ -118,35 +95,21 @@ Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
 
 Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
                         Value v, RoundingMode rounding) {
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  if (auto definingOp = v.getDefiningOp()) {
-    auto moduleOp = definingOp->getParentWithTrait<OpTrait::SymbolTable>();
-    if (moduleOp->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
-                              getSupportBF16ConversionAttrName()) &&
-        rounding == RoundingMode::RTNE) {
-      // Intel SPIR-V extension only supports round-to-nearest-even
-      // LLVM fptrunc operation also assumes round-to-nearest mode
-      if (gpu::intel::hasSpirvTargetArch(moduleOp)) {
-        constexpr StringLiteral baseName = "__spirv_ConvertFToBF16INTEL";
-        Type inTy = v.getType();
-        Type funcOutTy = getTypeWithSameShape(inTy, i16_ty);
-        Type outTy = getTypeWithSameShape(inTy, bf16_ty);
-        std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
-
-        auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
-            /*other=*/LLVM::ModRefInfo::NoModRef,
-            /*argMem=*/LLVM::ModRefInfo::NoModRef,
-            /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
-        auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
-        funcAttrs.memEffectsAttr = memAttr;
-
-        auto call = gpu::intel::createDeviceFunctionCall(
-            rewriter, funcName, funcOutTy, {inTy}, {v}, {}, funcAttrs);
-        return b.bitcast(call.getResult(), outTy);
-      }
-
+  TritonLLVMIRRewriter b(loc, rewriter);
+  // Intel SPIR-V extension only supports round-to-nearest-even
+  // LLVM fptrunc operation also assumes round-to-nearest mode
+  if (rounding == RoundingMode::RTNE) {
+    auto result = convertWithFunctionCall(
+        b, v, "__spirv_ConvertFToBF16INTEL", f32_ty, i16_ty,
+        TritonIntelGPUDialect::getSupportBF16ConversionAttrName());
+    if (result)
+      return b.bitcast(result, mlir::LLVM::intel::getTypeWithSameShape(
+                                   v.getType(), bf16_ty));
+
+    auto op = v.getDefiningOp();
+    if (op && mlir::LLVM::intel::hasModuleAttr(
+                  op, TritonIntelGPUDialect::getSupportBF16ConversionAttrName()))
       return LLVM::FPTruncOp::create(rewriter, loc, bf16_ty, v);
-    }
   }
 
   assert(!isa<VectorType>(v.getType()) && "Not yet supported");
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp
index d9bfc52023..69527a897f 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp
@@ -7,9 +7,15 @@
 //===----------------------------------------------------------------------===//
 
 #include "Utility.h"
+#include "Utils/LLVMIntr.h"
+#include "Utils/Mangling.h"
+
+#include "llvm/ADT/TypeSwitch.h"
 
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 
+#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
+
 using namespace mlir;
 using namespace mlir::triton;
 
@@ -168,4 +174,45 @@ convertTritonRoundingModeToLLVM(const triton::RoundingMode rounding) {
   }
 }
 
+Type getTypeWithSameShape(Type type, Type elementType) {
+  return TypeSwitch<Type, Type>(type)
+      .Case([elementType](VectorType type) {
+        return VectorType::get(type.getShape(), elementType,
+                               type.getScalableDims());
+      })
+      .Default(elementType);
+}
+
+bool hasModuleAttr(Operation *op, StringRef attrName) {
+  auto mod = op->getParentOfType<ModuleOp>();
+  return mod && mod->hasAttr(attrName);
+}
+
 } // namespace mlir::LLVM::intel
+
+namespace mlir::triton::intel {
+Value convertWithFunctionCall(TritonLLVMIRRewriter &rewriter, Value value,
+                              StringRef baseName, Type inType, Type outType,
+                              StringRef hasAttrName) {
+  auto op = value.getDefiningOp();
+  if (!op || !gpu::intel::hasSpirvTargetArch(op))
+    return {};
+  if (!hasAttrName.empty() &&
+      !mlir::LLVM::intel::hasModuleAttr(op, hasAttrName))
+    return {};
+
+  auto valueType = value.getType();
+  Type inTy = mlir::LLVM::intel::getTypeWithSameShape(valueType, inType);
+  Type outTy = mlir::LLVM::intel::getTypeWithSameShape(valueType, outType);
+  std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
+  auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+      /*other=*/LLVM::ModRefInfo::NoModRef,
+      /*argMem=*/LLVM::ModRefInfo::NoModRef,
+      /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+  auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
+  funcAttrs.memEffectsAttr = memAttr;
+  return gpu::intel::createDeviceFunctionCall(rewriter, funcName, outTy, {inTy},
+                                              {value}, {}, funcAttrs)
+      .getResult();
+}
+} // namespace mlir::triton::intel
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
index 0445e459b5..0f73964ea7 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
@@ -79,6 +79,10 @@ Block &createPredicatedBlock(RewriterBase &rewriter, Location loc, Value cond,
 LLVM::RoundingMode
 convertTritonRoundingModeToLLVM(const triton::RoundingMode rounding);
 
+Type getTypeWithSameShape(Type type, Type elementType);
+
+bool hasModuleAttr(Operation *op, StringRef attrName);
+
 } // namespace mlir::LLVM::intel
 
 namespace mlir::triton::intel {
@@ -88,6 +92,10 @@ Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
 Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
                         Value v, RoundingMode rounding);
 
+Value convertWithFunctionCall(TritonLLVMIRRewriter &rewriter, Value value,
+                              StringRef baseName, Type inType, Type outType,
+                              StringRef hasAttrName = {});
+
 } // namespace mlir::triton::intel
 
 #endif // TRITON_CONVERSION_TRITONINTELGPU_TO_LLVM_UTILITY_H