Skip to content

Commit 0f26443

Browse files
committed
fix: add bf16 to munge type detection
1 parent 1615cd9 commit 0f26443

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

lib/RefBackend/RefBackend.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/Dialect/Math/Transforms/Passes.h"
2727
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2828
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
29+
#include "mlir/IR/BuiltinTypes.h"
2930
#include "mlir/Transforms/DialectConversion.h"
3031
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3132
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
@@ -56,7 +57,7 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); }
5657
static bool isArgMemRefTypeValid(Type type) {
5758
if (auto memRefType = dyn_cast<MemRefType>(type)) {
5859
Type elemTy = memRefType.getElementType();
59-
if (isa<Float16Type, Float32Type, Float64Type>(elemTy)) {
60+
if (isa<Float16Type, Float32Type, Float64Type, BFloat16Type>(elemTy)) {
6061
return true;
6162
} else if (auto integerTy = dyn_cast<IntegerType>(elemTy)) {
6263
if (integerTy.isSignlessInteger(64))
@@ -90,6 +91,8 @@ static Type getAbiTypeForMemRef(Type type) {
9091
static std::string getTypeToken(Type type) {
9192
if (type.isSignlessInteger())
9293
return ("i" + Twine(type.getIntOrFloatBitWidth())).str();
94+
else if (isa<mlir::BFloat16Type>(type)) // only the 16 size exists in mlir
95+
return "bf16";
9396
else if (isa<mlir::FloatType>(type))
9497
return ("f" + Twine(type.getIntOrFloatBitWidth())).str();
9598
else if (auto complexTy = dyn_cast<mlir::ComplexType>(type))
@@ -150,10 +153,10 @@ static LogicalResult mungeFunction(
150153
auto type = arg.getType();
151154
if (!isArgMemRefTypeValid(type)) {
152155
return emitError(arg.getLoc())
153-
.append("argument must be a memref of f32, f64, i32, i64, i8, i1, "
154-
"c32, c64, but "
155-
"got ",
156-
type);
156+
.append(
157+
"argument must be a memref of f32, f64, bf16, i32, i64, i8, i1, "
158+
"c32, c64, but got ",
159+
type);
157160
}
158161
auto cast = b.create<memref::CastOp>(arg.getLoc(), type, arg);
159162
arg.replaceAllUsesExcept(cast, cast);

0 commit comments

Comments
 (0)