diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 7ca88f1e0a0df..63365cb544612 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -415,18 +415,21 @@ struct ConvertVectorStore final : OpConversionPattern { "only 1-D vectors are supported ATM"); auto loc = op.getLoc(); + auto valueToStore = cast(op.getValueToStore()); - auto oldElementType = valueToStore.getType().getElementType(); - auto newElementType = + auto containerElemTy = cast(adaptor.getBase().getType()).getElementType(); - int srcBits = oldElementType.getIntOrFloatBitWidth(); - int dstBits = newElementType.getIntOrFloatBitWidth(); + Type emulatedElemTy = op.getValueToStore().getType().getElementType(); + int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth(); + int containerBits = containerElemTy.getIntOrFloatBitWidth(); - if (dstBits % srcBits != 0) { + // Check per-element alignment. + if (containerBits % emulatedBits != 0) { return rewriter.notifyMatchFailure( - op, "only dstBits % srcBits == 0 supported"); + op, "impossible to pack emulated elements into container elements " + "(bit-wise misalignment)"); } - int numSrcElemsPerDest = dstBits / srcBits; + int numSrcElemsPerDest = containerBits / emulatedBits; // Adjust the number of elements to store when emulating narrow types. // Here only the 1-D vector store is considered, and the N-D memref types @@ -451,7 +454,7 @@ struct ConvertVectorStore final : OpConversionPattern { memref::LinearizedMemRefInfo linearizedInfo; std::tie(linearizedInfo, linearizedIndices) = memref::getLinearizedMemRefOffsetAndSize( - rewriter, loc, srcBits, dstBits, + rewriter, loc, emulatedBits, containerBits, stridedMetadata.getConstifiedMixedOffset(), stridedMetadata.getConstifiedMixedSizes(), stridedMetadata.getConstifiedMixedStrides(), @@ -483,7 +486,7 @@ struct ConvertVectorStore final : OpConversionPattern { // Basic case: storing full bytes. auto numElements = origElements / numSrcElemsPerDest; auto bitCast = rewriter.create( - loc, VectorType::get(numElements, newElementType), + loc, VectorType::get(numElements, containerElemTy), op.getValueToStore()); rewriter.replaceOpWithNewOp( op, bitCast.getResult(), memrefBase, @@ -638,18 +641,20 @@ struct ConvertVectorMaskedStore final "only 1-D vectors are supported ATM"); auto loc = op.getLoc(); - auto convertedType = cast(adaptor.getBase().getType()); - Type oldElementType = op.getValueToStore().getType().getElementType(); - Type newElementType = convertedType.getElementType(); - int srcBits = oldElementType.getIntOrFloatBitWidth(); - int dstBits = newElementType.getIntOrFloatBitWidth(); + auto containerElemTy = + cast(adaptor.getBase().getType()).getElementType(); + Type emulatedElemTy = op.getValueToStore().getType().getElementType(); + int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth(); + int containerBits = containerElemTy.getIntOrFloatBitWidth(); - if (dstBits % srcBits != 0) { + // Check per-element alignment. + if (containerBits % emulatedBits != 0) { return rewriter.notifyMatchFailure( - op, "only dstBits % srcBits == 0 supported"); + op, "impossible to pack emulated elements into container elements " + "(bit-wise misalignment)"); } - int scale = dstBits / srcBits; + int scale = containerBits / emulatedBits; int origElements = op.getValueToStore().getType().getNumElements(); if (origElements % scale != 0) return failure(); @@ -660,7 +665,7 @@ struct ConvertVectorMaskedStore final memref::LinearizedMemRefInfo linearizedInfo; std::tie(linearizedInfo, linearizedIndicesOfr) = memref::getLinearizedMemRefOffsetAndSize( - rewriter, loc, srcBits, dstBits, + rewriter, loc, emulatedBits, containerBits, stridedMetadata.getConstifiedMixedOffset(), stridedMetadata.getConstifiedMixedSizes(), stridedMetadata.getConstifiedMixedStrides(), @@ -706,7 +711,7 @@ struct ConvertVectorMaskedStore final return failure(); auto numElements = (origElements + scale - 1) / scale; - auto newType = VectorType::get(numElements, newElementType); + auto newType = VectorType::get(numElements, containerElemTy); auto passThru = rewriter.create( loc, newType, rewriter.getZeroAttr(newType)); @@ -714,7 +719,7 @@ struct ConvertVectorMaskedStore final loc, newType, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0), passThru); - auto newBitCastType = VectorType::get(numElements * scale, oldElementType); + auto newBitCastType = VectorType::get(numElements * scale, emulatedElemTy); Value valueToStore = rewriter.create(loc, newBitCastType, newLoad); valueToStore = rewriter.create( @@ -746,17 +751,19 @@ struct ConvertVectorLoad final : OpConversionPattern { "only 1-D vectors are supported ATM"); auto loc = op.getLoc(); - auto convertedType = cast(adaptor.getBase().getType()); - Type oldElementType = op.getType().getElementType(); - Type newElementType = convertedType.getElementType(); - int srcBits = oldElementType.getIntOrFloatBitWidth(); - int dstBits = newElementType.getIntOrFloatBitWidth(); + auto containerElemTy = + cast(adaptor.getBase().getType()).getElementType(); + Type emulatedElemTy = op.getType().getElementType(); + int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth(); + int containerBits = containerElemTy.getIntOrFloatBitWidth(); - if (dstBits % srcBits != 0) { + // Check per-element alignment. + if (containerBits % emulatedBits != 0) { return rewriter.notifyMatchFailure( - op, "only dstBits % srcBits == 0 supported"); + op, "impossible to pack emulated elements into container elements " + "(bit-wise misalignment)"); } - int scale = dstBits / srcBits; + int scale = containerBits / emulatedBits; // Adjust the number of elements to load when emulating narrow types, // and then cast back to the original type with vector.bitcast op. @@ -797,7 +804,7 @@ struct ConvertVectorLoad final : OpConversionPattern { memref::LinearizedMemRefInfo linearizedInfo; std::tie(linearizedInfo, linearizedIndices) = memref::getLinearizedMemRefOffsetAndSize( - rewriter, loc, srcBits, dstBits, + rewriter, loc, emulatedBits, containerBits, stridedMetadata.getConstifiedMixedOffset(), stridedMetadata.getConstifiedMixedSizes(), stridedMetadata.getConstifiedMixedStrides(), @@ -814,7 +821,7 @@ struct ConvertVectorLoad final : OpConversionPattern { llvm::divideCeil(maxintraDataOffset + origElements, scale); Value result = emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices, - numElements, oldElementType, newElementType); + numElements, emulatedElemTy, containerElemTy); if (!foldedIntraVectorOffset) { auto resultVector = rewriter.create( @@ -848,17 +855,20 @@ struct ConvertVectorMaskedLoad final "only 1-D vectors are supported ATM"); auto loc = op.getLoc(); - auto convertedType = cast(adaptor.getBase().getType()); - Type oldElementType = op.getType().getElementType(); - Type newElementType = convertedType.getElementType(); - int srcBits = oldElementType.getIntOrFloatBitWidth(); - int dstBits = newElementType.getIntOrFloatBitWidth(); - if (dstBits % srcBits != 0) { + auto containerElemTy = + cast(adaptor.getBase().getType()).getElementType(); + Type emulatedElemTy = op.getType().getElementType(); + int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth(); + int containerBits = containerElemTy.getIntOrFloatBitWidth(); + + // Check per-element alignment. + if (containerBits % emulatedBits != 0) { return rewriter.notifyMatchFailure( - op, "only dstBits % srcBits == 0 supported"); + op, "impossible to pack emulated elements into container elements " + "(bit-wise misalignment)"); } - int scale = dstBits / srcBits; + int scale = containerBits / emulatedBits; // Adjust the number of elements to load when emulating narrow types, // and then cast back to the original type with vector.bitcast op. @@ -912,7 +922,7 @@ struct ConvertVectorMaskedLoad final memref::LinearizedMemRefInfo linearizedInfo; std::tie(linearizedInfo, linearizedIndices) = memref::getLinearizedMemRefOffsetAndSize( - rewriter, loc, srcBits, dstBits, + rewriter, loc, emulatedBits, containerBits, stridedMetadata.getConstifiedMixedOffset(), stridedMetadata.getConstifiedMixedSizes(), stridedMetadata.getConstifiedMixedStrides(), @@ -933,8 +943,8 @@ struct ConvertVectorMaskedLoad final auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements, scale); - auto loadType = VectorType::get(numElements, newElementType); - auto newBitcastType = VectorType::get(numElements * scale, oldElementType); + auto loadType = VectorType::get(numElements, containerElemTy); + auto newBitcastType = VectorType::get(numElements * scale, emulatedElemTy); auto emptyVector = rewriter.create( loc, newBitcastType, rewriter.getZeroAttr(newBitcastType)); @@ -1009,23 +1019,25 @@ struct ConvertVectorTransferRead final "only 1-D vectors are supported ATM"); auto loc = op.getLoc(); - auto convertedType = cast(adaptor.getSource().getType()); - Type oldElementType = op.getType().getElementType(); - Type newElementType = convertedType.getElementType(); - int srcBits = oldElementType.getIntOrFloatBitWidth(); - int dstBits = newElementType.getIntOrFloatBitWidth(); - - if (dstBits % srcBits != 0) { + auto containerElemTy = + cast(adaptor.getSource().getType()).getElementType(); + Type emulatedElemTy = op.getType().getElementType(); + int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth(); + int containerBits = containerElemTy.getIntOrFloatBitWidth(); + + // Check per-element alignment. + if (containerBits % emulatedBits != 0) { return rewriter.notifyMatchFailure( - op, "only dstBits % srcBits == 0 supported"); + op, "impossible to pack emulated elements into container elements " + "(bit-wise misalignment)"); } - int scale = dstBits / srcBits; + int scale = containerBits / emulatedBits; auto origElements = op.getVectorType().getNumElements(); bool isAlignedEmulation = origElements % scale == 0; - auto newPadding = rewriter.create(loc, newElementType, + auto newPadding = rewriter.create(loc, containerElemTy, adaptor.getPadding()); auto stridedMetadata = @@ -1035,7 +1047,7 @@ struct ConvertVectorTransferRead final memref::LinearizedMemRefInfo linearizedInfo; std::tie(linearizedInfo, linearizedIndices) = memref::getLinearizedMemRefOffsetAndSize( - rewriter, loc, srcBits, dstBits, + rewriter, loc, emulatedBits, containerBits, stridedMetadata.getConstifiedMixedOffset(), stridedMetadata.getConstifiedMixedSizes(), stridedMetadata.getConstifiedMixedStrides(), @@ -1051,12 +1063,12 @@ struct ConvertVectorTransferRead final llvm::divideCeil(maxIntraDataOffset + origElements, scale); auto newRead = rewriter.create( - loc, VectorType::get(numElements, newElementType), adaptor.getSource(), + loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(), getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices), newPadding); auto bitCast = rewriter.create( - loc, VectorType::get(numElements * scale, oldElementType), newRead); + loc, VectorType::get(numElements * scale, emulatedElemTy), newRead); Value result = bitCast->getResult(0); if (!foldedIntraVectorOffset) {