Skip to content

Commit b5b8059

Browse files
authored
[LLVMGPUVectorDistribute] Fix vector step distribute (#19227)
Currently, the 'thread_strides' of NestedLayoutAttr are misinterpreted as the access strides of the multi-dimensional vector. However, they actually correspond to the tid -> vtid mapping, and the undistributed vector is packed as subgroup x batch x outer x thread x element, where vtid is used to index the 'thread' dimension. Therefore, this commit removes the usage of 'thread_strides' and 'subgroup_strides' when calculating the base constant offset and instead obtains the strides from the packed undistributed vector shape. Signed-off-by: Manupa Karunaratne <[email protected]>
1 parent 1aada43 commit b5b8059

File tree

4 files changed

+50
-21
lines changed

4 files changed

+50
-21
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -990,21 +990,27 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
990990
return lens;
991991
}
992992

993+
// Helper that computes packed strides for a shape: each stride is the product of the lengths of all dimensions that follow it.
994+
// E.g., a shape of 2x3x4 returns strides [12, 4, 1].
995+
SmallVector<int64_t> getStrides(ArrayRef<int64_t> shape) const {
996+
int64_t elementCount = ShapedType::getNumElements(shape);
997+
SmallVector<int64_t> strides;
998+
int64_t currStride = elementCount;
999+
for (int64_t len : shape) {
1000+
currStride = currStride / len;
1001+
strides.push_back(currStride);
1002+
}
1003+
return strides;
1004+
}
1005+
9931006
// Once we are in the realm of remaining dimensions,
9941007
// the strides are not packed. This is a helper to
9951008
// obtain the packed strides of the remaining dimensions.
9961009
// (See above for an example of remaining dimensions under
9971010
// getRemainingDims)
9981011
SmallVector<int64_t> getPackedStrides(ArrayRef<DimInfo> dims) const {
9991012
SmallVector<int64_t> lens = getLens(dims);
1000-
int64_t elementCount = ShapedType::getNumElements(lens);
1001-
SmallVector<int64_t> packedStrides;
1002-
int64_t currStride = elementCount;
1003-
for (int64_t len : lens) {
1004-
currStride = currStride / len;
1005-
packedStrides.push_back(currStride);
1006-
}
1007-
return packedStrides;
1013+
return getStrides(lens);
10081014
}
10091015

10101016
// This function emulates the slicing of otherwise large constant
@@ -1091,9 +1097,14 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
10911097
SmallVector<Value> subgroupIndices, threadIndices;
10921098
populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, resultLayout,
10931099
subgroupIndices, threadIndices);
1094-
ArrayRef<int64_t> subgroupStrides = resultLayout.getSubgroupStrides();
1100+
1101+
SmallVector<int64_t> undistributedShape =
1102+
resultLayout.getUndistributedPackedShape();
1103+
SmallVector<int64_t> undistributedStrides = getStrides(undistributedShape);
1104+
constexpr int64_t subgroupIdx = 0;
1105+
constexpr int64_t threadIdx = 3;
1106+
10951107
ArrayRef<int64_t> subgroupLengths = resultLayout.getSubgroupTile();
1096-
ArrayRef<int64_t> threadStrides = resultLayout.getThreadStrides();
10971108
ArrayRef<int64_t> threadLengths = resultLayout.getThreadTile();
10981109
// Step op by definition should be single dimensional.
10991110
SmallVector<int64_t> distributedShape =
@@ -1102,8 +1113,9 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
11021113
int64_t distributedElements = ShapedType::getNumElements(distributedShape);
11031114
int64_t originalElements = result.getType().getNumElements();
11041115
SmallVector<DimInfo, 2> distributedDims{
1105-
{subgroupIndices[0], subgroupLengths[0], subgroupStrides[0]},
1106-
{threadIndices[0], threadLengths[0], threadStrides[0]}};
1116+
{subgroupIndices[0], subgroupLengths[0],
1117+
undistributedStrides[subgroupIdx]},
1118+
{threadIndices[0], threadLengths[0], undistributedStrides[threadIdx]}};
11071119
llvm::sort(distributedDims, [](const DimInfo &lhs, const DimInfo &rhs) {
11081120
return lhs.dimStride > rhs.dimStride;
11091121
});

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_step.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,10 @@ builtin.module attributes { transform.with_named_sequence } {
2626
}
2727

