[MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass #144417
Conversation
        VectorType::get(sgShape, resultType.getElementType());

    SmallVector<Value> newBroadcastOps;
    for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
How about using a range-based for loop?
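A minimal sketch of that suggestion, reusing the variable names from the patch below; the behavior is unchanged:

    SmallVector<Value> newBroadcastOps;
    for (Value operand : adaptor.getOperands().front()) {
      auto newBroadcast = rewriter.create<vector::BroadcastOp>(
          op.getLoc(), newResultType, operand);
      xegpu::setLayoutAttr(newBroadcast->getResult(0),
                           layout.dropSgLayoutAndData());
      newBroadcastOps.push_back(newBroadcast.getResult());
    }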
    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
    if (!layout || !layout.getSgLayout())
      return failure();
It looks to me that the current implementation assumes the rank of the source is the same as the rank of the result, which is a subset of the supported semantics of vector.broadcast. I believe this is partly due to the limitations of LayoutAttr. It would be better to add a check.
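A sketch of such a guard, under the assumption that it sits right after the layout check in matchAndRewrite (a later revision of the patch, quoted further down, adds a similar srcType cast):

    // Reject rank-changing broadcasts; the pattern below assumes the
    // source and result vectors have the same rank.
    auto srcType =
        dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
    if (!srcType || srcType.getRank() != resultType.getRank())
      return failure();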
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

This PR adds a transformation pattern for the vector.broadcast op in the xegpu-wg-to-sg-distribute pass.

Full diff: https://github.com/llvm/llvm-project/pull/144417.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a26c6b52f0ddc..96c7032d6b812 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -328,6 +328,39 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+/// This pattern transforms vector.broadcast ops to work at subgroup level.
+struct WgToSgVectorBroadcastOp
+ : public OpConversionPattern<vector::BroadcastOp> {
+ using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResult().getType();
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+ if (!layout || !layout.getSgLayout())
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+
+ SmallVector<Value> newBroadcastOps;
+ for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
+ auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+ op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
+ xegpu::setLayoutAttr(newBroadcast->getResult(0),
+ layout.dropSgLayoutAndData());
+ newBroadcastOps.push_back(newBroadcast.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
+ return success();
+ }
+};
+
// Handles UnrealizedConversionCastOp generated during
// SCFStructuralTypeConversions (step 1). This op may appear as either a
// target or source materialization for Vector values, e.g.:
@@ -411,7 +444,8 @@ namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
- UnrealizedConversionCastOpPattern>(patterns.getContext());
+ WgToSgVectorBroadcastOp, UnrealizedConversionCastOpPattern>(
+ patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -518,6 +552,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
+ target.addDynamicallyLegalOp<vector::BroadcastOp>(
+ [=](vector::BroadcastOp op) -> bool {
+ return isLegal(xegpu::getLayoutAttr(op.getResult()));
+ });
+
target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
[=](UnrealizedConversionCastOp op) {
return llvm::is_contained(existingCastOps, op.getOperation());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 35ad16d8cd9a9..60ac266b0f112 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -103,6 +103,24 @@ gpu.module @test_round_robin_assignment {
gpu.return
}
+ // CHECK-LABEL: test_broadcast
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+ gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+ -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ -> vector<24x1xf32>
+ // CHECK-COUNT-3: vector.broadcast {{.*}}
+ // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
+ // CHECK-NOT: vector.broadcast
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
+ : vector<24x1xf32> to vector<24x8xf32>
+ gpu.return
+ }
+
gpu.func @test_scf_for(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
@@ -197,5 +215,4 @@ gpu.module @test_round_robin_assignment {
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
gpu.return
}
-
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 466842c968448..125bab349b4cb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,6 +170,22 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
gpu.return
}
+ // CHECK-LABEL: test_broadcast
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+ gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+ -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ -> vector<24x1xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
+ : vector<24x1xf32> to vector<24x8xf32>
+ gpu.return
+ }
+
gpu.func @test_scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
//CHECK: [[c0:%.+]] = arith.constant 0 : index
//CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -295,6 +311,5 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
gpu.return
}
-
-
}
+
adam-smnk left a comment:
Looks in line with other distributions
    // TODO: Currently only supports cases where the source and result ranks
    // are the same.
    auto srcType =
        dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
Can adaptor.getSource() be used here and later instead of adaptor.getOperands().front()?
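A sketch of that alternative, assuming the OneToNOpAdaptor exposes a generated getSource() accessor returning the 1:N-mapped source values; the loop body stays the same:

    for (Value operand : adaptor.getSource()) {
      auto newBroadcast = rewriter.create<vector::BroadcastOp>(
          op.getLoc(), newResultType, operand);
      xegpu::setLayoutAttr(newBroadcast->getResult(0),
                           layout.dropSgLayoutAndData());
      newBroadcastOps.push_back(newBroadcast.getResult());
    }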
    // and the other dimensions are the same as the destination type
    // TODO: Generalize it
    auto srcShape = srcType.getShape();
    for (size_t i = 0; i < srcShape.size(); ++i) {
It seems this check duplicates the check in the broadcast verifier, unless there are cases where the source vector, e.g., vector<32x1x1xf32>, can be distributed to a vector, e.g., <8x2x1>.