diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index bf1ecd7d4559c..5d8a525ac87f1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -432,7 +432,45 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+// Emulate `vector.store` using a multi-byte container type.
+//
+// The container type is obtained through the Op adaptor and would normally be
+// generated via `NarrowTypeEmulationConverter`.
+//
+// EXAMPLE 1
+// (aligned store of i4, emulated using i8 as the container type)
+//
+//   vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
+//
+// is rewritten as:
+//
+//   %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
+//   vector.store %src_bitcast, %dest_bitcast[%idx]
+//     : memref<16xi8>, vector<4xi8>
+//
+// EXAMPLE 2
+// (unaligned store of i2, emulated using i8 as the container type)
+//
+//   vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+//
+// The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
+// is modelled using 3 bytes:
+//
+//    Byte 0     Byte 1     Byte 2
+// +----------+----------+----------+
+// | oooooooo | ooooNNNN | NNoooooo |
+// +----------+----------+----------+
+//
+// N - (N)ew entries (i.e. to be overwritten by vector.store)
+// o - (o)ld entries (to be preserved)
+//
+// For the generated output in the non-atomic case, see:
+//  * `@vector_store_i2_const_index_two_partial_stores`
+// in:
+//  * "vector-emulate-narrow-type-unaligned-non-atomic.mlir".
+//
+// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
+// `true` to generate non-atomic RMW sequences.
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -464,7 +502,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
           op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
     }
-    int numSrcElemsPerDest = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -480,7 +518,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = valueToStore.getType().getNumElements();
-    bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
+    // Note: per-element alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -496,9 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
         getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedNumFrontPadElems =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
@@ -516,10 +554,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // need unaligned emulation because the store address is aligned and the
     // source is a whole byte.
     bool emulationRequiresPartialStores =
-        !isAlignedEmulation || *foldedNumFrontPadElems != 0;
+        !isFullyAligned || *foldedNumFrontPadElems != 0;
     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
-      auto numElements = origElements / numSrcElemsPerDest;
+      auto numElements = origElements / emulatedPerContainerElem;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, containerElemTy),
           op.getValueToStore());
@@ -567,7 +605,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     // Build a mask used for rmw.
     auto subWidthStoreMaskType =
-        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+        VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
 
     auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
 
@@ -576,10 +614,11 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // with the unaligned part so that the rest elements are aligned to width
     // boundary.
     auto frontSubWidthStoreElem =
-        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+        (emulatedPerContainerElem - *foldedNumFrontPadElems) %
+        emulatedPerContainerElem;
     if (frontSubWidthStoreElem > 0) {
-      SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
-      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+      SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
+      if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
         std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                     origElements, true);
         frontSubWidthStoreElem = origElements;
@@ -590,7 +629,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto frontMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
 
-      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
       auto value =
           extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                                frontSubWidthStoreElem, *foldedNumFrontPadElems);
@@ -614,8 +653,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // After the previous step, the store address is aligned to the emulated
     // width boundary.
     int64_t fullWidthStoreSize =
-        (origElements - currentSourceIndex) / numSrcElemsPerDest;
-    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+        (origElements - currentSourceIndex) / emulatedPerContainerElem;
+    int64_t numNonFullWidthElements =
+        fullWidthStoreSize * emulatedPerContainerElem;
     if (fullWidthStoreSize > 0) {
       auto fullWidthStorePart = staticallyExtractSubvector(
           rewriter, loc, valueToStore, currentSourceIndex,
@@ -624,7 +664,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto originType = cast<VectorType>(fullWidthStorePart.getType());
       auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
       auto storeType = VectorType::get(
-          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+          {originType.getNumElements() / emulatedPerContainerElem},
+          memrefElemType);
       auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
                                                         fullWidthStorePart);
       rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
@@ -646,7 +687,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
                                currentSourceIndex, remainingElements, 0);
 
       // Generate back mask.
-      auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+      auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
       std::fill_n(maskValues.begin(), remainingElements, 1);
       auto backMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
@@ -960,7 +1001,8 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+    // Note: per-element alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -975,9 +1017,8 @@ struct ConvertVectorMaskedLoad final
         getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1001,7 +1042,7 @@ struct ConvertVectorMaskedLoad final
       passthru = dynamicallyInsertSubVector(
           rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
          origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                            *foldedIntraVectorOffset);
     }
@@ -1029,7 +1070,7 @@ struct ConvertVectorMaskedLoad final
       mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
                                         linearizedInfo.intraDataOffset,
                                         origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                        *foldedIntraVectorOffset);
     }
@@ -1040,7 +1081,7 @@ struct ConvertVectorMaskedLoad final
       result = dynamicallyExtractSubVector(
           rewriter, loc, result, op.getPassThru(),
           linearizedInfo.intraDataOffset, origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
index 1d6263535ae80..d27e99a54529c 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 disable-atomic-rmw=true" --cse --split-input-file %s | FileCheck %s
 
+// NOTE: In this file, all RMW stores are non-atomic.
+
 // TODO: remove memref.alloc() in the tests to eliminate noises.
 // memref.alloc exists here because sub-byte vector data types such as i2
 // are currently not supported as input arguments.
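Aside: the `disable-atomic-rmw=true` option exercised by the RUN line above corresponds to
the `disableAtomicRMW` flag documented in ConvertVectorStore. A minimal sketch of how a
conversion pass might opt into the same non-atomic lowering is shown below; the header
paths and the populate-function signature are assumptions to be verified against the
in-tree declarations (VectorRewritePatterns.h), not an excerpt from this patch.

  // Sketch only: verify headers and the exact signature before use.
  #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
  #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
  #include "mlir/IR/PatternMatch.h"

  static void addNarrowTypeStorePatterns(
      mlir::RewritePatternSet &patterns,
      mlir::arith::NarrowTypeEmulationConverter &converter) {
    // Passing `true` requests the non-atomic RMW sequences checked in this file.
    mlir::vector::populateVectorNarrowTypeEmulationPatterns(
        converter, patterns, /*disableAtomicRMW=*/true);
  }
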
@@ -8,121 +10,144 @@
 /// vector.store
 ///----------------------------------------------------------------------------------------
 
-func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
-  %0 = memref.alloc() : memref<3x3xi2>
+func.func @vector_store_i2_const_index_two_partial_stores(%src: vector<3xi2>) {
+  %dest = memref.alloc() : memref<3x3xi2>
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
-  vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+  vector.store %src, %dest[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
   return
 }
 
-// Emit two non-atomic RMW partial stores. Store 6 bits from the input vector (bits [12:18)),
-// into bytes [1:2] from a 3-byte output memref. Due to partial storing,
-// both bytes are accessed partially through masking.
-
-// CHECK: func @vector_store_i2_const_index_two_partial_stores(
-// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
-// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-
-// Part 1 RMW sequence
-// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
-// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
-// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
-// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
-// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
-// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
-// CHECK: %[[LOAD:.+]] = vector.load
-// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
-// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[DOWNCAST]]
-// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[SELECT]]
-// CHECK: vector.store %[[UPCAST]], %[[ALLOC]][%[[C1]]]
-
-// Part 2 RMW sequence
-// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
-// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
-// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
-// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
-// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
-// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
-// CHECK: %[[LOAD2:.+]] = vector.load
-// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2>
-// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
-// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
-// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]]
-
+// Store 6 bits from the input vector into bytes [1:2] of a 3-byte destination
+// memref, i.e. into bits [12:18) of a 24-bit destination container
+// (`memref<3x3xi2>` is emulated via `memref<3xi8>`). This requires two
+// non-atomic RMW partial stores. Due to partial storing, both bytes are
+// accessed partially through masking.
+
+// CHECK: func @vector_store_i2_const_index_two_partial_stores(
+// CHECK-SAME: %[[SRC:.+]]: vector<3xi2>)
+
+// CHECK: %[[DEST:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// RMW sequence for Byte 1
+// CHECK: %[[MASK_1:.+]] = arith.constant dense<[false, false, true, true]>
+// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[SRC_SLICE_1:.+]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INIT_WITH_SLICE_1:.+]] = vector.insert_strided_slice %[[SRC_SLICE_1]], %[[INIT]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[DEST_BYTE_1:.+]] = vector.load %[[DEST]][%[[C1]]] : memref<3xi8>, vector<1xi8>
+// CHECK: %[[DEST_BYTE_1_AS_I2:.+]] = vector.bitcast %[[DEST_BYTE_1]]
+// CHECK-SAME: vector<1xi8> to vector<4xi2>
+// CHECK: %[[RES_BYTE_1:.+]] = arith.select %[[MASK_1]], %[[INIT_WITH_SLICE_1]], %[[DEST_BYTE_1_AS_I2]]
+// CHECK: %[[RES_BYTE_1_AS_I8:.+]] = vector.bitcast %[[RES_BYTE_1]]
+// CHECK-SAME: vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[RES_BYTE_1_AS_I8]], %[[DEST]][%[[C1]]]
+
+// RMW sequence for Byte 2
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[SRC_SLICE_2:.+]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INIT_WITH_SLICE_2:.+]] = vector.insert_strided_slice %[[SRC_SLICE_2]], %[[INIT]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[MASK_2:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[DEST_BYTE_2:.+]] = vector.load %[[DEST]][%[[OFFSET]]] : memref<3xi8>, vector<1xi8>
+// CHECK: %[[DEST_BYTE_2_AS_I2:.+]] = vector.bitcast %[[DEST_BYTE_2]]
+// CHECK-SAME: vector<1xi8> to vector<4xi2>
+// CHECK: %[[RES_BYTE_2:.+]] = arith.select %[[MASK_2]], %[[INIT_WITH_SLICE_2]], %[[DEST_BYTE_2_AS_I2]]
+// CHECK: %[[RES_BYTE_2_AS_I8:.+]] = vector.bitcast %[[RES_BYTE_2]]
+// CHECK-SAME: vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[RES_BYTE_2_AS_I8]], %[[DEST]][%[[OFFSET]]]
 
 // -----
 
-func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) {
-  %0 = memref.alloc() : memref<3x7xi2>
+func.func @vector_store_i2_two_partial_one_full_stores(%src: vector<7xi2>) {
+  %dest = memref.alloc() : memref<3x7xi2>
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+  vector.store %src, %dest[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
   return
 }
 
-// In this example, emit two RMW stores and one full-width store.
-
-// CHECK: func @vector_store_i2_two_partial_one_full_stores(
-// CHECK-SAME: %[[ARG0:.+]]:
-// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
-// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
-// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
-// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
-// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
-// CHECK-SAME: {offsets = [3], strides = [1]}
-// First sub-width RMW:
-// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
-// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
-// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
-// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
-// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+// Store 14 bits from the input vector into bytes [1:3] of a 6-byte destination
+// memref, i.e. bits [14:28) of a 48-bit destination container
+// (`memref<3x7xi2>` is emulated via `memref<6xi8>`). This requires two
+// non-atomic RMW stores (for the "boundary" bytes) and one full byte store
+// (for the "middle" byte). Note that partial stores require masking.
+
+// CHECK: func @vector_store_i2_two_partial_one_full_stores(
+// CHECK-SAME: %[[SRC:.+]]:
+
+// CHECK: %[[DEST:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// First partial/RMW store:
+// CHECK: %[[MASK_1:.+]] = arith.constant dense<[false, false, false, true]>
+// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[SRC_SLICE_0:.+]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: %[[INIT_WITH_SLICE_1:.+]] = vector.insert_strided_slice %[[SRC_SLICE_0]], %[[INIT]]
+// CHECK-SAME: {offsets = [3], strides = [1]}
+// CHECK: %[[DEST_BYTE_1:.+]] = vector.load %[[DEST]][%[[C1]]]
+// CHECK: %[[DEST_BYTE_AS_I2:.+]] = vector.bitcast %[[DEST_BYTE_1]]
+// CHECK-SAME: : vector<1xi8> to vector<4xi2>
+// CHECK: %[[RES_BYTE_1:.+]] = arith.select %[[MASK_1]], %[[INIT_WITH_SLICE_1]], %[[DEST_BYTE_AS_I2]]
+// CHECK: %[[RES_BYTE_1_AS_I8:.+]] = vector.bitcast %[[RES_BYTE_1]]
+// CHECK-SAME: : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[RES_BYTE_1_AS_I8]], %[[DEST]][%[[C1]]]
 
 // Full-width store:
-// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
-// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
-// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
-// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
-// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]
-
-// Second sub-width RMW:
-// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
-// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
-// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
-// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
-// CHECK-SAME: {offsets = [0], strides = [1]}
-// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
-// CHECK: %[[LOAD2:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
-// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]]
-// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
-// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
-// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[INDEX2]]]
+// CHECK: %[[C2:.+]] = arith.addi %[[C1]], %[[C1]]
+// CHECK: %[[SRC_SLICE_1:.+]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
+// CHECK: %[[SRC_SLICE_1_AS_I8:.+]] = vector.bitcast %[[SRC_SLICE_1]]
+// CHECK-SAME: : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[SRC_SLICE_1_AS_I8]], %[[DEST]][%[[C2]]]
+
+// Second partial/RMW store:
+// CHECK: %[[C3:.+]] = arith.addi %[[C2]], %[[C1]]
+// CHECK: %[[SRC_SLICE_2:.+]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
+// CHECK: %[[INIT_WITH_SLICE2:.+]] = vector.insert_strided_slice %[[SRC_SLICE_2]]
+// CHECK-SAME: {offsets = [0], strides = [1]}
+// CHECK: %[[MASK_2:.+]] = arith.constant dense<[true, true, false, false]>
+// CHECK: %[[DEST_BYTE_2:.+]] = vector.load %[[DEST]][%[[C3]]]
+// CHECK: %[[DEST_BYTE_2_AS_I2:.+]] = vector.bitcast %[[DEST_BYTE_2]]
+// CHECK: %[[RES_BYTE_2:.+]] = arith.select %[[MASK_2]], %[[INIT_WITH_SLICE2]], %[[DEST_BYTE_2_AS_I2]]
+// CHECK: %[[RES_BYTE_2_AS_I8:.+]] = vector.bitcast %[[RES_BYTE_2]]
+// CHECK-SAME: : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[RES_BYTE_2_AS_I8]], %[[DEST]][%[[C3]]]
 
 // -----
 
-func.func @vector_store_i2_const_index_one_partial_store(%arg0: vector<1xi2>) {
-  %0 = memref.alloc() : memref<4x1xi2>
+func.func @vector_store_i2_const_index_one_partial_store(%src: vector<1xi2>) {
+  %dest = memref.alloc() : memref<4x1xi2>
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+  vector.store %src, %dest[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
   return
 }
 
-// in this test, only emit partial RMW store as the store is within one byte.
-
-// CHECK: func @vector_store_i2_const_index_one_partial_store(
-// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
-// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
-// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
-// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
-// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
-// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
-// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
-// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
-// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
-// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]
+// Store 2 bits from the input vector into byte 0 of a 1-byte destination
+// memref, i.e. bits [2:4) of an 8-bit destination container
+// (`memref<4x1xi2>` is emulated via `memref<1xi8>`). This requires one
+// non-atomic RMW.
+
+// CHECK: func @vector_store_i2_const_index_one_partial_store(
+// CHECK-SAME: %[[SRC:.+]]: vector<1xi2>)
+
+// CHECK: %[[DEST:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, false, false]>
+// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INIT_WITH_SLICE:.+]] = vector.insert_strided_slice %[[SRC]], %[[INIT]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[DEST_BYTE:.+]] = vector.load %[[DEST]][%[[C0]]] : memref<1xi8>, vector<1xi8>
+// CHECK: %[[DEST_BYTE_AS_I2:.+]] = vector.bitcast %[[DEST_BYTE]]
+// CHECK-SAME: : vector<1xi8> to vector<4xi2>
+// CHECK: %[[RES_BYTE:.+]] = arith.select %[[MASK]], %[[INIT_WITH_SLICE]], %[[DEST_BYTE_AS_I2]]
+// CHECK: %[[RES_BYTE_AS_I8:.+]] = vector.bitcast %[[RES_BYTE]]
+// CHECK-SAME: : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[RES_BYTE_AS_I8]], %[[DEST]][%[[C0]]]
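
For reference, the IR verified by the first test above (two non-atomic partial stores of a
vector<3xi2> into bytes 1 and 2 of the emulated memref<3xi8>) has roughly the following
shape. This is a hand-written sketch reconstructed from the CHECK lines, with illustrative
SSA names; the authoritative output is whatever FileCheck matches above.

  func.func @two_partial_stores_sketch(%src: vector<3xi2>) {
    %dest = memref.alloc() : memref<3xi8>
    %c1 = arith.constant 1 : index

    // RMW for byte 1: keep the first two i2 elements, overwrite the last two.
    %mask_1 = arith.constant dense<[false, false, true, true]> : vector<4xi1>
    %init = arith.constant dense<0> : vector<4xi2>
    %slice_1 = vector.extract_strided_slice %src {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
    %new_1 = vector.insert_strided_slice %slice_1, %init {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
    %old_byte_1 = vector.load %dest[%c1] : memref<3xi8>, vector<1xi8>
    %old_1 = vector.bitcast %old_byte_1 : vector<1xi8> to vector<4xi2>
    %res_1 = arith.select %mask_1, %new_1, %old_1 : vector<4xi1>, vector<4xi2>
    %res_byte_1 = vector.bitcast %res_1 : vector<4xi2> to vector<1xi8>
    vector.store %res_byte_1, %dest[%c1] : memref<3xi8>, vector<1xi8>

    // RMW for byte 2: overwrite the first i2 element, keep the rest.
    %c2 = arith.addi %c1, %c1 : index
    %slice_2 = vector.extract_strided_slice %src {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
    %new_2 = vector.insert_strided_slice %slice_2, %init {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
    %mask_2 = arith.constant dense<[true, false, false, false]> : vector<4xi1>
    %old_byte_2 = vector.load %dest[%c2] : memref<3xi8>, vector<1xi8>
    %old_2 = vector.bitcast %old_byte_2 : vector<1xi8> to vector<4xi2>
    %res_2 = arith.select %mask_2, %new_2, %old_2 : vector<4xi1>, vector<4xi2>
    %res_byte_2 = vector.bitcast %res_2 : vector<4xi2> to vector<1xi8>
    vector.store %res_byte_2, %dest[%c2] : memref<3xi8>, vector<1xi8>
    return
  }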