Skip to content

Commit 797258f

Browse files
committed
enable printing bf16 - TBD
Bf16 seems not well supported. On X86 is a promotion to f32. See: https://reviews.llvm.org/rGfb34d531af953119593be74753b89baf99fbc194
1 parent f0c82ea commit 797258f

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class LLVMFuncOp;
3434
/// of the libc).
3535
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
3636
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
37+
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
3738
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
3839
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
3940
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1082,7 +1082,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
10821082
VectorType vectorType = printType.dyn_cast<VectorType>();
10831083
Type eltType = vectorType ? vectorType.getElementType() : printType;
10841084
Operation *printer;
1085-
if (eltType.isF32()) {
1085+
if (eltType.isBF16()) {
1086+
printer =
1087+
LLVM::lookupOrCreatePrintBF16Fn(printOp->getParentOfType<ModuleOp>());
1088+
} else if (eltType.isF32()) {
10861089
printer =
10871090
LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
10881091
} else if (eltType.isF64()) {

mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ using namespace mlir::LLVM;
2626
/// part of the libc).
2727
static constexpr llvm::StringRef kPrintI64 = "printI64";
2828
static constexpr llvm::StringRef kPrintU64 = "printU64";
29+
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
2930
static constexpr llvm::StringRef kPrintF32 = "printF32";
3031
static constexpr llvm::StringRef kPrintF64 = "printF64";
3132
static constexpr llvm::StringRef kPrintOpen = "printOpen";
@@ -66,6 +67,12 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
6667
LLVM::LLVMVoidType::get(moduleOp->getContext()));
6768
}
6869

70+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) {
71+
return lookupOrCreateFn(moduleOp, kPrintBF16,
72+
FloatType::getBF16(moduleOp->getContext()),
73+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
74+
}
75+
6976
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
7077
return lookupOrCreateFn(moduleOp, kPrintF32,
7178
Float32Type::get(moduleOp->getContext()),

0 commit comments

Comments
 (0)