Skip to content

Commit ece1c6d

Browse files
committed
[mlir][Vector] Add vector.shuffle fold for poison inputs
We recently added folding support for poison indices to `vector.shuffle`. This PR adds support for folding poison inputs.
1 parent 79e804b commit ece1c6d

File tree

2 files changed

+71
-11
lines changed

2 files changed

+71
-11
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
#include "mlir/IR/AffineMap.h"
2727
#include "mlir/IR/Builders.h"
2828
#include "mlir/IR/BuiltinAttributes.h"
29-
#include "mlir/IR/BuiltinOps.h"
3029
#include "mlir/IR/BuiltinTypes.h"
3130
#include "mlir/IR/DialectImplementation.h"
3231
#include "mlir/IR/IRMapping.h"
@@ -42,7 +41,6 @@
4241
#include "llvm/ADT/SmallVector.h"
4342
#include "llvm/ADT/StringSet.h"
4443
#include "llvm/ADT/TypeSwitch.h"
45-
#include "llvm/ADT/bit.h"
4644

4745
#include <cassert>
4846
#include <cstdint>
@@ -2696,25 +2694,45 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
26962694
if (!v1Attr || !v2Attr)
26972695
return {};
26982696

2697+
// Fold shuffle poison, poison -> poison.
2698+
bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
2699+
bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
2700+
if (isV1Poison && isV2Poison)
2701+
return ub::PoisonAttr::get(getContext());
2702+
26992703
// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
27002704
// manipulation.
27012705
if (v1Type.getRank() != 1)
27022706
return {};
27032707

2704-
int64_t v1Size = v1Type.getDimSize(0);
2708+
// Poison input attributes need special handling as they are not
2709+
// DenseElementsAttr. If an index is poison, we select the first element of
2710+
// the first non-poison input.
2711+
SmallVector<Attribute> v1Elements, v2Elements;
2712+
Attribute poisonElement;
2713+
if (!isV2Poison) {
2714+
v2Elements =
2715+
to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
2716+
poisonElement = v2Elements[0];
2717+
}
2718+
if (!isV1Poison) {
2719+
v1Elements =
2720+
to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
2721+
poisonElement = v1Elements[0];
2722+
}
27052723

27062724
SmallVector<Attribute> results;
2707-
auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
2708-
auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
2725+
int64_t v1Size = v1Type.getDimSize(0);
27092726
for (int64_t maskIdx : mask) {
27102727
Attribute indexedElm;
2711-
// Select v1[0] for poison indices.
27122728
// TODO: Return a partial poison vector when supported by the UB dialect.
27132729
if (maskIdx == ShuffleOp::kPoisonIndex) {
2714-
indexedElm = v1Elements[0];
2730+
indexedElm = poisonElement;
27152731
} else {
2716-
indexedElm =
2717-
maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
2732+
if (maskIdx < v1Size)
2733+
indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
2734+
else
2735+
indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
27182736
}
27192737

27202738
results.push_back(indexedElm);
@@ -3332,13 +3350,15 @@ class InsertStridedSliceConstantFolder final
33323350
!destVector.hasOneUse())
33333351
return failure();
33343352

3335-
auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3336-
33373353
TypedValue<VectorType> sourceValue = op.getSource();
33383354
Attribute sourceCst;
33393355
if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
33403356
return failure();
33413357

3358+
// TODO: Support poison.
3359+
if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3360+
return failure();
3361+
33423362
// TODO: Handle non-unit strides when they become available.
33433363
if (op.hasNonUnitStrides())
33443364
return failure();
@@ -3355,6 +3375,7 @@ class InsertStridedSliceConstantFolder final
33553375
// increasing linearized position indices.
33563376
// Because the destination may have higher dimensionality then the slice,
33573377
// we keep track of two overlapping sets of positions and offsets.
3378+
auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
33583379
auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
33593380
auto sliceValuesIt = denseSlice.value_begin<Attribute>();
33603381
auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2023,6 +2023,45 @@ func.func @shuffle_1d_poison_idx() -> vector<4xi32> {
20232023

20242024
// -----
20252025

2026+
// CHECK-LABEL: func @shuffle_1d_rhs_lhs_poison
2027+
// CHECK-NOT: vector.shuffle
2028+
// CHECK: %[[V:.+]] = ub.poison : vector<4xi32>
2029+
// CHECK: return %[[V]]
2030+
func.func @shuffle_1d_rhs_lhs_poison() -> vector<4xi32> {
2031+
%v0 = ub.poison : vector<3xi32>
2032+
%v1 = ub.poison : vector<3xi32>
2033+
%shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
2034+
return %shuffle : vector<4xi32>
2035+
}
2036+
2037+
// -----
2038+
2039+
// CHECK-LABEL: func @shuffle_1d_lhs_poison
2040+
// CHECK-NOT: vector.shuffle
2041+
// CHECK: %[[V:.+]] = arith.constant dense<[5, 4, 5, 5]> : vector<4xi32>
2042+
// CHECK: return %[[V]]
2043+
func.func @shuffle_1d_lhs_poison() -> vector<4xi32> {
2044+
%v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
2045+
%v1 = ub.poison : vector<3xi32>
2046+
%shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
2047+
return %shuffle : vector<4xi32>
2048+
}
2049+
2050+
// -----
2051+
2052+
// CHECK-LABEL: func @shuffle_1d_rhs_poison
2053+
// CHECK-NOT: vector.shuffle
2054+
// CHECK: %[[V:.+]] = arith.constant dense<[2, 2, 0, 1]> : vector<4xi32>
2055+
// CHECK: return %[[V]]
2056+
func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
2057+
%v0 = ub.poison : vector<3xi32>
2058+
%v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32>
2059+
%shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
2060+
return %shuffle : vector<4xi32>
2061+
}
2062+
2063+
// -----
2064+
20262065
// CHECK-LABEL: func @shuffle_canonicalize_0d
20272066
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
20282067
// CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>

0 commit comments

Comments
 (0)