
Commit 1430e84

Implement vector stores

1 parent 4610e5c commit 1430e84

2 files changed: +313 −32 lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 180 additions & 32 deletions
@@ -33,6 +33,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -157,13 +158,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
   auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  assert(vectorType.getRank() == 1 && "expected 1-D source types");
   assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
          "subvector out of bounds");
 
@@ -174,9 +172,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
       ->getResult(0);
 }
 
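With this change, staticallyExtractSubvector derives the result vector type from `subvecSize` and the source element type instead of taking it as a parameter. For illustration, extracting 3 elements at offset 2 from a vector<8xi4> would emit IR along these lines (a sketch; the SSA names are hypothetical):

    %slice = vector.extract_strided_slice %src
        {offsets = [2], sizes = [3], strides = [1]}
        : vector<8xi4> to vector<3xi4>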
@@ -185,12 +186,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
 /// `vector.insert_strided_slice`.
 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, Value dest, int64_t offset) {
-  auto srcType = cast<VectorType>(src.getType());
-  auto destType = cast<VectorType>(dest.getType());
+  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
   assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
          "expected source and dest to be vector type");
-  (void)srcType;
-  (void)destType;
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
@@ -257,6 +256,63 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                      newLoad);
 }
 
+static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
+                           Value memref, Value index, Value value) {
+  auto originType = dyn_cast<VectorType>(value.getType());
+  auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
+  auto scale = memrefElemType.getIntOrFloatBitWidth() /
+               originType.getElementType().getIntOrFloatBitWidth();
+  auto storeType =
+      VectorType::get({originType.getNumElements() / scale}, memrefElemType);
+  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
+  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
+}
+
+/// Atomically store a sub-byte-sized value to memory, with a mask.
+static Value atomicStore(OpBuilder &rewriter, Location loc,
+                         Value emulatedMemref, Value emulatedIndex,
+                         TypedValue<VectorType> value, Value mask,
+                         int64_t scale) {
+  auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
+  OpBuilder builder =
+      OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+  Value origValue = atomicOp.getCurrentValue();
+
+  // i8 -> vector<1xi8>, then bitcast to a vector of `scale` sub-byte elements.
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+  return atomicOp;
+}
+
+// Extract a slice of a vector, and insert it into a byte vector.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  auto vectorElementType = vector.getType().getElementType();
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
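To make the new helpers concrete, take scale = 4 (four i2 elements per i8 container byte). atomicStore wraps a masked merge of sub-byte lanes in a memref.generic_atomic_rmw so that partial-byte stores cannot race with neighboring lanes of the same byte, while nonAtomicStore just bitcasts and stores whole bytes. A sketch of the emitted IR, with hypothetical SSA names, assuming %mask : vector<4xi1> and %value : vector<4xi2>:

    // atomicStore: read-modify-write of one byte, preserving unmasked lanes.
    %res = memref.generic_atomic_rmw %mem[%idx] : memref<?xi8> {
    ^bb0(%byte: i8):
      %byte_vec = vector.from_elements %byte : vector<1xi8>
      %lanes = vector.bitcast %byte_vec : vector<1xi8> to vector<4xi2>
      %merged = arith.select %mask, %value, %lanes : vector<4xi1>, vector<4xi2>
      %packed = vector.bitcast %merged : vector<4xi2> to vector<1xi8>
      %new_byte = vector.extract %packed[0] : i8 from vector<1xi8>
      memref.atomic_yield %new_byte : i8
    }

    // nonAtomicStore: the byte-aligned part needs no read-modify-write.
    %bytes = vector.bitcast %part : vector<8xi2> to vector<2xi8>
    vector.store %bytes, %mem[%idx] : memref<?xi8>, vector<2xi8>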
@@ -277,7 +333,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
+    auto valueToStore = op.getValueToStore();
+    Type oldElementType = valueToStore.getType().getElementType();
     Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -301,30 +358,124 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     // vector<4xi8>
 
-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // Unimplemented case: dynamic front padding size.
+      return failure();
+    }
+
+    // Conditions under which atomic RMW stores are not needed:
+    // 1. the source vector size is a multiple of the byte size, and
+    // 2. the address of the store is byte aligned.
+    if (!isUnalignedEmulation && *foldedIntraVectorOffset == 0) {
+      auto numElements = origElements / scale;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), adaptor.getBase(),
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return llvm::success();
+    }
+
+    Value emulatedMemref = adaptor.getBase();
+    // The index into the target memref we are storing to.
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing.
+    auto currentSourceIndex = 0;
+
+    // 1. Atomic store for the first (potentially partial) byte.
+    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+    if (frontAtomicStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
+      if (*foldedIntraVectorOffset + origElements < scale) {
+        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+                    origElements, true);
+        frontAtomicStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
+                    frontAtomicStoreElem, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+
+      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      auto value = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+          frontAtomicStoreElem, *foldedIntraVectorOffset);
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+                  scale);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. Non-atomic store for the byte-aligned middle part.
+    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
+    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+    if (nonAtomicStoreSize != 0) {
+      auto nonAtomicStorePart = staticallyExtractSubvector(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, numNonAtomicElements);
+
+      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                     nonAtomicStorePart);
+
+      currentSourceIndex += numNonAtomicElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+    }
+
+    // 3. Atomic store for the last (partial) byte.
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto atomicStorePart = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, remainingElements, 0);
+
+      // Back mask: select the first `remainingElements` lanes from the value.
+      auto maskValues = llvm::SmallVector<bool>(scale, false);
+      std::fill_n(maskValues.begin(), remainingElements, true);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(atomicStorePart),
+                  backMask.getResult(), scale);
+    }
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    rewriter.eraseOp(op);
     return success();
   }
 };
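Putting the three phases together on a hypothetical unaligned case, assuming the seven i2 elements below linearize to byte 0 of the emulated memref<4xi8> with an intra-byte offset of 2:

    %c2 = arith.constant 2 : index
    vector.store %v, %src[%c2] : memref<16xi2>, vector<7xi2>

This pattern would lower the store to, in order: an atomic read-modify-write of byte 0 merging elements 0-1 of %v into lanes 2-3 under mask [0, 0, 1, 1]; a plain vector.store of elements 2-5 (bitcast from vector<4xi2> to vector<1xi8>) into byte 1; and an atomic read-modify-write of byte 2 merging element 6 into lane 0 under mask [1, 0, 0, 0].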
@@ -532,9 +683,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -693,9 +843,8 @@ struct ConvertVectorMaskedLoad final
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
           op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -778,9 +927,8 @@ struct ConvertVectorTransferRead final
                                        linearizedInfo.intraDataOffset,
                                        origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 