81 changes: 21 additions & 60 deletions third_party/intel/lib/TritonIntelGPUToLLVM/BF16Casts.cpp
@@ -2,7 +2,6 @@

 #include "Dialect/TritonIntelGPU/Transforms/Utility.h"
 #include "Utils/LLVMIntr.h"
-#include "Utils/Mangling.h"

 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Support/LLVM.h"
@@ -13,6 +12,7 @@
 #include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"

 using namespace mlir;
+using namespace triton::gpu::intel;

 namespace {
 static bool isBF16OrTensorOf(Type type) {
@@ -79,74 +79,35 @@ struct TruncBF16 : ConvertOpToLLVMPattern<arith::TruncFOp> {
 namespace mlir::triton::intel {
 Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
                         Value v) {
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  if (auto definingOp = v.getDefiningOp()) {
-    auto moduleOp = definingOp->getParentWithTrait<OpTrait::SymbolTable>();
-    if (moduleOp->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
-                              getSupportBF16ConversionAttrName())) {
-      // For SPIRV target, use specialized intrinsic call for conversion.
-      // Otherwise, use fpext operation.
-      if (gpu::intel::hasSpirvTargetArch(moduleOp)) {
-        constexpr StringLiteral baseName = "__spirv_ConvertBF16ToFINTEL";
-        Type inTy = getTypeWithSameShape(v.getType(), i16_ty);
-        Type outTy = getTypeWithSameShape(inTy, f32_ty);
-        std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
-
-        auto bitcastValue = b.bitcast(v, inTy).getResult();
-
-        auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
-            /*other=*/LLVM::ModRefInfo::NoModRef,
-            /*argMem=*/LLVM::ModRefInfo::NoModRef,
-            /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
-        auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
-        funcAttrs.memEffectsAttr = memAttr;
-
-        auto call = gpu::intel::createDeviceFunctionCall(
-            rewriter, funcName, outTy, {inTy}, {bitcastValue}, {}, funcAttrs);
-        return call.getResult();
-      }
-
-      return LLVM::FPExtOp::create(rewriter, loc, f32_ty, v);
-    }
-  }
+  TritonLLVMIRRewriter b(loc, rewriter);
+  auto as_int16 = b.bitcast(v, i16_ty).getResult();
+  auto result = convertWithFunctionCall(
+      b, as_int16, "__spirv_ConvertBF16ToFINTEL", i16_ty, f32_ty,
+      TritonIntelGPUDialect::getSupportBF16ConversionAttrName());
+  if (result)
+    return result;

-  auto as_int16 = b.bitcast(v, i16_ty);
   auto as_int32 = b.zext(i32_ty, as_int16);
   auto shifted = b.shl(i32_ty, as_int32, b.i32_val(16));
   return (b.bitcast(shifted, f32_ty));
 }

 Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
                         Value v, RoundingMode rounding) {
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  if (auto definingOp = v.getDefiningOp()) {
-    auto moduleOp = definingOp->getParentWithTrait<OpTrait::SymbolTable>();
-    if (moduleOp->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
-                              getSupportBF16ConversionAttrName()) &&
-        rounding == RoundingMode::RTNE) {
-      // Intel SPIR-V extension only supports round-to-nearest-even
-      // LLVM fptrunc operation also assumes round-to-nearest mode
-      if (gpu::intel::hasSpirvTargetArch(moduleOp)) {
-        constexpr StringLiteral baseName = "__spirv_ConvertFToBF16INTEL";
-        Type inTy = v.getType();
-        Type funcOutTy = getTypeWithSameShape(inTy, i16_ty);
-        Type outTy = getTypeWithSameShape(inTy, bf16_ty);
-        std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
-
-        auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
-            /*other=*/LLVM::ModRefInfo::NoModRef,
-            /*argMem=*/LLVM::ModRefInfo::NoModRef,
-            /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
-        auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
-        funcAttrs.memEffectsAttr = memAttr;
-
-        auto call = gpu::intel::createDeviceFunctionCall(
-            rewriter, funcName, funcOutTy, {inTy}, {v}, {}, funcAttrs);
-        return b.bitcast(call.getResult(), outTy);
-      }
-
+  TritonLLVMIRRewriter b(loc, rewriter);
+  // Intel SPIR-V extension only supports round-to-nearest-even
+  // LLVM fptrunc operation also assumes round-to-nearest mode
+  if (rounding == RoundingMode::RTNE) {
+    std::string attrName = "__spirv_ConvertFToBF16INTEL";
+    auto result = convertWithFunctionCall(
+        b, v, attrName, f32_ty, i16_ty,
+        TritonIntelGPUDialect::getSupportBF16ConversionAttrName());
+    if (result)
+      return b.bitcast(result, bf16_ty);
+
+    auto op = v.getDefiningOp();
+    if (mlir::LLVM::intel::hasModuleAttr(op, attrName))
       return LLVM::FPTruncOp::create(rewriter, loc, bf16_ty, v);
-    }
   }

   assert(!isa<VectorType>(v.getType()) && "Not yet supported");
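Not part of the PR: the fallback path that convertBf16ToFp32 keeps above emits zext + shl(16) + bitcast, which widens the 16 bf16 bits into the upper half of an fp32 word. Below is a minimal host-side C++ sketch of that bit manipulation; the round-to-nearest-even narrowing is included only as a generic reference (the PR's non-SPIR-V narrowing path is not visible in this diff), and all helper names are hypothetical.

#include <cassert>
#include <cstdint>
#include <cstring>

// bf16 -> fp32: place the 16 payload bits in the top half of a 32-bit word.
// This is what the emitted zext + shl(16) + bitcast sequence computes.
static float bf16BitsToFloat(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float result;
  std::memcpy(&result, &widened, sizeof(result));
  return result;
}

// fp32 -> bf16 with round-to-nearest-even (the RTNE mode the diff mentions).
// Generic reference only; NaN handling omitted for brevity.
static uint16_t floatToBf16BitsRTNE(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  uint32_t roundingBias = 0x7FFF + ((bits >> 16) & 1);
  return static_cast<uint16_t>((bits + roundingBias) >> 16);
}

int main() {
  // 0x3F80 is 1.0 in bf16 (bf16 shares fp32's sign/exponent layout).
  assert(bf16BitsToFloat(0x3F80) == 1.0f);
  assert(floatToBf16BitsRTNE(1.0f) == 0x3F80);
  return 0;
}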
47 changes: 47 additions & 0 deletions third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp
@@ -7,9 +7,15 @@
 //===----------------------------------------------------------------------===//

 #include "Utility.h"
+#include "Utils/LLVMIntr.h"
+#include "Utils/Mangling.h"

+#include "llvm/ADT/TypeSwitch.h"
+
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"

+#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
+
 using namespace mlir;
 using namespace mlir::triton;

@@ -168,4 +174,45 @@ convertTritonRoundingModeToLLVM(const triton::RoundingMode rounding) {
   }
 }

+Type getTypeWithSameShape(Type type, Type elementType) {
+  return TypeSwitch<Type, Type>(type)
+      .Case([elementType](VectorType type) {
+        return VectorType::get(type.getShape(), elementType,
+                               type.getScalableDims());
+      })
+      .Default(elementType);
+}
+
+bool hasModuleAttr(Operation *op, StringRef attrName) {
+  auto mod = op->getParentOfType<ModuleOp>();
+  return mod && mod->hasAttr(attrName);
+}
+
 } // namespace mlir::LLVM::intel
+
+namespace mlir::triton::intel {
+Value convertWithFunctionCall(TritonLLVMIRRewriter &rewriter, Value value,
+                              StringRef baseName, Type inType, Type outType,
+                              StringRef hasAttrName) {
+  auto op = value.getDefiningOp();
+  if (!gpu::intel::hasSpirvTargetArch(op))
+    return {};
+  if (!hasAttrName.empty() &&
+      !mlir::LLVM::intel::hasModuleAttr(op, hasAttrName))
+    return {};
+
+  auto valueType = value.getType();
+  Type inTy = mlir::LLVM::intel::getTypeWithSameShape(valueType, inType);
+  Type outTy = mlir::LLVM::intel::getTypeWithSameShape(valueType, outType);
+  std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
+  auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+      /*other=*/LLVM::ModRefInfo::NoModRef,
+      /*argMem=*/LLVM::ModRefInfo::NoModRef,
+      /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+  auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
+  funcAttrs.memEffectsAttr = memAttr;
+  return gpu::intel::createDeviceFunctionCall(rewriter, funcName, outTy, {inTy},
+                                              {value}, {}, funcAttrs)
+      .getResult();
+}
+} // namespace mlir::triton::intel
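Not part of the PR: a standalone sketch of the type mapping the new getTypeWithSameShape helper performs (vector types keep their shape and scalability and only swap the element type; any other type maps to the element type itself). The MLIR setup below is illustrative only, and the Sketch-suffixed name is hypothetical.

#include <cassert>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

// Same dispatch as the helper added above: vectors keep shape/scalability,
// everything else collapses to the requested element type.
static Type getTypeWithSameShapeSketch(Type type, Type elementType) {
  return llvm::TypeSwitch<Type, Type>(type)
      .Case([elementType](VectorType vecTy) {
        return VectorType::get(vecTy.getShape(), elementType,
                               vecTy.getScalableDims());
      })
      .Default(elementType);
}

int main() {
  MLIRContext ctx;
  Type f32 = Float32Type::get(&ctx);
  Type i16 = IntegerType::get(&ctx, 16);
  // vector<4xf32> -> vector<4xi16>; scalar f32 -> i16.
  Type vec4f32 = VectorType::get({4}, f32);
  assert(getTypeWithSameShapeSketch(vec4f32, i16) == VectorType::get({4}, i16));
  assert(getTypeWithSameShapeSketch(f32, i16) == i16);
  return 0;
}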
8 changes: 8 additions & 0 deletions third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
@@ -79,6 +79,10 @@ Block &createPredicatedBlock(RewriterBase &rewriter, Location loc, Value cond,
 LLVM::RoundingMode
 convertTritonRoundingModeToLLVM(const triton::RoundingMode rounding);

+Type getTypeWithSameShape(Type type, Type elementType);
+
+bool hasModuleAttr(Operation *op, StringRef attrName);
+
 } // namespace mlir::LLVM::intel

 namespace mlir::triton::intel {
@@ -88,6 +92,10 @@ Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
 Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
                         Value v, RoundingMode rounding);

+Value convertWithFunctionCall(TritonLLVMIRRewriter &rewriter, Value value,
+                              StringRef baseName, Type inType, Type outType,
+                              StringRef hasAttrName = {});
+
 } // namespace mlir::triton::intel

 #endif // TRITON_CONVERSION_TRITONINTELGPU_TO_LLVM_UTILITY_H
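Not part of the PR: a hypothetical caller illustrating the contract of the declaration above. convertWithFunctionCall returns a null Value when the module does not target SPIR-V (and, when a gating attribute name is passed, when the module lacks that attribute), so callers keep their own fallback. It assumes the includes and builder helpers (zext/shl/bitcast/i32_val) that BF16Casts.cpp already uses; the function name below is made up.

// Sketch only: widen bf16 bits (held in an i16 value) to fp32, preferring the
// SPIR-V builtin and otherwise falling back to the shift-based sequence.
static Value widenBf16Bits(TritonLLVMIRRewriter &b, Value asI16, Type i32Ty,
                           Type f32Ty) {
  Value viaBuiltin = mlir::triton::intel::convertWithFunctionCall(
      b, asI16, "__spirv_ConvertBF16ToFINTEL", asI16.getType(), f32Ty);
  if (viaBuiltin)
    return viaBuiltin;
  // Fallback mirroring BF16Casts.cpp: place the bf16 bits in the high half.
  Value asI32 = b.zext(i32Ty, asI16);
  Value shifted = b.shl(i32Ty, asI32, b.i32_val(16));
  return b.bitcast(shifted, f32Ty);
}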