diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index b94c5fce64f83..3cd25c3cb2fc2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -264,109 +264,172 @@ struct CombineContractResultTranspose final /// iterator_types = ["parallel", "parallel", "reduction"], /// kind = add} %arg0, %arg1, %cst_f0 /// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> -/// ``` -struct CombineContractBroadcast - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ContractionOp contractOp, - PatternRewriter &rewriter) const override { - SmallVector maps = - llvm::to_vector<4>(contractOp.getIndexingMapsArray()); - Value lhs = contractOp.getLhs(); - Value rhs = contractOp.getRhs(); - size_t index = 0; - bool changed = false; - for (Value *operand : {&lhs, &rhs}) { - AffineMap &map = maps[index++]; - auto broadcast = operand->getDefiningOp(); - if (!broadcast) - continue; - // contractionOp can only take vector as operands. - auto srcType = dyn_cast(broadcast.getSourceType()); - if (!srcType || - srcType.getRank() == broadcast.getResultVectorType().getRank()) - continue; - int64_t rankDiff = - broadcast.getResultVectorType().getRank() - srcType.getRank(); - bool innerDimBroadcast = false; - SmallVector originalDims; - for (const auto &dim : llvm::enumerate(srcType.getShape())) { - if (dim.value() != broadcast.getResultVectorType().getDimSize( - rankDiff + dim.index())) { - innerDimBroadcast = true; - break; - } - originalDims.push_back( - rewriter.getAffineDimExpr(dim.index() + rankDiff)); +/// ``` +/// +/// For masked vector.contract, the mask requires updating when a dimension is +/// dropped. In such cases, the dropped dimensions must correspond to the mask's +/// leading unit dimensions. Supporting more generic cases (e.g. non-unit dims) +/// is not supported. +FailureOr combineContractAndBroadcast(vector::ContractionOp contractOp, + MaskingOpInterface maskingOp, + PatternRewriter &rewriter) { + SmallVector maps = + llvm::to_vector<4>(contractOp.getIndexingMapsArray()); + Value lhs = contractOp.getLhs(); + Value rhs = contractOp.getRhs(); + size_t index = 0; + bool changed = false; + for (Value *operand : {&lhs, &rhs}) { + AffineMap &map = maps[index++]; + auto broadcast = operand->getDefiningOp(); + if (!broadcast) + continue; + // contractionOp can only take vector as operands. + auto srcType = dyn_cast(broadcast.getSourceType()); + if (!srcType || + srcType.getRank() == broadcast.getResultVectorType().getRank()) + continue; + int64_t rankDiff = + broadcast.getResultVectorType().getRank() - srcType.getRank(); + bool innerDimBroadcast = false; + SmallVector originalDims; + for (const auto &dim : llvm::enumerate(srcType.getShape())) { + if (dim.value() != + broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) { + innerDimBroadcast = true; + break; } - // Contract doesn't support inner dimension broadcast. Once this is - // relaxed we can remove this case. - if (innerDimBroadcast) - continue; + originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff)); + } + // Contract doesn't support inner dimension broadcast. Once this is + // relaxed we can remove this case. + if (innerDimBroadcast) + continue; - // It would be incorrect to fold a broadcast onto a reduction dimension - // of non-unit size. - bool nonUnitDimReductionBroadcast = false; - for (int64_t i = 0; i < rankDiff; ++i) { - if (broadcast.getResultVectorType().getDimSize(i) != 1 && - isReductionIterator(contractOp.getIteratorTypes() - .getValue()[map.getDimPosition(i)])) { - nonUnitDimReductionBroadcast = true; - break; - } + // It would be incorrect to fold a broadcast onto a reduction dimension + // of non-unit size. + bool nonUnitDimReductionBroadcast = false; + for (int64_t i = 0; i < rankDiff; ++i) { + if (broadcast.getResultVectorType().getDimSize(i) != 1 && + isReductionIterator(contractOp.getIteratorTypes() + .getValue()[map.getDimPosition(i)])) { + nonUnitDimReductionBroadcast = true; + break; } - if (nonUnitDimReductionBroadcast) - continue; - - AffineMap broadcastMap = - AffineMap::get(broadcast.getResultVectorType().getRank(), 0, - originalDims, contractOp.getContext()); - map = broadcastMap.compose(map); - *operand = broadcast.getSource(); - changed = true; } + if (nonUnitDimReductionBroadcast) + continue; - if (!changed) - return failure(); + AffineMap broadcastMap = + AffineMap::get(broadcast.getResultVectorType().getRank(), 0, + originalDims, contractOp.getContext()); + map = broadcastMap.compose(map); + *operand = broadcast.getSource(); + changed = true; + } - // Determine which dims are usused, now that the maps have been composed - // with the broadcast maps. - llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps); - // Compress unused dims. - for (auto &m : maps) - m = compressDims(m, unusedDimsBitVector); - // Compute the combined iterators. - SmallVector iterators; - for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) { - if (!unusedDimsBitVector.test(i)) - iterators.push_back(contractOp.getIteratorTypes().getValue()[i]); - } - // Check that compressing unused dims isn't removing all reduction dimension - // pairs. For example, if the vector.contract had only one reduction - // iterator and that was a unit-dimension created by a broadcast, - // then we should bail here, otherwise we would create a contract without - // a reduction dimension pair. - bool hasReductionIteratorApplyingOnBothSides = false; - for (unsigned i = 0; i < iterators.size(); ++i) { - if (!isReductionIterator(iterators[i])) - continue; - if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) { - hasReductionIteratorApplyingOnBothSides = true; + if (!changed) + return failure(); + + // Determine which dims are usused, now that the maps have been composed + // with the broadcast maps. + llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps); + // Compress unused dims. + for (auto &m : maps) + m = compressDims(m, unusedDimsBitVector); + // Compute the combined iterators. + SmallVector iterators; + for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) { + if (!unusedDimsBitVector.test(i)) + iterators.push_back(contractOp.getIteratorTypes().getValue()[i]); + } + + // Check whether any of the unused dims is non-unit, e.g.: + // * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32> + // This is only required when collapsing a mask. If there is no mask, skip. + VectorType oldMaskType; + bool isAnyUnusedDimNonUnit = false; + if (maskingOp) { + oldMaskType = cast(maskingOp.getMask().getType()); + for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) { + if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) { + isAnyUnusedDimNonUnit = true; break; } } - if (!hasReductionIteratorApplyingOnBothSides) - return failure(); + } - // If the compressed maps have a dimension that is not used by either LHS or - // RHS then the ContractionOp verifier would fail. - if (getUnusedDimsBitVector({maps[0], maps[1]}).any()) - return failure(); - rewriter.replaceOpWithNewOp( - contractOp, lhs, rhs, contractOp.getAcc(), - rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators)); - return success(); + // Check that compressing unused dims isn't removing all reduction dimension + // pairs. For example, if the vector.contract had only one reduction + // iterator and that was a unit-dimension created by a broadcast, + // then we should bail here, otherwise we would create a contract without + // a reduction dimension pair. + bool hasReductionIteratorApplyingOnBothSides = false; + for (unsigned i = 0; i < iterators.size(); ++i) { + if (!isReductionIterator(iterators[i])) + continue; + if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) { + hasReductionIteratorApplyingOnBothSides = true; + break; + } + } + if (!hasReductionIteratorApplyingOnBothSides) + return failure(); + + // If the compressed maps have a dimension that is not used by either LHS or + // RHS then the ContractionOp verifier would fail. + if (getUnusedDimsBitVector({maps[0], maps[1]}).any()) + return failure(); + + Operation *newOp = rewriter.create( + contractOp.getLoc(), lhs, rhs, contractOp.getAcc(), + rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators)); + + // Handle the mask. + if (maskingOp) { + if (isAnyUnusedDimNonUnit) + return rewriter.notifyMatchFailure(contractOp, + "Cannont drop non-unit mask dim."); + assert(unusedDimsBitVector.size() == + static_cast(oldMaskType.getRank()) && + "The mask rank is incorrect!"); + + // If a dimension has been dropped, update the mask accordingly. Otherwise, + // keep it as is. + Value mask = maskingOp.getMask(); + if (unusedDimsBitVector.count() != 0) { + // At this point, two assumptions are made: + // * The unused dimensions are the leading mask dimensions + // (vector.contract does not support inner dim broadcasting). + // * The unused dimensions are all unit. + // These conditions are effectively verified in the blocks preceeding this + // one. + auto newShape = + oldMaskType.getShape().drop_front(unusedDimsBitVector.count()); + auto newShapeScalableDims = + oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count()); + VectorType maskOpType = + VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims); + mask = rewriter + .create(contractOp.getLoc(), maskOpType, + maskingOp.getMask()) + .getResult(); + } + + newOp = mlir::vector::maskOperation(rewriter, newOp, mask); + } + return newOp->getResult(0); +} + +struct CombineContractBroadcastMask + : public MaskableOpRewritePattern { + using MaskableOpRewritePattern::MaskableOpRewritePattern; + FailureOr + + matchAndRewriteMaskableOp(vector::ContractionOp contractOp, + MaskingOpInterface maskingOp, + PatternRewriter &rewriter) const override { + return combineContractAndBroadcast(contractOp, maskingOp, rewriter); } }; @@ -2237,7 +2300,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT( void mlir::vector::populateVectorReductionToContractPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add( patterns.getContext(), benefit); } diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir index 24070dbf017a5..0bf38ba5947c0 100644 --- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir +++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir @@ -1,11 +1,15 @@ // RUN: mlir-opt %s -test-vector-reduction-to-contract-patterns -split-input-file | FileCheck %s -// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// TODO: Seperate tests for vector.multi_reduction -> vector.contract and +// * pre-op + vector.contract -> vector.contract, +// * vector.contract + post-op -> vector.contract. + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK-LABEL: multidimreduction_contract // CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xf32>, %[[ARG1:.*]]: vector<8x32x16xf32>, %[[ARG2:.*]]: vector<8x16xf32>) -// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]], +// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]], // CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32> // CHECK-NEXT: return %[[R]] : vector<8x16xf32> @@ -13,17 +17,16 @@ func.func @multidimreduction_contract( %arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>, %acc: vector<8x16xf32>) -> vector<8x16xf32> { %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32> %1 = vector.multi_reduction , %0, %acc [1] : vector<8x32x16xf32> to vector<8x16xf32> - return %1 : vector<8x16xf32> -} + return %1 : vector<8x16xf32> } // ----- -// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK-LABEL: multidimreduction_contract_int // CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xi32>, %[[ARG1:.*]]: vector<8x32x16xi32>, %[[ARG2:.*]]: vector<8x16xi32>) -// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]], +// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]], // CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32> // CHECK-NEXT: return %[[R]] : vector<8x16xi32> @@ -36,17 +39,21 @@ func.func @multidimreduction_contract_int( // ----- +//----------------------------------------------------------------------------- +// [Pattern: CombineContractABTranspose] +//----------------------------------------------------------------------------- + #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> #map1 = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)> -// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: contract_transpose // CHECK-SAME: (%[[ARG0:.+]]: vector<32x16x8xf32>, // CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32> -// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], +// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} // CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32> // CHECK-NEXT: return %[[R]] : vector<8x32xf32> @@ -62,17 +69,21 @@ func.func @contract_transpose( // ----- +//----------------------------------------------------------------------------- +// [Pattern: CombineContractBroadcast] +//----------------------------------------------------------------------------- + #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> #map1 = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: contract_broadcast // CHECK-SAME: (%[[ARG0:.+]]: vector<32x16xf32>, // CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32> -// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], +// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} // CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> // CHECK-NEXT: return %[[R]] : vector<8x32xf32> @@ -87,6 +98,79 @@ func.func @contract_broadcast( } // ----- + +// Same as above, but with a mask. + +#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: contract_broadcast_masked +// CHECK-SAME: %[[ARG0:.*]]: vector<32x16xf32>, +// CHECK-SAME: %[[ARG1:.*]]: vector<8x32x16xf32>, +// CHECK-SAME: %[[MASK:.*]]: vector<8x32x16xi1>) -> vector<8x32xf32> { +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<8x32xf32> +// CHECK: %[[R:.*]] = vector.mask %[[MASK]] { +// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], +// CHECK-SAME: kind = #vector.kind} +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> +// CHECK-SAME } : vector<8x32x16xi1> -> vector<8x32xf32> +// CHECK: return %[[R]] : vector<8x32xf32> +func.func @contract_broadcast_masked( + %arg0: vector<32x16xf32>, %arg1: vector<8x32x16xf32>, %mask: vector<8x32x16xi1>) -> vector<8x32xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<8x32xf32> + %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32> + %1 = vector.mask %mask { + vector.contract {indexing_maps = [#map0, #map0, #map1], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind + } %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> + } : vector<8x32x16xi1> -> vector<8x32xf32> + return %1 : vector<8x32xf32> +} + +// ----- + +// Same as above, but with a scalable dim. + +#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: contract_broadcast_masked_scalable +// CHECK-SAME: %[[ARG0:.*]]: vector<[32]x16xf32>, +// CHECK-SAME: %[[ARG1:.*]]: vector<8x[32]x16xf32>, +// CHECK-SAME: %[[MASK:.*]]: vector<8x[32]x16xi1>) -> vector<8x32xf32> { +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<8x32xf32> +// CHECK: %[[R:.*]] = vector.mask %[[MASK]] { +// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], +// CHECK-SAME: kind = #vector.kind} +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[C0]] : vector<[32]x16xf32>, vector<8x[32]x16xf32> into vector<8x32xf32> +// CHECK-SAME } : vector<8x[32]x16xi1> -> vector<8x32xf32> +// CHECK: return %[[R]] : vector<8x32xf32> +func.func @contract_broadcast_masked_scalable( + %arg0: vector<[32]x16xf32>, %arg1: vector<8x[32]x16xf32>, %mask: vector<8x[32]x16xi1>) -> vector<8x32xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<8x32xf32> + %0 = vector.broadcast %arg0 : vector<[32]x16xf32> to vector<8x[32]x16xf32> + %1 = vector.mask %mask { + vector.contract {indexing_maps = [#map0, #map0, #map1], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind + } %0, %arg1, %cst : vector<8x[32]x16xf32>, vector<8x[32]x16xf32> into vector<8x32xf32> + } : vector<8x[32]x16xi1> -> vector<8x32xf32> + return %1 : vector<8x32xf32> +} + +// ----- + // Test that CombineContractBroadcast is able to combine a broadcast that // creates a unit dim that is consumed by a reduction iterator, dropping that // reduction iterator, as long as there is another reduction iterator left. @@ -95,14 +179,14 @@ func.func @contract_broadcast( #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> -// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> -// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: contract_broadcast_unit_dim_reduction // CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>) // CHECK: vector.contract -// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]] +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32> func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> { @@ -116,6 +200,72 @@ func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1 return %result : vector<8x8xi32> } +// ----- + +// Same as above, but with a mask. + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked +// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>, %[[MASK:.+]]: vector<1x8x8x4xi1>) +// CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x8x4xi1> to vector<8x8x4xi1> +// CHECK: %[[R:.*]] = vector.mask %[[MASK_SC]] { +// CHECK-SAME: vector.contract +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32> +func.func @contract_broadcast_unit_dim_reduction_masked(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>, %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> { + %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32> + %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32> + %result = vector.mask %mask { + vector.contract { + indexing_maps = [#map0, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind + } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32> + } : vector<1x8x8x4xi1> -> vector<8x8xi32> + return %result : vector<8x8xi32> +} + +// ----- + +// Same as above, but with a scalable dim. + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked_scalable +// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<[8]x4xi32>, %[[ARG2:.+]]: vector<8x[8]xi32>, %[[MASK:.+]]: vector<1x8x[8]x4xi1>) +// CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x[8]x4xi1> to vector<8x[8]x4xi1> +// CHECK: %[[R:.*]] = vector.mask %[[MASK_SC]] { +// CHECK-SAME: vector.contract +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<[8]x4xi32> into vector<8x[8]xi32> +func.func @contract_broadcast_unit_dim_reduction_masked_scalable(%arg0 : vector<8x4xi32>, %arg1 : vector<[8]x4xi32>, %arg2 : vector<8x[8]xi32>, %mask: vector<1x8x[8]x4xi1>) -> vector<8x[8]xi32> { + %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32> + %1 = vector.broadcast %arg1 : vector<[8]x4xi32> to vector<1x[8]x4xi32> + %result = vector.mask %mask { + vector.contract { + indexing_maps = [#map0, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind + } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x[8]x4xi32> into vector<8x[8]xi32> + } : vector<1x8x[8]x4xi1> -> vector<8x[8]xi32> + return %result : vector<8x[8]xi32> +} + // ----- // Test that CombineContractBroadcast will not combine a broadcast that creates // a non-unit dim that is consumed by a reduction iterator. @@ -127,16 +277,16 @@ func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1 #map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2)> -// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)> -// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> -// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)> // CHECK-LABEL: contract_broadcast_non_unit_dim_reduction_with_permutation // CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>) // CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8x4xi32> to vector<2x8x4xi32> // CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<8x4xi32> to vector<2x8x4xi32> // CHECK: vector.contract -// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]] +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel", "reduction"] // CHECK-SAME: %[[BROADCAST0]], %[[BROADCAST1]], %[[ARG2]] : vector<2x8x4xi32>, vector<2x8x4xi32> into vector<8x8xi32> func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> { @@ -159,16 +309,16 @@ func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : ve #map1 = affine_map<(d0, d1, d2) -> (d0, d2)> #map2 = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> -// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> // CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction // CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>) // CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32> // CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<8xi32> to vector<1x8xi32> // CHECK: vector.contract -// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]] +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel"] // CHECK-SAME: %[[BROADCAST0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x8xi32>, vector<1x8xi32> into vector<8x8xi32> func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> { @@ -191,15 +341,15 @@ func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vecto #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> #map2 = affine_map<(d0, d1, d2) -> (d1)> -// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> -// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d1)> // CHECK-LABEL: contract_broadcast_dimension_would_go_unused_in_lhs_rhs // CHECK-SAME: (%[[ARG0:.+]]: vector<1x2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<1xi32>) // CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<2xi32> to vector<1x1x2xi32> // CHECK: vector.contract -// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]] +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"] // CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32> @@ -230,7 +380,7 @@ func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vecto // CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>, %[[ARG2:.+]]: vector<1xf32>) // CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<1xf32> to vector<1x1xf32> // CHECK: vector.contract -// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]] +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1xf32>, vector<1x1xf32> into vector<1xf32> @@ -247,6 +397,10 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x // ----- +//----------------------------------------------------------------------------- +// [Pattern: CombineContractResultTranspose] +//----------------------------------------------------------------------------- + // CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)> // CHECK-DAG: #[[$RHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> // CHECK-DAG: #[[$ACC_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>