2828
// CHECK-LABEL: func @step_1
29-
// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
29+
// CHECK: %[[CST:.+]] = arith.constant dense<[0, 4, 8, 12]> : vector<4xindex>
3030
// CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 16) mod 4)>()[%thread_id_x]
31-
// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c16 : index
32-
// CHECK: %[[TID_STRIDEV:.+]] = vector.broadcast %[[TID_STRIDE]] : index to vector<4xindex>
33-
// CHECK: %[[OFFSET:.+]] = arith.addi %[[TID_STRIDEV]], %[[CST]] : vector<4xindex>
31+
// CHECK: %[[TIDB:.+]] = vector.broadcast %[[TID]] : index to vector<4xindex>
32+
// CHECK: %[[OFFSET:.+]] = arith.addi %[[TIDB]], %[[CST]] : vector<4xindex>
3433

3534
// -----
3635

@@ -94,10 +93,10 @@ builtin.module attributes { transform.with_named_sequence } {
9493
}
9594

9695
// CHECK-LABEL: func @step_3
97-
// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 24, 25]> : vector<4xindex>
96+
// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 8, 9]> : vector<4xindex>
9897
// CHECK: %[[WID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 512) mod 3)>()[%thread_id_x]
9998
// CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 2) mod 4)>()[%thread_id_x]
100-
// CHECK: %[[WID_STRIDE:.+]] = arith.muli %[[WID]], %c8 : index
99+
// CHECK: %[[WID_STRIDE:.+]] = arith.muli %[[WID]], %c16 : index
101100
// CHECK: %[[WID_STRIDEV:.+]] = vector.broadcast %[[WID_STRIDE]] : index to vector<4xindex>
102101
// CHECK: %[[OFFSET0:.+]] = arith.addi %[[WID_STRIDEV]], %[[CST]] : vector<4xindex>
103102
// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c2 : index
@@ -132,7 +131,8 @@ builtin.module attributes { transform.with_named_sequence } {
132131
}
133132

134133
// CHECK-LABEL: func @step_4
135-
// CHECK: %[[CST:.+]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112]> : vector<8xindex>
134+
// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
136135
// CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%thread_id_x]
137-
// CHECK: %[[TIDV:.+]] = vector.broadcast %[[TID]] : index to vector<8xindex>
138-
// CHECK: %[[OFFSET:.+]] = arith.addi %[[TIDV]], %[[CST]] : vector<8xindex>
136+
// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c8 : index
137+
// CHECK: %[[TID_STRIDEV:.+]] = vector.broadcast %[[TID_STRIDE]] : index to vector<8xindex>
138+
// CHECK: %[[OFFSET:.+]] = arith.addi %[[TID_STRIDEV]], %[[CST]] : vector<8xindex>

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,20 @@ SmallVector<int64_t> NestedLayoutAttr::getDistributedShape() const {
311311
return shape;
312312
}
313313

314+
/// Returns the shape of the vector before distribution, formed by concatenating the per-level tiles in the order:
315+
/// <SUBGROUP x BATCH x OUTER x THREAD x ELEMENT>
316+
SmallVector<int64_t> NestedLayoutAttr::getUndistributedPackedShape() const {
317+
SmallVector<int64_t> shape;
318+
int64_t rank = getRank();
319+
shape.reserve(rank * 5);
320+
shape.append(getSubgroupTile().begin(), getSubgroupTile().end());
321+
shape.append(getBatchTile().begin(), getBatchTile().end());
322+
shape.append(getOuterTile().begin(), getOuterTile().end());
323+
shape.append(getThreadTile().begin(), getThreadTile().end());
324+
shape.append(getElementTile().begin(), getElementTile().end());
325+
return shape;
326+
}
327+
314328
// Gets the rank of the undistributed vector for this layout.
315329
int64_t NestedLayoutAttr::getRank() const {
316330
// The layout requires that all size lists are the same length and match

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,9 @@ def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
292292
// Returns the subgroup/lane ids delinearized from a single linearized
293293
// thread ID.
294294
SmallVector<Value> computeThreadIds(Value threadId, int64_t subgroupSize, RewriterBase &rewriter) const;
295+
296+
// Returns the undistributed shape, packed as subgroup x batch x outer x thread x element.
297+
SmallVector<int64_t> getUndistributedPackedShape() const;
295298
}];
296299

297300
let genVerifyDecl = 1;

0 commit comments

Comments
 (0)