diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index d40d02d43ffc..6470e25ea30b 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -26,6 +26,7 @@ #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" @@ -56,7 +57,7 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); } static bool isArgMemRefTypeValid(Type type) { if (auto memRefType = dyn_cast(type)) { Type elemTy = memRefType.getElementType(); - if (isa(elemTy)) { + if (isa(elemTy)) { return true; } else if (auto integerTy = dyn_cast(elemTy)) { if (integerTy.isSignlessInteger(64)) @@ -90,6 +91,8 @@ static Type getAbiTypeForMemRef(Type type) { static std::string getTypeToken(Type type) { if (type.isSignlessInteger()) return ("i" + Twine(type.getIntOrFloatBitWidth())).str(); + else if (isa(type)) // only the 16 size exists in mlir + return "bf16"; else if (isa(type)) return ("f" + Twine(type.getIntOrFloatBitWidth())).str(); else if (auto complexTy = dyn_cast(type)) @@ -150,10 +153,10 @@ static LogicalResult mungeFunction( auto type = arg.getType(); if (!isArgMemRefTypeValid(type)) { return emitError(arg.getLoc()) - .append("argument must be a memref of f32, f64, i32, i64, i8, i1, " - "c32, c64, but " - "got ", - type); + .append( + "argument must be a memref of f32, f64, bf16, i32, i64, i8, i1, " + "c32, c64, but got ", + type); } auto cast = b.create(arg.getLoc(), type, arg); arg.replaceAllUsesExcept(cast, cast);