diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt index 0fe01824b8248..bbe8e4eb892dd 100644 --- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt @@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector) add_mlir_interface(X86VectorInterfaces) add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen) + +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..6f377e10fa8f8 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td) +mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls) +mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs) +add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen) diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h new file mode 100644 index 0000000000000..e1d8b8762e799 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h @@ -0,0 +1,31 @@ +//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// X86Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace x86vector {
+/// Registers the X86Vector transform ops as a Transform dialect extension.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace x86vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..4009a140bb097
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -0,0 +1,42 @@
+//===- X86VectorTransformOps.td - X86Vector transform ops ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef X86VECTOR_TRANSFORM_OPS
+#define X86VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+
+// Pattern-descriptor ops, used inside a `transform.apply_patterns` region.
+def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.vector_contract_to_fma",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that a `vector.contract` operation can be lowered to an FMA.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.vector_contract_to_packed_type_dot_product",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that `vector.contract` can be lowered to a BF16/Int8 dot-product.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+#endif // X86VECTOR_TRANSFORM_OPS
+
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index d54111ca41e69..943d7182d1960 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -11,6 +11,10 @@
 
 #include "mlir/IR/Value.h"
 
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
 namespace mlir {
 
 class ImplicitLocOpBuilder;
@@ -79,6 +83,17 @@ struct MaskHelper {
   }
 };
 
