Skip to content

Commit b4e3068

Browse files
committed
Add xegpu.sg_id_range attribute
1 parent 1d18b89 commit b4e3068

File tree

3 files changed

+45
-22
lines changed

3 files changed

+45
-22
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,4 +315,31 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
315315
let genVerifyDecl = 1;
316316
}
317317

318+
def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> {
319+
let summary = [{Specifies a half-open range}];
320+
let description = [{
321+
`RangeAttr` is an attribute that defines a half-open range [start, end).
322+
The range is inclusive of the start value and exclusive of the end value.
323+
One usage of this attribute can be for warp specialization.
324+
For warp specialization, this attribute can be attached to a scf.if op like
325+
```mlir
326+
scf.if %cond {
327+
// some operations
328+
}{sg_id_range = #xegpu.range<[2, 4]>}
329+
```
330+
In this case, the scf.if op will only be executed for subgroup IDs 2 and 3.
331+
}];
332+
333+
let parameters = (ins
334+
"IntegerAttr": $start,
335+
"IntegerAttr": $end
336+
);
337+
338+
let builders = [
339+
AttrBuilder<(ins "int":$start, "int":$end)>
340+
];
341+
342+
let assemblyFormat = "`<` `[`$start ```,` $end `]``>`";
343+
}
344+
318345
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -175,41 +175,37 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
175175
}
176176

177177
// Check if there is warp specialization.
178-
auto isWarpSpecialized = [](Operation *op, int64_t &startRange,
179-
int64_t &endRange) -> bool {
178+
auto isWarpSpecialized = [](Operation *op, int64_t &startOfRange,
179+
int64_t &endOfRange) -> bool {
180180
Operation *parent = op->getParentOp();
181181
// Find the outermost scf::IfOp with xegpu.sg_id_range.
182182
while (parent) {
183183
if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
184-
if (Attribute attr = ifOp->getAttr("xegpu.sg_id_range")) {
185-
if (auto denseAttr = dyn_cast<DenseI32ArrayAttr>(attr)) {
186-
auto values = denseAttr.asArrayRef();
187-
if (values.size() == 2) {
188-
startRange = values[0];
189-
endRange = values[1];
190-
}
191-
}
184+
if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>(
185+
ifOp->getAttr("sg_id_range"))) {
186+
startOfRange = attr.getStart().getInt();
187+
endOfRange = attr.getEnd().getInt();
192188
break;
193189
}
194190
}
195191
parent = parent->getParentOp();
196192
}
197-
// Return false if startRange is 0
198-
return (startRange > 0 && endRange > startRange);
193+
// Return false if startOfRange is 0
194+
return (startOfRange > 0 && endOfRange > startOfRange);
199195
};
200196

201-
int64_t startRange = -1, endRange = -1;
202-
bool warpSpecialized = isWarpSpecialized(op, startRange, endRange);
197+
int64_t startOfRange = -1, endOfRange = -1;
198+
bool warpSpecialized = isWarpSpecialized(op, startOfRange, endOfRange);
203199

204200
// If warp specialization is detected, adjust the subgroup id accordingly
205201
Value adjustedSgId = linearSgId;
206202
if (warpSpecialized) {
207-
// Subtract startRange from the original subgroup id to get the adjusted
203+
// Subtract startOfRange from the original subgroup id to get the adjusted
208204
// sg id
209-
Value startRangeVal =
210-
rewriter.create<arith::ConstantIndexOp>(loc, startRange);
205+
Value startOfRangeVal =
206+
rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
211207
adjustedSgId =
212-
rewriter.createOrFold<index::SubOp>(loc, linearSgId, startRangeVal);
208+
rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
213209
}
214210

215211
auto deLinearizeSgId =

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
314314
%load = xegpu.load_nd %tdesc
315315
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
316316
-> vector<256x128xf32>
317-
} {xegpu.sg_id_range = array<i32: 0, 1>}
317+
} {sg_id_range = #xegpu.range<[0, 1]>}
318318
%cond3 = arith.cmpi sge, %sg_id, %c1 : index
319319
%cond4 = arith.cmpi slt, %sg_id, %c2 : index
320320
%cond5 = arith.andi %cond3, %cond4 : i1
@@ -333,7 +333,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
333333
: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
334334
-> vector<128x256xf32>
335335
%dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>} : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
336-
}{xegpu.sg_id_range = array<i32: 1, 2>}
336+
}{sg_id_range = #xegpu.range<[1, 2]>}
337337
%cond6 = arith.cmpi sge, %sg_id, %c2 : index
338338
%cond7 = arith.cmpi slt, %sg_id, %c31 : index
339339
%cond8 = arith.andi %cond6, %cond7 : i1
@@ -347,7 +347,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
347347
: !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
348348
-> vector<128x64xf32>
349349
%exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
350-
}{xegpu.sg_id_range = array<i32: 2, 32>}
350+
}{sg_id_range = #xegpu.range<[2, 32]>}
351351
gpu.return
352352
}
353353

@@ -377,7 +377,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
377377
-> vector<128x64xf32>
378378
%exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
379379
}
380-
} {xegpu.sg_id_range = array<i32: 3, 8>}
380+
} {sg_id_range = #xegpu.range<[3, 8]>}
381381
gpu.return
382382
}
383383
}

0 commit comments

Comments
 (0)