diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a295bf1eb4d95..2a2357319bd23 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4668,12 +4668,15 @@ struct TransferReadAfterWriteToBroadcast
 
   LogicalResult matchAndRewrite(TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-    if (readOp.hasOutOfBoundsDim() ||
-        !llvm::isa<RankedTensorType>(readOp.getShapedType()))
-      return failure();
     auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
     if (!defWrite)
       return failure();
+    // Bail if we need an alias analysis.
+    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
+      return failure();
+    // Bail if we need a bounds analysis.
+    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
+      return failure();
     // TODO: If the written transfer chunk is a superset of the read transfer
     // chunk we could do an extract_strided_slice.
     if (readOp.getTransferChunkAccessed() !=
@@ -4684,15 +4687,28 @@ struct TransferReadAfterWriteToBroadcast
     if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
         getUnusedDimsBitVector({defWrite.getPermutationMap()}))
       return failure();
-    if (readOp.getIndices() != defWrite.getIndices() ||
-        readOp.getMask() != defWrite.getMask())
+    // This pattern should only catch the broadcast case, the non-broadcast case
+    // should be done separately to keep application conditions clean and
+    // separate.
+    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+    bool bcast = !readMap.getBroadcastDims().empty() ||
+                 !writeMap.getBroadcastDims().empty();
+    if (!bcast)
+      return failure();
+    // At this point, we know we have a bcast.
+    // Bail in the masked case (too complex atm and needed to properly account
+    // for padding).
+    if (readOp.getMask() || defWrite.getMask())
+      return failure();
+    // If indices are not the same a shift may be required, bail.
+    if (readOp.getIndices() != defWrite.getIndices())
       return failure();
+
     Value vec = defWrite.getVector();
     // TODO: loop through the chain of transfer_write if we can prove that they
     // don't overlap with the transfer_read. This requires improving
     // `isDisjointTransferIndices` helper.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
     AffineMap map = readMap.compose(writeMap);
     if (map.getNumResults() == 0)
       return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a06a9f67d54dc..6691cb52acdc0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -408,7 +408,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
 // -----
 
 // Negative test where the extract is not a subset of the element inserted.
-// CHECK-LABEL: extract_strided_fold_negative
+// CHECK-LABEL: negative_extract_strided_fold
 // CHECK-SAME: (%[[ARG0:.*]]: vector<4x4xf32>, %[[ARG1:.*]]: vector<8x16xf32>
 // CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
 // CHECK-SAME: {offsets = [2, 2], strides = [1, 1]}
@@ -417,7 +417,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
 // CHECK-SAME: {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
 // CHECK-SAME: : vector<8x16xf32> to vector<6x4xf32>
 // CHECK-NEXT: return %[[EXT]] : vector<6x4xf32>
-func.func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
+func.func @negative_extract_strided_fold(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
   -> (vector<6x4xf32>) {
   %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
     : vector<4x4xf32> into vector<8x16xf32>
@@ -753,10 +753,10 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_negative
+// CHECK-LABEL: negative_fold_extract_broadcast
 // CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
 // CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
-func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32> {
+func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
   %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
   return %r : vector<4xf32>
@@ -895,11 +895,11 @@ func.func @fold_extract_shapecast_0d_source(%arg0 : vector<f32>) -> f32 {
 
 // -----
 
-// CHECK-LABEL: fold_extract_shapecast_negative
+// CHECK-LABEL: negative_fold_extract_shapecast
 // CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
 // CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
 // CHECK: return %[[R]] : vector<4x2xf32>
-func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
+func.func @negative_fold_extract_shapecast(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
   %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
   %r = vector.extract %0[1] : vector<4x2xf32> from vector<2x4x2xf32>
   return %r : vector<4x2xf32>
@@ -1460,11 +1460,11 @@ func.func @store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
 
 // -----
 
-// CHECK-LABEL: func @store_after_load_tensor_negative
+// CHECK-LABEL: func @negative_store_after_load_tensor
 // CHECK: vector.transfer_read
 // CHECK: vector.transfer_write
 // CHECK: return
-func.func @store_after_load_tensor_negative(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
+func.func @negative_store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
   %c1 = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -1499,12 +1499,12 @@ func.func @store_to_load_tensor(%arg0 : tensor<4x4xf32>,
 
 // -----
 
-// CHECK-LABEL: func @store_to_load_negative_tensor
+// CHECK-LABEL: func @negative_store_to_load_tensor
 // CHECK: vector.transfer_write
 // CHECK: vector.transfer_write
 // CHECK: %[[V:.*]] = vector.transfer_read
 // CHECK: return %[[V]] : vector<1x4xf32>
-func.func @store_to_load_negative_tensor(%arg0 : tensor<4x4xf32>,
+func.func @negative_store_to_load_tensor(%arg0 : tensor<4x4xf32>,
   %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> vector<1x4xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -1540,6 +1540,86 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
 
 // -----
 
+// CHECK-LABEL: func @negative_store_to_load_tensor_memref
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_memref(
+    %arg0 : tensor<?x?xf32>,
+    %arg1 : memref<?x?xf32>,
+    %v0 : vector<4x2xf32>
+  ) -> vector<4x2xf32>
+{
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  vector.transfer_write %v0, %arg1[%c0, %c0] {in_bounds = [true, true]} :
+    vector<4x2xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<?x?xf32>, vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_no_actual_broadcast
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
+  %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<?x?xf32>, vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_out_of_bounds
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<?x?xf32>,
+  %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<?x?xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_masked
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_broadcast_masked(
+    %arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
+  -> vector<4x2x6xf32>
+{
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<?x?xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?xf32>, %[[V0:.*]]: vector<[4]xf32>)
 // CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>
@@ -1604,7 +1684,7 @@ func.func @dead_store_tensor(%arg0 : tensor<4x4xf32>,
 
 // -----
 
-// CHECK-LABEL: func @dead_store_tensor_negative
+// CHECK-LABEL: func @negative_dead_store_tensor
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: vector.transfer_write
@@ -1612,7 +1692,7 @@ func.func @dead_store_tensor(%arg0 : tensor<4x4xf32>,
 // CHECK: vector.transfer_read
 // CHECK: %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]]
 // CHECK: return %[[VTW]] : tensor<4x4xf32>
-func.func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
+func.func @negative_dead_store_tensor(%arg0 : tensor<4x4xf32>,
   %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -2063,10 +2143,10 @@ func.func @extract_insert_rank_reduce(%a: vector<4xf32>, %b: vector<8x16xf32>)
 
 // -----
 
-// CHECK-LABEL: extract_insert_negative
+// CHECK-LABEL: negative_extract_insert
 // CHECK: vector.insert_strided_slice
 // CHECK: vector.extract
-func.func @extract_insert_negative(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
+func.func @negative_extract_insert(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
   -> vector<16xf32> {
   %0 = vector.insert_strided_slice %a, %b {offsets = [4, 2, 0], strides = [1, 1]}
     : vector<2x15xf32> into vector<12x8x16xf32>