-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR] support dynamic indexing in VectorEmulateNarrowTypes
#114169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: lialan (lialan) ChangesFull diff: https://github.com/llvm/llvm-project/pull/114169.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 1d6f8a991d9b5b..04514725c3aeee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -149,6 +150,61 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
dest, offsets, strides);
}
+static void dynamicallyExtractElementsToVector(
+ RewriterBase &rewriter, Location loc, TypedValue<VectorType> srcVec,
+ Value destVec, OpFoldResult srcOffsetVar, int64_t loopSize) {
+ /*
+ // Create affine maps for the lower and upper bounds
+ AffineMap lowerBoundMap = AffineMap::getConstantMap(0, rewriter.getContext());
+ AffineMap upperBoundMap =
+ AffineMap::getConstantMap(loopSize, rewriter.getContext());
+
+ auto forLoop = rewriter.create<affine::AffineForOp>(
+ loc, ValueRange{}, lowerBoundMap, ValueRange{}, upperBoundMap, 1,
+ ArrayRef<Value>(destVec));
+
+ OpBuilder builder =
+ OpBuilder::atBlockEnd(forLoop.getBody(), rewriter.getListener());
+
+ auto iv = forLoop.getInductionVar();
+
+ auto loopDestVec = forLoop.getRegionIterArgs()[0];
+ auto extractLoc = builder.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), srcOffsetVar.dyn_cast<Value>(), iv);
+ auto extractElemOp = builder.create<vector::ExtractElementOp>(
+ loc, elemType, srcVec, extractLoc);
+ auto insertElemOp = builder.create<vector::InsertElementOp>(
+ loc, extractElemOp, loopDestVec, iv);
+ builder.create<affine::AffineYieldOp>(loc,
+ ValueRange{insertElemOp->getResult(0)});
+ return forLoop->getResult(0);
+ */
+ for (int i = 0; i < loopSize; ++i) {
+ Value extractLoc;
+ if (i == 0) {
+ extractLoc = srcOffsetVar.dyn_cast<Value>();
+ } else {
+ extractLoc = rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), srcOffsetVar.dyn_cast<Value>(),
+ rewriter.create<arith::ConstantIndexOp>(loc, i));
+ }
+ auto extractOp =
+ rewriter.create<vector::ExtractOp>(loc, srcVec, extractLoc);
+ rewriter.create<vector::InsertOp>(loc, extractOp, destVec, i);
+ }
+}
+
+static TypedValue<VectorType>
+emulatedVectorLoad(ConversionPatternRewriter &rewriter, Location loc,
+ Value base, OpFoldResult linearizedIndices, int64_t numBytes,
+ int64_t scale, Type oldElememtType, Type newElementType) {
+ auto newLoad = rewriter.create<vector::LoadOp>(
+ loc, VectorType::get(numBytes, newElementType), base,
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+ return rewriter.create<vector::BitCastOp>(
+ loc, VectorType::get(numBytes * scale, oldElememtType), newLoad);
+};
+
namespace {
//===----------------------------------------------------------------------===//
@@ -380,26 +436,29 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- if (!foldedIntraVectorOffset) {
- // unimplemented case for dynamic intra vector offset
- return failure();
- }
-
+ // always load enough elements which can cover the original elements
+ auto maxintraVectorOffset =
+ foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
auto numElements =
- llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
- auto newLoad = rewriter.create<vector::LoadOp>(
- loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
- getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
-
- Value result = rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numElements * scale, oldElementType), newLoad);
+ llvm::divideCeil(maxintraVectorOffset + origElements, scale);
+ Value result =
+ emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
+ numElements, scale, oldElementType, newElementType);
- if (isUnalignedEmulation) {
- result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ if (foldedIntraVectorOffset) {
+ if (isUnalignedEmulation) {
+ result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+ *foldedIntraVectorOffset, origElements);
+ }
+ rewriter.replaceOp(op, result);
+ } else {
+ auto resultVector = rewriter.create<arith::ConstantOp>(
+ loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+ dynamicallyExtractElementsToVector(
+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+ linearizedInfo.intraVectorOffset, origElements);
+ rewriter.replaceOp(op, resultVector);
}
-
- rewriter.replaceOp(op, result);
return success();
}
};
@@ -604,13 +663,10 @@ struct ConvertVectorTransferRead final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- if (!foldedIntraVectorOffset) {
- // unimplemented case for dynamic inra-vector offset
- return failure();
- }
-
+ auto maxIntraVectorOffset =
+ foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
auto numElements =
- llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+ llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
auto newRead = rewriter.create<vector::TransferReadOp>(
loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -621,9 +677,17 @@ struct ConvertVectorTransferRead final
loc, VectorType::get(numElements * scale, oldElementType), newRead);
Value result = bitCast->getResult(0);
- if (isUnalignedEmulation) {
- result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ if (foldedIntraVectorOffset) {
+ if (isUnalignedEmulation) {
+ result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+ *foldedIntraVectorOffset, origElements);
+ }
+ } else {
+ result = rewriter.create<arith::ConstantOp>(
+ loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+ dynamicallyExtractElementsToVector(rewriter, loc, bitCast, result,
+ linearizedInfo.intraVectorOffset,
+ origElements);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
new file mode 100644
index 00000000000000..a92e62538c5332
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+
+// CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
+// CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
+func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %cst = arith.constant dense<0> : vector<3x3xi2>
+ %1 = vector.load %0[%arg1, %arg2] : memref<3x3xi2>, vector<3xi2>
+ %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
+ return %2 : vector<3x3xi2>
+}
+
+// CHECK: func @vector_load_i2
+// CHECK: %[[ALLOC:.+]]= memref.alloc() : memref<3xi8>
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
+// CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[OFFSET2:.+]] = arith.addi %1, %c2 : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>
+
+//-----
+
+func.func @vector_transfer_read_i2(%arg1: index, %arg2: index) -> vector<3xi2> {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0i2 = arith.constant 0 : i2
+ %1 = vector.transfer_read %0[%arg1, %arg2], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+ return %1 : vector<3xi2>
+}
+
+// CHECK: func @vector_transfer_read_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C0:.+]] = arith.extui %c0_i2 : i2 to i8
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
|
|
|
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
Outdated
Show resolved
Hide resolved
cc7b19b to
2effa6a
Compare
MaheshRavishankar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I missing where there is a test for dynamic indices?
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also added more static unaligned loading tests
Please send a separate PR.
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
Outdated
Show resolved
Hide resolved
328a0df to
c47d30f
Compare
c47d30f to
b777a60
Compare
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for all the improvements. I am sending a few more suggestion, nothing major. Feel free to ignore my nits.
| /// A wrapper function for emitting `vector.extract_strided_slice`. The vector | ||
| /// has to be of 1-D shape. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Rather than saying that the vector has to be 1-D, why not say "Extracts 1-D subvector from a 1-D vector". This way, the intent becomes clearer and the requirements are implied ;-)
Just so that you don't have to guess what I had in mind:
/// A wrapper function to extract a 1-D subvector from the 1-D source vector.
Also, could you remind my why use vector.extract_strided_slice rather than vector.extract?
Btw, none of this is a blocker for this PR. These are nice-to-have improvements, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated.
if I understand correctly, vector.extract can extract a single element or the whole innermost dimension, but if we want to operate on a certain part of the inner most dimension then we will have to use vector.extract_strided_slice?
Too bad the docs are not super formal so this is how I read it.
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
Outdated
Show resolved
Hide resolved
dcaballe
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG, just some minor comments.
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
Outdated
Show resolved
Hide resolved
Co-authored-by: Han-Chung Wang <[email protected]>
Co-authored-by: Han-Chung Wang <[email protected]>
Co-authored-by: Han-Chung Wang <[email protected]>
Co-authored-by: Han-Chung Wang <[email protected]>
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
hanhanW
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks! Just final nit.
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
| // CHECK-LABEL: func @vector_load_i2_dynamic_indexing( | ||
| // CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index) -> vector<3xi2> | ||
| // CHECK: %[[ALLOC:.+]]= memref.alloc() : memref<3xi8> | ||
| // CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%[[ARG0]], %[[ARG1]]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do not use #map directly. Those could be anything, the name is not semantically useful. You need to capture what #map is defined as and then make sure it is used the right way
MaheshRavishankar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The mapping still seems off. It needs to pick the maps for each tests. Unblocking, but please fix before landing.
…fc4896927 Local branch amd-gfx a12fc48 Merged main:0a68171b3c67503f7143856580f1b22a93ef566e into amd-gfx:cbff18bd3aba Remote branch main ce112a7 [MLIR] support dynamic indexing in `VectorEmulateNarrowTypes` (llvm#114169)
vector.loadandvector.transfer_readops.getCompressedMaskOp.