[mlir][vector] Document ConvertVectorStore + unify var names (nfc)
#126422
|
|
@@ -432,7 +432,86 @@ namespace { | |
| // ConvertVectorStore | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| // TODO: Document-me | ||
| // Emulate vector.store using a multi-byte container type | ||
| // | ||
| // The container type is obtained through the Op adaptor and would normally be | ||
| // generated via `NarrowTypeEmulationConverter`. | ||
| // | ||
| // EXAMPLE 1 | ||
| // (aligned store of i4, emulated using i8) | ||
| // | ||
| // vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4> | ||
| // | ||
| // is rewritten as: | ||
| // | ||
| // %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8> | ||
| // vector.store %src_bitcast, %dest_bitcast[%idx] | ||
| // : memref<16xi8>, vector<4xi8> | ||
| // | ||
| // EXAMPLE 2 | ||
| // (unaligned store of i2, emulated using i8, non-atomic) | ||
| // | ||
| // vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2> | ||
| // | ||
| // The i2 store is emulated through 2 x RMW sequences. The destination i2 memref | ||
| // is modelled using 3 bytes: | ||
| // | ||
| // Byte 0 Byte 1 Byte 2 | ||
| // +----------+----------+----------+ | ||
| // | oooooooo | ooooNNNN | NNoooooo | | ||
| // +----------+----------+----------+ | ||
| // | ||
| // N - (N)ew entries (i.e. to be overwritten by vector.store) | ||
| // o - (o)ld entries (to be preserved) | ||
| // | ||
| // The following 2 RMW sequences will be generated: | ||
| // | ||
| // %init = arith.constant dense<0> : vector<4xi2> | ||
| // | ||
| // (RMW sequence for Byte 1) | ||
| // (Mask for 4 x i2 elements, i.e. a byte) | ||
| // %mask_1 = arith.constant dense<[false, false, true, true]> | ||
| // %src_slice_1 = vector.extract_strided_slice %src | ||
| // {offsets = [0], sizes = [2], strides = [1]} | ||
| // : vector<3xi2> to vector<2xi2> | ||
| // %init_with_slice_1 = vector.insert_strided_slice %src_slice_1, %init | ||
| // {offsets = [2], strides = [1]} | ||
| // : vector<2xi2> into vector<4xi2> | ||
| // %dest_byte_1 = vector.load %dest[%c1] | ||
| // %dest_byte_1_as_i2 = vector.bitcast %dest_byte_1 | ||
| // : vector<1xi8> to vector<4xi2> | ||
| // %res_byte_1 = arith.select %mask_1, %init_with_slice_1, %dest_byte_1_as_i2 | ||
| // %res_byte_1_as_i8 = vector.bitcast %res_byte_1 | ||
| // vector.store %res_byte_1_as_i8, %dest[1] | ||
| // (RMW sequence for Byte 2) | ||
| // (Mask for 4 x i2 elements, i.e. a byte) | ||
| // %mask_2 = arith.constant dense<[true, false, false, false]> | ||
| // %src_slice_2 = vector.extract_strided_slice %src | ||
| // {offsets = [2], sizes = [1], strides = [1]} | ||
| // : vector<3xi2> to vector<1xi2> | ||
| // %init_with_slice_2 = vector.insert_strided_slice %src_slice_2, %init | ||
| // {offsets = [0], strides = [1]} | ||
| // : vector<1xi2> into vector<4xi2> | ||
| // %dest_byte_2 = vector.load %dest[%c2] | ||
| // %dest_byte_2_as_i2 = vector.bitcast %dest_byte_2 | ||
| // : vector<1xi8> to vector<4xi2> | ||
| // %res_byte_2 = arith.select %mask_2, %init_with_slice_2, %dest_byte_2_as_i2 | ||
| // %res_byte_2_as_i8 = vector.bitcast %res_byte_2 | ||
| // vector.store %res_byte_2_as_i8, %dest[2] | ||
| // | ||
| // NOTE: Unlike EXAMPLE 1, this case requires index re-calculation. | ||
| // NOTE: This example assumes that `disableAtomicRMW` was set. | ||
| // | ||
| // EXAMPLE 3 | ||
| // (unaligned store of i2, emulated using i8, atomic) | ||
| // | ||
| // Similar to EXAMPLE 2, with the addition of | ||
| // * `memref.generic_atomic_rmw`, | ||
| // to guarantee atomicity. The actual output is skipped for brevity. | ||
| // | ||
| // NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to | ||
| // `true` to generate non-atomic RMW sequences. | ||
| struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { | ||
| using OpConversionPattern::OpConversionPattern; | ||
|
|
||
|
|
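As an editorial aside to the documentation block above: a minimal standalone C++ sketch of the index arithmetic behind EXAMPLE 2, showing why that case needs index re-calculation and two RMW sequences. None of this code is from the patch; all names are illustrative.

```cpp
// Standalone sketch (not from the patch): the index arithmetic that makes
// EXAMPLE 2 need two partial (RMW) stores.
#include <cstdio>

int main() {
  // EXAMPLE 2: vector.store %src, %dest[%c2, %c0]
  //            : memref<3x3xi2>, vector<3xi2>, emulated via i8.
  const int emulatedBits = 2;  // i2
  const int containerBits = 8; // i8
  const int emulatedPerContainerElem = containerBits / emulatedBits; // 4

  // Linearize the destination indices: row 2, column 0 in a 3x3 memref.
  const int linearElemIndex = 2 * 3 + 0;                // 6 x i2 elements in
  const int bitOffset = linearElemIndex * emulatedBits; // 12 bits

  const int byteIndex = bitOffset / containerBits; // Byte 1
  const int frontPadElems =
      (bitOffset % containerBits) / emulatedBits;  // 2 old i2 entries kept

  const int numElemsToStore = 3; // vector<3xi2>
  // 2 pad + 3 new = 5 i2 elements -> spills past the 4-element byte boundary,
  // so the store touches Byte 1 and Byte 2 and needs two RMW sequences.
  const int bytesTouched =
      (frontPadElems + numElemsToStore + emulatedPerContainerElem - 1) /
      emulatedPerContainerElem;

  std::printf("start byte = %d, front pad = %d, bytes touched = %d\n",
              byteIndex, frontPadElems, bytesTouched);
  // Expected: start byte = 1, front pad = 2, bytes touched = 2
  return 0;
}
```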
@@ -464,7 +543,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { | |
| op, "impossible to pack emulated elements into container elements " | ||
| "(bit-wise misalignment)"); | ||
| } | ||
| int numSrcElemsPerDest = containerBits / emulatedBits; | ||
| int emulatedPerContainerElem = containerBits / emulatedBits; | ||
|
|
||
| // Adjust the number of elements to store when emulating narrow types. | ||
| // Here only the 1-D vector store is considered, and the N-D memref types | ||
|
|
@@ -480,7 +559,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { | |
| // vector<4xi8> | ||
|
|
||
| auto origElements = valueToStore.getType().getNumElements(); | ||
| bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0; | ||
| // Note, per-element-alignment was already verified above. | ||
|
Member
I think so far we cannot support non-aligned cases for per element? i.e. we cannot support 7bit emulation?
Contributor
Author
AFAIK, all patterns require per-element alignment, yes. We should extract that condition somewhere to avoid repeating it. On a related note, given how quickly this logic is growing, I feel that we should try to avoid emulating:
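To illustrate the constraint discussed in this thread: a minimal standalone sketch of the per-element alignment precondition that every pattern checks, factored into a hypothetical helper (the patch itself does not add such a helper).

```cpp
// Sketch only: emulation is supported only when the container width is a
// multiple of the emulated width, so e.g. i7-in-i8 ("7bit emulation") is
// rejected as bit-wise misalignment.
#include <cassert>

// Returns the number of emulated elements per container element, or 0 if the
// two widths are bit-wise misaligned. Hypothetical helper name.
static int getEmulatedPerContainerElem(int emulatedBits, int containerBits) {
  assert(emulatedBits > 0 && containerBits > 0);
  if (containerBits % emulatedBits != 0)
    return 0; // bit-wise misalignment, e.g. i7 into i8
  return containerBits / emulatedBits;
}

int main() {
  assert(getEmulatedPerContainerElem(2, 8) == 4); // i2 in i8: supported
  assert(getEmulatedPerContainerElem(4, 8) == 2); // i4 in i8: supported
  assert(getEmulatedPerContainerElem(7, 8) == 0); // i7 in i8: rejected
  return 0;
}
```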
||
| bool isFullyAligned = origElements % emulatedPerContainerElem == 0; | ||
|
|
||
| auto stridedMetadata = | ||
| rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase()); | ||
|
|
@@ -496,7 +576,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { | |
| getAsOpFoldResult(adaptor.getIndices())); | ||
|
|
||
| std::optional<int64_t> foldedNumFrontPadElems = | ||
| isAlignedEmulation | ||
| isFullyAligned | ||
| ? 0 | ||
| : getConstantIntValue(linearizedInfo.intraDataOffset); | ||
|
|
||
|
|
@@ -516,10 +596,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { | |
| // need unaligned emulation because the store address is aligned and the | ||
| // source is a whole byte. | ||
| bool emulationRequiresPartialStores = | ||
| !isAlignedEmulation || *foldedNumFrontPadElems != 0; | ||
| !isFullyAligned || *foldedNumFrontPadElems != 0; | ||
| if (!emulationRequiresPartialStores) { | ||
| // Basic case: storing full bytes. | ||
| auto numElements = origElements / numSrcElemsPerDest; | ||
| auto numElements = origElements / emulatedPerContainerElem; | ||
| auto bitCast = rewriter.create<vector::BitCastOp>( | ||
| loc, VectorType::get(numElements, containerElemTy), | ||
| op.getValueToStore()); | ||
|
|
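For the fully aligned fast path above (no front padding, `origElements` a multiple of `emulatedPerContainerElem`), a standalone sketch of what the bitcast-plus-store amounts to at the byte level for EXAMPLE 1. This is not code from the patch, and the low-nibble-first packing order is an assumption of the sketch, not something the patch specifies.

```cpp
// Standalone sketch: packing 8 x i4 values into 4 bytes, mirroring the
// "Basic case: storing full bytes" branch (bitcast to i8 + regular store).
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int emulatedPerContainerElem = 2; // 2 x i4 per i8
  std::vector<uint8_t> srcI4 = {1, 2, 3, 4, 5, 6, 7, 8}; // vector<8xi4>
  assert(srcI4.size() % emulatedPerContainerElem == 0);  // fully aligned

  // Pack two nibbles per byte (low nibble first: an assumption here).
  std::vector<uint8_t> packedI8; // vector<4xi8> after the bitcast
  for (size_t i = 0; i < srcI4.size(); i += 2)
    packedI8.push_back(static_cast<uint8_t>((srcI4[i] & 0xF) |
                                            ((srcI4[i + 1] & 0xF) << 4)));

  // numElements = origElements / emulatedPerContainerElem
  assert(packedI8.size() == srcI4.size() / emulatedPerContainerElem);
  return 0;
}
```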
@@ -567,7 +647,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { | |
|
|
||
| // Build a mask used for rmw. | ||
| auto subWidthStoreMaskType = | ||
| VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type()); | ||
| VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type()); | ||
|
|
||
| auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW; | ||
|
|
||
|
|
@@ -576,10 +656,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { | |
| // with the unaligned part so that the rest elements are aligned to width | ||
| // boundary. | ||
| auto frontSubWidthStoreElem = | ||
| (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest; | ||
| (emulatedPerContainerElem - *foldedNumFrontPadElems) % emulatedPerContainerElem; | ||
| if (frontSubWidthStoreElem > 0) { | ||
| SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false); | ||
| if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) { | ||
| SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false); | ||
| if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) { | ||
| std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems, | ||
| origElements, true); | ||
| frontSubWidthStoreElem = origElements; | ||
|
|
@@ -590,7 +670,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { | |
| auto frontMask = rewriter.create<arith::ConstantOp>( | ||
| loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues)); | ||
|
|
||
| currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems); | ||
| currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems); | ||
| auto value = | ||
| extractSliceIntoByte(rewriter, loc, valueToStore, 0, | ||
| frontSubWidthStoreElem, *foldedNumFrontPadElems); | ||
|
|
@@ -614,8 +694,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { | |
| // After the previous step, the store address is aligned to the emulated | ||
| // width boundary. | ||
| int64_t fullWidthStoreSize = | ||
| (origElements - currentSourceIndex) / numSrcElemsPerDest; | ||
| int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest; | ||
| (origElements - currentSourceIndex) / emulatedPerContainerElem; | ||
| int64_t numNonFullWidthElements = fullWidthStoreSize * emulatedPerContainerElem; | ||
| if (fullWidthStoreSize > 0) { | ||
| auto fullWidthStorePart = staticallyExtractSubvector( | ||
| rewriter, loc, valueToStore, currentSourceIndex, | ||
|
|
@@ -624,7 +704,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { | |
| auto originType = cast<VectorType>(fullWidthStorePart.getType()); | ||
| auto memrefElemType = getElementTypeOrSelf(memrefBase.getType()); | ||
| auto storeType = VectorType::get( | ||
| {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType); | ||
| {originType.getNumElements() / emulatedPerContainerElem}, memrefElemType); | ||
| auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, | ||
| fullWidthStorePart); | ||
| rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase, | ||
|
|
@@ -646,7 +726,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { | |
| currentSourceIndex, remainingElements, 0); | ||
|
|
||
| // Generate back mask. | ||
| auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0); | ||
| auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0); | ||
| std::fill_n(maskValues.begin(), remainingElements, 1); | ||
| auto backMask = rewriter.create<arith::ConstantOp>( | ||
| loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues)); | ||
|
|
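As a standalone model of the front and back sub-byte masks built in the hunks above (using `std::vector<bool>` in place of `SmallVector<bool>`; not code from the patch), the following reproduces `%mask_1` and `%mask_2` from EXAMPLE 2.

```cpp
// Standalone model of the front/back RMW masks for EXAMPLE 2.
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  const int emulatedPerContainerElem = 4; // 4 x i2 per i8 byte
  const int frontPadElems = 2;            // old entries preserved in Byte 1
  const int remainingElements = 1;        // new entries spilling into Byte 2

  // Front mask for Byte 1: keep the first `frontPadElems` entries, overwrite
  // the rest -> [false, false, true, true], i.e. %mask_1 in the doc comment.
  std::vector<bool> frontMask(emulatedPerContainerElem, false);
  std::fill(frontMask.begin() + frontPadElems, frontMask.end(), true);
  assert((frontMask == std::vector<bool>{false, false, true, true}));

  // Back mask for Byte 2: overwrite only the first `remainingElements`
  // entries -> [true, false, false, false], i.e. %mask_2 in the doc comment.
  std::vector<bool> backMask(emulatedPerContainerElem, false);
  std::fill_n(backMask.begin(), remainingElements, true);
  assert((backMask == std::vector<bool>{true, false, false, false}));
  return 0;
}
```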
@@ -960,7 +1040,8 @@ struct ConvertVectorMaskedLoad final | |
| // subvector at the proper offset after bit-casting. | ||
| auto origType = op.getVectorType(); | ||
| auto origElements = origType.getNumElements(); | ||
| bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0; | ||
| // Note, per-element-alignment was already verified above. | ||
| bool isFullyAligned = origElements % emulatedPerContainerElem == 0; | ||
|
|
||
| auto stridedMetadata = | ||
| rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase()); | ||
|
|
@@ -975,7 +1056,7 @@ struct ConvertVectorMaskedLoad final | |
| getAsOpFoldResult(adaptor.getIndices())); | ||
|
|
||
| std::optional<int64_t> foldedIntraVectorOffset = | ||
| isAlignedEmulation | ||
| isFullyAligned | ||
| ? 0 | ||
| : getConstantIntValue(linearizedInfo.intraDataOffset); | ||
|
|
||
|
|
@@ -1001,7 +1082,7 @@ struct ConvertVectorMaskedLoad final | |
| passthru = dynamicallyInsertSubVector( | ||
| rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset, | ||
| origElements); | ||
| } else if (!isAlignedEmulation) { | ||
| } else if (!isFullyAligned) { | ||
| passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector, | ||
| *foldedIntraVectorOffset); | ||
| } | ||
|
|
@@ -1029,7 +1110,7 @@ struct ConvertVectorMaskedLoad final | |
| mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask, | ||
| linearizedInfo.intraDataOffset, | ||
| origElements); | ||
| } else if (!isAlignedEmulation) { | ||
| } else if (!isFullyAligned) { | ||
| mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask, | ||
| *foldedIntraVectorOffset); | ||
| } | ||
|
|
@@ -1040,7 +1121,7 @@ struct ConvertVectorMaskedLoad final | |
| result = dynamicallyExtractSubVector( | ||
| rewriter, loc, result, op.getPassThru(), | ||
| linearizedInfo.intraDataOffset, origElements); | ||
| } else if (!isAlignedEmulation) { | ||
| } else if (!isFullyAligned) { | ||
| result = staticallyExtractSubvector( | ||
| rewriter, loc, result, *foldedIntraVectorOffset, origElements); | ||
| } | ||
|
|
||
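Analogously, for the masked-load hunks above: a standalone model (not the patch's code) of why the passthru and mask vectors are inserted at the intra-byte offset before bit-casting when the load is not fully aligned. The helper name is hypothetical and only mirrors what the static insertion path does conceptually.

```cpp
// Standalone model: emulated elements must be placed at their intra-byte
// offset inside a vector covering whole container bytes before bit-casting.
#include <cassert>
#include <vector>

// Hypothetical helper: copy `sub` into a zero-initialized vector of
// `paddedSize` slots, starting at `offset` emulated elements.
static std::vector<int> insertAtOffset(const std::vector<int> &sub, int offset,
                                       int paddedSize) {
  std::vector<int> padded(paddedSize, 0);
  for (size_t i = 0; i < sub.size(); ++i)
    padded[offset + i] = sub[i];
  return padded;
}

int main() {
  // 3 x i2 passthru values at intra-byte offset 2, with 4 i2 per byte:
  // they span two bytes, so the padded vector has 8 slots.
  std::vector<int> passthru = {1, 2, 3};
  std::vector<int> padded =
      insertAtOffset(passthru, /*offset=*/2, /*paddedSize=*/8);
  assert((padded == std::vector<int>{0, 0, 1, 2, 3, 0, 0, 0}));
  return 0;
}
```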