97 changes: 64 additions & 33 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -53,6 +53,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Location loc, Value mask,
int origElements, int scale,
int intraDataOffset = 0) {
assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
Contributor
This is not clear from the method name ...

What are origElements, scale and intraDataOffset?

Member Author
  • origElements is the number of elements in the sub-byte vector.
  • scale is the emulated (byte) element type size divided by the original element type size. For example, if the original element type is i2, then the scale is sizeof(i8) / sizeof(i2) = 4.
  • intraDataOffset is the element offset into the emulated byte. For example, to extract the second slice of vector<3xi2> from a vector<3x3xi2> (assuming the sub-byte elements are stored packed in memory), we need to load 2 bytes (the first and the second byte) and extract bits [6, 12) from them. The first 3 elements of the loaded bytes are irrelevant, hence intraDataOffset == 3 in this case.
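
To make the arithmetic concrete, here is a minimal standalone C++ sketch of the size computation for the vector<3x3xi2> example above. The variable names mirror the parameters and the formula matches the one used in getCompressedMaskOp, but the snippet itself is purely illustrative:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Extracting the second vector<3xi2> slice of a packed vector<3x3xi2>.
  int64_t origElements = 3;    // elements in the sub-byte slice
  int64_t scale = 8 / 2;       // sizeof(i8) / sizeof(i2) == 4
  int64_t intraDataOffset = 3; // slice starts at element 3 of the first byte

  // Number of emulated i8 elements (bytes) that must be loaded,
  // i.e. ceil((intraDataOffset + origElements) / scale).
  int64_t numElements = (intraDataOffset + origElements + scale - 1) / scale;
  assert(numElements == 2); // two bytes, covering bits [6, 12)
  return 0;
}
```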

Contributor

OK, this was not clear to me at all :)

I was trying to understand all of this a bit better and am just thinking that this logic needs TLC. The comment for this method needs updating to capture the info that you shared above. I think that it would also be good to provide more descriptive argument names.

Now, I appreciate that it wasn't you who wrote this to begin with, and updating it shouldn't be a blocker for this PR. Any help here would be appreciated, and I also want to help:

Contributor

Here's my attempt to improve the comments and input variable names:

Please let me know whether that makes sense to you, and any feedback is welcome.

Note, I've also created these two:

(again, your feedback would be appreciated). Last, but not least, this example seems off. In particular:

/// [Comment from Andrzej] 6 elements
///  %mask = [1, 1, 0, 0, 0, 0]
///
/// will first be padded with number of `intraDataOffset` zeros:
/// [Comment from Andrzej] 8 elements != 6 + 1
///   %mask = [0, 1, 1, 0, 0, 0, 0, 0]

Shouldn't the padded mask be: %mask = [0, 1, 1, 0, 0, 0, 0] (7 elements)?

Btw, thanks so much for working on this - your efforts are truly appreciated! Please don’t let my comments (and appetite to improve things overall) give you any other impression 😅.

Member Author

You are right! Here I just exposed some intermediate calculation details in the comment; since scale == 2 in this case, it is easier to make the padded mask a multiple of scale in the intermediate result.
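
For what it's worth, here is a minimal standalone sketch of that intermediate step, assuming the example values from your comment (a 6-element mask, intraDataOffset == 1, scale == 2). The padding/compression loop is illustrative only and is not the actual implementation in getCompressedMaskOp:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<bool> mask = {1, 1, 0, 0, 0, 0}; // original sub-byte mask
  int64_t intraDataOffset = 1, scale = 2;
  int64_t origElements = static_cast<int64_t>(mask.size());

  // Number of emulated (byte-level) mask elements.
  int64_t numElements =
      (intraDataOffset + origElements + scale - 1) / scale; // == 4

  // Pad to numElements * scale so every compressed element covers a full
  // group of `scale` original elements: [0, 1, 1, 0, 0, 0, 0, 0].
  std::vector<bool> padded(numElements * scale, false);
  for (int64_t i = 0; i < origElements; ++i)
    padded[intraDataOffset + i] = mask[i];

  // Each compressed element is the OR of its group of `scale` elements.
  std::vector<bool> compressed(numElements, false);
  for (int64_t i = 0; i < numElements * scale; ++i)
    if (padded[i])
      compressed[i / scale] = true;

  assert((compressed == std::vector<bool>{1, 1, 0, 0}));
  return 0;
}
```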

Member Author

Slightly updated the comment. Can you take a look at it again?

auto numElements = (intraDataOffset + origElements + scale - 1) / scale;

Operation *maskOp = mask.getDefiningOp();
@@ -182,6 +183,25 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
return dest;
}

/// Inserts a 1-D subvector into a 1-D `dest` vector at index `offset`.
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
TypedValue<VectorType> source,
Value dest, OpFoldResult destOffsetVar,
int64_t length) {
assert(length > 0 && "length must be greater than 0");
for (int i = 0; i < length; ++i) {
Value insertLoc =
i == 0
? destOffsetVar.dyn_cast<Value>()
: rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), destOffsetVar.dyn_cast<Value>(),
rewriter.create<arith::ConstantIndexOp>(loc, i));
auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
}
return dest;
}

/// Returns the op sequence for an emulated sub-byte data type vector load.
/// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
/// The load location is given by `base` and `linearizedIndices`, and the
@@ -199,7 +219,7 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
return rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
newLoad);
};
}

namespace {

@@ -546,29 +566,30 @@ struct ConvertVectorMaskedLoad final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;

if (!foldedIntraVectorOffset) {
// unimplemented case for dynamic intra vector offset
return failure();
}

FailureOr<Operation *> newMask =
getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
*foldedIntraVectorOffset);
auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
FailureOr<Operation *> newMask = getCompressedMaskOp(
rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
if (failed(newMask))
return failure();

Value passthru = op.getPassThru();

auto numElements =
llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
llvm::divideCeil(maxIntraDataOffset + origElements, scale);
auto loadType = VectorType::get(numElements, newElementType);
auto newBitcastType = VectorType::get(numElements * scale, oldElementType);

Value passthru = op.getPassThru();
if (isUnalignedEmulation) {
// create an empty vector of the new type
auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
if (foldedIntraVectorOffset) {
if (isUnalignedEmulation) {
passthru = staticallyInsertSubvector(
rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset);
}
} else {
passthru = dynamicallyInsertSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
emptyVector, linearizedInfo.intraDataOffset, origElements);
}
auto newPassThru =
rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
@@ -585,23 +606,34 @@ struct ConvertVectorMaskedLoad final
rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);

Value mask = op.getMask();
if (isUnalignedEmulation) {
auto newSelectMaskType =
VectorType::get(numElements * scale, rewriter.getI1Type());
// TODO: can fold if op's mask is constant
auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyVector,
*foldedIntraVectorOffset);
auto newSelectMaskType =
VectorType::get(numElements * scale, rewriter.getI1Type());
// TODO: try to fold if op's mask is constant
auto emptyMask = rewriter.create<arith::ConstantOp>(
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
if (foldedIntraVectorOffset) {
if (isUnalignedEmulation) {
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
*foldedIntraVectorOffset);
}
} else {
mask = dynamicallyInsertSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
linearizedInfo.intraDataOffset, origElements);
}

Value result =
rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);

if (isUnalignedEmulation) {
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
if (foldedIntraVectorOffset) {
if (isUnalignedEmulation) {
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
}
} else {
result = dynamicallyExtractSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
}
rewriter.replaceOp(op, result);

@@ -659,10 +691,9 @@ struct ConvertVectorTransferRead final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;

auto maxIntraVectorOffset =
foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
auto numElements =
llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
llvm::divideCeil(maxIntraDataOffset + origElements, scale);

auto newRead = rewriter.create<vector::TransferReadOp>(
loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
65 changes: 65 additions & 0 deletions mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -183,3 +183,68 @@ func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vecto
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
// -----

func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, %idx: index) -> vector<3xi2> {
%0 = memref.alloc() : memref<3x3xi2>
%cst = arith.constant dense<0> : vector<3x3xi2>
%c2 = arith.constant 2 : index
%mask = vector.constant_mask [3] : vector<3xi1>
%1 = vector.maskedload %0[%idx, %c2], %mask, %passthru :
memref<3x3xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
return %1 : vector<3xi2>
}

// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 * 3 + 2) floordiv 4)>
// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3 - ((s0 * 3 + 2) floordiv 4) * 4 + 2)>
// CHECK: func @vector_maskedload_i2_dynamic_indexing_mixed(
// CHECK-SAME: %[[PTH:.+]]: vector<3xi2>, %[[IDX:.+]]: index) -> vector<3xi2>
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
// CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
// CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>

// extract passthru vector, and insert into zero vector, this is for constructing a new passthru
// CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2>
// CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[INCIDX:.+]] = arith.addi %[[LINEAR2]], %[[C1]] : index
// CHECK: %[[EX2:.+]] = vector.extract %[[PTH]][1] : i2 from vector<3xi2>
// CHECK: %[[IN2:.+]] = vector.insert %[[EX2]], %[[IN1]] [%[[INCIDX]]] : i2 into vector<8xi2>
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index
// CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2>
// CHECK: %[[IN3:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2>

// bitcast the new passthru vector to emulated i8 vector
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[IN3]] : vector<8xi2> to vector<2xi8>

// use the emulated i8 vector to masked load from the memory
// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]]
// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>

// bitcast back to i2 vector
// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>

// CHECK: %[[CST1:.+]] = arith.constant dense<false> : vector<8xi1>

// create a mask vector and select passthru part from the loaded vector.
// note that if indices are known then we can fold the part generating mask.
// CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1>
// CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1>
// CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1>
// CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1>
// CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1>
// CHECK: %[[IN6:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1>

// CHECK: %[[SELECT:.+]] = arith.select %[[IN6]], %[[BITCAST2]], %[[IN3]] : vector<8xi1>, vector<8xi2>

// finally, insert the selected parts into actual passthru vector.
// CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2>
// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[PTH]] [0] : i2 into vector<3xi2>
// CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2>
// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>