+//===----------------------------------------------------------------------===//
+
+/// Adds the pattern that rewrites an F32 `vector.contract` whose LHS has only
+/// unit dimensions into a `vector.broadcast` followed by `vector.fma`.
+void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
+
+/// Adds the pattern that rewrites a packed BF16/Int8 `vector.contract` into
+/// an x86vector packed dot-product operation.
+void populateVectorContractToPackedTypeDotProductPatterns(
+    RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git
a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt index 9f57627c321fb..cb1e9d01821a2 100644 --- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..f4c9f8a05acbc --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRX86VectorTransformOps + X86VectorTransformOps.cpp + + DEPENDS + MLIRX86VectorTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRVectorDialect + MLIRSideEffectInterfaces + MLIRTransformDialect + MLIRTransformDialectUtils + MLIRX86VectorDialect + MLIRX86VectorTransforms + ) diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp new file mode 100644 index 0000000000000..68d577326a308 --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp @@ -0,0 +1,64 @@ +//===- X86VectorTransformOps.cpp - X86Vector transform ops ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::x86vector;
+using namespace mlir::transform;
+
+void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  x86vector::populateVectorContractToFMAPatterns(patterns);
+}
+
+void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class X86VectorTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          X86VectorTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      X86VectorTransformDialectExtension)
+
+  X86VectorTransformDialectExtension() {
+    // NOTE(review): dialect arguments inferred from the includes - confirm.
+    declareGeneratedDialect<x86vector::X86VectorDialect>();
+    declareGeneratedDialect<LLVM::LLVMDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+
+void mlir::x86vector::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<X86VectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index c51266afe9e8f..3d2288049e49e 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,6 +1,8 @@
 add_mlir_dialect_library(MLIRX86VectorTransforms
   AVXTranspose.cpp
   LegalizeForLLVMExport.cpp
+  VectorContractToFMA.cpp
+  VectorContractToPackedTypeDotProduct.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
new file mode 100644
index 0000000000000..764ec46681094
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
@@ -0,0 +1,99 @@
+//===- VectorContractToFMA.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Implements outer product contraction as a sequence of broadcast and
+// FMA operations.
+//
+// For example - for F32 type:
+// ```
+//   vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32>
+// ```
+// to
+// ```
+//   vector.broadcast %lhs to <16xf32>
+//   vector.fma vector<16xf32>
+// ```
+struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind");
+
+    VectorType lhsTy = contractOp.getLhsType();
+    if (!lhsTy.getElementType().isF32())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only F32 lowering is supported.");
+    if (llvm::any_of(lhsTy.getShape(), [](int64_t dim) { return dim != 1; }))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects one for all dimensions of LHS");
+
+    VectorType rhsTy = contractOp.getRhsType();
+    ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+    llvm::SmallVector<int64_t> dimsRhs;
+    llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
+                  [](int64_t dim) { return dim != 1; });
+    if (dimsRhs.size() != 1)
+      return rewriter.notifyMatchFailure(contractOp, "Irregular RHS shape");
+
+    // The accumulator need not be an F32 vector (the dyn_cast can yield
+    // null); reject it here rather than asserting or building an invalid FMA.
+    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy || !accTy.getElementType().isF32())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects an F32 vector accumulator");
+    ArrayRef<int64_t> accShape = accTy.getShape();
+    llvm::SmallVector<int64_t> dimsAcc;
+    llvm::copy_if(accShape, std::back_inserter(dimsAcc),
+                  [](int64_t dim) { return dim != 1; });
+    if (dimsAcc.size() != 1)
+      return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");
+
+    // Lowers vector.contract into a broadcast+FMA sequence.
+    auto loc = contractOp.getLoc();
+    auto castLhs = vector::ShapeCastOp::create(
+        rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
+        contractOp.getLhs());
+    auto castRhs = vector::ShapeCastOp::create(
+        rewriter, loc, VectorType::get(dimsRhs.front(), rhsTy.getElementType()),
+        contractOp.getRhs());
+    auto castAcc = vector::ShapeCastOp::create(
+        rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
+        contractOp.getAcc());
+    auto broadcastLhs = vector::BroadcastOp::create(
+        rewriter, loc, castRhs.getResult().getType(), castLhs);
+    auto fma =
+        vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc);
+    auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
+    rewriter.replaceOp(contractOp, castFma);
+    return success();
+  }
+};
+
+void x86vector::populateVectorContractToFMAPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorContractToFMA>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
new file mode 100644
index 0000000000000..1dabbddbebb7e
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -0,0 +1,148 @@
+//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Implements packed type outer product contraction as a sequence
+// of broadcast and packed dot-product operations.
+//
+// For example - for BF16 type:
+// ```
+//   vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
+// ```
+// to
+// ```
+//   vector.broadcast %lhs to <32xbf16>
+//   x86vector.avx512.dot vector<32xbf16> -> vector<16xf32>
+// ```
+struct VectorContractToPackedTypeDotProduct
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind");
+
+    VectorType lhsTy = contractOp.getLhsType();
+    const bool isBF16 = lhsTy.getElementType().isBF16();
+    const bool isInt8 = lhsTy.getElementType().isSignlessInteger(8);
+    if (!isBF16 && !isInt8)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only BF16/Int8 lowering is supported.");
+    ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+    if (isBF16 && lhsShape.back() != 2)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The LHS vnni dim should be 2 for BF16.");
+    if (isInt8 && lhsShape.back() != 4)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The LHS vnni dim should be 4 for Int8.");
+    llvm::SmallVector<int64_t> dimsLhs;
+    llvm::copy_if(lhsShape, std::back_inserter(dimsLhs),
+                  [](int64_t dim) { return dim != 1; });
+    if (dimsLhs.size() != 1)
+      return rewriter.notifyMatchFailure(contractOp, "Irregular LHS shape");
+
+    VectorType rhsTy = contractOp.getRhsType();
+    ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+    if (isBF16 && rhsShape.back() != 2)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The RHS vnni dim should be 2 for BF16.");
+    if (isInt8 && rhsShape.back() != 4)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The RHS vnni dim should be 4 for Int8.");
+    llvm::SmallVector<int64_t> dimsRhs;
+    llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
+                  [](int64_t dim) { return dim != 1; });
+    if (dimsRhs.size() != 2)
+      return rewriter.notifyMatchFailure(contractOp, "Irregular RHS shape");
+
+    // The accumulator is not guaranteed to be a vector, and its element type
+    // must pair with the packed input type: F32 for BF16, Int32 for Int8.
+    // Reject anything else before any IR is created.
+    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp, "Invalid accumulator");
+    Type accElemTy = accTy.getElementType();
+    if (isBF16 ? !accElemTy.isF32() : !accElemTy.isSignlessInteger(32))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects F32 accumulation for BF16 and Int32 for Int8.");
+    ArrayRef<int64_t> accShape = accTy.getShape();
+    llvm::SmallVector<int64_t> dimsAcc;
+    llvm::copy_if(accShape, std::back_inserter(dimsAcc),
+                  [](int64_t dim) { return dim != 1; });
+    if (dimsAcc.size() != 1)
+      return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");
+
+    auto loc = contractOp.getLoc();
+    auto castRhs = vector::ShapeCastOp::create(
+        rewriter, loc,
+        VectorType::get(dimsRhs.front() * dimsRhs.back(),
+                        rhsTy.getElementType()),
+        contractOp.getRhs());
+
+    auto castAcc = vector::ShapeCastOp::create(
+        rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
+        contractOp.getAcc());
+
+    // The vnni group of the LHS (2 x BF16 or 4 x Int8) is 32 bits wide, so
+    // reinterpret it as one i32, broadcast that across all lanes, and then
+    // view the result again as the packed RHS vector type.
+    auto castLhs = vector::ShapeCastOp::create(
+        rewriter, loc, VectorType::get(dimsLhs.front(), lhsTy.getElementType()),
+        contractOp.getLhs());
+    auto bitcastLhs = vector::BitCastOp::create(
+        rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+        castLhs);
+    auto broadcastLhs = vector::BroadcastOp::create(
+        rewriter, loc,
+        VectorType::get({dimsRhs.front()}, rewriter.getIntegerType(32)),
+        bitcastLhs);
+    auto bitcastLhsPkType = vector::BitCastOp::create(
+        rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
+
+    // Exactly one of BF16/Int8 holds at this point (checked above).
+    Value dp;
+    if (isBF16) {
+      dp = x86vector::DotBF16Op::create(
+          rewriter, loc,
+          VectorType::get(dimsRhs.front(), rewriter.getF32Type()), castAcc,
+          bitcastLhsPkType, castRhs);
+    } else {
+      dp = x86vector::DotInt8Op::create(
+          rewriter, loc,
+          VectorType::get(dimsRhs.front(), rewriter.getIntegerType(32)),
+          castAcc, bitcastLhsPkType, castRhs);
+    }
+
+    auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
+    rewriter.replaceOp(contractOp, castDp);
+    return success();
+  }
+};
+
+// Registers the packed-type (BF16/Int8) contraction lowering pattern.
+void x86vector::populateVectorContractToPackedTypeDotProductPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index c857c38df717c..4312100a0c0b0 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -56,6 +56,7 @@
 #include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
 #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
 #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -113,6 +114,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
   transform::registerSMTExtension(registry);
   transform::registerTuneExtension(registry);
   vector::registerTransformDialectExtension(registry);
