From 0f26443ea3c92f80801d80708dfe8c9d040df74a Mon Sep 17 00:00:00 2001 From: Grigory Vartanyan Date: Fri, 22 Aug 2025 09:09:02 -0600 Subject: [PATCH] fix: add bf16 to munge type detection --- lib/RefBackend/RefBackend.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index d40d02d43ffc..6470e25ea30b 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -26,6 +26,7 @@ #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" @@ -56,7 +57,7 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); } static bool isArgMemRefTypeValid(Type type) { if (auto memRefType = dyn_cast(type)) { Type elemTy = memRefType.getElementType(); - if (isa(elemTy)) { + if (isa(elemTy)) { return true; } else if (auto integerTy = dyn_cast(elemTy)) { if (integerTy.isSignlessInteger(64)) @@ -90,6 +91,8 @@ static Type getAbiTypeForMemRef(Type type) { static std::string getTypeToken(Type type) { if (type.isSignlessInteger()) return ("i" + Twine(type.getIntOrFloatBitWidth())).str(); + else if (isa(type)) // only the 16 size exists in mlir + return "bf16"; else if (isa(type)) return ("f" + Twine(type.getIntOrFloatBitWidth())).str(); else if (auto complexTy = dyn_cast(type)) @@ -150,10 +153,10 @@ static LogicalResult mungeFunction( auto type = arg.getType(); if (!isArgMemRefTypeValid(type)) { return emitError(arg.getLoc()) - .append("argument must be a memref of f32, f64, i32, i64, i8, i1, " - "c32, c64, but " - "got ", - type); + .append( + "argument must be a memref of f32, f64, bf16, i32, i64, i8, i1, " + "c32, c64, but got ", + type); } auto cast = b.create(arg.getLoc(), type, arg); arg.replaceAllUsesExcept(cast, cast);