|
26 | 26 | #include "mlir/Dialect/Math/Transforms/Passes.h" |
27 | 27 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
28 | 28 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
| 29 | +#include "mlir/IR/BuiltinTypes.h" |
29 | 30 | #include "mlir/Transforms/DialectConversion.h" |
30 | 31 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
31 | 32 | #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" |
@@ -56,7 +57,7 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); } |
56 | 57 | static bool isArgMemRefTypeValid(Type type) { |
57 | 58 | if (auto memRefType = dyn_cast<MemRefType>(type)) { |
58 | 59 | Type elemTy = memRefType.getElementType(); |
59 | | - if (isa<Float16Type, Float32Type, Float64Type>(elemTy)) { |
| 60 | + if (isa<Float16Type, Float32Type, Float64Type, BFloat16Type>(elemTy)) { |
60 | 61 | return true; |
61 | 62 | } else if (auto integerTy = dyn_cast<IntegerType>(elemTy)) { |
62 | 63 | if (integerTy.isSignlessInteger(64)) |
@@ -90,6 +91,8 @@ static Type getAbiTypeForMemRef(Type type) { |
90 | 91 | static std::string getTypeToken(Type type) { |
91 | 92 | if (type.isSignlessInteger()) |
92 | 93 | return ("i" + Twine(type.getIntOrFloatBitWidth())).str(); |
| 94 | + else if (isa<mlir::BFloat16Type>(type)) // only the 16 size exists in mlir |
| 95 | + return "bf16"; |
93 | 96 | else if (isa<mlir::FloatType>(type)) |
94 | 97 | return ("f" + Twine(type.getIntOrFloatBitWidth())).str(); |
95 | 98 | else if (auto complexTy = dyn_cast<mlir::ComplexType>(type)) |
@@ -150,10 +153,10 @@ static LogicalResult mungeFunction( |
150 | 153 | auto type = arg.getType(); |
151 | 154 | if (!isArgMemRefTypeValid(type)) { |
152 | 155 | return emitError(arg.getLoc()) |
153 | | - .append("argument must be a memref of f32, f64, i32, i64, i8, i1, " |
154 | | - "c32, c64, but " |
155 | | - "got ", |
156 | | - type); |
| 156 | + .append( |
| 157 | + "argument must be a memref of f32, f64, bf16, i32, i64, i8, i1, " |
| 158 | + "c32, c64, but got ", |
| 159 | + type); |
157 | 160 | } |
158 | 161 | auto cast = b.create<memref::CastOp>(arg.getLoc(), type, arg); |
159 | 162 | arg.replaceAllUsesExcept(cast, cast); |
|
0 commit comments