Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
Expand All @@ -42,7 +41,6 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"

#include <cassert>
#include <cstdint>
Expand Down Expand Up @@ -2684,25 +2682,45 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
if (!v1Attr || !v2Attr)
return {};

// Fold shuffle poison, poison -> poison.
bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
if (isV1Poison && isV2Poison)
return ub::PoisonAttr::get(getContext());

// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
// manipulation.
if (v1Type.getRank() != 1)
return {};

int64_t v1Size = v1Type.getDimSize(0);
// Poison input attributes need special handling as they are not
// DenseElementsAttr. If an index is poison, we select the first element of
// the first non-poison input.
Comment on lines +2696 to +2698
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] To me this is a fairly significant (and not immediately intuitive) part of the design. Perhaps move above the signature?

Also, is this based on some prior-art? Just curious, this does make sense.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I follow the prior-art part. Do you mean why we pick the first element of the first non-poison input? Poison is basically UB so given that we can't represent a partially poison vector we just make a random decision, which is ok as part of the UB behavior.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's valid to substitute poison with an arbitrary value

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's valid to substitute poison with an arbitrary value

Sure, but we are selecting a specific "arbitrary value" :)

Not sure I follow the prior-art part.

I was just curious whether there's any rationale behind this specific option. For example, something else in LLVM or MLIR makes similar choice?

Basically, what I'm missing is "why would we select the first element"? Something along the lines would be helpful:

I doesn't matter what we select, but we need to make a choice. We choose the first element.

Copy link
Member

@kuhar kuhar Feb 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but we are selecting a specific "arbitrary value" :)

??? In this context, arbitrary is synonymous to non-deterministics, as in: absolutely any value will do and the choice doesn't have to be fair by any definition of fair.

SmallVector<Attribute> v1Elements, v2Elements;
Attribute poisonElement;
if (!isV2Poison) {
v2Elements =
to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
poisonElement = v2Elements[0];
}
if (!isV1Poison) {
v1Elements =
to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
poisonElement = v1Elements[0];
}

SmallVector<Attribute> results;
auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
int64_t v1Size = v1Type.getDimSize(0);
for (int64_t maskIdx : mask) {
Attribute indexedElm;
// Select v1[0] for poison indices.
// TODO: Return a partial poison vector when supported by the UB dialect.
if (maskIdx == ShuffleOp::kPoisonIndex) {
indexedElm = v1Elements[0];
indexedElm = poisonElement;
} else {
indexedElm =
maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
if (maskIdx < v1Size)
indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
else
indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
}

results.push_back(indexedElm);
Expand Down Expand Up @@ -3319,13 +3337,15 @@ class InsertStridedSliceConstantFolder final
!destVector.hasOneUse())
return failure();

auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);

TypedValue<VectorType> sourceValue = op.getSource();
Attribute sourceCst;
if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
return failure();

// TODO: Support poison.
if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
return failure();

// TODO: Handle non-unit strides when they become available.
if (op.hasNonUnitStrides())
return failure();
Expand All @@ -3342,6 +3362,7 @@ class InsertStridedSliceConstantFolder final
// increasing linearized position indices.
// Because the destination may have higher dimensionality then the slice,
// we keep track of two overlapping sets of positions and offsets.
auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
auto sliceValuesIt = denseSlice.value_begin<Attribute>();
auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
Expand Down
39 changes: 39 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2023,6 +2023,45 @@ func.func @shuffle_1d_poison_idx() -> vector<4xi32> {

// -----

// CHECK-LABEL: func @shuffle_1d_rhs_lhs_poison
// CHECK-NOT: vector.shuffle
// CHECK: %[[V:.+]] = ub.poison : vector<4xi32>
// CHECK: return %[[V]]
func.func @shuffle_1d_rhs_lhs_poison() -> vector<4xi32> {
%v0 = ub.poison : vector<3xi32>
%v1 = ub.poison : vector<3xi32>
%shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
return %shuffle : vector<4xi32>
}

// -----

// CHECK-LABEL: func @shuffle_1d_lhs_poison
// CHECK-NOT: vector.shuffle
// CHECK: %[[V:.+]] = arith.constant dense<[5, 4, 5, 5]> : vector<4xi32>
// CHECK: return %[[V]]
func.func @shuffle_1d_lhs_poison() -> vector<4xi32> {
%v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
%v1 = ub.poison : vector<3xi32>
%shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] There's a value and index == 5, so it's not obvious that the first element of %v0 is in any way significant. Perhaps use some more distinct number? (e.g. 123).

Suggested change
%v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
%v1 = ub.poison : vector<3xi32>
%shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
%v0 = arith.constant dense<[123, 4, 3]> : vector<3xi32>
%v1 = ub.poison : vector<3xi32>
%shuffle = vector.shuffle %v0, %v1 [3, 1, 123, 4] : vector<3xi32>, vector<3xi32>

I appreciate that this is obvious right now, but lets also cater for our future selves :)

return %shuffle : vector<4xi32>
}

// -----

// CHECK-LABEL: func @shuffle_1d_rhs_poison
// CHECK-NOT: vector.shuffle
// CHECK: %[[V:.+]] = arith.constant dense<[2, 2, 0, 1]> : vector<4xi32>
// CHECK: return %[[V]]
func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
%v0 = ub.poison : vector<3xi32>
%v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32>
%shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
return %shuffle : vector<4xi32>
}

// -----

// CHECK-LABEL: func @shuffle_canonicalize_0d
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
// CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
Expand Down