diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 95965872f4098..1e8e1265affa0 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -724,59 +724,6 @@ struct LiftIllegalVectorTransposeToMemory
   }
 };
 
-/// A rewrite to turn unit dim transpose-like vector.shape_casts into
-/// vector.transposes. The shape_cast has to be from an illegal vector type to a
-/// legal one (as defined by isLegalVectorType).
-///
-/// The reasoning for this is if we've got to this pass and we still have
-/// shape_casts of illegal types, then they likely will not cancel out. Turning
-/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
-/// eliminate them.
-///
-/// Example:
-///
-/// BEFORE:
-/// ```mlir
-/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-/// ```
-struct ConvertIllegalShapeCastOpsToTransposes
-    : public OpRewritePattern<vector::ShapeCastOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
-                                PatternRewriter &rewriter) const override {
-    auto sourceType = shapeCastOp.getSourceVectorType();
-    auto resultType = shapeCastOp.getResultVectorType();
-    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
-      return rewriter.notifyMatchFailure(shapeCastOp,
-                                         kMatchFailureNotIllegalToLegal);
-
-    // Note: If we know that `sourceType` is an illegal vector type (and 2D)
-    // then dim 0 is scalable and dim 1 is fixed.
-    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
-      return rewriter.notifyMatchFailure(
-          shapeCastOp, "expected source to be a 2D scalable vector with a "
-                       "trailing unit dim");
-
-    auto loc = shapeCastOp.getLoc();
-    auto transpose = rewriter.create<vector::TransposeOp>(
-        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
-
-    if (resultType.getRank() == 1)
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
-                                                       transpose);
-    else
-      rewriter.replaceOp(shapeCastOp, transpose);
-
-    return success();
-  }
-};
-
 /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
 /// the ZA state. This is a workaround to support such transposes when ZA is
 /// available.
@@ -920,6 +867,116 @@ struct LowerIllegalTransposeStoreViaZA
   }
 };
 
+/// Lower `vector.transfer_read` of a scalable column to `scf::for`.
+///
+/// Lowers a "read" of a scalable column from a MemRef, for which there is no
+/// hardware operation we could use, to a loop over the rows that loads one
+/// element at a time.
+///
+/// BEFORE:
+/// ```
+/// %res = vector.transfer_read %mem[%a, %b] (...)
+///   : memref<?x?xf32>, vector<[4]x1xf32>
+/// ```
+///
+/// AFTER:
+/// ```
+/// %cst = arith.constant (...) : vector<[4]xf32>
+/// %vscale = vector.vscale
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// %scf = scf.for %i = %c0 to %c4_vscale step %c1 iter_args(%acc = %cst)
+///     -> (vector<[4]xf32>) {
+///   %idx = arith.addi %i, %a : index
+///   %load = memref.load %mem[%idx, %b] : memref<?x?xf32>
+///   %vec = vector.insert %load, %acc [%i] : f32 into vector<[4]xf32>
+///   scf.yield %vec : vector<[4]xf32>
+/// }
+/// %res = vector.shape_cast %scf : vector<[4]xf32> to vector<[4]x1xf32>
+/// ```
+///
+/// TODO: This transformation isn't specific to SME - move it to the SVE
+/// dialect.
+/// TODO: Check the in_bounds attribute and generate vector.maskedload if
+/// required.
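+///
+/// Note: the upper bound `%c4_vscale` above is the runtime row count of
+/// vector<[4]x1xf32>, i.e. 4 * vscale. As a purely illustrative example
+/// (vscale is only known at runtime), with vscale = 2 the loop performs
+/// 4 * 2 = 8 scalar loads: %mem[%a + 0, %b], ..., %mem[%a + 7, %b].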
+struct LowerColumnTransferReadToLoops
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    // NOTE: This is a fairly low-level transformation, so we shouldn't be
+    // adding support for Tensors without good rationale.
+    if (readOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(
+          readOp, "Tensor semantics are unsupported (either bufferize or "
+                  "extend this pattern)");
+
+    auto resType = readOp.getVectorType();
+
+    if (resType.getRank() != 2)
+      return rewriter.notifyMatchFailure(readOp,
+                                         "Only 2D vectors are supported!");
+
+    if (resType.getShape()[1] != 1)
+      return rewriter.notifyMatchFailure(
+          readOp, "The trailing output dim is != 1 (not supported ATM)");
+
+    if (!resType.getScalableDims()[0] || resType.getScalableDims()[1])
+      return rewriter.notifyMatchFailure(
+          readOp, "Expected the leading dim to be scalable and the trailing "
+                  "dim to be fixed.");
+
+    // Create a new result type - similar to the original vector, but with
+    // the trailing unit dim collapsed.
+    int64_t numRows = resType.getShape()[0];
+    VectorType newResType = VectorType::get(numRows, resType.getElementType(),
+                                            /*scalableDims=*/{true});
+
+    // Create a loop over all rows and load one element at a time.
+    auto loc = readOp.getLoc();
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto createVscaleMultiple =
+        vector::makeVscaleConstantBuilder(rewriter, loc);
+    auto upperBound = createVscaleMultiple(numRows);
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value init = rewriter.create<arith::ConstantOp>(
+        loc, newResType, DenseElementsAttr::get(newResType, 0.0f));
+
+    scf::ForOp loadLoop;
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      loadLoop = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
+                                             step, ValueRange{init});
+      rewriter.setInsertionPointToStart(loadLoop.getBody());
+
+      auto tileSliceIndex = loadLoop.getInductionVar();
+
+      auto idx0 = rewriter.create<arith::AddIOp>(loc, tileSliceIndex,
+                                                 readOp.getIndices()[0]);
+      auto idx1 = readOp.getIndices()[1];
+
+      Value scalar = rewriter.create<memref::LoadOp>(
+          loc, readOp.getBase(), SmallVector<Value>({idx0, idx1}));
+
+      Operation *updateInit = rewriter.create<vector::InsertOp>(
+          loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex);
+
+      rewriter.create<scf::YieldOp>(loc, updateInit->getResult(0));
+    }
+
+    // The read operation has been "legalized", but since the original result
+    // type was a 2D vector, we need to cast before returning the result. This
+    // ShapeCast should cancel out with some other ShapeCast (i.e. it's a
+    // no-op).
+    auto sc = rewriter.create<vector::ShapeCastOp>(
+        loc, readOp.getResult().getType(), loadLoop.getResult(0));
+
+    rewriter.replaceOp(readOp, sc);
+
+    return success();
+  }
+};
+
 struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
@@ -941,10 +998,10 @@ struct VectorLegalizationPass
 
     // Apply preprocessing patterns.
     RewritePatternSet rewritePatterns(context);
-    rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
-                        ConvertIllegalShapeCastOpsToTransposes,
-                        LiftIllegalVectorTransposeToMemory,
-                        LowerIllegalTransposeStoreViaZA>(context);
+    rewritePatterns
+        .add<FoldExtractFromVectorOfSMELikeCreateMasks,
+             LowerColumnTransferReadToLoops, LiftIllegalVectorTransposeToMemory,
+             LowerIllegalTransposeStoreViaZA>(context);
     if (failed(
             applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
       return signalPassFailure();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a2357319bd23..887773172339f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5758,18 +5758,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // shape_cast(transpose(x)) -> shape_cast(x)
   if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
-    // This folder does
-    //   shape_cast(transpose) -> shape_cast
-    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
-    //   shape_cast -> shape_cast(transpose)
-    // i.e. the complete opposite. When paired, these 2 patterns can cause
-    // infinite cycles in pattern rewriting.
-    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
-    // vectors, so by disabling this folder for scalable vectors the
-    // cycle is avoided.
-    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
-    // still needed. If it's not, then we can fold here.
-    if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
+    if (isOrderPreserving(transpose)) {
       setOperand(transpose.getVector());
       return getResult();
     }
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index d56df9814f173..6cdf576272ebc 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -491,51 +491,6 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
 // -----
 
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
-// CHECK-SAME:  %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
-  // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %0 : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
-// CHECK-SAME:  %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
-  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
-  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
-  return %0 : vector<[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
-func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
-  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32>, vector<1x[4]xf32>
-  // CHECK-NOT: vector.shape_cast
-  %pad = arith.constant 0.0 : f32
-  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
-  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %cast : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
-func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
-  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32>, vector<1x[4]xf32>
-  // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
-  %pad = arith.constant 0.0 : f32
-  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
-  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
-  return %cast : vector<[4]xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @multi_tile_splat
 func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
 {
@@ -656,3 +611,59 @@ func.func @vector_mask_without_maskable_op(%mask: vector<16x2xi1>, %vec: vector<
   %0 = vector.mask %mask { vector.yield %vec : vector<16x16xf32> } : vector<16x2xi1> -> vector<16x16xf32>
   return %0 : vector<16x16xf32>
 }
+
+// -----
+
+//=============================================================================
+// 1D examples - to be moved to the SVE dialect
+//=============================================================================
+
+/// TODO: Handle in_bounds
+
+// CHECK-LABEL: func.func @xfer_read_scalable_column(
+// CHECK-SAME:    %[[IDX_0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[PAD:.*]]: f32,
+// CHECK-SAME:    %[[SRC:.*]]: memref<?x?xf32>) -> vector<[4]x1xf32> {
+func.func @xfer_read_scalable_column(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<[4]x1xf32>) {
+  // CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+  // CHECK: %[[STEP:.*]] = arith.constant 1 : index
+  // CHECK: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK: %[[LB:.*]] = arith.constant 0 : index
+  // CHECK: %[[VSCALE:.*]] = vector.vscale
+  // CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+
+  // <scf.for>
+  // CHECK: %[[SCF:.*]] = scf.for %[[IND_VAR:.*]] = %[[LB]] to %[[C4_VSCALE]] step %[[STEP]] iter_args(%[[SCF_RES:.*]] = %[[INIT]]) -> (vector<[4]xf32>) {
+  // CHECK:   %[[IDX_0_UPDATED:.*]] = arith.addi %[[IND_VAR]], %[[IDX_0]] : index
+  // CHECK:   %[[VAL_10:.*]] = memref.load %[[SRC]][%[[IDX_0_UPDATED]], %[[IDX_1]]] : memref<?x?xf32>
+  // CHECK:   %[[RES_UPDATED:.*]] = vector.insert %[[VAL_10]], %[[SCF_RES]] [%[[IND_VAR]]] : f32 into vector<[4]xf32>
+  // CHECK:   scf.yield %[[RES_UPDATED]] : vector<[4]xf32>
+  // CHECK: }
+
+  // <shape_cast>
+  // CHECK: %[[SC:.*]] = vector.shape_cast %[[SCF]] : vector<[4]xf32> to vector<[4]x1xf32>
+  // CHECK: return %[[SC]]
+  %read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<[4]x1xf32>
+  return %read : vector<[4]x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_xfer_read_scalable_column_x2
+func.func @negative_xfer_read_scalable_column_x2(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<[4]x2xf32>) {
+  // CHECK-NOT: scf.for
+  // CHECK-NOT: memref.load
+  %read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<[4]x2xf32>
+  return %read : vector<[4]x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_xfer_read_scalable_column_scalable_trailing_dim
+func.func @negative_xfer_read_scalable_column_scalable_trailing_dim(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<4x[1]xf32>) {
+  // CHECK-NOT: scf.for
+  // CHECK-NOT: memref.load
+  %read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<4x[1]xf32>
+  return %read : vector<4x[1]xf32>
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index c84aea6609665..f1e1c5e896c66 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -165,6 +165,25 @@ func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8
 
 // -----
 
+// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
+//   1 -> 0
+//   2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a
+// shape_cast. (Same as the example above, but one of the dims is scalable.)
+// CHECK-LABEL: @shape_cast_of_transpose_scalable
+// CHECK-SAME:  %[[ARG:.*]]: vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+// CHECK:       %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME:    vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8>
+// CHECK:       return %[[SHAPE_CAST]] : vector<[4]x4xi8>
+func.func @shape_cast_of_transpose_scalable(%arg : vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+  %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+    : vector<1x[4]x4x1x1xi8> to vector<[4]x1x1x1x4xi8>
+  %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4xi8> to vector<[4]x4xi8>
+  return %1 : vector<[4]x4xi8>
+}
+
+// -----
+
 // In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
 //   1 -> 2
 //   2 -> 1
@@ -184,36 +203,10 @@ func.func @negative_shape_cast_of_transpose(%arg : vector<1x4x4x1xi8>) -> vector
 
 // -----
 
-// Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
-// scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes
-// CHECK-LABEL: @negative_shape_cast_of_transpose_scalable
-// CHECK: vector.transpose
-// CHECK: vector.shape_cast
-func.func @negative_shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
-  %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8>
-  %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8>
-  return %1 : vector<[4]xi8>
-}
-
-// -----
-
 /// +--------------------------------------------------------------------------
 /// Tests of FoldTransposeShapeCast: transpose(shape_cast) -> shape_cast
 /// +--------------------------------------------------------------------------
 
-// The conversion transpose(shape_cast) -> shape_cast is not disabled for scalable
-// vectors.
-// CHECK-LABEL: @transpose_of_shape_cast_scalable
-// CHECK: vector.shape_cast
-// CHECK-SAME: vector<[4]xi8> to vector<[4]x1xi8>
-func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
-  %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8>
-  %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8>
-  return %1 : vector<[4]x1xi8>
-}
-
-// -----
-
 // A transpose that is 'order preserving' can be treated like a shape_cast.
 // CHECK-LABEL: @transpose_of_shape_cast
 // CHECK-SAME:  %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
@@ -229,11 +222,26 @@ func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi
 
 // -----
 
-// Scalable dimensions should be treated as non-unit dimensions.
 // CHECK-LABEL: @transpose_of_shape_cast_scalable
+// CHECK-SAME:  %[[ARG:.*]]: vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
+// CHECK:       %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME:    vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+// CHECK:       return %[[SHAPE_CAST]] : vector<[6]x1x1xi8>
+func.func @transpose_of_shape_cast_scalable(%arg : vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
+  %0 = vector.shape_cast %arg : vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+  %1 = vector.transpose %0, [0, 2, 1]
+    : vector<[6]x1x1xi8> to vector<[6]x1x1xi8>
+  return %1 : vector<[6]x1x1xi8>
+}
+
+// -----
+
+// Scalable unit dimensions (i.e. [1]) should be treated as non-unit
+// dimensions (hence no folding).
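+// (A scalable "unit" dim [1] holds vscale elements at runtime, and vscale is
+// only known at runtime, so it cannot be dropped or moved the way a fixed
+// unit dim can be.)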
+// CHECK-LABEL: @negative_transpose_of_shape_cast_scalable_unit // CHECK: vector.shape_cast // CHECK: vector.transpose -func.func @transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> { +func.func @negative_transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> { %0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8> %1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8> return %1 : vector<4x[1]xi8>