
Commit 1d18b89

Add support for subgroup_id_range
1 parent: f1acd69

File tree

2 files changed: +122 −1 lines


mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 39 additions & 1 deletion
@@ -174,8 +174,46 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
     }
 
+    // Check if there is warp specialization.
+    auto isWarpSpecialized = [](Operation *op, int64_t &startRange,
+                                int64_t &endRange) -> bool {
+      Operation *parent = op->getParentOp();
+      // Walk up to the closest enclosing scf::IfOp carrying xegpu.sg_id_range.
+      while (parent) {
+        if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
+          if (Attribute attr = ifOp->getAttr("xegpu.sg_id_range")) {
+            if (auto denseAttr = dyn_cast<DenseI32ArrayAttr>(attr)) {
+              auto values = denseAttr.asArrayRef();
+              if (values.size() == 2) {
+                startRange = values[0];
+                endRange = values[1];
+              }
+            }
+            break;
+          }
+        }
+        parent = parent->getParentOp();
+      }
+      // Only report warp specialization for a valid range with a non-zero start.
+      return (startRange > 0 && endRange > startRange);
+    };
+
+    int64_t startRange = -1, endRange = -1;
+    bool warpSpecialized = isWarpSpecialized(op, startRange, endRange);
+
+    // If warp specialization is detected, adjust the subgroup id accordingly.
+    Value adjustedSgId = linearSgId;
+    if (warpSpecialized) {
+      // Subtract startRange from the original subgroup id to get the adjusted
+      // sg id.
+      Value startRangeVal =
+          rewriter.create<arith::ConstantIndexOp>(loc, startRange);
+      adjustedSgId =
+          rewriter.createOrFold<index::SubOp>(loc, linearSgId, startRangeVal);
+    }
+
     auto deLinearizeSgId =
-        affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
+        affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim);
     if (failed(deLinearizeSgId))
       return failure();
     SmallVector<Value> sgIds = *deLinearizeSgId;
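
For reference, a minimal sketch (not part of the commit) of the IR this adjustment produces when the rewritten op sits inside an scf.if tagged with xegpu.sg_id_range = array<i32: 2, 32>; the SSA names are illustrative and mirror the CHECK lines in the test below:

  // The linear subgroup id is re-based at the start of the enclosing range
  // before it is delinearized over sg_layout to compute per-subgroup offsets.
  %sgid = gpu.subgroup_id : index
  %c2 = arith.constant 2 : index   // start of the enclosing sg_id_range
  %adj = index.sub %sgid, %c2      // adjusted id used for delinearization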

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 83 additions & 0 deletions
@@ -296,5 +296,88 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     gpu.return
   }
 
+  // CHECK-LABEL: @warp_specialized
+  gpu.func @warp_specialized(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) {
+    %sg_id = gpu.subgroup_id : index
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c31 = arith.constant 31 : index
+    %c3 = arith.constant 3 : index
+    %cond1 = arith.cmpi sge, %sg_id, %c0 : index
+    %cond2 = arith.cmpi slt, %sg_id, %c1 : index
+    %cond = arith.andi %cond1, %cond2 : i1
+    scf.if %cond {
+      // CHECK-NOT: index.sub
+      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      %load = xegpu.load_nd %tdesc
+        : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+        -> vector<256x128xf32>
+    } {xegpu.sg_id_range = array<i32: 0, 1>}
+    %cond3 = arith.cmpi sge, %sg_id, %c1 : index
+    %cond4 = arith.cmpi slt, %sg_id, %c2 : index
+    %cond5 = arith.andi %cond3, %cond4 : i1
+    scf.if %cond5 {
+      // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+      // CHECK: %[[C1:.*]] = arith.constant 1 : index
+      // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C1]]
+      %tdesc_a = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      %load_a = xegpu.load_nd %tdesc_a
+        : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+        -> vector<256x128xf32>
+      %tdesc_b = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x256xf32>
+        -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
+      %load_b = xegpu.load_nd %tdesc_b
+        : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
+        -> vector<128x256xf32>
+      %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>} : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
+    } {xegpu.sg_id_range = array<i32: 1, 2>}
+    %cond6 = arith.cmpi sge, %sg_id, %c2 : index
+    %cond7 = arith.cmpi slt, %sg_id, %c31 : index
+    %cond8 = arith.andi %cond6, %cond7 : i1
+    scf.if %cond8 {
+      // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+      // CHECK: %[[C2:.*]] = arith.constant 2 : index
+      // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
+      %tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32>
+        -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+      %load = xegpu.load_nd %tdesc
+        : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+        -> vector<128x64xf32>
+      %exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+    } {xegpu.sg_id_range = array<i32: 2, 32>}
+    gpu.return
+  }
 
+  // CHECK-LABEL: @subgroup_id_range_nested_if
+  gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
+    %sg_id = gpu.subgroup_id : index
+    %c1 = arith.constant 1 : i1
+    %c3 = arith.constant 3 : index
+    %c32 = arith.constant 32 : index
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    %cond1 = arith.cmpi sge, %sg_id, %c3 : index
+    %cond2 = arith.cmpi slt, %sg_id, %c32 : index
+    %cond = arith.andi %cond1, %cond2 : i1
+    scf.if %c1 {
+      scf.if %cond {
+        // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+        // CHECK: %[[C3:.*]] = arith.constant 3 : index
+        // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
+        %td = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x64xf32>
+          -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+        %ld = xegpu.load_nd %td
+          : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+          -> vector<128x64xf32>
+        %exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+      }
+    } {xegpu.sg_id_range = array<i32: 3, 8>}
+    gpu.return
+  }
 }
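
Usage note (not part of the diff): the new functions extend the existing FileCheck test, so they run under the file's RUN line. The exact pass flag below is an assumption inferred from the pass name:

  // Assumed RUN line for xegpu-wg-to-sg.mlir:
  // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s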
