-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[MLIR][XeGPU] Matrix load/store subgroup distribution #165008
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
887f978
f80ee32
b4f5a4d
3c4a5aa
5965b54
246761e
c99294a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,47 +38,52 @@ void XeGPUDialect::initialize() { | |
| >(); | ||
| } | ||
|
|
||
| /// Generates instructions to compute offsets for a subgroup identified by | ||
| /// its multidimensional indices (sgId), using the specified subgroup layout | ||
| /// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data | ||
| /// dimensions (sizePerWg). | ||
| // A `srcShape` consists of N distribution units, each being `subShapesLayout` x | ||
akroviakov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| // `subShape`. A `delinearizedId` is used to identify a particular `subShape` | ||
| // within each distribution unit. | ||
| // Example: | ||
| // WG data is 128x256. SG data is 16x32, in 4x2 layout, this gives a | ||
| // distribution unit of shape 64x64, we have 2x4 such distribution units. | ||
| // `delinearizedId` is used to identify a 16x32 of a subgroup in each | ||
| // distribution unit. | ||
| static SmallVector<SmallVector<Value>> | ||
| genOffsetsComputingInsts(OpBuilder &builder, Location loc, | ||
| SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout, | ||
| ArrayRef<int64_t> sizePerSg, | ||
| ArrayRef<int64_t> sizePerWg) { | ||
|
|
||
| genOffsets(OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId, | ||
|
||
| ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape, | ||
| ArrayRef<int64_t> srcShape) { | ||
| SmallVector<SmallVector<Value>> offsets; | ||
|
|
||
| // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i] | ||
| SmallVector<Value> localOffsets = llvm::map_to_vector( | ||
| llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value { | ||
| // A distribution unit must be less than or equal to `srcShape` | ||
| SmallVector<int64_t> distUnitShape = llvm::map_to_vector( | ||
| llvm::zip_equal(srcShape, | ||
| computeElementwiseMul(subShapesLayout, subShape)), | ||
| [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); | ||
|
|
||
| // Get the offset of `subShape` within a distribution unit. | ||
| SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector( | ||
| llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value { | ||
| return builder.createOrFold<index::MulOp>( | ||
| loc, std::get<0>(t), | ||
| builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t))); | ||
| }); | ||
|
|
||
| // distUnit[i] is the minimum value between sizePerWg[i] and | ||
| // sgLayout[i] * sizePerSg[i] | ||
| SmallVector<int64_t> distUnit = llvm::map_to_vector( | ||
| llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)), | ||
| [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); | ||
|
|
||
| // For each dist unit | ||
| for (SmallVector<int64_t> unitOffs : | ||
| StaticTileOffsetRange(sizePerWg, distUnit)) { | ||
| StaticTileOffsetRange(srcShape, distUnitShape)) { | ||
| // Get dist unit offset within `srcShape`. | ||
| SmallVector<Value> base = | ||
| llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value { | ||
| return arith::ConstantIndexOp::create(builder, loc, d); | ||
| }); | ||
|
|
||
| SmallVector<Value> adds = llvm::map_to_vector( | ||
| llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value { | ||
| return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t), | ||
| std::get<1>(t)); | ||
| }); | ||
|
|
||
| // Calculate `subShape` offset within `srcShape`. | ||
| SmallVector<Value> adds = | ||
| llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset), | ||
| [&](const auto &t) -> Value { | ||
| return builder.createOrFold<arith::AddIOp>( | ||
| loc, std::get<0>(t), std::get<1>(t)); | ||
| }); | ||
| // Do not go beyond `srcShape` bounds. | ||
| SmallVector<Value> mods = llvm::map_to_vector( | ||
| llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value { | ||
| llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value { | ||
| return builder.createOrFold<index::RemUOp>( | ||
| loc, std::get<0>(t), | ||
| arith::ConstantIndexOp::create(builder, loc, std::get<1>(t))); | ||
|
|
@@ -268,12 +273,7 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, | |
| } | ||
|
|
||
| FailureOr<SmallVector<Value>> | ||
| LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, | ||
| Value linearId) { | ||
| // delinearizeSubgroupId is only available for | ||
| // workgroup-level layout attribute | ||
| if (!isForWorkgroup()) | ||
| return failure(); | ||
| LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { | ||
|
|
||
| // TODO: handle order attribute | ||
| auto hasDefaultOrder = [&]() { | ||
|
|
@@ -283,41 +283,52 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, | |
| }; | ||
| if (!hasDefaultOrder()) | ||
| return mlir::emitError(loc, "order attribute is currently not supported."); | ||
|
|
||
| auto dims = | ||
| llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value { | ||
| return builder.createOrFold<arith::ConstantIndexOp>(loc, d); | ||
| }); | ||
| SmallVector<int64_t> layout; | ||
| if (isForWorkgroup()) { | ||
| layout = getEffectiveSgLayoutAsInt(); | ||
| } else if (isForSubgroup()) { | ||
| layout = getEffectiveLaneLayoutAsInt(); | ||
| } else { | ||
| return failure(); | ||
| } | ||
| auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value { | ||
| return builder.createOrFold<arith::ConstantIndexOp>(loc, d); | ||
| }); | ||
|
|
||
| return affine::delinearizeIndex(builder, loc, linearId, dims); | ||
| } | ||
|
|
||
| /// Implements DistributeLayoutAttr::getOffsets to generate | ||
| /// Implements DistributeLayoutAttr::computeDistributedOffsets to generate | ||
| /// instructions for computing multi-dimensional offsets when distributed by | ||
| /// LayoutAttr. | ||
| FailureOr<SmallVector<SmallVector<Value>>> | ||
| LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, | ||
| ArrayRef<int64_t> shape) { | ||
| if (!isForWorkgroup()) | ||
| LayoutAttr::computeDistributedOffsets(OpBuilder &builder, Location loc, | ||
| Value linearId, ArrayRef<int64_t> shape) { | ||
| SmallVector<int64_t> layout; | ||
| SmallVector<int64_t> subShape; | ||
| if (isForWorkgroup()) { | ||
| layout = getEffectiveSgLayoutAsInt(); | ||
| subShape = getEffectiveSgDataAsInt(); | ||
| } else if (isForSubgroup()) { | ||
| layout = getEffectiveLaneLayoutAsInt(); | ||
| subShape = getEffectiveLaneDataAsInt(); | ||
| } else { | ||
| return failure(); | ||
|
|
||
| SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt(); | ||
| SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt(); | ||
| if (sgShape.empty()) { | ||
| if (auto derivedShape = computeShapeRatio(shape, sgLayout)) | ||
| sgShape = derivedShape.value(); | ||
| } | ||
| if (subShape.empty()) { | ||
| if (auto derivedShape = computeShapeRatio(shape, layout)) | ||
| subShape = derivedShape.value(); | ||
| else | ||
| return failure(); | ||
| } | ||
|
|
||
| // delinearize Ids | ||
| auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); | ||
| auto maybeIds = delinearizeId(builder, loc, linearId); | ||
| if (failed(maybeIds)) | ||
| return failure(); | ||
| SmallVector<Value> sgIds = *maybeIds; | ||
| SmallVector<Value> ids = *maybeIds; | ||
|
|
||
| return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, | ||
| shape); | ||
| return genOffsets(builder, loc, ids, layout, subShape, shape); | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
|
|
@@ -371,34 +382,43 @@ SliceAttr SliceAttr::flatten() const { | |
| } | ||
|
|
||
| FailureOr<SmallVector<Value>> | ||
| SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, | ||
| Value linearId) { | ||
| SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { | ||
| SliceAttr attr = flatten(); | ||
| auto parent = dyn_cast<LayoutAttr>(attr.getParent()); | ||
| return parent.delinearizeSubgroupId(builder, loc, linearId); | ||
| return parent.delinearizeId(builder, loc, linearId); | ||
| } | ||
|
|
||
| /// Implements DistributeLayoutAttr::getOffsets to generate | ||
| /// instructions for computing multi-dimensional offsets when distributed by | ||
| /// SliceAttr. | ||
| // Implements DistributeLayoutAttr::computeDistributedOffsets to generate | ||
| // instructions for computing multi-dimensional offsets when distributed by | ||
| // LayoutAttr. | ||
| FailureOr<SmallVector<SmallVector<Value>>> | ||
| SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, | ||
| ArrayRef<int64_t> shape) { | ||
| SliceAttr::computeDistributedOffsets(OpBuilder &builder, Location loc, | ||
| Value linearId, ArrayRef<int64_t> shape) { | ||
| assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape."); | ||
| if (!isForWorkgroup()) | ||
| return failure(); | ||
|
|
||
| SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt(); | ||
| SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt(); | ||
| if (sgShape.empty()) { | ||
| if (auto derivedShape = computeShapeRatio(shape, sgLayout)) | ||
| sgShape = derivedShape.value(); | ||
| SmallVector<int64_t> layout; | ||
| SmallVector<int64_t> subShape; | ||
| if (isForWorkgroup()) { | ||
| layout = getEffectiveSgLayoutAsInt(); | ||
| subShape = getEffectiveSgDataAsInt(); | ||
| } else if (isForSubgroup()) { | ||
| layout = getEffectiveLaneLayoutAsInt(); | ||
| subShape = getEffectiveLaneDataAsInt(); | ||
| } else { | ||
| return failure(); | ||
| } | ||
|
|
||
| if (subShape.empty()) { | ||
| if (auto derivedShape = computeShapeRatio(shape, layout)) | ||
| subShape = derivedShape.value(); | ||
| else | ||
| return failure(); | ||
| } | ||
|
|
||
| // delinearize Ids | ||
| auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); | ||
| auto maybeIds = delinearizeId(builder, loc, linearId); | ||
| if (failed(maybeIds)) | ||
| return failure(); | ||
|
|
||
|
|
@@ -408,8 +428,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, | |
| SmallVector<Value> sgIds = | ||
| XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims); | ||
|
|
||
| return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, | ||
| shape); | ||
| return genOffsets(builder, loc, sgIds, layout, subShape, shape); | ||
| } | ||
|
|
||
| bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a verification here for load/store_matrix with @subgroup_block_io attribute: The payload must be contiguous in the memory.
Both of the IRs in the tests added in this PR are actually incorrect, since the payload data are not contiguous between lanes. They would be correct if you changed the vector<2x16xf32> to <16x2xf32> (lane_layout/lane_data need to change accordingly, but that is out of the IR verifier's scope).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added verification. However,
I understand the logical reasoning for this in the matrix ops case, but the current distribution does not allow it, considering the "correct" lane layout the block load requires.
We have
Meaning that given `lane_layout = [1, 16]`, `lane_data = [1, 1]` and a `16x2` data shape, we get:
We can change the layout to be `[16, 1]`, which would allow the pattern to complete and the distributed code to still be correct, since the lane layout is not used in further coordinate calculations. But `[16, 1]` may be harder for users to reason about by simply looking at the xevm block load description and the sg-level `subgroup_block_io` matrix op.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the user uses stride=[1, 32] in the memory layout, then the user should be able to reason that sg_layout = [16, 1].
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the user uses lane_layout = [1, 16], they should not use a strided memory layout; the example above should just use a block layout. The matrix op with subgroup_block_io is a subgroup operation, and all lanes collectively access a contiguous memory buffer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is now a test with a `16x2xf32` result using the proper stride.
Short snippet:
It distributes to `1x2xf32`.