[MLIR][XeGPU] Add support for subgroup_id_range #148661
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/148661.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index be7b860dd1729..56dc132d8083d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -174,8 +174,46 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
}
+ // Check if there is warp specialization.
+ auto isWarpSpecialized = [](Operation *op, int64_t &startRange,
+ int64_t &endRange) -> bool {
+ Operation *parent = op->getParentOp();
+ // Find the outermost scf::IfOp with xegpu.sg_id_range.
+ while (parent) {
+ if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
+ if (Attribute attr = ifOp->getAttr("xegpu.sg_id_range")) {
+ if (auto denseAttr = dyn_cast<DenseI32ArrayAttr>(attr)) {
+ auto values = denseAttr.asArrayRef();
+ if (values.size() == 2) {
+ startRange = values[0];
+ endRange = values[1];
+ }
+ }
+ break;
+ }
+ }
+ parent = parent->getParentOp();
+ }
+ // Return false if startRange is 0
+ return (startRange > 0 && endRange > startRange);
+ };
+
+ int64_t startRange = -1, endRange = -1;
+ bool warpSpecialized = isWarpSpecialized(op, startRange, endRange);
+
+ // If warp specialization is detected, adjust the subgroup id accordingly
+ Value adjustedSgId = linearSgId;
+ if (warpSpecialized) {
+ // Subtract startRange from the original subgroup id to get the adjusted
+ // sg id
+ Value startRangeVal =
+ rewriter.create<arith::ConstantIndexOp>(loc, startRange);
+ adjustedSgId =
+ rewriter.createOrFold<index::SubOp>(loc, linearSgId, startRangeVal);
+ }
+
auto deLinearizeSgId =
- affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
+ affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim);
if (failed(deLinearizeSgId))
return failure();
SmallVector<Value> sgIds = *deLinearizeSgId;
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 44b11c304cc80..71eb732ac4953 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -296,5 +296,88 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
gpu.return
}
+ // CHECK-LABEL: @warp_specialized
+ gpu.func @warp_specialized(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) {
+ %sg_id = gpu.subgroup_id : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c31 = arith.constant 31 : index
+ %c3 = arith.constant 3 : index
+ %cond1 = arith.cmpi sge, %sg_id, %c0 : index
+ %cond2 = arith.cmpi slt, %sg_id, %c1 : index
+ %cond = arith.andi %cond1, %cond2 : i1
+ scf.if %cond {
+ // CHECK-NOT: index.sub
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ } {xegpu.sg_id_range = array<i32: 0, 1>}
+ %cond3 = arith.cmpi sge, %sg_id, %c1 : index
+ %cond4 = arith.cmpi slt, %sg_id, %c2 : index
+ %cond5 = arith.andi %cond3, %cond4 : i1
+ scf.if %cond5 {
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C1]]
+ %tdesc_a = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ %tdesc_b = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x256xf32>
+ -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
+ -> vector<128x256xf32>
+ %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>} : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
+ }{xegpu.sg_id_range = array<i32: 1, 2>}
+ %cond6 = arith.cmpi sge, %sg_id, %c2 : index
+ %cond7 = arith.cmpi slt, %sg_id, %c31 : index
+ %cond8 = arith.andi %cond6, %cond7 : i1
+ scf.if %cond8 {
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
+ %tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32>
+ -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+ -> vector<128x64xf32>
+ %exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+ }{xegpu.sg_id_range = array<i32: 2, 32>}
+ gpu.return
+ }
+ // CHECK-LABEL: @subgroup_id_range_nested_if
+ gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
+ %sg_id = gpu.subgroup_id : index
+ %c1 = arith.constant 1 : i1
+ %c3 = arith.constant 3 : index
+ %c32 = arith.constant 32 : index
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ %cond1 = arith.cmpi sge, %sg_id, %c3 : index
+ %cond2 = arith.cmpi slt, %sg_id, %c32 : index
+ %cond = arith.andi %cond1, %cond2 : i1
+ scf.if %c1 {
+ scf.if %cond {
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
+ %td = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x64xf32>
+ -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+ %ld = xegpu.load_nd %td
+ : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+ -> vector<128x64xf32>
+ %exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+ }
+ } {xegpu.sg_id_range = array<i32: 3, 8>}
+ gpu.return
+ }
}
}

// Check if there is warp specialization.
auto isWarpSpecialized = [](Operation *op, int64_t &startOfRange,
consider taking this out as a separate utility function for wg-to-sg distribution.
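For reference, a minimal sketch of what such a standalone helper might look like (the name `getSgIdRange` and its exact placement are assumptions, not part of this PR; it mirrors the lambda above):

```c++
#include <cstdint>
#include <optional>
#include <utility>

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Walk the parent ops of `op` and return the [start, end) pair from the
// closest enclosing scf.if that carries an "xegpu.sg_id_range" attribute.
static std::optional<std::pair<int64_t, int64_t>> getSgIdRange(Operation *op) {
  for (Operation *parent = op->getParentOp(); parent;
       parent = parent->getParentOp()) {
    auto ifOp = dyn_cast<scf::IfOp>(parent);
    if (!ifOp)
      continue;
    auto attr = ifOp->getAttrOfType<DenseI32ArrayAttr>("xegpu.sg_id_range");
    if (!attr)
      continue;
    auto values = attr.asArrayRef();
    if (values.size() == 2)
      return std::pair<int64_t, int64_t>(values[0], values[1]);
    break; // Malformed range attribute; stop searching, as the lambda does.
  }
  return std::nullopt;
}
```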
ok will change it
// If warp specialization is detected, adjust the subgroup id accordingly
Value adjustedSgId = linearSgId;
if (warpSpecialized) {
You also need to verify that the sg id ranges match with xegpu.sg_layout
Verify that the number of subgroups in the sg_layout is equal to the number of subgroups specified by the sg_id_range?
    }
    parent = parent->getParentOp();
  }
  // Return false if startOfRange is 0
why can't startOfRange be 0?
It can be 0, but if the starting subgroup id is 0 we don't need to adjust the ids, so the check returns false.
%load_b = xegpu.load_nd %tdesc_b
  : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
  -> vector<128x256xf32>
%dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>} : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
sg_layout size should match the range size
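One way this could be expressed inside the pattern (a sketch only; `sgLayout`, `warpSpecialized`, `startRange`, and `endRange` are assumed to be the corresponding values from the surrounding code and may be named differently in the actual change):

```c++
// Sketch: fail the match if the subgroup count implied by sg_layout does not
// equal the number of subgroup ids covered by the half-open range.
if (warpSpecialized) {
  int64_t sgCount = 1;
  for (int64_t dim : sgLayout)
    sgCount *= dim;
  if (sgCount != endRange - startRange)
    return rewriter.notifyMatchFailure(
        op, "sg_id_range size does not match sg_layout size");
}
```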
Hi @adam-smnk, do you have any comments on this?
adam-smnk left a comment:
Overall looks neat - minor comments
startOfRange = attr.getStart().getInt();
endOfRange = attr.getEnd().getInt();
General suggestion, non-blocker here: getting int value directly would make for a nice attribute helper method
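A rough illustration of what such a helper could look like (the accessor names `getStartValue`/`getEndValue` are invented for this sketch; in-tree they would presumably be declared through the attribute's extraClassDeclaration):

```c++
// Hypothetical convenience accessors on RangeAttr (names are illustrative).
int64_t RangeAttr::getStartValue() { return getStart().getInt(); }
int64_t RangeAttr::getEndValue() { return getEnd().getInt(); }

// The call site above would then read:
//   startOfRange = attr.getStartValue();
//   endOfRange = attr.getEndValue();
```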
RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                  IntegerAttr startOfRange, IntegerAttr endOfRange) {
  if (startOfRange.getInt() >= endOfRange.getInt())
    return emitError() << "EndOfRange must be greater than StartOfRange";
nit: in the error, I'd refer to the values by their attribute names end and start; it should improve error readability
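For example, the diagnostic could read along these lines (a sketch of one possible wording, not the wording adopted in the PR):

```c++
if (startOfRange.getInt() >= endOfRange.getInt())
  return emitError() << "'end' (" << endOfRange.getInt()
                     << ") must be greater than 'start' ("
                     << startOfRange.getInt() << ")";
return success();
```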
    AttrBuilder<(ins "int":$start, "int":$end)>
  ];

  let assemblyFormat = "`<` `[`$start ```,` $end `]``>`";
Suggested change:
- let assemblyFormat = "`<` `[`$start ```,` $end `]``>`";
+ let assemblyFormat = "`<` `[`$start `,` $end `]` `>`";
nit: minor cleanup
//===----------------------------------------------------------------------===//

LogicalResult
RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
Could you add one invalid test case?
Not sure if it's possible to add a negative test case with this pass, because it will always give a legalization error for the create_nd_desc op if the pattern returns a failure in this case.
```mlir
scf.if %cond {
  // some operations
}{sg_id_range = #xegpu.range<[2, 4]>}
```
Suggested change:
- }{sg_id_range = #xegpu.range<[2, 4]>}
+ } {sg_id_range = #xegpu.range<[2, 4]>}
This PR adds a new attribute to the xegpu dialect called xegpu.range. One use case of this attribute is to attach a subgroup_id_range to scf.if to drive execution.