
Commit 3630966

apply getOffsets in CreateNdDescOp
1 parent 60e20a0 commit 3630966

File tree: 3 files changed (+60, -55 lines)

3 files changed

+60
-55
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 18 additions & 11 deletions

@@ -217,14 +217,14 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
 FailureOr<SmallVector<Value>>
 LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
                                   Value linearId) {
-  // delinearizeSubgroupId is only available for workgroup-level layout
-  // attribute
+  // delinearizeSubgroupId is only available for
+  // workgroup-level layout attribute
   if (!isWgLayout())
     return failure();

   auto dims =
-      llvm::map_to_vector(getSgLayout().asArrayRef(), [&](int32_t d) -> Value {
-        return arith::ConstantIndexOp::create(builder, loc, d);
+      llvm::map_to_vector(*getEffectiveSgLayout(), [&](int64_t d) -> Value {
+        return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
       });

   return affine::delinearizeIndex(builder, loc, linearId, dims);
@@ -260,25 +260,32 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
   // nd local offset, localOffset[i] = sgId[i] * sgShape[i]
   SmallVector<Value> localOffsets = llvm::map_to_vector(
       llvm::zip(sgIds, sgShape), [&](const auto &t) -> Value {
-        auto &[id, s] = t;
-        Value d = arith::ConstantIndexOp::create(builder, loc, s);
-        return index::MulOp::create(builder, loc, id, d);
+        return builder.createOrFold<index::MulOp>(
+            loc, std::get<0>(t),
+            builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
       });

   SmallVector<SmallVector<Value>> offsets;
   for (SmallVector<int64_t> unitOffs : StaticTileOffsetRange(shape, distUnit)) {
     SmallVector<Value> base =
         llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
-          return arith::ConstantIndexOp::create(builder, loc, d);
+          return builder.create<arith::ConstantIndexOp>(loc, d);
         });

     SmallVector<Value> adds = llvm::map_to_vector(
         llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
-          return arith::AddIOp::create(builder, loc, std::get<0>(t),
-                                       std::get<1>(t));
+          return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
+                                                     std::get<1>(t));
         });

-    offsets.push_back(adds);
+    SmallVector<Value> mods = llvm::map_to_vector(
+        llvm::zip_equal(adds, distUnit), [&](const auto &t) -> Value {
+          return builder.createOrFold<index::RemUOp>(
+              loc, std::get<0>(t),
+              builder.create<arith::ConstantIndexOp>(loc, std::get<1>(t)));
+        });
+
+    offsets.push_back(mods);
   }

   return offsets;
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 22 additions & 22 deletions

@@ -212,39 +212,39 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
           rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
     }

-    auto deLinearizeSgId =
-        layout.delinearizeSubgroupId(rewriter, loc, adjustedSgId);
-    if (failed(deLinearizeSgId))
+    auto tdescOffsets = layout.getOffsets(rewriter, loc, adjustedSgId, wgShape);
+    if (failed(tdescOffsets))
       return failure();
-    SmallVector<Value> sgIds = *deLinearizeSgId;
-
-    // Calculate distribution unit shape and local offsets for subgroup
-    SmallVector<int64_t> distUnitShape(sgLayout.size());
-    SmallVector<Value> localOffset(sgLayout.size());
-    for (size_t i = 0; i < sgLayout.size(); i++) {
-      distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
-      localOffset[i] =
-          rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
-    }
-
-    SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();

     xegpu::TensorDescType newTdescTy =
         xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                    layout.dropSgLayoutAndData());
+
     SmallVector<Value> newCreateNdOps;
-    for (SmallVector<int64_t> distUnitBaseAddr :
-         StaticTileOffsetRange(wgShape, distUnitShape)) {
-      SmallVector<OpFoldResult> globalOffsets =
-          calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
-                                 distUnitBaseAddr, distUnitShape);
+    SmallVector<OpFoldResult> offset = op.getMixedOffsets();
+
+    for (auto tdescOffset : *tdescOffsets) {
+      SmallVector<OpFoldResult> newOffsets = llvm::map_to_vector(
+          llvm::zip_longest(tdescOffset, offset),
+          [&](const auto &t) -> OpFoldResult {
+            std::optional<Value> off = std::get<0>(t);
+            std::optional<OpFoldResult> old = std::get<1>(t);
+            if (!off.has_value())
+              return *old;
+
+            if (!old.has_value() || isZeroInteger(*old))
+              return *off;
+
+            return rewriter.createOrFold<index::AddOp>(
+                loc, *off,
+                getValueOrCreateConstantIndexOp(rewriter, loc, *old));
+          });

       auto newCreateNdOp = xegpu::CreateNdDescOp::create(
-          rewriter, loc, newTdescTy, op.getSource(), globalOffsets,
+          rewriter, loc, newTdescTy, op.getSource(), newOffsets,
           op.getMixedSizes(), op.getMixedStrides());
       newCreateNdOps.push_back(newCreateNdOp);
     }
-
    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
     return success();
   }
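The pattern now merges each per-subgroup offset from getOffsets with the op's original mixed offsets (zip_longest pads the shorter list with "no value" when the ranks differ). A minimal sketch of that merge rule, with plain C++ and std::optional standing in for Value/OpFoldResult; the values and the helper name mergeOffset are illustrative, not from the commit:

// Sketch of the offset-merge rule used when building each new CreateNdDescOp.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

std::optional<int64_t> mergeOffset(std::optional<int64_t> computed,
                                   std::optional<int64_t> original) {
  if (!computed)
    return original; // no subgroup offset for this dim: keep the original
  if (!original || *original == 0)
    return computed; // original absent or zero: use the subgroup offset
  return *computed + *original; // otherwise fold the two (the index.add case)
}

int main() {
  // Original create_nd_tdesc offsets [0, 0]; computed subgroup offsets [12, 8].
  const std::vector<std::optional<int64_t>> computed = {12, 8};
  const std::vector<std::optional<int64_t>> original = {0, 0};
  for (size_t i = 0; i < computed.size(); ++i)
    std::cout << mergeOffset(computed[i], original[i]).value() << " ";
  std::cout << "\n"; // prints: 12 8
  return 0;
}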

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 20 additions & 22 deletions

@@ -4,27 +4,25 @@
 //CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
 gpu.module @test_1_1_assignment {
   // CHECK-LABEL: create_nd_tdesc
-  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-SAME: [[ARG_0:%.*]]: memref<24x32xf32>
   gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) {
-  // CHECK: %[[SGID:.*]] = gpu.subgroup_id
-  // CHECK: %[[C12:.*]] = arith.constant 12 : index
-  // CHECK: %[[C4:.*]] = arith.constant 4 : index
-  // CHECK: %[[C8:.*]] = arith.constant 8 : index
-  // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
-  // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
-  // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]]
-  // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]]
-  // CHECK: %[[C24:.*]] = arith.constant 24 : index
-  // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C24]]
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]]
-  // CHECK: %[[C32:.*]] = arith.constant 32 : index
-  // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C32]]
-  // CHECK: %[[C0_1:.*]] = arith.constant 0 : index
-  // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_1]]
-  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32>
-  // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-  // CHECK: gpu.return
+  //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+  //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+  //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+  //CHECK: [[C12:%.+]] = arith.constant 12 : index
+  //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C12]]
+  //CHECK: [[C8:%.+]] = arith.constant 8 : index
+  //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C8]]
+  //CHECK: [[C0:%.+]] = arith.constant 0 : index
+  //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
+  //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
+  //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
+  //CHECK: [[C24:%.+]] = arith.constant 24 : index
+  //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C24]]
+  //CHECK: [[C32:%.+]] = arith.constant 32 : index
+  //CHECK: [[X:%.+]] = index.remu [[UX]], [[C32]]
+  //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][[[Y]], [[X]]] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
       -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     gpu.return
@@ -180,7 +178,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
       -> vector<24x1xf32>
     // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
     // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
-    %broadcast = vector.broadcast %load 
+    %broadcast = vector.broadcast %load
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
      : vector<24x1xf32> to vector<24x8xf32>
    gpu.return
@@ -367,7 +365,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
 // CHECK-LABEL: @subgroup_id_range_nested_if
 gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
   %sg_id = gpu.subgroup_id : index
-  %c1 = arith.constant 1 : i1 
+  %c1 = arith.constant 1 : i1
   %c3 = arith.constant 3 : index
   %c32 = arith.constant 32 : index
   %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
