Skip to content

Commit 65dec99

Browse files
authored
[MLIR][XeGPU] Add support for subgroup_id_range (#148661)
This PR adds a new attribute to the xegpu dialect called xegpu.range. One use case of this attribute can be to attach subgroup_id_range to scf.if of to drive the execution.
1 parent 45d99c2 commit 65dec99

File tree

4 files changed

+149
-2
lines changed

4 files changed

+149
-2
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,4 +336,33 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
336336
let genVerifyDecl = 1;
337337
}
338338

339+
def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> {
340+
let summary = [{Specifies a half-open range}];
341+
let description = [{
342+
`RangeAttr` is an attribute that defines a half-open range [start, end).
343+
The range is inclusive of the start value and exclusive of the end value.
344+
One usage of this attribute can be to specify the subgroup id range.
345+
The subgroup id range can be specified using this attribute,
346+
and it can be attached to a scf.if op like
347+
```mlir
348+
scf.if %cond {
349+
// some operations
350+
} {sg_id_range = #xegpu.range<[2, 4]>}
351+
```
352+
In this case, the scf.if op will only be executed for subgroup IDs 2 and 3.
353+
}];
354+
355+
let parameters = (ins
356+
"IntegerAttr": $start,
357+
"IntegerAttr": $end
358+
);
359+
360+
let builders = [
361+
AttrBuilder<(ins "int":$start, "int":$end)>
362+
];
363+
364+
let assemblyFormat = "`<` `[`$start `,` $end `]` `>`";
365+
let genVerifyDecl = 1;
366+
}
367+
339368
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,21 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
211211
return success();
212212
}
213213

214+
//===----------------------------------------------------------------------===//
215+
// XeGPU_RangeAttr
216+
//===----------------------------------------------------------------------===//
217+
218+
LogicalResult
219+
RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
220+
IntegerAttr startOfRange, IntegerAttr endOfRange) {
221+
if (startOfRange.getInt() >= endOfRange.getInt())
222+
return emitError() << "'end' : " << endOfRange.getInt()
223+
<< " must be greater than 'start' : "
224+
<< startOfRange.getInt();
225+
226+
return success();
227+
}
228+
214229
//===----------------------------------------------------------------------===//
215230
// XeGPU_TensorDescType
216231
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,26 @@ using namespace mlir;
3434

3535
namespace {
3636

37+
// Check if there is sg id range attached to the scf.if op.
38+
static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
39+
int64_t &endOfRange) {
40+
Operation *parent = op->getParentOp();
41+
// Find the outermost scf::IfOp with xegpu.sg_id_range.
42+
while (parent) {
43+
if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
44+
if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>(
45+
ifOp->getAttr("sg_id_range"))) {
46+
startOfRange = attr.getStart().getInt();
47+
endOfRange = attr.getEnd().getInt();
48+
break;
49+
}
50+
}
51+
parent = parent->getParentOp();
52+
}
53+
// Return false if startOfRange is 0
54+
return (startOfRange > 0 && endOfRange > startOfRange);
55+
}
56+
3757
static std::pair<SmallVector<int64_t>, int>
3858
getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
3959
int count = 1;
@@ -174,8 +194,26 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
174194
sgDataDim[i] = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
175195
}
176196

197+
int64_t startOfRange = -1, endOfRange = -1;
198+
bool sgIdRangeSpecified =
199+
isSgIdRangeSpecified(op, startOfRange, endOfRange);
200+
201+
Value adjustedSgId = linearSgId;
202+
if (sgIdRangeSpecified) {
203+
int64_t sgCount = endOfRange - startOfRange;
204+
if (computeProduct(sgLayout) != sgCount)
205+
return rewriter.notifyMatchFailure(
206+
op, "sg_layout size must match the sg_id_range");
207+
// Subtract startOfRange from the original subgroup id to get the adjusted
208+
// sg id
209+
Value startOfRangeVal =
210+
rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
211+
adjustedSgId =
212+
rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
213+
}
214+
177215
auto deLinearizeSgId =
178-
affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
216+
affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim);
179217
if (failed(deLinearizeSgId))
180218
return failure();
181219
SmallVector<Value> sgIds = *deLinearizeSgId;

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,5 +327,70 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
327327
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
328328
gpu.return
329329
}
330-
}
331330

331+
// CHECK-LABEL: @subgroup_id_range
332+
gpu.func @subgroup_id_range(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) {
333+
%sg_id = gpu.subgroup_id : index
334+
%c0 = arith.constant 0 : index
335+
%c1 = arith.constant 1 : index
336+
%c2 = arith.constant 2 : index
337+
%c31 = arith.constant 31 : index
338+
%c3 = arith.constant 3 : index
339+
%cond1 = arith.cmpi sge, %sg_id, %c0 : index
340+
%cond2 = arith.cmpi slt, %sg_id, %c1 : index
341+
%cond = arith.andi %cond1, %cond2 : i1
342+
scf.if %cond {
343+
// CHECK-NOT: index.sub
344+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
345+
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
346+
%load = xegpu.load_nd %tdesc
347+
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
348+
-> vector<256x128xf32>
349+
} {sg_id_range = #xegpu.range<[0, 32]>}
350+
%cond3 = arith.cmpi sge, %sg_id, %c2 : index
351+
%cond4 = arith.cmpi slt, %sg_id, %c31 : index
352+
%cond5 = arith.andi %cond3, %cond4 : i1
353+
scf.if %cond5 {
354+
// CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
355+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
356+
// CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
357+
%tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32>
358+
-> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
359+
%load = xegpu.load_nd %tdesc
360+
: !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
361+
-> vector<128x64xf32>
362+
%exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
363+
}{sg_id_range = #xegpu.range<[2, 18]>}
364+
gpu.return
365+
}
366+
367+
// CHECK-LABEL: @subgroup_id_range_nested_if
368+
gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
369+
%sg_id = gpu.subgroup_id : index
370+
%c1 = arith.constant 1 : i1
371+
%c3 = arith.constant 3 : index
372+
%c32 = arith.constant 32 : index
373+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
374+
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
375+
%load = xegpu.load_nd %tdesc
376+
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
377+
-> vector<256x128xf32>
378+
%cond1 = arith.cmpi sge, %sg_id, %c3 : index
379+
%cond2 = arith.cmpi slt, %sg_id, %c32 : index
380+
%cond = arith.andi %cond1, %cond2 : i1
381+
scf.if %c1 {
382+
scf.if %cond {
383+
// CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
384+
// CHECK: %[[C3:.*]] = arith.constant 3 : index
385+
// CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
386+
%td = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x64xf32>
387+
-> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
388+
%ld = xegpu.load_nd %td
389+
: !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
390+
-> vector<128x64xf32>
391+
%exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
392+
}
393+
} {sg_id_range = #xegpu.range<[3, 19]>}
394+
gpu.return
395+
}
396+
}

0 commit comments

Comments
 (0)