@@ -47,15 +47,67 @@ inline SmallVector<unsigned> extractDimTypeIdx(ArrayRef<DimType> tyList,
4747 return idxList;
4848}
4949
50+ inline void getDimTypeFromIterators (linalg::LinalgOp linalgOp,
51+ SmallVectorImpl<DimType> &dimTypes) {
52+ SmallVector<mlir::utils::IteratorType> iteratorTypes =
53+ linalgOp.getIteratorTypesArray ();
54+
55+ for (const auto &&[idx, iterType] : llvm::enumerate (iteratorTypes)) {
56+ if (iterType == mlir::utils::IteratorType::parallel) {
57+ SmallVector<std::pair<Value, unsigned >> operandDimPairs;
58+ linalgOp.mapIterationSpaceDimToAllOperandDims (idx, operandDimPairs);
59+ if (operandDimPairs.size () == 3 ) {
60+ dimTypes.push_back (DimType::Batch);
61+ } else if (llvm::any_of (operandDimPairs,
62+ [&](std::pair<Value, unsigned > pair) {
63+ return pair.first ==
64+ dyn_cast<linalg::ContractionOpInterface>(
65+ linalgOp.getOperation ())
66+ .lhs ();
67+ })) {
68+ dimTypes.push_back (DimType::M);
69+ } else {
70+ dimTypes.push_back (DimType::N);
71+ }
72+ } else if (iterType == mlir::utils::IteratorType::reduction) {
73+ dimTypes.push_back (DimType::K);
74+ }
75+ }
76+ }
77+
78+ inline SmallVector<DimType>
79+ matchOperandToDimTypes (linalg::LinalgOp linalgOp, OpOperand *operand,
80+ ArrayRef<DimType> allDimTypes) {
81+ ArrayRef<AffineExpr> map =
82+ linalgOp.getMatchingIndexingMap (operand).getResults ();
83+ SmallVector<DimType> res;
84+ for (const AffineExpr &dim : map) {
85+ AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(dim);
86+ res.push_back (allDimTypes[dimExpr.getPosition ()]);
87+ }
88+ return res;
89+ }
90+
91+ inline SmallVector<SmallVector<DimType>>
92+ getContractionOpOperandDimType (linalg::LinalgOp linalgOp) {
93+ SmallVector<DimType> dimTypes;
94+ getDimTypeFromIterators (linalgOp, dimTypes);
95+ SmallVector<DimType> ADimTypes = matchOperandToDimTypes (
96+ linalgOp, linalgOp.getDpsInputOperand (0 ), dimTypes);
97+ SmallVector<DimType> BDimTypes = matchOperandToDimTypes (
98+ linalgOp, linalgOp.getDpsInputOperand (1 ), dimTypes);
99+ SmallVector<DimType> CDimTypes =
100+ matchOperandToDimTypes (linalgOp, linalgOp.getDpsInitOperand (0 ), dimTypes);
101+
102+ return SmallVector<SmallVector<DimType>>{ADimTypes, BDimTypes, CDimTypes};
103+ }
104+
50105// Get the operand dim type for every operand for the given linalg op
51106inline FailureOr<SmallVector<SmallVector<DimType>>>
52107getOprandDimType (linalg::LinalgOp &linalgOp) {
53108 // TODO: replace the linalgx op with generic op
54- if (llvm::isa<linalg::MatmulOp>(linalgOp)) {
55- return SmallVector<SmallVector<DimType>>{
56- SmallVector<DimType>{DimType::M, DimType::K},
57- SmallVector<DimType>{DimType::K, DimType::N},
58- SmallVector<DimType>{DimType::M, DimType::N}};
109+ if (llvm::isa<linalg::ContractionOpInterface>(linalgOp.getOperation ())) {
110+ return getContractionOpOperandDimType (linalgOp);
59111 } else if (linalgx::isGenericPackedMatmulOp (
60112 linalgOp.getOperation (), linalgx::PackingType::VNNI_MM2D) ||
61113 llvm::isa<linalgx::Mm2DVnniOp>(linalgOp)) {
@@ -72,31 +124,6 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
72124 SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
73125 DimType::K},
74126 SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
75- } else if (llvm::isa<linalg::BatchMatmulOp>(linalgOp)) {
76- return SmallVector<SmallVector<DimType>>{
77- SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
78- SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
79- SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
80- } else if (llvm::isa<linalg::MatmulTransposeAOp>(linalgOp)) {
81- return SmallVector<SmallVector<DimType>>{
82- SmallVector<DimType>{DimType::K, DimType::M},
83- SmallVector<DimType>{DimType::K, DimType::N},
84- SmallVector<DimType>{DimType::M, DimType::N}};
85- } else if (llvm::isa<linalg::MatmulTransposeBOp>(linalgOp)) {
86- return SmallVector<SmallVector<DimType>>{
87- SmallVector<DimType>{DimType::M, DimType::K},
88- SmallVector<DimType>{DimType::N, DimType::K},
89- SmallVector<DimType>{DimType::M, DimType::N}};
90- } else if (llvm::isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
91- return SmallVector<SmallVector<DimType>>{
92- SmallVector<DimType>{DimType::Batch, DimType::K, DimType::M},
93- SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
94- SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
95- } else if (llvm::isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
96- return SmallVector<SmallVector<DimType>>{
97- SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
98- SmallVector<DimType>{DimType::Batch, DimType::N, DimType::K},
99- SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
100127 } else if (linalgx::isGenericPackedMatmulOp (linalgOp.getOperation (),
101128 linalgx::PackingType::MM4D)) {
102129 return SmallVector<SmallVector<DimType>>{
0 commit comments