NOTE(review): re-flowed from a whitespace-collapsed copy of this patch.
Template/type arguments that had been stripped (`isa<...>`, `SetVector<...>`,
`SmallVector<...>`, `concat<...>`, `vector<f32>`) are restored, and the stale
"I64ArrayAttr" comment is replaced, so the post-image blob id on the
VectorOps.cpp `index` line is out of date. Column alignment inside the test
file's lines could not be recovered; if it differs from the tree, apply with
`git apply --ignore-space-change`.

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d5f3634377e4c..7eaf34a55cb4a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1660,10 +1660,6 @@ static bool hasZeroDimVectors(Operation *op) {
 
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
-  // TODO: Canonicalization for dynamic position not implemented yet.
-  if (extractOp.hasDynamicPosition())
-    return Value();
-
   Operation *defOp = extractOp.getVector().getDefiningOp();
   if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
     return Value();
@@ -1700,20 +1696,25 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   // extract position to `0` when extracting from the source operand.
   llvm::SetVector<int64_t> broadcastedUnitDims =
       broadcastOp.computeBroadcastedUnitDims();
-  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
+  SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
+  // OpBuilder is only used as a helper to build index attributes.
+  OpBuilder b(extractOp.getContext());
   int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
   for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
     if (broadcastedUnitDims.contains(i))
-      extractPos[i] = 0;
+      extractPos[i] = b.getIndexAttr(0);
   // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
   // matching extract position when extracting from the source operand.
   int64_t rankDiff = broadcastSrcRank - extractResultRank;
   extractPos.erase(extractPos.begin(),
                    std::next(extractPos.begin(), extractPos.size() - rankDiff));
-  // OpBuilder is only used as a helper to build an I64ArrayAttr.
-  OpBuilder b(extractOp.getContext());
-  extractOp.setOperand(0, source);
-  extractOp.setStaticPosition(extractPos);
+  // Split the remaining mixed positions into static and dynamic parts. All
+  // operands are rewritten (source first, then the dynamic positions) because
+  // dropped or zeroed positions may have been dynamic operands.
+  auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
+  extractOp->setOperands(
+      llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
+  extractOp.setStaticPosition(staticPos);
   return extractOp.getResult();
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f17d917ca521e..bf755b466c7eb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -710,24 +710,38 @@ func.func @fold_extract_transpose(
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
 // CHECK-SAME: %[[A:.*]]: f32
 // CHECK: return %[[A]] : f32
-func.func @fold_extract_broadcast(%a : f32) -> f32 {
+func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
+  %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_0dvec
+// CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>
+// CHECK: return %[[A]] : vector<4xf32>
+func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
+  %idx0 : index, %idx1 : index) -> vector<4xf32> {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
 // CHECK-SAME: %[[A:.*]]: vector<f32>
 // CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
 // CHECK: return %[[B]] : f32
-func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
+func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
+  %idx0 : index, %idx1 : index, %idx2: index) -> f32 {
   %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
@@ -747,57 +761,68 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
 // CHECK-LABEL: fold_extract_splat
 // CHECK-SAME: %[[A:.*]]: f32
 // CHECK: return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32) -> f32 {
+func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.splat %a : vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_vector
-// CHECK-SAME: %[[A:.*]]: vector<4xf32>
-// CHECK: return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
-  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
-  return %r : vector<4xf32>
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
+// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
+// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
+// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
+// CHECK: return %[[R]] : f32
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
+  %idx : index, %idx1 : index, %idx2 : index) -> f32 {
+  %b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
+  return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
-// CHECK-SAME: %[[A:.*]]: vector<4xf32>
-// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
-// CHECK: return %[[R]] : f32
-func.func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 {
-  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
-  return %r : f32
+// CHECK-LABEL: fold_extract_broadcast_to_lower_rank
+// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
+// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
+// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
+// CHECK: return %[[B]] : vector<4xf32>
+// rank(extract_output) < rank(broadcast_input)
+func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
+  %idx0 : index, %idx1 : index) -> vector<4xf32> {
+  %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
 // CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 // CHECK: return %[[B]] : vector<4xf32>
-func.func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
+// rank(extract_output) > rank(broadcast_input)
+func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
+  -> vector<4xf32> {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_to_equal_rank
 // CHECK-SAME: %[[A:.*]]: vector<1xf32>
 // CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
 // CHECK: return %[[R]] : vector<8xf32>
-func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
+// rank(extract_output) == rank(broadcast_input)
+func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
+  -> vector<8xf32> {
   %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
-  %r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
+  %r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
   return %r : vector<8xf32>
 }
+
 // -----
 
 // CHECK-LABEL: @fold_extract_shuffle