[MLIR][XeGPU] Scattered ops sg-to-wi distribution #154949
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir
Author: Artem Kroviakov (akroviakov)
Changes
This PR adds distribution patterns for scattered load and store ops, chunk size included. XeGPU moves toward offsets being part of the load/store ops, so the pass only supports this case. Manipulating a vector of offsets indirectly through create_tdesc is complex and soon to become obsolete anyway.
Full diff: https://github.com/llvm/llvm-project/pull/154949.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2088c3c7fc5ec..a1e5855aed264 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -811,6 +811,136 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
}
};
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ auto yield = cast<gpu::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ Operation *lastNode = yield->getPrevNode();
+ auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+ if (!storeScatterOp)
+ return failure();
+ else if (!storeScatterOp.getOffsets())
+ return rewriter.notifyMatchFailure(storeScatterOp,
+ "Store op must have offsets argument");
+ else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
+ .getRank() != 1)
+ return rewriter.notifyMatchFailure(storeScatterOp,
+ "Expected 1D offsets vector");
+
+ VectorType storeVecTy =
+ cast<VectorType>(storeScatterOp.getValue().getType());
+ assert(storeVecTy.getRank() <= 2 &&
+ "Expected at most 2D result at SG level");
+ VectorType distStoreVecTy;
+ if (storeVecTy.getRank() == 2)
+ distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
+ else // rank 1
+ distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
+
+ SmallVector<size_t> newRetIndices;
+ SmallVector<Value> operands =
+ llvm::to_vector_of<Value>(storeScatterOp->getOperands());
+ SmallVector<Type> operandTypes =
+ llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
+ operandTypes[0] = distStoreVecTy;
+
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+ SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ Value offsetsVec = newStoreScatterOpOperands[2];
+ Value maskVec = newStoreScatterOpOperands[3];
+
+ auto loc = newWarpOp.getLoc();
+ Value laneId = warpOp.getLaneid();
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Value laneOffset =
+ vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
+ laneOffset = vector::BroadcastOp::create(
+ rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
+ Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
+ laneMask = vector::BroadcastOp::create(
+ rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
+ newStoreScatterOpOperands[2] = laneOffset;
+ newStoreScatterOpOperands[3] = laneMask;
+
+ xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
+ rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
+ storeScatterOp->getAttrs());
+ xegpu::removeLayoutAttrs(newOp);
+ rewriter.eraseOp(storeScatterOp);
+ return success();
+ }
+};
+
+struct LoadDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
+ if (!isa<xegpu::LoadGatherOp>(op))
+ return false;
+ auto yield = cast<gpu::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ return yield->getPrevNode() == op;
+ });
+ if (!yieldOperand)
+ return rewriter.notifyMatchFailure(
+ warpOp, "warp result is not a xegpu::LoadGatherOp op");
+
+ auto loadGatherOp =
+ yieldOperand->get().getDefiningOp<xegpu::LoadGatherOp>();
+ if (!loadGatherOp.getOffsets())
+ return rewriter.notifyMatchFailure(loadGatherOp,
+ "Load op must have offsets argument");
+ else if (cast<VectorType>(loadGatherOp.getOffsets().getType()).getRank() !=
+ 1)
+ return rewriter.notifyMatchFailure(loadGatherOp,
+ "Expected 1D offsets vector");
+
+ SmallVector<size_t> newRetIndices;
+ SmallVector<Value> operands =
+ llvm::to_vector_of<Value>(loadGatherOp->getOperands());
+ SmallVector<Type> operandTypes =
+ llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
+
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+
+ SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ const unsigned operandIdx = yieldOperand->getOperandNumber();
+ VectorType loadVecTy =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
+
+ Value offsetsVec = newLoadGatherOperands[1];
+ Value maskVec = newLoadGatherOperands[2];
+ auto loc = newWarpOp.getLoc();
+ Value laneId = warpOp.getLaneid();
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Value laneOffset =
+ vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
+ laneOffset = vector::BroadcastOp::create(
+ rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
+ Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
+ laneMask = vector::BroadcastOp::create(
+ rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
+ newLoadGatherOperands[1] = laneOffset;
+ newLoadGatherOperands[2] = laneMask;
+
+ xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
+ loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
+ Value distributedVal = newWarpOp.getResult(operandIdx);
+ rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+ return success();
+ }
+};
+
} // namespace
namespace {
@@ -823,10 +953,11 @@ struct XeGPUSubgroupDistributePass final
void xegpu::populateXeGPUSubgroupDistributePatterns(
RewritePatternSet &patterns) {
- patterns.add<CreateNdDescDistribution, StoreNdDistribution,
- LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
- UpdateNdOffsetDistribution, GpuBarrierDistribution>(
- patterns.getContext());
+ patterns
+ .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+ DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
+ GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
+ patterns.getContext());
}
void XeGPUSubgroupDistributePass::runOnOperation() {
@@ -841,6 +972,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
if (!isa<VectorType>(operand.get().getType()))
continue;
+ if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
+ continue;
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
if (!layout) {
op->emitError("Could not find layout attribute for operand ")
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 54ef56e013abb..a4757dd132024 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -302,20 +302,43 @@ gpu.module @test {
}
// -----
-// CHECK-LABEL: gpu.func @gpu_barrier({{.*}}) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
-// CHECK-NEXT: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16xf16> -> vector<1xf16>
-// CHECK-NEXT: gpu.barrier
-// CHECK-NEXT: %[[T2:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
-// CHECK-NEXT: xegpu.store_nd %[[T1]], %[[T2]] : vector<1xf16>, !xegpu.tensor_desc<16xf16>
+// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
+// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
gpu.module @test {
- gpu.func @gpu_barrier(%arg0: memref<256xf16>, %arg1: memref<256xf16>) {
- %c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- %1 = xegpu.load_nd %0 {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16>
- gpu.barrier
- %2 = xegpu.create_nd_tdesc %arg1[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- xegpu.store_nd %1, %2 : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ gpu.func @scatter_ops_chunksize(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %3 = xegpu.load %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ xegpu.store %3, %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
+// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+ gpu.func @scatter_ops(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %3 = xegpu.load %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ xegpu.store %3, %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
gpu.return
}
}
Pinging @charithaintc
I did a first pass.
Generally, I don't agree with the code sequence extract[laneid] -> broadcast. Instead, I think offsets and masks must be distributed. We should discuss this.
General comments:
- No need of casts for typed values.
- Please pay attention to variable names.
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                              PatternRewriter &rewriter) const override {
  auto yield = cast<gpu::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
A `getTerminator` helper was added recently to WarpOp, please use it.
Using `getTerminator` now.
auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
if (!storeScatterOp)
  return failure();
else if (!storeScatterOp.getOffsets())
nit: no need of `else` here. Can simply use a new if.
Removed
"Expected 1D offsets vector"); | ||
|
||
VectorType storeVecTy = | ||
cast<VectorType>(storeScatterOp.getValue().getType()); |
Suggested change:
-    cast<VectorType>(storeScatterOp.getValue().getType());
+    storeScatterOp.getValueType();
Added
else if (!storeScatterOp.getOffsets())
  return rewriter.notifyMatchFailure(storeScatterOp,
                                     "Store op must have offsets argument");
else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
Looks like `getOffsets` returns a typed value, so no need of casts.
Offsets can now also be scalar, so a cast is needed; added a requirement for vector input.
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                              PatternRewriter &rewriter) const override {
  OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
    if (!isa<xegpu::LoadGatherOp>(op))
add explanation why this check is needed (memory ordering preservation)
Added
if (!loadGatherOp.getOffsets())
  return rewriter.notifyMatchFailure(loadGatherOp,
                                     "Load op must have offsets argument");
else if (cast<VectorType>(loadGatherOp.getOffsets().getType()).getRank() !=
nit: drop else
No need of casts for typed values.
Offsets can now also be scalar, so a cast is needed; added a requirement for vector input.
SmallVector<size_t> newRetIndices;
SmallVector<Value> operands =
    llvm::to_vector_of<Value>(loadGatherOp->getOperands());
SmallVector<Type> operandTypes =
Use a better name; check other patterns. Also, same comment as above: I think offsets and masks must be distributed.
Renamed
Value laneOffset =
    vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
laneOffset = vector::BroadcastOp::create(
    rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
laneMask = vector::BroadcastOp::create(
    rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
newLoadGatherOperands[1] = laneOffset;
newLoadGatherOperands[2] = laneMask;
not sure about this code sequence.
Addressed
This is the main difference between scattered ops and nd ops.
The offsets are the layout, and they are not necessarily linear (w.r.t. the lane id) or compile-time defined. The documentation does not prevent me from supplying a completely unstructured vector of offsets (see the sketch below).
Therefore, we cannot "distribute" such a vector based on the lane id. The same applies to the mask: one can supply any random mask vector. The offsets/mask vectors are not SG-uniform, they are allowed to be unstructured, and they can be completely runtime-defined. What is distribution supposed to do with them at compile time, in your opinion?
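A minimal sketch of such an unstructured case, with made-up offset and mask values; the `%src` memref and the shapes are hypothetical, mirroring the tests in this PR:

```mlir
// Nothing ties the offsets to the lane id: any compile-time or runtime-defined
// vector<16xindex> is legal, e.g. a scrambled constant like this one.
%offsets = arith.constant dense<[7, 0, 129, 3, 64, 42, 11, 250,
                                 5, 98, 17, 200, 33, 1, 77, 9]> : vector<16xindex>
%mask = arith.constant dense<[true, false, true, true, false, true, true, true,
                              true, true, false, true, true, true, true, false]> : vector<16xi1>
%v = xegpu.load %src[%offsets], %mask
     : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
```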
When I say "offsets are distributed" it does not mean we have to describe them as some affine function of the laneID. I mean that the vector<16xindex> will become vector<1xindex>, and then each lane can extract the scalar value from this <1xindex> vector. A before/after sketch is given below.
Can you please explain why such a strategy would not work? If instead we broadcast the offsets, we are wasting a lot of registers, plus broadcasting needs cross-lane communication. Also, upstream already has patterns to distribute the constants (i.e. elementwise ops), so you don't have to do anything there.
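A minimal sketch of this strategy, assuming an opaque offsets producer; the SSA names, shapes, and the `"compute_offsets"` op are illustrative, not taken from the original example:

```mlir
// Before: the gather sits inside the warp region with SG-sized operands.
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) {
  %offsets = "compute_offsets"() : () -> vector<16xindex>
  %mask = arith.constant dense<true> : vector<16xi1>
  %v = xegpu.load %src[%offsets], %mask
       : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
  gpu.yield %v : vector<16xf16>
}

// After: offsets and mask are yielded in distributed form (vector<16xindex> ->
// vector<1xindex>, vector<16xi1> -> vector<1xi1>) and the load is re-created
// outside the region with per-lane operands; no cross-lane broadcast is needed.
%r:2 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xindex>, vector<1xi1>) {
  %offsets = "compute_offsets"() : () -> vector<16xindex>
  %mask = arith.constant dense<true> : vector<16xi1>
  gpu.yield %offsets, %mask : vector<16xindex>, vector<16xi1>
}
%lane_v = xegpu.load %src[%r#0], %r#1
          : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
```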
That is the point: the offsets need not be uniform or structured. They can be arbitrary; how does the proposed distribution work with an arbitrary vector of offsets?
UPD: I see, let me think about it.
I feel such random offsets are more likely to come from a memory buffer and are not likely to be static at the SG level.
But should we care about the vector producer? Arith distribution is there, but there are more ways to create a vector; we could even receive it as an argument from the runtime. Am I missing an op or distribution-logic constraint that the offsets producer must be retrievable?
"even receive it as an argument from the runtime." what does this mean? a func argument? AFAIK, from the load gather distribution perspective, only thing we need to care about is what are the distributed types for the base, offsets and masks. Everything else should ideally be handled by the framework. I suggest, testing it with distributed type as <1xindex> and see what the framework does. |
There are cases where SG distribution has no access to the offsets' producer. I agree that extracting at idx 0 would work if the input is already distributed, but we go bottom-up, and we cannot assume that it will be distributed.
I did some quick testing. My conclusion is that we don't have to care about how the offset is defined. It will be taken care of by the framework (unless it is produced by some op that is not supported, in which case we need to add support).
Example 1 (trivially distributable): the framework distributes the offsets down to vector<1xindex> on its own.
Example 2 (complicated case): the framework cannot prove the offsets distributable and falls back to broadcasting them to every lane.
I agree, for the complex case broadcasting is indeed needed. But I guess this is outside the scope of gather/scatter distribution; it should not care about it.
I get your point. But I think the gather/scatter logic should not care about this; it should simply assume this is always distributable. Maybe we should wait for @Jianhui-Li's input also :-)
This solves the major issue. In your examples, how does the distribution pattern decide whether the offsets can be distributed directly or have to be broadcast?
This is done by WarpOpElementwise. If the compiler can prove the value is uniform, it will uniformly distribute the vector; if not, the values cannot be distributed.
In the final result (after lowering the remaining warp op), if the offsets are "weird" they will go through SLM. My point is that gather/scatter distribution should not care about it; it is separation of concerns.
Thanks @charithaintc for the examples and clarifications. I updated the patterns to distribute the offsets and mask.
Will address the rest of the feedback in a separate commit.
LGTM.
    cast<VectorType>(storeScatterOp.getValue().getType());
assert(storeVecTy.getRank() <= 2 &&
       "Expected at most 2D result at SG level");
VectorType distStoreVecTy;
I understand that the layout is not useful here. But it is better to keep this logic in a single place. This also ensures that the layout assigned to offsets (by the propagation logic) is indeed correct.
generally looks good. better if you can address the comments.
LGTM.
I think this may have broken a buildbot (https://lab.llvm.org/buildbot/#/builders/55/builds/16630/steps/11/logs/stdio). Could you please take a look?
Reverts #154949 due to suspected buildbot breakage (https://lab.llvm.org/buildbot/#/builders/55/builds/16630/steps/11/logs/stdio). Previously commented on the original pull request: #154949 (comment) ``` ******************** TEST 'MLIR :: Dialect/XeGPU/subgroup-distribute.mlir' FAILED ******************** ... # | PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace. # | Stack dump: # | 0. Program arguments: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/bin/mlir-opt -xegpu-subgroup-distribute -allow-unregistered-dialect -canonicalize -cse -split-input-file /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir # | #0 0x0000c0af4b066df0 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/lib/Support/Unix/Signals.inc:834:13 # | #1 0x0000c0af4b060e20 llvm::sys::RunSignalHandlers() /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/lib/Support/Signals.cpp:105:18 # | #2 0x0000c0af4b0691b4 SignalHandler(int, siginfo_t*, void*) /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/lib/Support/Unix/Signals.inc:426:38 # | #3 0x0000ee25a3dcb8f8 (linux-vdso.so.1+0x8f8) # | #4 0x0000ee25a36c7608 (/lib/aarch64-linux-gnu/libc.so.6+0x87608) # | #5 0x0000ee25a367cb3c raise (/lib/aarch64-linux-gnu/libc.so.6+0x3cb3c) # | #6 0x0000ee25a3667e00 abort (/lib/aarch64-linux-gnu/libc.so.6+0x27e00) # | #7 0x0000c0af4ae7e4b0 __sanitizer::Atexit(void (*)()) /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/compiler-rt/lib/sanitizer_common/sanitizer_posix_libcdep.cpp:168:10 # | #8 0x0000c0af4ae7c354 __sanitizer::Die() /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/compiler-rt/lib/sanitizer_common/sanitizer_termination.cpp:52:5 # | #9 0x0000c0af4ae66a30 Unlock /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/compiler-rt/lib/hwasan/../sanitizer_common/sanitizer_mutex.h:250:16 # | #10 0x0000c0af4ae66a30 ~GenericScopedLock /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/compiler-rt/lib/hwasan/../sanitizer_common/sanitizer_mutex.h:386:51 # | #11 0x0000c0af4ae66a30 __hwasan::ScopedReport::~ScopedReport() /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/compiler-rt/lib/hwasan/hwasan_report.cpp:54:5 # | #12 0x0000c0af4ae661b8 __hwasan::(anonymous namespace)::BaseReport::~BaseReport() /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/compiler-rt/lib/hwasan/hwasan_report.cpp:477:7 # | #13 0x0000c0af4ae63f5c __hwasan::ReportTagMismatch(__sanitizer::StackTrace*, unsigned long, unsigned long, bool, bool, unsigned long*) /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/compiler-rt/lib/hwasan/hwasan_report.cpp:1094:1 # | #14 0x0000c0af4ae4f8e0 Destroy /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/compiler-rt/lib/hwasan/../sanitizer_common/sanitizer_common.h:532:31 # | #15 0x0000c0af4ae4f8e0 ~InternalMmapVector /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/compiler-rt/lib/hwasan/../sanitizer_common/sanitizer_common.h:642:56 # | #16 0x0000c0af4ae4f8e0 __hwasan::HandleTagMismatch(__hwasan::AccessInfo, unsigned long, unsigned long, void*, unsigned long*) /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/compiler-rt/lib/hwasan/hwasan.cpp:245:1 # | #17 0x0000c0af4ae51e8c __hwasan_tag_mismatch4 
/home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/compiler-rt/lib/hwasan/hwasan.cpp:764:1 # | #18 0x0000c0af4ae67b30 __interception::InterceptFunction(char const*, unsigned long*, unsigned long, unsigned long) /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/compiler-rt/lib/interception/interception_linux.cpp:60:0 # | #19 0x0000c0af5641cd24 getNumResults /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/mlir/include/mlir/IR/Operation.h:404:37 # | #20 0x0000c0af5641cd24 getOpResultImpl /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/mlir/include/mlir/IR/Operation.h:1010:5 # | #21 0x0000c0af5641cd24 getResult /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/mlir/include/mlir/IR/Operation.h:407:54 # | #22 0x0000c0af5641cd24 mlir::OpTrait::detail::MultiResultTraitBase<mlir::gpu::WarpExecuteOnLane0Op, mlir::OpTrait::VariadicResults>::getResult(unsigned int) /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/mlir/include/mlir/IR/OpDefinition.h:638:62 # | #23 0x0000c0af56426b60 getType /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/mlir/include/mlir/IR/Value.h:63:33 # | #24 0x0000c0af56426b60 getType /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/mlir/include/mlir/IR/Value.h:105:39 # | #25 0x0000c0af56426b60 (anonymous namespace)::LoadDistribution::matchAndRewrite(mlir::gpu::WarpExecuteOnLane0Op, mlir::PatternRewriter&) const /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp:991:55 ... ```
…#154949" (#156924) This PR is a reapply of #154949, which failed one of sanitizer checks. The issue was querying the `warpOp` results in `LoadDistribution` after calling `moveRegionToNewWarpOpAndAppendReturns()`, which resulted in use after free. This PR solves the issue by moving the op query before the call and is otherwise identical to the one linked above. --------- Co-authored-by: Charitha Saumya <[email protected]>
This PR adds distribution patterns for scattered load and store ops, chunk size included.
XeGPU moves toward offsets being part of the load/store ops, so the pass only supports this case. Manipulating a vector of offsets indirectly through create_tdesc is complex and soon to become obsolete anyway.
This PR assumes the SIMT-adapted scatter ops verification introduced in #154653. The distribution itself can be reviewed in the meantime.
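For reference, a condensed before/after of the chunked case, paraphrasing the new test in subgroup-distribute.mlir above (cache hints omitted; the per-lane SSA names are illustrative):

```mlir
// Subgroup level: 16 lanes, each lane owning a chunk of 8 f16 elements.
%v = xegpu.load %src[%offsets], %mask <{chunk_size = 8 : i64}>
     : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>

// After sg-to-wi distribution: per-lane offsets and mask, the chunk dimension stays.
%lane_v = xegpu.load %src[%lane_offset], %lane_mask <{chunk_size = 8 : i64}>
          : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
```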