diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index ac86e8461d277..4c6dcbebb38df 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -34,6 +34,7 @@ class LLVMFuncOp; /// of the libc). LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index 33583d5b1f01c..7e2a0942f7a90 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -3791,3 +3791,71 @@ structured_op: !LinalgStructuredOpConfig scalar_const: '2.3283063999999999E-10 : f64' - !ScalarExpression scalar_arg: min +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: reduce_batch_matmul + cpp_class_name: ReduceBatchMatmulOp + doc: |- + Performs a batch-reduce matrix multiplication of two 3D inputs: the per-batch products are accumulated into a single 2D output. + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ implements: + - LinalgContractionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: A + kind: input_tensor + type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)> + - !LinalgOperandDefConfig + name: B + kind: input_tensor + type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)> + - !LinalgOperandDefConfig + name: C + kind: output_tensor + type_var: U + shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s2)> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)> + iterator_types: + - reduction + - parallel + - parallel + - reduction + assignments: + - !ScalarAssign + arg: C + value: !ScalarExpression + scalar_fn: + kind: binary + fn_name: add + operands: + - !ScalarExpression + scalar_arg: C + - !ScalarExpression + scalar_fn: + kind: binary + fn_name: mul + operands: + - !ScalarExpression + scalar_fn: + kind: type + fn_name: cast_signed + type_var: U + operands: + - !ScalarExpression + scalar_arg: A + - !ScalarExpression + scalar_fn: + kind: type + fn_name: cast_signed + type_var: U + operands: + - !ScalarExpression + scalar_arg: B diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index a56b6b44e4657..12969e0e3e306 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -110,6 +110,11 @@ struct SCFTileAndFuseResult { SmallVector tiledAndFusedOps; SmallVector loops; }; + +using checkProducerFn = + std::function rootIterationDomain, + Operation *producer, OpBuilder &builder)>; + struct TileConsumerAndFuseProducersUsingSCFForOp : public OpInterfaceRewritePattern { @@ -127,7 +132,8 @@ struct 
TileConsumerAndFuseProducersUsingSCFForOp /// `matchAndRewrite` implementation that returns the significant transformed /// pieces of IR. FailureOr - returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const; + returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter, + checkProducerFn = nullptr) const; LogicalResult matchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const override { diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 9fa683a92e47d..76feacebbca44 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -60,6 +60,7 @@ class Builder { // Types. FloatType getBF16Type(); + FloatType getPackedBF16Type(); FloatType getF16Type(); FloatType getF32Type(); FloatType getF64Type(); diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index 96c7804d96495..12509298f4b39 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -41,6 +41,7 @@ class FloatType : public Type { // Convenience factories. 
static FloatType getBF16(MLIRContext *ctx); + static FloatType getPackedBF16(MLIRContext *ctx); static FloatType getF16(MLIRContext *ctx); static FloatType getF32(MLIRContext *ctx); static FloatType getF64(MLIRContext *ctx); @@ -374,13 +375,17 @@ inline bool BaseMemRefType::isValidElementType(Type type) { inline bool FloatType::classof(Type type) { return type.isa(); + Float80Type, Float128Type, PackedBF16Type>(); } inline FloatType FloatType::getBF16(MLIRContext *ctx) { return BFloat16Type::get(ctx); } +inline FloatType FloatType::getPackedBF16(MLIRContext *ctx) { + return PackedBF16Type::get(ctx); +} + inline FloatType FloatType::getF16(MLIRContext *ctx) { return Float16Type::get(ctx); } diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 55dc24134b451..676d9480e20d7 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -118,6 +118,14 @@ def Builtin_Float128 : Builtin_FloatType<"Float128"> { let summary = "128-bit floating-point type"; } +//===----------------------------------------------------------------------===// +// PackedBF16Type + +def Builtin_PackedBF16 : Builtin_FloatType<"PackedBF16"> { + let summary = "Packed BF16 format"; +} + + //===----------------------------------------------------------------------===// // FunctionType //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index 5cac1e240d653..94546f29c240c 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -127,6 +127,7 @@ class Type { bool isF64() const; bool isF80() const; bool isF128() const; + bool isPackedBF16() const; /// Return true if this is an integer type with the specified width. 
bool isInteger(unsigned width) const; diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def index 94d2fd3687fc3..02c42b6109041 100644 --- a/mlir/lib/AsmParser/TokenKinds.def +++ b/mlir/lib/AsmParser/TokenKinds.def @@ -85,6 +85,7 @@ TOK_KEYWORD(affine_set) TOK_KEYWORD(array) TOK_KEYWORD(attributes) TOK_KEYWORD(bf16) +TOK_KEYWORD(pbf16) TOK_KEYWORD(ceildiv) TOK_KEYWORD(complex) TOK_KEYWORD(dense) diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp index 16da006809d29..88a51adec0542 100644 --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -31,6 +31,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) { case Token::kw_vector: case Token::inttype: case Token::kw_bf16: + case Token::kw_pbf16: case Token::kw_f16: case Token::kw_f32: case Token::kw_f64: @@ -249,7 +250,7 @@ Type Parser::parseMemRefType() { /// | none-type /// /// index-type ::= `index` -/// float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128` +/// float-type ::= `f16` | `bf16` | `pbf16` | `f32` | `f64` | `f80` | `f128` /// none-type ::= `none` /// Type Parser::parseNonFunctionType() { @@ -289,6 +290,9 @@ Type Parser::parseNonFunctionType() { case Token::kw_bf16: consumeToken(Token::kw_bf16); return builder.getBF16Type(); + case Token::kw_pbf16: + consumeToken(Token::kw_pbf16); + return builder.getPackedBF16Type(); case Token::kw_f16: consumeToken(Token::kw_f16); return builder.getF16Type(); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 26ea7f3631707..cdc3dc6ce330c 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1082,7 +1082,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { VectorType vectorType = printType.dyn_cast(); Type eltType = vectorType ? 
vectorType.getElementType() : printType; Operation *printer; - if (eltType.isF32()) { + if (eltType.isBF16()) { + printer = + LLVM::lookupOrCreatePrintBF16Fn(printOp->getParentOfType()); + } else if (eltType.isF32()) { printer = LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType()); } else if (eltType.isF64()) { diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index 409c513f69ddb..bc888eca23c71 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -26,6 +26,7 @@ using namespace mlir::LLVM; /// part of the libc). static constexpr llvm::StringRef kPrintI64 = "printI64"; static constexpr llvm::StringRef kPrintU64 = "printU64"; +static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; @@ -66,6 +67,12 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) { LLVM::LLVMVoidType::get(moduleOp->getContext())); } +LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) { + return lookupOrCreateFn(moduleOp, kPrintBF16, + FloatType::getBF16(moduleOp->getContext()), + LLVM::LLVMVoidType::get(moduleOp->getContext())); +} + LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) { return lookupOrCreateFn(moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()), diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index 3f39a6e07748b..fe5a2164f012f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -765,6 +765,7 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) { Float64Type, Float80Type, Float128Type, + PackedBF16Type, LLVMArrayType, LLVMFunctionType, LLVMLabelType, @@ -865,8 +866,9 @@ bool 
mlir::LLVM::isCompatibleType(Type type) { } bool mlir::LLVM::isCompatibleFloatingPointType(Type type) { - return type.isa(); + return type + .isa(); } bool mlir::LLVM::isCompatibleVectorType(Type type) { @@ -880,7 +882,7 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) { if (auto intType = elementType.dyn_cast()) return intType.isSignless(); return elementType.isa(); + Float80Type, Float128Type, PackedBF16Type>(); } return false; } @@ -965,7 +967,7 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) { "expected a type compatible with the LLVM dialect"); return llvm::TypeSwitch(type) - .Case( + .Case( [](Type) { return llvm::TypeSize::Fixed(16); }) .Case([](Type) { return llvm::TypeSize::Fixed(32); }) .Case( diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 8e55ce9613baf..dd1d1ab080324 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -397,13 +397,14 @@ static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor, FailureOr scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( - TilingInterface op, PatternRewriter &rewriter) const { + TilingInterface op, PatternRewriter &rewriter, checkProducerFn fn) const { // This transformation is only valid for ops that return values (i.e. not // valid to use with operations that have memref operands). if (!op->getNumResults()) { return rewriter.notifyMatchFailure( op, "invalid pattern for op with no results"); } + SmallVector iterationDomain = op.getIterationDomain(rewriter); // 1. First tile the consumer. SCFTileAndFuseResult tileAndFuseResult; @@ -446,6 +447,10 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( if (!fusableProducer) continue; + if (fn && + failed(fn(iterationDomain, fusableProducer->getDefiningOp(), rewriter))) + continue; + // 2c. 
Generate the tiled implementation of the producer of the source rewriter.setInsertionPoint(candidateSliceOp); FailureOr fusedProducerValue = diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 9cf9501850a6b..409918c44ac89 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2078,6 +2078,7 @@ void AsmPrinter::Impl::printType(Type type) { }) .Case([&](Type) { os << "index"; }) .Case([&](Type) { os << "bf16"; }) + .Case([&](Type) { os << "pbf16"; }) .Case([&](Type) { os << "f16"; }) .Case([&](Type) { os << "f32"; }) .Case([&](Type) { os << "f64"; }) diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 50f84a6998222..37b1876ea9fcb 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -35,6 +35,10 @@ Location Builder::getFusedLoc(ArrayRef locs, Attribute metadata) { FloatType Builder::getBF16Type() { return FloatType::getBF16(context); } +FloatType Builder::getPackedBF16Type() { + return FloatType::getPackedBF16(context); +} + FloatType Builder::getF16Type() { return FloatType::getF16(context); } FloatType Builder::getF32Type() { return FloatType::getF32(context); } diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp index ce6120c77ffbf..c04a774a6865e 100644 --- a/mlir/lib/IR/BuiltinDialectBytecode.cpp +++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp @@ -146,36 +146,41 @@ enum TypeCode { /// kBFloat16Type = 3, + /// PackedBF16Type { + /// } + /// + kPackedBF16Type = 4, + /// Float16Type { /// } /// - kFloat16Type = 4, + kFloat16Type = 5, /// Float32Type { /// } /// - kFloat32Type = 5, + kFloat32Type = 6, /// Float64Type { /// } /// - kFloat64Type = 6, + kFloat64Type = 7, /// Float80Type { /// } /// - kFloat80Type = 7, + kFloat80Type = 8, /// Float128Type { /// } /// - kFloat128Type = 8, + kFloat128Type = 9, /// ComplexType { /// elementType: Type /// } /// - kComplexType = 9, + kComplexType = 10, /// MemRefType { /// shape: svarint[], @@ -183,7 +188,7 
@@ enum TypeCode { /// layout: Attribute /// } /// - kMemRefType = 10, + kMemRefType = 11, /// MemRefTypeWithMemSpace { /// memorySpace: Attribute, @@ -192,19 +197,19 @@ enum TypeCode { /// layout: Attribute /// } /// Variant of MemRefType with non-default memory space. - kMemRefTypeWithMemSpace = 11, + kMemRefTypeWithMemSpace = 12, /// NoneType { /// } /// - kNoneType = 12, + kNoneType = 13, /// RankedTensorType { /// shape: svarint[], /// elementType: Type, /// } /// - kRankedTensorType = 13, + kRankedTensorType = 14, /// RankedTensorTypeWithEncoding { /// encoding: Attribute, @@ -212,38 +217,38 @@ enum TypeCode { /// elementType: Type /// } /// Variant of RankedTensorType with an encoding. - kRankedTensorTypeWithEncoding = 14, + kRankedTensorTypeWithEncoding = 15, /// TupleType { /// elementTypes: Type[] /// } - kTupleType = 15, + kTupleType = 16, /// UnrankedMemRefType { /// shape: svarint[] /// } /// - kUnrankedMemRefType = 16, + kUnrankedMemRefType = 17, /// UnrankedMemRefTypeWithMemSpace { /// memorySpace: Attribute, /// shape: svarint[] /// } /// Variant of UnrankedMemRefType with non-default memory space. - kUnrankedMemRefTypeWithMemSpace = 17, + kUnrankedMemRefTypeWithMemSpace = 18, /// UnrankedTensorType { /// elementType: Type /// } /// - kUnrankedTensorType = 18, + kUnrankedTensorType = 19, /// VectorType { /// shape: svarint[], /// elementType: Type /// } /// - kVectorType = 19, + kVectorType = 20, /// VectorTypeWithScalableDims { /// numScalableDims: varint, @@ -251,7 +256,7 @@ enum TypeCode { /// elementType: Type /// } /// Variant of VectorType with scalable dimensions. 
- kVectorTypeWithScalableDims = 20, + kVectorTypeWithScalableDims = 21, }; } // namespace builtin_encoding @@ -711,6 +716,8 @@ Type BuiltinDialectBytecodeInterface::readType( return readFunctionType(reader); case builtin_encoding::kBFloat16Type: return BFloat16Type::get(getContext()); + case builtin_encoding::kPackedBF16Type: + return PackedBF16Type::get(getContext()); case builtin_encoding::kFloat16Type: return Float16Type::get(getContext()); case builtin_encoding::kFloat32Type: @@ -767,6 +774,9 @@ LogicalResult BuiltinDialectBytecodeInterface::writeType( .Case([&](BFloat16Type) { return writer.writeVarInt(builtin_encoding::kBFloat16Type), success(); }) + .Case([&](PackedBF16Type) { + return writer.writeVarInt(builtin_encoding::kPackedBF16Type), success(); + }) .Case([&](Float16Type) { return writer.writeVarInt(builtin_encoding::kFloat16Type), success(); }) diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index f53a94146efcb..67410497eb383 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -98,6 +98,8 @@ unsigned FloatType::getWidth() { return 80; if (isa()) return 128; + if (isa()) + return 16; llvm_unreachable("unexpected float type"); } @@ -115,6 +117,8 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() { return APFloat::x87DoubleExtended(); if (isa()) return APFloat::IEEEquad(); + if (isa()) + return APFloat::BFloat(); llvm_unreachable("non-floating point type used"); } @@ -122,7 +126,7 @@ FloatType FloatType::scaleElementBitwidth(unsigned scale) { if (!scale) return FloatType(); MLIRContext *ctx = getContext(); - if (isF16() || isBF16()) { + if (isF16() || isBF16() || isPackedBF16()) { if (scale == 2) return FloatType::getF32(ctx); if (scale == 4) diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 2e00bd4778c30..7e7e50c1d64c9 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -207,6 +207,7 @@ class MLIRContextImpl { /// Cached Type Instances. 
BFloat16Type bf16Ty; + PackedBF16Type packedBf16Ty; Float16Type f16Ty; Float32Type f32Ty; Float64Type f64Ty; @@ -277,6 +278,7 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) //// Types. /// Floating-point Types. impl->bf16Ty = TypeUniquer::get(this); + impl->packedBf16Ty = TypeUniquer::get(this); impl->f16Ty = TypeUniquer::get(this); impl->f32Ty = TypeUniquer::get(this); impl->f64Ty = TypeUniquer::get(this); @@ -808,6 +810,11 @@ StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } BFloat16Type BFloat16Type::get(MLIRContext *context) { return context->getImpl().bf16Ty; } + +PackedBF16Type PackedBF16Type::get(MLIRContext *context) { + return context->getImpl().packedBf16Ty; +} + Float16Type Float16Type::get(MLIRContext *context) { return context->getImpl().f16Ty; } diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index defe2dacfac29..700536baf0209 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -24,6 +24,7 @@ bool Type::isF32() const { return isa(); } bool Type::isF64() const { return isa(); } bool Type::isF80() const { return isa(); } bool Type::isF128() const { return isa(); } +bool Type::isPackedBF16() const { return isa(); } bool Type::isIndex() const { return isa(); } diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp index a65757a1bbd5f..63508c482a5ea 100644 --- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp +++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp @@ -44,6 +44,9 @@ class TypeToLLVMIRTranslatorImpl { .Case([this](BFloat16Type) { return llvm::Type::getBFloatTy(context); }) + .Case([this](PackedBF16Type) { + return llvm::Type::getBFloatTy(context); + }) .Case( [this](Float32Type) { return llvm::Type::getFloatTy(context); }) .Case([this](Float64Type) { diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir index 6512847c8530d..ad0c828511f46 100644 --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ 
b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -762,3 +762,11 @@ func.func @conv_interface_wrong_num_operands( }) {dilations = dense<1> : tensor<2xi64>, linalg.memoized_indexing_maps = [#map0, #map1, #map2], operand_segment_sizes = array, strides = dense<1> : tensor<2xi64>} : (tensor, tensor, tensor) -> tensor return %0 : tensor } + +// ----- + +func.func @brgemm_test(%arg0: tensor<8x128x256xf32>, %arg1: tensor<8x256x512xf32>, %arg2: tensor<128x512xf32>) -> tensor<128x512xf32> { + // CHECK: linalg.reduce_batch_matmul + %0 = linalg.reduce_batch_matmul ins(%arg0, %arg1 : tensor<8x128x256xf32>, tensor<8x256x512xf32>) outs(%arg2: tensor<128x512xf32>) -> tensor<128x512xf32> + return %0: tensor<128x512xf32> +}