diff --git a/include/gc/Analysis/MatmulConfigAnalysis.h b/include/gc/Analysis/MatmulConfigAnalysis.h index 4cb19ac23..2b275f246 100644 --- a/include/gc/Analysis/MatmulConfigAnalysis.h +++ b/include/gc/Analysis/MatmulConfigAnalysis.h @@ -109,16 +109,14 @@ getOprandDimType(linalg::LinalgOp &linalgOp) { if (llvm::isa(linalgOp.getOperation())) { return getContractionOpOperandDimType(linalgOp); } else if (linalgx::isGenericPackedMatmulOp( - linalgOp.getOperation(), linalgx::PackingType::VNNI_MM2D) || - llvm::isa(linalgOp)) { + linalgOp.getOperation(), linalgx::PackingType::VNNI_MM2D)) { return SmallVector>{ SmallVector{DimType::M, DimType::K}, SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, DimType::K}, SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; } else if (linalgx::isGenericPackedMatmulOp( - linalgOp.getOperation(), linalgx::PackingType::VNNI_MM4D) || - llvm::isa(linalgOp)) { + linalgOp.getOperation(), linalgx::PackingType::VNNI_MM4D)) { return SmallVector>{ SmallVector{DimType::M, DimType::K, DimType::M, DimType::K}, SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, diff --git a/include/gc/Dialect/Linalgx/LinalgxStructuredOps.td b/include/gc/Dialect/Linalgx/LinalgxStructuredOps.td index dee5eef74..5e24a05b6 100644 --- a/include/gc/Dialect/Linalgx/LinalgxStructuredOps.td +++ b/include/gc/Dialect/Linalgx/LinalgxStructuredOps.td @@ -104,212 +104,4 @@ def Linalgx_SigmoidOp : LinalgxStructuredBase_Op<"sigmoid", }]; } -def Linalgx_Mm2DVnniOp - : LinalgxStructuredBase_Op<"mm2d_vnni", [AttrSizedOperandSegments]> { - let summary = "Transposed matmul with 2d input and vnni packed weights"; - let description = [{ - Supported format: A[M, K] * B[N0, K0, k, n, v] -> C[M, N], with: - N = N0 * n - K = K0 * k * v; v = (2, 4) - }]; - let arguments = (ins - Variadic:$inputs, - Variadic:$outputs); - let results = (outs Variadic:$results); - let regions = (region AnyRegion:$region); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder< - (ins - "TypeRange":$resultTensorTypes, - "ValueRange":$inputs, - "ValueRange":$outputs, - CArg<"ArrayRef", "{}">:$attributes), - [{ - buildStructuredOp($_builder, $_state, resultTensorTypes, - inputs, outputs, attributes, Mm2DVnniOp::getRegionBuilder()); - }]> - ]; - - let hasCustomAssemblyFormat = 1; - let hasFolder = 1; - let hasVerifier = 1; - - let extraClassDeclaration = structuredOpsBaseDecls # [{ - // Declare functions necessary for LinalgStructuredInterface. - SmallVector getIteratorTypesArray(); - ArrayAttr getIndexingMaps(); - static unsigned getNumRegionArgs() { return 3; } - std::string getLibraryCallName() { - return "op_has_no_registered_library_name"; - } - - // Implement functions necessary for DestinationStyleOpInterface. 
- MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); } - - static void regionBuilder(ImplicitLocOpBuilder &b, - Block &block, ArrayRef attrs); - static std::function)> - getRegionBuilder() { - return regionBuilder; - } - }]; -} - -def Linalgx_Mm4DVnniOp - : LinalgxStructuredBase_Op<"mm4d_vnni", [AttrSizedOperandSegments]> { - let summary = "Transposed matmul with 4d blocking input and vnni packed weights"; - let description = [{ - Supported format: A[M, K, m, k] * B[N, K, k0, n, v] -> C[M, N, m, n], with: - k = k0 * v; v = (2, 4) - }]; - let arguments = (ins - Variadic:$inputs, - Variadic:$outputs); - let results = (outs Variadic:$results); - let regions = (region AnyRegion:$region); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder< - (ins - "TypeRange":$resultTensorTypes, - "ValueRange":$inputs, - "ValueRange":$outputs, - CArg<"ArrayRef", "{}">:$attributes), - [{ - buildStructuredOp($_builder, $_state, resultTensorTypes, - inputs, outputs, attributes, Mm4DVnniOp::getRegionBuilder()); - }]> - ]; - - let hasCustomAssemblyFormat = 1; - let hasFolder = 1; - let hasVerifier = 1; - - let extraClassDeclaration = structuredOpsBaseDecls # [{ - // Declare functions necessary for LinalgStructuredInterface. - SmallVector getIteratorTypesArray(); - ArrayAttr getIndexingMaps(); - static unsigned getNumRegionArgs() { return 3; } - std::string getLibraryCallName() { - return "op_has_no_registered_library_name"; - } - - // Implement functions necessary for DestinationStyleOpInterface. - MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); } - - static void regionBuilder(ImplicitLocOpBuilder &b, - Block &block, ArrayRef attrs); - static std::function)> - getRegionBuilder() { - return regionBuilder; - } - }]; -} - -def Linalgx_BatchReduceMatmulVnniOp - : LinalgxStructuredBase_Op<"batch_reduce_matmul_vnni", [AttrSizedOperandSegments]> { - let summary = "Batch reduced matmul with 3d batch input and vnni packed weights"; - let description = [{ - Supported format: A[B, M, K] * B[B, k, N, v] -> C[M, N], with: - K = k * v; v = (2, 4) - }]; - let arguments = (ins - Variadic:$inputs, - Variadic:$outputs); - let results = (outs Variadic:$results); - let regions = (region AnyRegion:$region); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder< - (ins - "TypeRange":$resultTensorTypes, - "ValueRange":$inputs, - "ValueRange":$outputs, - CArg<"ArrayRef", "{}">:$attributes), - [{ - buildStructuredOp($_builder, $_state, resultTensorTypes, - inputs, outputs, attributes, BatchReduceMatmulVnniOp::getRegionBuilder()); - }]> - ]; - - let hasCustomAssemblyFormat = 1; - let hasFolder = 1; - let hasVerifier = 1; - - let extraClassDeclaration = structuredOpsBaseDecls # [{ - // Declare functions necessary for LinalgStructuredInterface. - SmallVector getIteratorTypesArray(); - ArrayAttr getIndexingMaps(); - static unsigned getNumRegionArgs() { return 3; } - std::string getLibraryCallName() { - return "op_has_no_registered_library_name"; - } - - // Implement functions necessary for DestinationStyleOpInterface. 
- MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); } - - static void regionBuilder(ImplicitLocOpBuilder &b, - Block &block, ArrayRef attrs); - static std::function)> - getRegionBuilder() { - return regionBuilder; - } - }]; -} - -def Linalgx_MultiBatchMatmulOp : LinalgxStructuredBase_Op<"multi_batch_matmul", - [AttrSizedOperandSegments, LinalgContractionOpInterface]> { - let summary = "Batch matmul with variable batch dims"; - let arguments = (ins - Variadic:$inputs, - Variadic:$outputs); - let results = (outs Variadic:$results); - let regions = (region AnyRegion:$region); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder< - (ins - "TypeRange":$resultTensorTypes, - "ValueRange":$inputs, - "ValueRange":$outputs, - CArg<"ArrayRef", "{}">:$attributes), - [{ - buildStructuredOp($_builder, $_state, resultTensorTypes, - inputs, outputs, attributes, MultiBatchMatmulOp::getRegionBuilder()); - }]> - ]; - - let hasCustomAssemblyFormat = 1; - let hasFolder = 1; - - let extraClassDeclaration = structuredOpsBaseDecls # [{ - // Declare functions necessary for LinalgStructuredInterface. - SmallVector getIteratorTypesArray(); - ArrayAttr getIndexingMaps(); - static unsigned getNumRegionArgs() { return 3; } - std::string getLibraryCallName() { - return "op_has_no_registered_library_name"; - } - - // Implement functions necessary for DestinationStyleOpInterface. - MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); } - - static void regionBuilder(ImplicitLocOpBuilder &b, - Block &block, ArrayRef attrs); - static std::function)> - getRegionBuilder() { - return regionBuilder; - } - }]; -} - #endif // LINALGX_STRUCTURED_OPS \ No newline at end of file diff --git a/include/gc/Dialect/Linalgx/Utils.h b/include/gc/Dialect/Linalgx/Utils.h index b9c681864..5bc83b449 100644 --- a/include/gc/Dialect/Linalgx/Utils.h +++ b/include/gc/Dialect/Linalgx/Utils.h @@ -18,7 +18,7 @@ namespace mlir { namespace linalgx { /// @brief enum of type of matmul packing -enum class PackingType { +enum class PackingType : int { MM4D = 0, // MKmk x NKkn VNNI_MM2D, // MK x NKknV VNNI_MM4D, // MKmk x NKknV @@ -43,6 +43,30 @@ makeGenericPackedMatmulOp(OpBuilder &builder, Location loc, PackingType opType, /// @return true if op is a generic packed matmul Op bool isGenericPackedMatmulOp(Operation *op, PackingType opType); +template +inline bool isGenericPackedMatmulOp(Operation *op, PackingType first, + Args... 
args) { + return isGenericPackedMatmulOp(op, first) || + isGenericPackedMatmulOp(op, args...); +} + +/// @brief identify a generic packed matmul Op based on any PackingType +/// @param op the op +/// @return true if op is a generic packed matmul Op +template inline bool isAnyGenericPackedMatmulOp(Operation *op) { + return isGenericPackedMatmulOp(op, (PackingType)N) || + isAnyGenericPackedMatmulOp(op); +} +constexpr int NUM_ALL_TYPES = (int)PackingType::NUM_TYPES; +template <> +inline bool +isAnyGenericPackedMatmulOp(Operation *op) { + return false; +} +inline bool isAnyGenericPackedMatmulOp(Operation *op) { + return isAnyGenericPackedMatmulOp<0, NUM_ALL_TYPES>(op); +} + /// @brief identify a matmul Op based on ContractionOp and PackingType /// @param op the op /// @return true if op is a matmul Op diff --git a/lib/gc/Dialect/Linalgx/LinalgxOps.cpp b/lib/gc/Dialect/Linalgx/LinalgxOps.cpp index 04eae3657..1f7fc72b5 100644 --- a/lib/gc/Dialect/Linalgx/LinalgxOps.cpp +++ b/lib/gc/Dialect/Linalgx/LinalgxOps.cpp @@ -81,533 +81,6 @@ void SigmoidOp::getEffects( getGenericEffectsImpl(effects, cast(getOperation())); } -//===----------------------------------------------------------------------===// -// Mm2DVnniOp -//===----------------------------------------------------------------------===// - -SmallVector Mm2DVnniOp::getIteratorTypesArray() { - return SmallVector{ - utils::IteratorType::parallel, utils::IteratorType::parallel, - utils::IteratorType::parallel, utils::IteratorType::reduction, - utils::IteratorType::reduction, utils::IteratorType::reduction}; -} - -static SmallVector getSymbolBindings(Mm2DVnniOp self) { - MLIRContext *context = self.getContext(); - - auto vnniShape = ShapeAdaptor(self.getInputs()[1].getType()); - - SmallVector exprs; - exprs.push_back(getAffineSymbolExpr(0, context)); - exprs.push_back(getAffineSymbolExpr(1, context)); - - int64_t cst2 = vnniShape.getDimSize(3); - exprs.push_back(getAffineConstantExpr(cst2, context)); - - exprs.push_back(getAffineSymbolExpr(3, context)); - - int64_t cst4 = vnniShape.getDimSize(2); - exprs.push_back(getAffineConstantExpr(cst4, context)); - - int64_t cst5 = vnniShape.getDimSize(4); - exprs.push_back(getAffineConstantExpr(cst5, context)); - return exprs; -} - -ArrayAttr Mm2DVnniOp::getIndexingMaps() { - static const char memoizeAttr[] = "linalg.memoized_indexing_maps"; - ArrayAttr cached = getOperation()->getAttrOfType(memoizeAttr); - if (cached) - return cached; - constexpr const char *mapA = - "affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5]" - " -> (d0, (d3 * s4 + d4) * s5 + d5)>"; - constexpr const char *mapB = - "affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5]" - " -> (d1, d3, d4, d2, d5)>"; - constexpr const char *mapC = - "affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5]" - " -> (d0, d1 * s2 + d2)>"; - MLIRContext *context = getContext(); - auto symbolBindings = getSymbolBindings(*this); - SmallVector maps; - maps.push_back(llvm::cast(mlir::parseAttribute(mapA, context)) - .getValue()); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, 6, 0)); - maps.push_back(llvm::cast(mlir::parseAttribute(mapB, context)) - .getValue()); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, 6, 0)); - maps.push_back(llvm::cast(mlir::parseAttribute(mapC, context)) - .getValue()); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, 6, 0)); - cached = Builder(context).getAffineMapArrayAttr(maps); - 
getOperation()->setAttr(memoizeAttr, cached); - return cached; -} - -void Mm2DVnniOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, - ArrayRef attrs) { - assert(3 > 0 && block.getNumArguments() == 3 && - "Mm2DVnniOp regionBuilder expects 3 (>=0) args"); - RegionBuilderHelper helper(b, block); - SmallVector yields; - - Value value1 = - helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(), - block.getArgument(0)); - Value value2 = - helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(), - block.getArgument(1)); - Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2); - Value value4 = - helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3); - yields.push_back(value4); - helper.yieldOutputs(yields); -} - -ParseResult Mm2DVnniOp::parse(OpAsmParser &parser, OperationState &result) { - return ::parseNamedStructuredOp(parser, result, - Mm2DVnniOp::getNumRegionArgs(), - Mm2DVnniOp::getRegionBuilder()); -} - -void Mm2DVnniOp::print(OpAsmPrinter &p) { - ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs()); -} - -LogicalResult Mm2DVnniOp::fold(FoldAdaptor, SmallVectorImpl &) { - return memref::foldMemRefCast(*this); -} - -void Mm2DVnniOp::getEffects( - SmallVectorImpl> - &effects) { - if (hasPureTensorSemantics()) - return; - getGenericEffectsImpl(effects, cast(getOperation())); -} - -LogicalResult Mm2DVnniOp::verify() { - // A[M, K] - // B[N0, K0, K1, N1, K2] - // C[M, N] - auto shapeA = ShapeAdaptor(getInputs()[0].getType()); - auto shapeB = ShapeAdaptor(getInputs()[1].getType()); - auto shapeC = ShapeAdaptor(getOutputs()[0].getType()); - // check rank - auto hasRank = shapeA.hasRank() && shapeB.hasRank() && shapeC.hasRank(); - if (!hasRank) - return emitOpError() << "input/output must have rank."; - auto checkRank = (shapeA.getRank() == 2) && (shapeB.getRank() == 5) && - (shapeC.getRank() == 2); - if (!checkRank) - return emitOpError() << "not supported input/output shape."; - // match M, N, K dims - bool matchM = shapeA.getDimSize(0) == shapeC.getDimSize(0); - bool matchN = - (shapeB.getDimSize(0) * shapeB.getDimSize(3)) == shapeC.getDimSize(1); - bool matchK = - shapeA.getDimSize(1) == - (shapeB.getDimSize(1) * shapeB.getDimSize(2) * shapeB.getDimSize(4)); - bool result = matchM && matchN && matchK; - if (!result) - return emitOpError() << "input/output dims packing not match."; - // match vnni dim: bf16 == 2; i8 == 4 - auto dataSize = DataLayout().getTypeSizeInBits(shapeB.getElementType()); - bool matchVnni = (shapeB.getDimSize(4) == 2 && dataSize == 16) || - (shapeB.getDimSize(4) == 4 && dataSize == 8); - if (!matchVnni) - return emitOpError() << "input b vnni dim not valid."; - return success(); -} - -//===----------------------------------------------------------------------===// -// Mm4DVnniOp -//===----------------------------------------------------------------------===// - -SmallVector Mm4DVnniOp::getIteratorTypesArray() { - return SmallVector{ - utils::IteratorType::parallel, utils::IteratorType::parallel, - utils::IteratorType::parallel, utils::IteratorType::parallel, - utils::IteratorType::reduction, utils::IteratorType::reduction, - utils::IteratorType::reduction}; -} - -static SmallVector getSymbolBindings(Mm4DVnniOp self) { - MLIRContext *context = self.getContext(); - - auto vnniShape = ShapeAdaptor(self.getInputs()[1].getType()); - - SmallVector exprs; - exprs.push_back(getAffineSymbolExpr(0, context)); - exprs.push_back(getAffineSymbolExpr(1, context)); - exprs.push_back(getAffineSymbolExpr(2, 
context)); - exprs.push_back(getAffineSymbolExpr(3, context)); - exprs.push_back(getAffineSymbolExpr(4, context)); - exprs.push_back(getAffineSymbolExpr(5, context)); - - int64_t cst6 = vnniShape.getDimSize(4); - exprs.push_back(getAffineConstantExpr(cst6, context)); - return exprs; -} - -ArrayAttr Mm4DVnniOp::getIndexingMaps() { - static const char memoizeAttr[] = "linalg.memoized_indexing_maps"; - ArrayAttr cached = getOperation()->getAttrOfType(memoizeAttr); - if (cached) - return cached; - constexpr const char *mapA = - "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6]" - " -> (d0, d4, d2, d5 * s6 + d6)>"; - constexpr const char *mapB = - "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6]" - " -> (d1, d4, d5, d3, d6)>"; - constexpr const char *mapC = - "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6]" - " -> (d0, d1, d2, d3)>"; - MLIRContext *context = getContext(); - auto symbolBindings = getSymbolBindings(*this); - SmallVector maps; - maps.push_back(llvm::cast(mlir::parseAttribute(mapA, context)) - .getValue()); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); - maps.push_back(llvm::cast(mlir::parseAttribute(mapB, context)) - .getValue()); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); - maps.push_back(llvm::cast(mlir::parseAttribute(mapC, context)) - .getValue()); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); - cached = Builder(context).getAffineMapArrayAttr(maps); - getOperation()->setAttr(memoizeAttr, cached); - return cached; -} - -void Mm4DVnniOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, - ArrayRef attrs) { - assert(3 > 0 && block.getNumArguments() == 3 && - "Mm4DVnniOp regionBuilder expects 3 (>=0) args"); - RegionBuilderHelper helper(b, block); - SmallVector yields; - - Value value1 = - helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(), - block.getArgument(0)); - Value value2 = - helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(), - block.getArgument(1)); - Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2); - Value value4 = - helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3); - yields.push_back(value4); - helper.yieldOutputs(yields); -} - -ParseResult Mm4DVnniOp::parse(OpAsmParser &parser, OperationState &result) { - return ::parseNamedStructuredOp(parser, result, - Mm4DVnniOp::getNumRegionArgs(), - Mm4DVnniOp::getRegionBuilder()); -} - -void Mm4DVnniOp::print(OpAsmPrinter &p) { - ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs()); -} - -LogicalResult Mm4DVnniOp::fold(FoldAdaptor, SmallVectorImpl &) { - return memref::foldMemRefCast(*this); -} - -void Mm4DVnniOp::getEffects( - SmallVectorImpl> - &effects) { - if (hasPureTensorSemantics()) - return; - getGenericEffectsImpl(effects, cast(getOperation())); -} - -LogicalResult Mm4DVnniOp::verify() { - // A[M0, K0, M1, K] - // B[N0, K0, K1, N1, K2] - // C[M0, N0, M1, N1] - auto shapeA = ShapeAdaptor(getInputs()[0].getType()); - auto shapeB = ShapeAdaptor(getInputs()[1].getType()); - auto shapeC = ShapeAdaptor(getOutputs()[0].getType()); - // check rank - auto hasRank = shapeA.hasRank() && shapeB.hasRank() && shapeC.hasRank(); - if (!hasRank) - return emitOpError() << "input/output must have rank."; - auto checkRank = (shapeA.getRank() == 4) && (shapeB.getRank() == 5) && - (shapeC.getRank() == 4); - if (!checkRank) - 
return emitOpError() << "not supported input/output shape."; - // match M0, M1, N0, N1, K0, K dims - bool matchM0 = shapeA.getDimSize(0) == shapeC.getDimSize(0); - bool matchM1 = shapeA.getDimSize(2) == shapeC.getDimSize(2); - bool matchN0 = shapeB.getDimSize(0) == shapeC.getDimSize(1); - bool matchN1 = shapeB.getDimSize(3) == shapeC.getDimSize(3); - bool matchK0 = shapeA.getDimSize(1) == shapeB.getDimSize(1); - bool matchK = - shapeA.getDimSize(3) == (shapeB.getDimSize(2) * shapeB.getDimSize(4)); - bool result = matchM0 && matchM1 && matchN0 && matchN1 && matchK0 && matchK; - if (!result) - return emitOpError() << "input/output dims packing not match."; - // match vnni dim: bf16 == 2; i8 == 4 - auto dataSize = DataLayout().getTypeSizeInBits(shapeB.getElementType()); - bool matchVnni = (shapeB.getDimSize(4) == 2 && dataSize == 16) || - (shapeB.getDimSize(4) == 4 && dataSize == 8); - if (!matchVnni) - return emitOpError() << "input b vnni dim not valid."; - return success(); -} - -//===----------------------------------------------------------------------===// -// BatchReduceMatmulVnniOp -//===----------------------------------------------------------------------===// - -SmallVector -BatchReduceMatmulVnniOp::getIteratorTypesArray() { - return SmallVector{ - utils::IteratorType::parallel, utils::IteratorType::parallel, - utils::IteratorType::reduction, utils::IteratorType::reduction, - utils::IteratorType::reduction}; -} - -static SmallVector getSymbolBindings(BatchReduceMatmulVnniOp self) { - MLIRContext *context = self.getContext(); - - auto vnniShape = ShapeAdaptor(self.getInputs()[1].getType()); - - SmallVector exprs; - exprs.push_back(getAffineSymbolExpr(0, context)); - exprs.push_back(getAffineSymbolExpr(1, context)); - exprs.push_back(getAffineSymbolExpr(2, context)); - exprs.push_back(getAffineSymbolExpr(3, context)); - - int64_t cst4 = vnniShape.getDimSize(3); - exprs.push_back(getAffineConstantExpr(cst4, context)); - return exprs; -} - -ArrayAttr BatchReduceMatmulVnniOp::getIndexingMaps() { - static const char memoizeAttr[] = "linalg.memoized_indexing_maps"; - ArrayAttr cached = getOperation()->getAttrOfType(memoizeAttr); - if (cached) - return cached; - constexpr const char *mapA = - "affine_map<(d0, d1, d2, d3, d4)[s0, s1, s2, s3, s4]" - " -> (d2, d0, d3 * s4 + d4)>"; - constexpr const char *mapB = - "affine_map<(d0, d1, d2, d3, d4)[s0, s1, s2, s3, s4]" - " -> (d2, d3, d1, d4)>"; - constexpr const char *mapC = - "affine_map<(d0, d1, d2, d3, d4)[s0, s1, s2, s3, s4]" - " -> (d0, d1)>"; - MLIRContext *context = getContext(); - auto symbolBindings = getSymbolBindings(*this); - SmallVector maps; - maps.push_back(llvm::cast(mlir::parseAttribute(mapA, context)) - .getValue()); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, 5, 0)); - maps.push_back(llvm::cast(mlir::parseAttribute(mapB, context)) - .getValue()); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, 5, 0)); - maps.push_back(llvm::cast(mlir::parseAttribute(mapC, context)) - .getValue()); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, 5, 0)); - cached = Builder(context).getAffineMapArrayAttr(maps); - getOperation()->setAttr(memoizeAttr, cached); - return cached; -} - -void BatchReduceMatmulVnniOp::regionBuilder(ImplicitLocOpBuilder &b, - Block &block, - ArrayRef attrs) { - assert(3 > 0 && block.getNumArguments() == 3 && - "BatchReduceMatmulVnniOp regionBuilder expects 3 (>=0) args"); - 
RegionBuilderHelper helper(b, block); - SmallVector yields; - - Value value1 = - helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(), - block.getArgument(0)); - Value value2 = - helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(), - block.getArgument(1)); - Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2); - Value value4 = - helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3); - yields.push_back(value4); - helper.yieldOutputs(yields); -} - -ParseResult BatchReduceMatmulVnniOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseNamedStructuredOp(parser, result, - BatchReduceMatmulVnniOp::getNumRegionArgs(), - BatchReduceMatmulVnniOp::getRegionBuilder()); -} - -void BatchReduceMatmulVnniOp::print(OpAsmPrinter &p) { - ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs()); -} - -LogicalResult BatchReduceMatmulVnniOp::fold(FoldAdaptor, - SmallVectorImpl &) { - return memref::foldMemRefCast(*this); -} - -void BatchReduceMatmulVnniOp::getEffects( - SmallVectorImpl> - &effects) { - if (hasPureTensorSemantics()) - return; - getGenericEffectsImpl(effects, cast(getOperation())); -} - -LogicalResult BatchReduceMatmulVnniOp::verify() { - // A[B, M, K] - // B[B, K0, N, K1] - // C[M, N] - auto shapeA = ShapeAdaptor(getInputs()[0].getType()); - auto shapeB = ShapeAdaptor(getInputs()[1].getType()); - auto shapeC = ShapeAdaptor(getOutputs()[0].getType()); - // check rank - auto hasRank = shapeA.hasRank() && shapeB.hasRank() && shapeC.hasRank(); - if (!hasRank) - return emitOpError() << "input/output must have rank."; - auto checkRank = (shapeA.getRank() == 3) && (shapeB.getRank() == 4) && - (shapeC.getRank() == 2); - if (!checkRank) - return emitOpError() << "not supported input/output shape."; - // match B, M, N, K dims - bool matchB = shapeA.getDimSize(0) == shapeB.getDimSize(0); - bool matchM = shapeA.getDimSize(1) == shapeC.getDimSize(0); - bool matchN = shapeB.getDimSize(2) == shapeC.getDimSize(1); - bool matchK = - shapeA.getDimSize(2) == (shapeB.getDimSize(1) * shapeB.getDimSize(3)); - bool result = matchB && matchM && matchN && matchK; - if (!result) - return emitOpError() << "input/output dims packing not match."; - // match vnni dim: bf16 == 2; i8 == 4 - auto dataSize = DataLayout().getTypeSizeInBits(shapeB.getElementType()); - bool matchVnni = (shapeB.getDimSize(3) == 2 && dataSize == 16) || - (shapeB.getDimSize(3) == 4 && dataSize == 8); - if (!matchVnni) - return emitOpError() << "input b vnni dim not valid."; - return success(); -} - -//===----------------------------------------------------------------------===// -// MultiBatchMatmulOp -//===----------------------------------------------------------------------===// - -SmallVector MultiBatchMatmulOp::getIteratorTypesArray() { - int64_t rank = getRank(getDpsInitOperand(0)); - SmallVector iteratorTypes(rank, - utils::IteratorType::parallel); - iteratorTypes.push_back(utils::IteratorType::reduction); - return iteratorTypes; -} - -static SmallVector getSymbolBindings(MultiBatchMatmulOp self) { - MLIRContext *context = self.getContext(); - int64_t symbols = self.getRank(self.getDpsInitOperand(0)) + 1; - SmallVector exprs; - for (auto dim : llvm::seq(0, symbols)) { - exprs.push_back(getAffineSymbolExpr(dim, context)); - } - return exprs; -} - -ArrayAttr MultiBatchMatmulOp::getIndexingMaps() { - static const char memoizeAttr[] = "linalg.memoized_indexing_maps"; - ArrayAttr cached = getOperation()->getAttrOfType(memoizeAttr); - if (cached) - 
return cached; - int64_t symbols = getRank(getDpsInitOperand(0)) + 1; - int64_t batches = getRank(getDpsInitOperand(0)) - 2; - MLIRContext *context = getContext(); - // Get affine_map with specified mat dims - auto getBatchMMAffineMap = [&](int64_t mat1, int64_t mat2) { - SmallVector exprs; - // batch dims - for (auto dim : llvm::seq(0, batches)) { - auto expr = getAffineDimExpr(dim, context); - exprs.push_back(expr); - } - // mat dims - exprs.push_back(getAffineDimExpr(mat1, context)); - exprs.push_back(getAffineDimExpr(mat2, context)); - return AffineMap::get(symbols, symbols, exprs, context); - }; - auto symbolBindings = getSymbolBindings(*this); - SmallVector maps; - maps.push_back(getBatchMMAffineMap(batches + 0, batches + 2)); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, symbols, 0)); - maps.push_back(getBatchMMAffineMap(batches + 2, batches + 1)); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, symbols, 0)); - maps.push_back(getBatchMMAffineMap(batches + 0, batches + 1)); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, symbols, 0)); - cached = Builder(context).getAffineMapArrayAttr(maps); - getOperation()->setAttr(memoizeAttr, cached); - return cached; -} - -void MultiBatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, - ArrayRef attrs) { - assert(3 > 0 && block.getNumArguments() == 3 && - "MultiBatchMatmulOp regionBuilder expects 3 (>=0) args"); - RegionBuilderHelper helper(b, block); - SmallVector yields; - - Value value1 = - helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(), - block.getArgument(0)); - Value value2 = - helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(), - block.getArgument(1)); - Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2); - Value value4 = - helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3); - yields.push_back(value4); - helper.yieldOutputs(yields); -} - -ParseResult MultiBatchMatmulOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseNamedStructuredOp(parser, result, - MultiBatchMatmulOp::getNumRegionArgs(), - MultiBatchMatmulOp::getRegionBuilder()); -} - -void MultiBatchMatmulOp::print(OpAsmPrinter &p) { - ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs()); -} - -LogicalResult MultiBatchMatmulOp::fold(FoldAdaptor, - SmallVectorImpl &) { - return memref::foldMemRefCast(*this); -} - -void MultiBatchMatmulOp::getEffects( - SmallVectorImpl> - &effects) { - if (hasPureTensorSemantics()) - return; - getGenericEffectsImpl(effects, cast(getOperation())); -} - /////// Operations corresponding to library calls defined with Tablegen //////// #define GET_OP_CLASSES diff --git a/lib/gc/Transforms/DeepTileContractionOp.cpp b/lib/gc/Transforms/DeepTileContractionOp.cpp index 769bb5321..21de7b778 100644 --- a/lib/gc/Transforms/DeepTileContractionOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionOp.cpp @@ -949,14 +949,10 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { } bool checkLinalgMatmulType(linalg::LinalgOp linalgOp) const { - return llvm::isa(linalgOp) || - linalgx::isGenericPackedMatmulOp(linalgOp.getOperation(), - linalgx::PackingType::VNNI_MM2D) || - linalgx::isGenericPackedMatmulOp(linalgOp.getOperation(), - linalgx::PackingType::VNNI_MM4D) || - linalgx::isGenericPackedMatmulOp(linalgOp.getOperation(), - linalgx::PackingType::MM4D); + return llvm::isa(linalgOp) || + 
linalgx::isGenericPackedMatmulOp( + linalgOp.getOperation(), linalgx::PackingType::VNNI_MM2D, + linalgx::PackingType::VNNI_MM4D, linalgx::PackingType::MM4D); } LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, diff --git a/lib/gc/Transforms/IterativeTilingAndFusion.cpp b/lib/gc/Transforms/IterativeTilingAndFusion.cpp index cfbc6d9e2..294342f20 100644 --- a/lib/gc/Transforms/IterativeTilingAndFusion.cpp +++ b/lib/gc/Transforms/IterativeTilingAndFusion.cpp @@ -610,12 +610,6 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp( return success(numTiledOps > 1); } -/// This is a workaround to deal with LinalgXOp -static bool isTilableLinalgXOp(Operation *op) { - return isa(op); -} - /// Check if tiled op inside a loop? /// E.g. /// %1 = scf.for(){ @@ -628,7 +622,7 @@ static bool isTilableLinalgXOp(Operation *op) { /// } static LogicalResult isTiledOpInLoop(Operation *targetOp) { // 1. check tilable - if (!isa(targetOp) && !isTilableLinalgXOp(targetOp)) + if (!isa(targetOp)) return failure(); // 2. check parentOp auto forOp = targetOp->getParentOfType(); diff --git a/lib/gc/Transforms/LowerToTileVector.cpp b/lib/gc/Transforms/LowerToTileVector.cpp index c2e0c895d..d105eaeb8 100644 --- a/lib/gc/Transforms/LowerToTileVector.cpp +++ b/lib/gc/Transforms/LowerToTileVector.cpp @@ -5,7 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "gc/Dialect/Linalgx/Utils.h" #include "gc/Transforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -34,14 +34,6 @@ namespace { #define SAFE_EXPAND(X) X #define LDBG(X) LLVM_DEBUG(DBGS() << SAFE_EXPAND(X) << "\n") -#define IMPLEMENTED_MATMUL \ - linalgx::BatchReduceMatmulVnniOp, linalgx::MultiBatchMatmulOp, \ - linalg::BatchReduceMatmulOp, linalgx::Mm2DVnniOp, linalgx::Mm4DVnniOp, \ - linalg::MatmulOp, linalg::BatchMatmulOp, \ - linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp, \ - linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \ - linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp - #define SUPPORT_TENSOR_OP \ tensor::ExpandShapeOp, tensor::CollapseShapeOp, tensor::ConcatOp @@ -54,7 +46,7 @@ static inline bool isRequiredTensorOp(Operation *operation) { /// matmul operation or fill + matmul operation static bool isMatchedOperationSet(Operation *op) { - if (isa(op)) + if (linalgx::isMatmulOp(op)) return true; // Operation produce for matmul can't lower. 
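Note (illustrative, not part of the patch): after this change the matmul-detection call sites all funnel through the utilities declared in gc/Dialect/Linalgx/Utils.h. Below is a minimal sketch of what such a call site looks like, assuming only the declarations visible in this diff (linalgx::isMatmulOp and the variadic linalgx::isGenericPackedMatmulOp). The exact list of named linalg ops inside the isa<> check is not recoverable from this extract (the template arguments were stripped), so the two ops named in the sketch are placeholder assumptions.

// Sketch only -- mirrors the DeepTileContractionOp.cpp and
// LowerToTileVector.cpp hunks above; op list in isa<> is an assumption.
#include "gc/Dialect/Linalgx/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;

// One variadic call replaces the chain of per-PackingType checks plus the
// removed isa<linalgx::Mm2DVnniOp, ...> test.
static bool isDeepTileableMatmul(linalg::LinalgOp linalgOp) {
  return llvm::isa<linalg::MatmulOp, linalg::BatchMatmulOp>(linalgOp) ||
         linalgx::isGenericPackedMatmulOp(
             linalgOp.getOperation(), linalgx::PackingType::VNNI_MM2D,
             linalgx::PackingType::VNNI_MM4D, linalgx::PackingType::MM4D);
}

// The former IMPLEMENTED_MATMUL isa<> list collapses to a single predicate.
static bool isMatmulLike(Operation *op) { return linalgx::isMatmulOp(op); }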
@@ -63,7 +55,7 @@ static bool isMatchedOperationSet(Operation *op) { return false; return llvm::any_of(op->getUsers(), - [](Operation *x) { return isa(x); }); + [](Operation *x) { return linalgx::isMatmulOp(x); }); } static bool isContainsDynamicSize(ArrayRef sizes) { diff --git a/test/mlir/test/gc/Dialect/Linlagx/generalize-named-ops.mlir b/test/mlir/test/gc/Dialect/Linlagx/generalize-named-ops.mlir index 51ca3ecce..562e020e7 100644 --- a/test/mlir/test/gc/Dialect/Linlagx/generalize-named-ops.mlir +++ b/test/mlir/test/gc/Dialect/Linlagx/generalize-named-ops.mlir @@ -22,117 +22,3 @@ func.func @generalize_sigmoid(%arg0: tensor<4x256x64xbf16>, %arg1: tensor<4x256x // CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[EXP]], %[[CST]] : bf16 // CHECK-NEXT: %[[DIV:.+]] = arith.divf %[[CST]], %[[ADD]] : bf16 // CHECK-NEXT: linalg.yield %[[DIV]] : bf16 - -// ----- - -func.func @generalize_mm2d_vnni(%arg0: tensor<256x64xi8>, %arg1: tensor<16x2x8x32x4xi8>, - %arg2: tensor<256x512xi32>) -> tensor<256x512xi32> { - %0 = linalgx.mm2d_vnni ins(%arg0, %arg1 : tensor<256x64xi8>, tensor<16x2x8x32x4xi8>) - outs(%arg2 : tensor<256x512xi32>) -> tensor<256x512xi32> - return %0 : tensor<256x512xi32> -} - -// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3 * 32 + d4 * 4 + d5)> -// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d4, d2, d5)> -// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 32 + d2)> - -// CHECK: func @generalize_mm2d_vnni - -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "reduction"] -// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<256x64xi8>, tensor<16x2x8x32x4xi8>) -// CHECK-SAME: outs(%{{.+}} : tensor<256x512xi32>) - -// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32) -// CHECK-NEXT: %[[A_EXT:.+]] = arith.extsi %[[A_ARG]] : i8 to i32 -// CHECK-NEXT: %[[B_EXT:.+]] = arith.extsi %[[B_ARG]] : i8 to i32 -// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[A_EXT]], %[[B_EXT]] : i32 -// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32 -// CHECK-NEXT: linalg.yield %[[ADD]] : i32 -// CHECK-NEXT: -> tensor<256x512xi32> - -// ----- - -func.func @generalize_mm4d_vnni(%arg0: tensor<2x8x32x32xbf16>, %arg1: tensor<4x8x16x32x2xbf16>, - %arg2: tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> { - %0 = linalgx.mm4d_vnni ins(%arg0, %arg1 : tensor<2x8x32x32xbf16>, tensor<4x8x16x32x2xbf16>) - outs(%arg2 : tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> - return %0 : tensor<2x4x32x32xbf16> -} - -// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2, d5 * 2 + d6)> -// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d3, d6)> -// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - -// CHECK: func @generalize_mm4d_vnni - -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"] -// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x8x32x32xbf16>, tensor<4x8x16x32x2xbf16>) -// CHECK-SAME: outs(%{{.+}} : tensor<2x4x32x32xbf16>) - -// CHECK: ^{{.*}}(%[[A_ARG:.+]]: bf16, %[[B_ARG:.+]]: bf16, %[[C_ARG:.+]]: bf16) -// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_ARG]], %[[B_ARG]] : bf16 -// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : bf16 -// CHECK-NEXT: 
linalg.yield %[[ADD]] : bf16 -// CHECK-NEXT: -> tensor<2x4x32x32xbf16> - -// ----- - -func.func @generalize_batch_reduce_matmul_vnni(%arg0: tensor<512x32x64xbf16>, %arg1: tensor<512x32x128x2xbf16>, - %arg2: tensor<32x128xf32>) -> tensor<32x128xf32> { - %0 = linalgx.batch_reduce_matmul_vnni ins(%arg0, %arg1 : tensor<512x32x64xbf16>, tensor<512x32x128x2xbf16>) - outs(%arg2 : tensor<32x128xf32>) -> tensor<32x128xf32> - return %0 : tensor<32x128xf32> -} - -// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d0, d3 * 2 + d4)> -// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1, d4)> -// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)> - -// CHECK: func @generalize_batch_reduce_matmul_vnni - -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction", "reduction"] -// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<512x32x64xbf16>, tensor<512x32x128x2xbf16>) -// CHECK-SAME: outs(%{{.+}} : tensor<32x128xf32>) - -// CHECK: ^{{.*}}(%[[A_ARG:.+]]: bf16, %[[B_ARG:.+]]: bf16, %[[C_ARG:.+]]: f32) -// CHECK-NEXT: %[[A_EXT:.+]] = arith.extf %[[A_ARG]] : bf16 to f32 -// CHECK-NEXT: %[[B_EXT:.+]] = arith.extf %[[B_ARG]] : bf16 to f32 -// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_EXT]], %[[B_EXT]] : f32 -// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32 -// CHECK-NEXT: linalg.yield %[[ADD]] : f32 -// CHECK-NEXT: -> tensor<32x128xf32> - -// ----- - -func.func @generalize_multi_batch_matmul(%arg0: tensor<13x5x6x128x512xbf16>, %arg1: tensor<13x5x6x512x256xbf16>, - %arg2: tensor<13x5x6x128x256xbf16>) -> tensor<13x5x6x128x256xbf16> { - %0 = linalgx.multi_batch_matmul ins(%arg0, %arg1 : tensor<13x5x6x128x512xbf16>, tensor<13x5x6x512x256xbf16>) - outs(%arg2 : tensor<13x5x6x128x256xbf16>) -> tensor<13x5x6x128x256xbf16> - return %0 : tensor<13x5x6x128x256xbf16> -} - -// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d5)> -// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5, d4)> -// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)> - -// CHECK: func @generalize_multi_batch_matmul - -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction"] -// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<13x5x6x128x512xbf16>, tensor<13x5x6x512x256xbf16>) -// CHECK-SAME: outs(%{{.+}} : tensor<13x5x6x128x256xbf16>) - -// CHECK: ^{{.*}}(%[[A_ARG:.+]]: bf16, %[[B_ARG:.+]]: bf16, %[[C_ARG:.+]]: bf16) -// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_ARG]], %[[B_ARG]] : bf16 -// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : bf16 -// CHECK-NEXT: linalg.yield %[[ADD]] : bf16 -// CHECK-NEXT: -> tensor<13x5x6x128x256xbf16> - -// ----- diff --git a/test/mlir/test/gc/Dialect/Linlagx/linalgx-named-ops.mlir b/test/mlir/test/gc/Dialect/Linlagx/linalgx-named-ops.mlir index c87ca2259..fc552aa0a 100644 --- a/test/mlir/test/gc/Dialect/Linlagx/linalgx-named-ops.mlir +++ b/test/mlir/test/gc/Dialect/Linlagx/linalgx-named-ops.mlir @@ -6,39 +6,3 @@ func.func @sigmoid(%arg0: tensor<4x256x64xbf16>, %arg1: tensor<4x256x64xbf16>) - %0 = linalgx.sigmoid ins(%arg0 : tensor<4x256x64xbf16>) outs(%arg1 : tensor<4x256x64xbf16>) -> tensor<4x256x64xbf16> return %0 : tensor<4x256x64xbf16> } - -// CHECK-LABEL: @mm2d_vnni -func.func @mm2d_vnni(%arg0: 
tensor<256x64xi8>, %arg1: tensor<16x2x8x32x4xi8>, - %arg2: tensor<256x512xi32>) -> tensor<256x512xi32> { - // CHECK: linalgx.mm2d_vnni - %0 = linalgx.mm2d_vnni ins(%arg0, %arg1 : tensor<256x64xi8>, tensor<16x2x8x32x4xi8>) - outs(%arg2 : tensor<256x512xi32>) -> tensor<256x512xi32> - return %0 : tensor<256x512xi32> -} - -// CHECK-LABEL: @mm4d_vnni -func.func @mm4d_vnni(%arg0: tensor<2x8x32x32xbf16>, %arg1: tensor<4x8x16x32x2xbf16>, - %arg2: tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> { - // CHECK: linalgx.mm4d_vnni - %0 = linalgx.mm4d_vnni ins(%arg0, %arg1 : tensor<2x8x32x32xbf16>, tensor<4x8x16x32x2xbf16>) - outs(%arg2 : tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> - return %0 : tensor<2x4x32x32xbf16> -} - -// CHECK-LABEL: @batch_reduce_matmul_vnni -func.func @batch_reduce_matmul_vnni(%arg0: tensor<512x32x64xbf16>, %arg1: tensor<512x32x128x2xbf16>, - %arg2: tensor<32x128xf32>) -> tensor<32x128xf32> { - // CHECK: linalgx.batch_reduce_matmul_vnni - %0 = linalgx.batch_reduce_matmul_vnni ins(%arg0, %arg1 : tensor<512x32x64xbf16>, tensor<512x32x128x2xbf16>) - outs(%arg2 : tensor<32x128xf32>) -> tensor<32x128xf32> - return %0 : tensor<32x128xf32> -} - -// CHECK-LABEL: @multi_batch_matmul -func.func @multi_batch_matmul(%arg0: tensor<13x5x6x128x512xbf16>, %arg1: tensor<13x5x6x512x256xbf16>, - %arg2: tensor<13x5x6x128x256xbf16>) -> tensor<13x5x6x128x256xbf16> { - // CHECK: linalgx.multi_batch_matmul - %0 = linalgx.multi_batch_matmul ins(%arg0, %arg1 : tensor<13x5x6x128x512xbf16>, tensor<13x5x6x512x256xbf16>) - outs(%arg2 : tensor<13x5x6x128x256xbf16>) -> tensor<13x5x6x128x256xbf16> - return %0 : tensor<13x5x6x128x256xbf16> -} diff --git a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir index 0b024d026..61848dcb7 100644 --- a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir +++ b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir @@ -73,7 +73,21 @@ func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<12 // CHECK: scf.if // CHECK: linalg.copy // CHECK: else - %2 = linalgx.mm4d_vnni {MThreads = 16 : i32, NThreads = 2 : i32, KThreads = 1 : i32, MBlock = 256 : i32, NBlock = 256 : i32, KBlock = 256 : i32,innermostMBlock = 32 : i32, innermostNBlock = 32 : i32, innermostKBlock = 32 : i32 } ins(%arg0, %arg1 : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> + %2 = linalg.generic { + MThreads = 16 : i32, NThreads = 2 : i32, KThreads = 1 : i32, MBlock = 256 : i32, NBlock = 256 : i32, KBlock = 256 : i32,innermostMBlock = 32 : i32, innermostNBlock = 32 : i32, innermostKBlock = 32 : i32, + indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d1, d5 * 2 + d6)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d5, d3, d6)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d1, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"] + } + ins(%arg0, %arg1 : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) + outs(%1 : tensor<128x128x32x32xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %b0 = arith.mulf %in, %in_0 : bf16 + %b1 = arith.addf %out, %b0 : bf16 + linalg.yield %b1 : bf16 + } -> tensor<128x128x32x32xbf16> + return %2 : tensor<128x128x32x32xbf16> } @@ -114,7 +128,21 @@ func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x12 // CHECK: scf.forall.in_parallel // CHECK: linalg.reduce 
{{.*}} dimensions = [0, 1, 2] // CHECK: linalg.copy - %2 = linalgx.mm2d_vnni {MThreads = 32 : i32, NThreads = 2 : i32, KThreads = 2 : i32, MBlock = 256 : i32, NBlock = 256 : i32, KBlock = 256 : i32,innermostMBlock = 32 : i32, innermostNBlock = 32 : i32, innermostKBlock = 32 : i32 } ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> + %2 = linalg.generic { + MThreads = 32 : i32, NThreads = 2 : i32, KThreads = 2 : i32, MBlock = 256 : i32, NBlock = 256 : i32, KBlock = 256 : i32,innermostMBlock = 32 : i32, innermostNBlock = 32 : i32, innermostKBlock = 32 : i32, + indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3 * 32 + d4 * 2 + d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d4, d2, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 32 + d2)>], + iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "reduction"] + } + ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) + outs(%1 : tensor<4096x4096xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %b0 = arith.mulf %in, %in_0 : bf16 + %b1 = arith.addf %out, %b0 : bf16 + linalg.yield %b1 : bf16 + } -> tensor<4096x4096xbf16> + return %2 : tensor<4096x4096xbf16> } @@ -154,7 +182,19 @@ func.func @matmul_2Dx4D_bf16_with_dlti(%arg0: tensor<4096x4096xbf16>, %arg1: ten // CHECK: else // CHECK: linalg.generic {indexing_maps = [#[[mapA]], #[[mapB]], #[[mapC]]], iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} // CHECK: scf.forall.in_parallel - %2 = linalgx.mm2d_vnni ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> + %2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3 * 32 + d4 * 2 + d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d4, d2, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 32 + d2)>], + iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "reduction"] + } + ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) + outs(%1 : tensor<4096x4096xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %b0 = arith.mulf %in, %in_0 : bf16 + %b1 = arith.addf %out, %b0 : bf16 + linalg.yield %b1 : bf16 + } -> tensor<4096x4096xbf16> return %2 : tensor<4096x4096xbf16> } diff --git a/test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir b/test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir index d9efaca66..fdd3ce1bc 100644 --- a/test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir +++ b/test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir @@ -134,8 +134,22 @@ module { %15 = tensor.empty() : tensor<2x128x16xbf16> /// CHECK: %[[TRANSPOSE_OUT:.*]] = linalg.transpose ins(%[[EXPAND_OUT]] : %transposed = linalg.transpose ins(%expanded : tensor<128x2x16xbf16>) outs(%15 : tensor<2x128x16xbf16>) permutation = [1, 0, 2] - /// CHECK: %[[MATMUL_OUT:.*]] = linalgx.batch_reduce_matmul_vnni ins(%[[TRANSPOSE_OUT]], %[[COLLAPSE_OUT]] : - %16 = linalgx.batch_reduce_matmul_vnni ins(%transposed, %collapsed : tensor<2x128x16xbf16>, tensor<2x8x32x2xbf16>) outs(%arg6 : tensor<128x32xf32>) -> tensor<128x32xf32> + /// CHECK: %[[MATMUL_OUT:.*]] = linalg.generic + /// CHECK-SAME: ins(%[[TRANSPOSE_OUT]], %[[COLLAPSE_OUT]] : + %16 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3 * 2 + d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>, + 
affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>], + iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} + ins(%transposed, %collapsed : tensor<2x128x16xbf16>, tensor<2x8x32x2xbf16>) + outs(%arg6 : tensor<128x32xf32>) { + ^bb0(%in: bf16, %in_0: bf16, %out: f32): + %b1 = arith.extf %in : bf16 to f32 + %b2 = arith.extf %in_0 : bf16 to f32 + %b3 = arith.mulf %b1, %b2 : f32 + %b4 = arith.addf %out, %b3 : f32 + linalg.yield %b4 : f32 + } -> tensor<128x32xf32> %17 = arith.addi %arg5, %c2 : index %18 = arith.cmpi sge, %17, %c64 : index /// CHECK: %[[IF_RESULT:.*]] = scf.if @@ -339,16 +353,44 @@ module { %collapse_6 = tensor.collapse_shape %extracted_slice_6 [[0, 1, 2], [3]] : tensor<1x1x32x32xbf16> into tensor<32x32xbf16> %13 = arith.cmpi eq, %arg5, %c0 : index /// CHECK: %[[IF_RESULT_1:.*]] = scf.if - /// CHECK: linalgx.batch_reduce_matmul_vnni ins(%[[COLLAPSE_OUT_1]], %[[COLLAPSE_OUT_2]] : + /// CHECK: linalg.generic + /// CHECK-SAME: ins(%[[COLLAPSE_OUT_1]], %[[COLLAPSE_OUT_2]] : /// CHECK: } else { - /// CHECK: linalgx.batch_reduce_matmul_vnni ins(%[[COLLAPSE_OUT_1]], %[[COLLAPSE_OUT_2]] : + /// CHECK: linalg.generic + /// CHECK-SAME: ins(%[[COLLAPSE_OUT_1]], %[[COLLAPSE_OUT_2]] : /// CHECK: } %14 = scf.if %13 -> (tensor<32x32xf32>) { %18 = linalg.fill ins(%cst : bf16) outs(%collapse_5 : tensor<32x32xf32>) -> tensor<32x32xf32> - %19 = linalgx.batch_reduce_matmul_vnni ins(%collapse_3, %collapse_4 : tensor<32x32x32xbf16>, tensor<32x16x32x2xbf16>) outs(%18 : tensor<32x32xf32>) -> tensor<32x32xf32> + %19 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3 * 2 + d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>], + iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} + ins(%collapse_3, %collapse_4 : tensor<32x32x32xbf16>, tensor<32x16x32x2xbf16>) + outs(%18 : tensor<32x32xf32>) { + ^bb0(%in: bf16, %in_0: bf16, %out: f32): + %b1 = arith.extf %in : bf16 to f32 + %b2 = arith.extf %in_0 : bf16 to f32 + %b3 = arith.mulf %b1, %b2 : f32 + %b4 = arith.addf %out, %b3 : f32 + linalg.yield %b4 : f32 + } -> tensor<32x32xf32> scf.yield %19 : tensor<32x32xf32> } else { - %18 = linalgx.batch_reduce_matmul_vnni ins(%collapse_3, %collapse_4 : tensor<32x32x32xbf16>, tensor<32x16x32x2xbf16>) outs(%collapse_5 : tensor<32x32xf32>) -> tensor<32x32xf32> + %18 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3 * 2 + d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>], + iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} + ins(%collapse_3, %collapse_4 : tensor<32x32x32xbf16>, tensor<32x16x32x2xbf16>) + outs(%collapse_5 : tensor<32x32xf32>) { + ^bb0(%in: bf16, %in_0: bf16, %out: f32): + %b1 = arith.extf %in : bf16 to f32 + %b2 = arith.extf %in_0 : bf16 to f32 + %b3 = arith.mulf %b1, %b2 : f32 + %b4 = arith.addf %out, %b3 : f32 + linalg.yield %b4 : f32 + } -> tensor<32x32xf32> scf.yield %18 : tensor<32x32xf32> } %15 = arith.addi %arg5, %c32 : index
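Side note on the new helpers in gc/Dialect/Linalgx/Utils.h: the extraction above dropped the template headers of isGenericPackedMatmulOp and isAnyGenericPackedMatmulOp, so the following is a small, self-contained sketch of the compile-time recursion pattern those helpers appear to use to test every PackingType. The enum values are taken from the diff, but checkAny/checkOne and the integer "op" are toy stand-ins, not the real MLIR code.

// Toy illustration: recurse over [N, NUM_TYPES) at compile time, with a
// specialization at NUM_TYPES terminating the recursion. checkOne() stands
// in for isGenericPackedMatmulOp(Operation *, PackingType).
#include <iostream>

enum class PackingType : int { MM4D = 0, VNNI_MM2D, VNNI_MM4D, NUM_TYPES };

constexpr int kNumTypes = static_cast<int>(PackingType::NUM_TYPES);

// Stand-in predicate: does this "op" match the given packing type?
static bool checkOne(int op, PackingType type) {
  return op == static_cast<int>(type);
}

// Primary template: try PackingType N, then recurse on N + 1.
template <int N, int NumTypes> bool checkAny(int op) {
  return checkOne(op, static_cast<PackingType>(N)) ||
         checkAny<N + 1, NumTypes>(op);
}

// Terminating specialization: no PackingType left to try.
template <> bool checkAny<kNumTypes, kNumTypes>(int) { return false; }

int main() {
  // "op" 1 matches VNNI_MM2D; "op" 7 matches nothing.
  std::cout << checkAny<0, kNumTypes>(1) << "\n"; // prints 1
  std::cout << checkAny<0, kNumTypes>(7) << "\n"; // prints 0
  return 0;
}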