+
x86vector::registerTransformDialectExtension(registry); xegpu::registerTransformDialectExtension(registry); arm_neon::registerTransformDialectExtension(registry); arm_sve::registerTransformDialectExtension(registry); diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir new file mode 100644 index 0000000000000..3a79037ca37c2 --- /dev/null +++ b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir @@ -0,0 +1,210 @@ +// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s + +!vecA = vector<1x1xf32> +!vecB = vector<1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @matmul_outer_product_to_fma( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_fma +// CHECK-COUNT-1: vector.shape_cast{{.*}}to vector<1xf32> +// CHECK-COUNT-2: vector.shape_cast{{.*}}to vector<64xf32> +// CHECK: vector.broadcast{{.*}}vector<1xf32> to vector<64xf32> +// CHECK: vector.fma{{.*}}vector<64xf32> +// CHECK: vector.shape_cast{{.*}}vector<64xf32> to vector<1x64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1xf32> +!vecB = vector<1x1x64xf32> +!vecC = vector<1x1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, 
d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_to_fma( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @batch_matmul_to_fma +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1xf32> +!vecB = vector<1x1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_fma( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_fma +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + 
transform.yield + } +} + +// ----- + +!vecA = vector<3x1x1xf32> +!vecB = vector<3x1x64xf32> +!vecC = vector<3x1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_non_unit_batch_dim( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// Batch dimension should've been simplified earlier. + +// CHECK-LABEL: @negative_non_unit_batch_dim +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<3x1x1xf32> +!vecB = vector<3x1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +func.func @negative_non_unit_batch_reduce_dim( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// Batch-reduce dimension should've been simplified earlier. 
+ +// CHECK-LABEL: @negative_non_unit_batch_reduce_dim +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1xf32> +!vecB = vector<1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @negative_invalid_kind( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_invalid_kind +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir new file mode 100644 index 0000000000000..551f1f95ed9c0 --- /dev/null +++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir @@ -0,0 +1,374 @@ +// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s + +!vecA = vector<1x1x1x2xbf16> 
+!vecB = vector<1x1x16x2xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_bf16dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_bf16dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xi8> +!vecB = vector<1x1x8x4xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_int8dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_int8dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 
: (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x2xbf16> +!vecB = vector<1x1x16x2xbf16> +!vecC = vector<1x1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_bf16dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @batch_matmul_bf16dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xi8> +!vecB = vector<1x1x8x4xi8> +!vecC = vector<1x1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_int8dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + + +// CHECK-LABEL: 
@batch_matmul_int8dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x2xbf16> +!vecB = vector<1x16x2xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @matmul_outer_product_to_bf16dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_bf16dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x4xi8> +!vecB = vector<1x8x4xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @matmul_outer_product_to_int8dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { 
+ indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_int8dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x2xbf16> +!vecB = vector<1x16x2xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_invalid_vc_kind( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<mul>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_invalid_vc_kind +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xbf16> +!vecB = vector<1x1x16x4xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> 
(d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_false_vnni_bf16( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_false_vnni_bf16 +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x2xi8> +!vecB = vector<1x1x8x2xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_false_vnni_int8( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_false_vnni_int8 +// CHECK-NOT: x86vector.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + 
transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<3x1x1x2xbf16> +!vecB = vector<3x1x16x2xbf16> +!vecC = vector<3x1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_batch_dimension( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_batch_dimension +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<2x1x1x4xi8> +!vecB = vector<2x1x8x4xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_brgemm_dimension( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_brgemm_dimension +// CHECK-NOT: 
x86vector.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +}