diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 2ec1b97f2f241..7a10d2f2c0dfc 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -26,7 +26,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IRMapping.h" @@ -42,7 +41,6 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/ADT/bit.h" #include #include @@ -2684,25 +2682,45 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) { if (!v1Attr || !v2Attr) return {}; + // Fold shuffle poison, poison -> poison. + bool isV1Poison = isa(v1Attr); + bool isV2Poison = isa(v2Attr); + if (isV1Poison && isV2Poison) + return ub::PoisonAttr::get(getContext()); + // Only support 1-D for now to avoid complicated n-D DenseElementsAttr // manipulation. if (v1Type.getRank() != 1) return {}; - int64_t v1Size = v1Type.getDimSize(0); + // Poison input attributes need special handling as they are not + // DenseElementsAttr. If an index is poison, we select the first element of + // the first non-poison input. + SmallVector v1Elements, v2Elements; + Attribute poisonElement; + if (!isV2Poison) { + v2Elements = + to_vector(cast(v2Attr).getValues()); + poisonElement = v2Elements[0]; + } + if (!isV1Poison) { + v1Elements = + to_vector(cast(v1Attr).getValues()); + poisonElement = v1Elements[0]; + } SmallVector results; - auto v1Elements = cast(v1Attr).getValues(); - auto v2Elements = cast(v2Attr).getValues(); + int64_t v1Size = v1Type.getDimSize(0); for (int64_t maskIdx : mask) { Attribute indexedElm; - // Select v1[0] for poison indices. // TODO: Return a partial poison vector when supported by the UB dialect. if (maskIdx == ShuffleOp::kPoisonIndex) { - indexedElm = v1Elements[0]; + indexedElm = poisonElement; } else { - indexedElm = - maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size]; + if (maskIdx < v1Size) + indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx]; + else + indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size]; } results.push_back(indexedElm); @@ -3319,13 +3337,15 @@ class InsertStridedSliceConstantFolder final !destVector.hasOneUse()) return failure(); - auto denseDest = llvm::cast(vectorDestCst); - TypedValue sourceValue = op.getSource(); Attribute sourceCst; if (!matchPattern(sourceValue, m_Constant(&sourceCst))) return failure(); + // TODO: Support poison. + if (isa(vectorDestCst) || isa(sourceCst)) + return failure(); + // TODO: Handle non-unit strides when they become available. if (op.hasNonUnitStrides()) return failure(); @@ -3342,6 +3362,7 @@ class InsertStridedSliceConstantFolder final // increasing linearized position indices. // Because the destination may have higher dimensionality then the slice, // we keep track of two overlapping sets of positions and offsets. + auto denseDest = llvm::cast(vectorDestCst); auto denseSlice = llvm::cast(sourceCst); auto sliceValuesIt = denseSlice.value_begin(); auto newValues = llvm::to_vector(denseDest.getValues()); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 6858f0d56e641..61e858f5f226a 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2012,17 +2012,56 @@ func.func @shuffle_1d() -> vector<4xi32> { // input vector. That is, %v[0] (i.e., 5) in this test. // CHECK-LABEL: func @shuffle_1d_poison_idx -// CHECK: %[[V:.+]] = arith.constant dense<[2, 5, 0, 5]> : vector<4xi32> +// CHECK: %[[V:.+]] = arith.constant dense<[13, 10, 15, 10]> : vector<4xi32> // CHECK: return %[[V]] func.func @shuffle_1d_poison_idx() -> vector<4xi32> { - %v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32> - %v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32> + %v0 = arith.constant dense<[10, 11, 12]> : vector<3xi32> + %v1 = arith.constant dense<[13, 14, 15]> : vector<3xi32> %shuffle = vector.shuffle %v0, %v1 [3, -1, 5, -1] : vector<3xi32>, vector<3xi32> return %shuffle : vector<4xi32> } // ----- +// CHECK-LABEL: func @shuffle_1d_rhs_lhs_poison +// CHECK-NOT: vector.shuffle +// CHECK: %[[V:.+]] = ub.poison : vector<4xi32> +// CHECK: return %[[V]] +func.func @shuffle_1d_rhs_lhs_poison() -> vector<4xi32> { + %v0 = ub.poison : vector<3xi32> + %v1 = ub.poison : vector<3xi32> + %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32> + return %shuffle : vector<4xi32> +} + +// ----- + +// CHECK-LABEL: func @shuffle_1d_lhs_poison +// CHECK-NOT: vector.shuffle +// CHECK: %[[V:.+]] = arith.constant dense<[11, 12, 11, 11]> : vector<4xi32> +// CHECK: return %[[V]] +func.func @shuffle_1d_lhs_poison() -> vector<4xi32> { + %v0 = arith.constant dense<[11, 12, 13]> : vector<3xi32> + %v1 = ub.poison : vector<3xi32> + %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32> + return %shuffle : vector<4xi32> +} + +// ----- + +// CHECK-LABEL: func @shuffle_1d_rhs_poison +// CHECK-NOT: vector.shuffle +// CHECK: %[[V:.+]] = arith.constant dense<[11, 11, 13, 12]> : vector<4xi32> +// CHECK: return %[[V]] +func.func @shuffle_1d_rhs_poison() -> vector<4xi32> { + %v0 = ub.poison : vector<3xi32> + %v1 = arith.constant dense<[11, 12, 13]> : vector<3xi32> + %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32> + return %shuffle : vector<4xi32> +} + +// ----- + // CHECK-LABEL: func @shuffle_canonicalize_0d func.func @shuffle_canonicalize_0d(%v0 : vector, %v1 : vector) -> vector<1xi32> { // CHECK: vector.broadcast %{{.*}} : vector to vector<1xi32>