From 5eebcc0daa1f1594955159e3d3ea13512dcacb41 Mon Sep 17 00:00:00 2001 From: Ubuntu <450283+lialan@users.noreply.github.com> Date: Wed, 30 Oct 2024 19:37:11 +0000 Subject: [PATCH 1/7] Implement dynamic indexing for MaskedLoads --- .../Transforms/VectorEmulateNarrowType.cpp | 101 ++++++++++++------ .../vector-emulate-narrow-type-unaligned.mlir | 52 +++++++++ 2 files changed, 120 insertions(+), 33 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index f169dab3bdd9a..3c94e992d695c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -53,6 +53,7 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, Location loc, Value mask, int origElements, int scale, int intraDataOffset = 0) { + assert(intraDataOffset < scale && "intraDataOffset must be less than scale"); auto numElements = (intraDataOffset + origElements + scale - 1) / scale; Operation *maskOp = mask.getDefiningOp(); @@ -182,6 +183,27 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc, return dest; } +/// Inserts a 1-D subvector into a 1-D `dest` vector at index `offset`. +static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc, + TypedValue source, + Value dest, OpFoldResult destOffsetVar, + int64_t length) { + assert(length > 0 && "length must be greater than 0"); + for (int i = 0; i < length; ++i) { + Value insertLoc; + if (i == 0) { + insertLoc = destOffsetVar.dyn_cast(); + } else { + insertLoc = rewriter.create( + loc, rewriter.getIndexType(), destOffsetVar.dyn_cast(), + rewriter.create(loc, i)); + } + auto extractOp = rewriter.create(loc, source, i); + dest = rewriter.create(loc, extractOp, dest, insertLoc); + } + return dest; +} + /// Returns the op sequence for an emulated sub-byte data type vector load. /// specifically, use `emulatedElemType` for loading a vector of `origElemType`. /// The load location is given by `base` and `linearizedIndices`, and the @@ -199,7 +221,7 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base, return rewriter.create( loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType), newLoad); -}; +} namespace { @@ -546,29 +568,30 @@ struct ConvertVectorMaskedLoad final ? getConstantIntValue(linearizedInfo.intraDataOffset) : 0; - if (!foldedIntraVectorOffset) { - // unimplemented case for dynamic intra vector offset - return failure(); - } - - FailureOr newMask = - getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale, - *foldedIntraVectorOffset); + auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1); + FailureOr newMask = getCompressedMaskOp( + rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset); if (failed(newMask)) return failure(); + Value passthru = op.getPassThru(); + auto numElements = - llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale); + llvm::divideCeil(maxIntraDataOffset + origElements, scale); auto loadType = VectorType::get(numElements, newElementType); auto newBitcastType = VectorType::get(numElements * scale, oldElementType); - Value passthru = op.getPassThru(); - if (isUnalignedEmulation) { - // create an empty vector of the new type - auto emptyVector = rewriter.create( - loc, newBitcastType, rewriter.getZeroAttr(newBitcastType)); - passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector, - *foldedIntraVectorOffset); + auto emptyVector = rewriter.create( + loc, newBitcastType, rewriter.getZeroAttr(newBitcastType)); + if (foldedIntraVectorOffset) { + if (isUnalignedEmulation) { + passthru = staticallyInsertSubvector( + rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset); + } + } else { + passthru = dynamicallyInsertSubVector( + rewriter, loc, dyn_cast>(passthru), + emptyVector, linearizedInfo.intraDataOffset, origElements); } auto newPassThru = rewriter.create(loc, loadType, passthru); @@ -585,23 +608,36 @@ struct ConvertVectorMaskedLoad final rewriter.create(loc, newBitcastType, newLoad); Value mask = op.getMask(); - if (isUnalignedEmulation) { - auto newSelectMaskType = - VectorType::get(numElements * scale, rewriter.getI1Type()); - // TODO: can fold if op's mask is constant - auto emptyVector = rewriter.create( - loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType)); - mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyVector, - *foldedIntraVectorOffset); + auto newSelectMaskType = + VectorType::get(numElements * scale, rewriter.getI1Type()); + // TODO: try to fold if op's mask is constant + auto emptyMask = rewriter.create( + loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType)); + if (foldedIntraVectorOffset) { + if (isUnalignedEmulation) { + mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask, + *foldedIntraVectorOffset); + } + } else { + mask = dynamicallyInsertSubVector( + rewriter, loc, dyn_cast>(mask), emptyMask, + linearizedInfo.intraDataOffset, origElements); } Value result = rewriter.create(loc, mask, bitCast, passthru); - - if (isUnalignedEmulation) { - result = - staticallyExtractSubvector(rewriter, loc, op.getType(), result, - *foldedIntraVectorOffset, origElements); + if (foldedIntraVectorOffset) { + if (isUnalignedEmulation) { + result = + staticallyExtractSubvector(rewriter, loc, op.getType(), result, + *foldedIntraVectorOffset, origElements); + } + } else { + auto resultVector = rewriter.create( + loc, op.getType(), rewriter.getZeroAttr(op.getType())); + result = dynamicallyExtractSubVector( + rewriter, loc, dyn_cast>(result), resultVector, + linearizedInfo.intraDataOffset, origElements); } rewriter.replaceOp(op, result); @@ -659,10 +695,9 @@ struct ConvertVectorTransferRead final ? getConstantIntValue(linearizedInfo.intraDataOffset) : 0; - auto maxIntraVectorOffset = - foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1; + auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1); auto numElements = - llvm::divideCeil(maxIntraVectorOffset + origElements, scale); + llvm::divideCeil(maxIntraDataOffset + origElements, scale); auto newRead = rewriter.create( loc, VectorType::get(numElements, newElementType), adaptor.getSource(), diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir index 0cecaddc5733e..efa31b8bf5ac7 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir @@ -183,3 +183,55 @@ func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vecto // CHECK: %[[C2:.+]] = arith.constant 2 : index // CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2> +// ----- + +func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, %idx: index) -> vector<3xi2> { + %0 = memref.alloc() : memref<3x3xi2> + %cst = arith.constant dense<0> : vector<3x3xi2> + %c2 = arith.constant 2 : index + %mask = vector.constant_mask [3] : vector<3xi1> + %1 = vector.maskedload %0[%idx, %c2], %mask, %passthru : + memref<3x3xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2> + return %1 : vector<3xi2> +} + +// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 * 3 + 2) floordiv 4)> +// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3 - ((s0 * 3 + 2) floordiv 4) * 4 + 2)> +// CHECK: func @vector_maskedload_i2_dynamic_indexing_mixed( +// CHECK-SAME: %[[PTH:.+]]: vector<3xi2>, %[[IDX:.+]]: index) -> vector<3xi2> +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8> +// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1> +// CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]] +// CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]] +// CHECK: %[[ONE:.+]] = arith.constant dense : vector<2xi1> +// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2> +// CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2> +// CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2> +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[INCIDX:.+]] = arith.addi %[[LINEAR2]], %[[C1]] : index +// CHECK: %[[EX2:.+]] = vector.extract %[[PTH]][1] : i2 from vector<3xi2> +// CHECK: %[[IN2:.+]] = vector.insert %[[EX2]], %[[IN1]] [%[[INCIDX]]] : i2 into vector<8xi2> +// CHECK: %[[C2:.+]] = arith.constant 2 : index +// CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index +// CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2> +// CHECK: %[[IN3:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2> +// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[IN3]] : vector<8xi2> to vector<2xi8> +// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]] +// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8> +// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2> +// extracts: +// CHECK: %[[CST1:.+]] = arith.constant dense : vector<8xi1> +// CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1> +// CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1> +// CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1> +// CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1> +// CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1> +// CHECK: %[[IN6:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1> +// CHECK: %[[SELECT:.+]] = arith.select %[[IN6]], %[[BITCAST2]], %[[IN3]] : vector<8xi1>, vector<8xi2> +// CHECK: %[[CST2:.+]] = arith.constant dense<0> : vector<3xi2> +// CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2> +// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[CST2]] [0] : i2 into vector<3xi2> +// CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2> +// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2> +// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2> +// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2> From f3c2d3ac5a1ba10c1a571c6bed01ed83161b8fea Mon Sep 17 00:00:00 2001 From: Ubuntu <450283+lialan@users.noreply.github.com> Date: Tue, 5 Nov 2024 21:32:23 +0000 Subject: [PATCH 2/7] Small update --- .../Vector/Transforms/VectorEmulateNarrowType.cpp | 14 ++++++-------- .../vector-emulate-narrow-type-unaligned.mlir | 1 - 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 3c94e992d695c..56273ac2899d7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -190,14 +190,12 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc, int64_t length) { assert(length > 0 && "length must be greater than 0"); for (int i = 0; i < length; ++i) { - Value insertLoc; - if (i == 0) { - insertLoc = destOffsetVar.dyn_cast(); - } else { - insertLoc = rewriter.create( - loc, rewriter.getIndexType(), destOffsetVar.dyn_cast(), - rewriter.create(loc, i)); - } + Value insertLoc = + 1 == 0 + ? destOffsetVar.dyn_cast() + : rewriter.create( + loc, rewriter.getIndexType(), destOffsetVar.dyn_cast(), + rewriter.create(loc, i)); auto extractOp = rewriter.create(loc, source, i); dest = rewriter.create(loc, extractOp, dest, insertLoc); } diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir index efa31b8bf5ac7..6a10a2f9ed32f 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir @@ -219,7 +219,6 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, // CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]] // CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8> // CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2> -// extracts: // CHECK: %[[CST1:.+]] = arith.constant dense : vector<8xi1> // CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1> // CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1> From 94ab287240cf9d347b322fffc5ac878c4a558431 Mon Sep 17 00:00:00 2001 From: Ubuntu <450283+lialan@users.noreply.github.com> Date: Tue, 5 Nov 2024 22:49:27 +0000 Subject: [PATCH 3/7] fix --- mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 56273ac2899d7..dabb137351601 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -191,7 +191,7 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc, assert(length > 0 && "length must be greater than 0"); for (int i = 0; i < length; ++i) { Value insertLoc = - 1 == 0 + i == 0 ? destOffsetVar.dyn_cast() : rewriter.create( loc, rewriter.getIndexType(), destOffsetVar.dyn_cast(), From 21bd52c3aa08ae4b98a4d7059f2d2d3d0c453a58 Mon Sep 17 00:00:00 2001 From: hasekawa-takumi <167335845+hasekawa-takumi@users.noreply.github.com> Date: Thu, 7 Nov 2024 23:10:35 -0500 Subject: [PATCH 4/7] Update --- .../Transforms/VectorEmulateNarrowType.cpp | 6 ++---- .../vector-emulate-narrow-type-unaligned.mlir | 18 ++++++++++++++++-- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index dabb137351601..9c565c6881c4e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -631,11 +631,9 @@ struct ConvertVectorMaskedLoad final *foldedIntraVectorOffset, origElements); } } else { - auto resultVector = rewriter.create( - loc, op.getType(), rewriter.getZeroAttr(op.getType())); result = dynamicallyExtractSubVector( - rewriter, loc, dyn_cast>(result), resultVector, - linearizedInfo.intraDataOffset, origElements); + rewriter, loc, dyn_cast>(result), + op.getPassThru(), linearizedInfo.intraDataOffset, origElements); } rewriter.replaceOp(op, result); diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir index 6a10a2f9ed32f..6d37493d174a2 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir @@ -205,6 +205,8 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, // CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]] // CHECK: %[[ONE:.+]] = arith.constant dense : vector<2xi1> // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2> + +// extract passthru vector, and insert into zero vector, this is for constructing a new passthru // CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2> // CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2> // CHECK: %[[C1:.+]] = arith.constant 1 : index @@ -215,21 +217,33 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, // CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index // CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2> // CHECK: %[[IN3:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2> + +// bitcast the new passthru vector to emulated i8 vector // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[IN3]] : vector<8xi2> to vector<2xi8> + +// use the emulated i8 vector to masked load from the memory // CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]] // CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8> + +// bitcast back to i2 vector // CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2> + // CHECK: %[[CST1:.+]] = arith.constant dense : vector<8xi1> + +// create a mask vector and select passthru part from the loaded vector. +// note that if indices are known then we can fold the part generating mask. // CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1> // CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1> // CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1> // CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1> // CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1> // CHECK: %[[IN6:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1> + // CHECK: %[[SELECT:.+]] = arith.select %[[IN6]], %[[BITCAST2]], %[[IN3]] : vector<8xi1>, vector<8xi2> -// CHECK: %[[CST2:.+]] = arith.constant dense<0> : vector<3xi2> + +// finally, insert the selected parts into actual passthru vector. // CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2> -// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[CST2]] [0] : i2 into vector<3xi2> +// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[PTH]] [0] : i2 into vector<3xi2> // CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2> // CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2> // CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2> From d6437e9e82a92b5653d50624a3add9a15a7cb68c Mon Sep 17 00:00:00 2001 From: Alan Li Date: Mon, 11 Nov 2024 09:34:42 -0500 Subject: [PATCH 5/7] update comments --- .../Transforms/VectorEmulateNarrowType.cpp | 6 +++-- .../vector-emulate-narrow-type-unaligned.mlir | 27 ++++++++++--------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index ef072638af26e..58b799b028694 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -43,7 +43,9 @@ using namespace mlir; /// /// %mask = [1, 1, 0, 0, 0, 0] /// -/// will first be padded with number of `intraDataOffset` zeros: +/// will first be padded in the front with number of `intraDataOffset` zeros, +/// and pad zeros in the back to make the number of elements a multiple of +/// `scale` (just to make it easier to compute). The new mask will be: /// %mask = [0, 1, 1, 0, 0, 0, 0, 0] /// /// then it will return the following new compressed mask: @@ -54,7 +56,7 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, int origElements, int scale, int intraDataOffset = 0) { assert(intraDataOffset < scale && "intraDataOffset must be less than scale"); - auto numElements = (intraDataOffset + origElements + scale - 1) / scale; + auto numElements = llvm::divideCeil(intraDataOffset + origElements, scale); Operation *maskOp = mask.getDefiningOp(); SmallVector extractOps; diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir index 6d37493d174a2..7ed75ff7f1579 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir @@ -206,7 +206,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, // CHECK: %[[ONE:.+]] = arith.constant dense : vector<2xi1> // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2> -// extract passthru vector, and insert into zero vector, this is for constructing a new passthru +// Extract passthru vector, and insert into zero vector, this is for constructing a new passthru // CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2> // CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2> // CHECK: %[[C1:.+]] = arith.constant 1 : index @@ -216,32 +216,33 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, // CHECK: %[[C2:.+]] = arith.constant 2 : index // CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index // CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2> -// CHECK: %[[IN3:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2> +// CHECK: %[[NEW_PASSTHRU:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2> -// bitcast the new passthru vector to emulated i8 vector -// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[IN3]] : vector<8xi2> to vector<2xi8> +// Bitcast the new passthru vector to emulated i8 vector +// CHECK: %[[BCAST_PASSTHRU:.+]] = vector.bitcast %[[NEW_PASSTHRU]] : vector<8xi2> to vector<2xi8> -// use the emulated i8 vector to masked load from the memory -// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]] +// Use the emulated i8 vector for masked load from the source memory +// CHECK: %[[SOURCE:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BCAST_PASSTHRU]] // CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8> -// bitcast back to i2 vector -// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2> +// Bitcast back to i2 vector +// CHECK: %[[BCAST_MASKLOAD:.+]] = vector.bitcast %[[SOURCE]] : vector<2xi8> to vector<8xi2> // CHECK: %[[CST1:.+]] = arith.constant dense : vector<8xi1> -// create a mask vector and select passthru part from the loaded vector. -// note that if indices are known then we can fold the part generating mask. +// Create a mask vector +// Note that if indices are known then we can fold the part generating mask. // CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1> // CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1> // CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1> // CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1> // CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1> -// CHECK: %[[IN6:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1> +// CHECK: %[[NEW_MASK:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1> -// CHECK: %[[SELECT:.+]] = arith.select %[[IN6]], %[[BITCAST2]], %[[IN3]] : vector<8xi1>, vector<8xi2> +// Select the effective part from the source and passthru vectors +// CHECK: %[[SELECT:.+]] = arith.select %[[NEW_MASK]], %[[BCAST_MASKLOAD]], %[[NEW_PASSTHRU]] : vector<8xi1>, vector<8xi2> -// finally, insert the selected parts into actual passthru vector. +// Finally, insert the selected parts into actual passthru vector. // CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2> // CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[PTH]] [0] : i2 into vector<3xi2> // CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2> From 9e7aedfe8c88ad2be2426ac394456d83697149ed Mon Sep 17 00:00:00 2001 From: Alan Li Date: Mon, 11 Nov 2024 19:43:11 -0500 Subject: [PATCH 6/7] fix according to comments --- .../Transforms/VectorEmulateNarrowType.cpp | 73 ++++++++----------- 1 file changed, 32 insertions(+), 41 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 58b799b028694..604d261b4513d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -194,13 +194,14 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc, Value dest, OpFoldResult destOffsetVar, int64_t length) { assert(length > 0 && "length must be greater than 0"); + Value destOffsetVal = + getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar); for (int i = 0; i < length; ++i) { - Value insertLoc = - i == 0 - ? destOffsetVar.dyn_cast() - : rewriter.create( - loc, rewriter.getIndexType(), destOffsetVar.dyn_cast(), - rewriter.create(loc, i)); + auto insertLoc = i == 0 + ? destOffsetVal + : rewriter.create( + loc, rewriter.getIndexType(), destOffsetVal, + rewriter.create(loc, i)); auto extractOp = rewriter.create(loc, source, i); dest = rewriter.create(loc, extractOp, dest, insertLoc); } @@ -465,18 +466,16 @@ struct ConvertVectorLoad final : OpConversionPattern { emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices, numElements, oldElementType, newElementType); - if (foldedIntraVectorOffset) { - if (isUnalignedEmulation) { - result = - staticallyExtractSubvector(rewriter, loc, op.getType(), result, - *foldedIntraVectorOffset, origElements); - } - } else { + if (!foldedIntraVectorOffset) { auto resultVector = rewriter.create( loc, op.getType(), rewriter.getZeroAttr(op.getType())); result = dynamicallyExtractSubVector( rewriter, loc, dyn_cast>(result), resultVector, linearizedInfo.intraDataOffset, origElements); + } else if (isUnalignedEmulation) { + result = + staticallyExtractSubvector(rewriter, loc, op.getType(), result, + *foldedIntraVectorOffset, origElements); } rewriter.replaceOp(op, result); return success(); @@ -571,7 +570,7 @@ struct ConvertVectorMaskedLoad final ? getConstantIntValue(linearizedInfo.intraDataOffset) : 0; - auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1); + int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1); FailureOr newMask = getCompressedMaskOp( rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset); if (failed(newMask)) @@ -586,15 +585,13 @@ struct ConvertVectorMaskedLoad final auto emptyVector = rewriter.create( loc, newBitcastType, rewriter.getZeroAttr(newBitcastType)); - if (foldedIntraVectorOffset) { - if (isUnalignedEmulation) { - passthru = staticallyInsertSubvector( - rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset); - } - } else { + if (!foldedIntraVectorOffset) { passthru = dynamicallyInsertSubVector( rewriter, loc, dyn_cast>(passthru), emptyVector, linearizedInfo.intraDataOffset, origElements); + } else if (isUnalignedEmulation) { + passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector, + *foldedIntraVectorOffset); } auto newPassThru = rewriter.create(loc, loadType, passthru); @@ -616,29 +613,25 @@ struct ConvertVectorMaskedLoad final // TODO: try to fold if op's mask is constant auto emptyMask = rewriter.create( loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType)); - if (foldedIntraVectorOffset) { - if (isUnalignedEmulation) { - mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask, - *foldedIntraVectorOffset); - } - } else { + if (!foldedIntraVectorOffset) { mask = dynamicallyInsertSubVector( rewriter, loc, dyn_cast>(mask), emptyMask, linearizedInfo.intraDataOffset, origElements); + } else if (isUnalignedEmulation) { + mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask, + *foldedIntraVectorOffset); } Value result = rewriter.create(loc, mask, bitCast, passthru); - if (foldedIntraVectorOffset) { - if (isUnalignedEmulation) { - result = - staticallyExtractSubvector(rewriter, loc, op.getType(), result, - *foldedIntraVectorOffset, origElements); - } - } else { + if (!foldedIntraVectorOffset) { result = dynamicallyExtractSubVector( rewriter, loc, dyn_cast>(result), op.getPassThru(), linearizedInfo.intraDataOffset, origElements); + } else if (isUnalignedEmulation) { + result = + staticallyExtractSubvector(rewriter, loc, op.getType(), result, + *foldedIntraVectorOffset, origElements); } rewriter.replaceOp(op, result); @@ -696,7 +689,7 @@ struct ConvertVectorTransferRead final ? getConstantIntValue(linearizedInfo.intraDataOffset) : 0; - auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1); + int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1); auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements, scale); @@ -709,18 +702,16 @@ struct ConvertVectorTransferRead final loc, VectorType::get(numElements * scale, oldElementType), newRead); Value result = bitCast->getResult(0); - if (foldedIntraVectorOffset) { - if (isUnalignedEmulation) { - result = - staticallyExtractSubvector(rewriter, loc, op.getType(), result, - *foldedIntraVectorOffset, origElements); - } - } else { + if (!foldedIntraVectorOffset) { auto zeros = rewriter.create( loc, op.getType(), rewriter.getZeroAttr(op.getType())); result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros, linearizedInfo.intraDataOffset, origElements); + } else if (isUnalignedEmulation) { + result = + staticallyExtractSubvector(rewriter, loc, op.getType(), result, + *foldedIntraVectorOffset, origElements); } rewriter.replaceOp(op, result); From f72ac5c339de0a3ae065fb7e35f69e9e56760476 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Mon, 11 Nov 2024 20:14:06 -0500 Subject: [PATCH 7/7] another update according to comments. --- .../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 604d261b4513d..c1324e4f3a8ea 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -188,15 +188,15 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc, return dest; } -/// Inserts a 1-D subvector into a 1-D `dest` vector at index `offset`. +/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`. static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc, TypedValue source, Value dest, OpFoldResult destOffsetVar, - int64_t length) { + size_t length) { assert(length > 0 && "length must be greater than 0"); Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar); - for (int i = 0; i < length; ++i) { + for (size_t i = 0; i < length; ++i) { auto insertLoc = i == 0 ? destOffsetVal : rewriter.create(