Commit 45e56ff

Address feedback
1 parent a1b35a4 commit 45e56ff

2 files changed: +99 -144 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 99 additions & 99 deletions
@@ -296,10 +296,38 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   }
 };
 
-template <typename OpTy, typename AdaptorTy, typename CreateFn>
-LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
-                                       ConversionPatternRewriter &rewriter,
-                                       CreateFn &&createOp) {
+// Utility function to compute distributed offsets for subgroup operations.
+// Returns a vector of new offsets for each subgroup, given the original op's
+// offsets and subgroup relative offsets.
+static SmallVector<SmallVector<OpFoldResult>> computeDistributedOffsets(
+    Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
+    ArrayRef<OpFoldResult> wgOffsets, ConversionPatternRewriter &rewriter) {
+  SmallVector<SmallVector<OpFoldResult>> distributedOffsets;
+  Location loc = op->getLoc();
+  for (const auto &sgOffsets : sgOffsetsList) {
+    SmallVector<OpFoldResult> newOffsets;
+    size_t rank = sgOffsets.size();
+    for (size_t i = 0; i < rank; i++) {
+      size_t idx = wgOffsets.size() - rank + i;
+      Value add = rewriter.createOrFold<index::AddOp>(
+          loc, sgOffsets[i],
+          getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
+      newOffsets.push_back(add);
+    }
+    distributedOffsets.push_back(std::move(newOffsets));
+  }
+  return distributedOffsets;
+}
+
+// Utility function to get sgShape, sgOffsetList, and wgOffsets for a given
+// op.
+template <typename OpTy, typename AdaptorTy>
+LogicalResult
+prepareOpDistribution(OpTy op, AdaptorTy adaptor,
+                      ConversionPatternRewriter &rewriter,
+                      SmallVector<int64_t> &sgShape,
+                      SmallVector<SmallVector<Value>> &sgOffsetList,
+                      SmallVector<OpFoldResult> &wgOffsets) {
   int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
   if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
     return failure();
@@ -321,7 +349,6 @@ LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
         op, "sgLayout attribute is required in layout");
 
   ArrayRef<int64_t> wgShape = tdescTy.getShape();
-  SmallVector<int64_t> sgShape;
   int count;
   std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
 
@@ -343,21 +370,19 @@ LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
         rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
   }
 
-  auto maybeTdescOffsets =
-      layout.getOffsets(rewriter, loc, linearSgId, wgShape);
-  if (failed(maybeTdescOffsets))
+  auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+  if (failed(sgOffsets))
     return failure();
 
-  SmallVector<OpFoldResult> oldOffsets;
   if (auto constOffsets = op.getConstOffsetsAttr()) {
     for (auto attr : constOffsets.asArrayRef())
-      oldOffsets.push_back(rewriter.getIndexAttr(attr));
+      wgOffsets.push_back(rewriter.getIndexAttr(attr));
   }
   for (auto v : op.getOffsets())
-    oldOffsets.push_back(v);
+    wgOffsets.push_back(v);
 
-  return createOp(loc, sgShape, *maybeTdescOffsets, oldOffsets, adaptor,
-                  rewriter, op);
+  sgOffsetList = *sgOffsets;
+  return success();
 }
 
 // This pattern transforms the LoadNdOp with explicit offsets to load
@@ -368,39 +393,31 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
       xegpu::LoadNdOp op,
       typename OpConversionPattern<xegpu::LoadNdOp>::OneToNOpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
-    return distributeNdOpWithOffset(
-        op, adaptor, rewriter,
-        [](Location loc, SmallVector<int64_t> &sgShape,
-           ArrayRef<SmallVector<Value>> tdescOffsetsList,
-           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
-           ConversionPatternRewriter &rewriter,
-           xegpu::LoadNdOp &op) -> LogicalResult {
-          SmallVector<Value> newLoadOps;
-          for (auto [tdescOffsets, tdesc] :
-               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc())) {
-            SmallVector<OpFoldResult> newOffsets;
-            size_t rank = tdescOffsets.size();
-            for (size_t i = 0; i < rank; i++) {
-              size_t idx = oldOffsets.size() - rank + i;
-              Value add = rewriter.createOrFold<index::AddOp>(
-                  loc, tdescOffsets[i],
-                  getValueOrCreateConstantIndexOp(rewriter, loc,
-                                                  oldOffsets[idx]));
-              newOffsets.push_back(add);
-            }
-            VectorType newResTy = VectorType::get(
-                sgShape, dyn_cast<xegpu::TensorDescType>(tdesc.getType())
-                             .getElementType());
-            auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
-                loc, newResTy, tdesc, newOffsets,
-                /*packed=*/nullptr,
-                /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
-                op.getL3HintAttr());
-            newLoadOps.push_back(newLoadOp);
-          }
-          rewriter.replaceOpWithMultiple(op, {newLoadOps});
-          return success();
-        });
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+    SmallVector<OpFoldResult> wgOffsets;
+    if (failed(prepareOpDistribution(op, adaptor, rewriter, sgShape,
+                                     sgOffsetList, wgOffsets)))
+      return failure();
+
+    auto distributedOffsets =
+        computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
+
+    SmallVector<Value> newLoadOps;
+    for (auto [newOffsets, tdesc] :
+         llvm::zip(distributedOffsets, adaptor.getTensorDesc())) {
+      VectorType newResTy = VectorType::get(
+          sgShape,
+          dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType());
+      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+          op.getLoc(), newResTy, tdesc, newOffsets,
+          /*packed=*/nullptr,
+          /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return success();
   }
 };
 
@@ -413,33 +430,24 @@ struct WgToSgStoreNdOpWithOffset
       xegpu::StoreNdOp op,
       typename OpConversionPattern<xegpu::StoreNdOp>::OneToNOpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    return distributeNdOpWithOffset(
-        op, adaptor, rewriter,
-        [](Location loc, SmallVector<int64_t> &sgShape,
-           ArrayRef<SmallVector<Value>> tdescOffsetsList,
-           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
-           ConversionPatternRewriter &rewriter,
-           xegpu::StoreNdOp &op) -> LogicalResult {
-          for (auto [tdescOffsets, tdesc, value] :
-               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc(),
-                         adaptor.getValue())) {
-            SmallVector<OpFoldResult> newOffsets;
-            size_t rank = tdescOffsets.size();
-            for (size_t i = 0; i < rank; i++) {
-              size_t idx = oldOffsets.size() - rank + i;
-              Value add = rewriter.createOrFold<index::AddOp>(
-                  loc, tdescOffsets[i],
-                  getValueOrCreateConstantIndexOp(rewriter, loc,
-                                                  oldOffsets[idx]));
-              newOffsets.push_back(add);
-            }
-            rewriter.create<xegpu::StoreNdOp>(
-                loc, value, tdesc, newOffsets, op.getL1HintAttr(),
-                op.getL2HintAttr(), op.getL3HintAttr());
-          }
-          rewriter.eraseOp(op);
-          return success();
-        });
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+    SmallVector<OpFoldResult> wgOffsets;
+    if (failed(prepareOpDistribution(op, adaptor, rewriter, sgShape,
+                                     sgOffsetList, wgOffsets)))
+      return failure();
+
+    auto distributedOffsets =
+        computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
+
+    for (auto [newOffsets, tdesc, value] : llvm::zip(
+             distributedOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, newOffsets,
+                                        op.getL1HintAttr(), op.getL2HintAttr(),
+                                        op.getL3HintAttr());
+    }
+    rewriter.eraseOp(op);
+    return success();
   }
 };
 
@@ -453,32 +461,24 @@ struct WgToSgPrefetchNdOpWithOffset
       typename OpConversionPattern<xegpu::PrefetchNdOp>::OneToNOpAdaptor
          adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    return distributeNdOpWithOffset(
-        op, adaptor, rewriter,
-        [](Location loc, SmallVector<int64_t> &sgShape,
-           ArrayRef<SmallVector<Value>> tdescOffsetsList,
-           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
-           ConversionPatternRewriter &rewriter,
-           xegpu::PrefetchNdOp &op) -> LogicalResult {
-          for (auto [tdescOffsets, tdesc] :
-               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc())) {
-            SmallVector<OpFoldResult> newOffsets;
-            size_t rank = tdescOffsets.size();
-            for (size_t i = 0; i < rank; i++) {
-              size_t idx = oldOffsets.size() - rank + i;
-              Value add = rewriter.createOrFold<index::AddOp>(
-                  loc, tdescOffsets[i],
-                  getValueOrCreateConstantIndexOp(rewriter, loc,
-                                                  oldOffsets[idx]));
-              newOffsets.push_back(add);
-            }
-            rewriter.create<xegpu::PrefetchNdOp>(
-                loc, tdesc, newOffsets, op.getL1HintAttr(), op.getL2HintAttr(),
-                op.getL3HintAttr());
-          }
-          rewriter.eraseOp(op);
-          return success();
-        });
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+    SmallVector<OpFoldResult> wgOffsets;
+    if (failed(prepareOpDistribution(op, adaptor, rewriter, sgShape,
+                                     sgOffsetList, wgOffsets)))
+      return failure();
+
+    auto distributedOffsets =
+        computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
+
+    for (auto [newOffsets, tdesc] :
+         llvm::zip(distributedOffsets, adaptor.getTensorDesc())) {
+      rewriter.create<xegpu::PrefetchNdOp>(
+          op.getLoc(), tdesc, newOffsets, op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+    }
+    rewriter.eraseOp(op);
+    return success();
   }
 };
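
The extracted computeDistributedOffsets helper right-aligns each subgroup's relative offsets against the op's workgroup-level offsets and adds them elementwise, so a lower-rank subgroup offset list maps onto the innermost dimensions. A minimal standalone sketch of that arithmetic, using plain integers in place of MLIR Values (the name addOffsets and the driver are illustrative, not part of this patch):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Same arithmetic as computeDistributedOffsets above: right-align the ranks,
// then add each subgroup-relative offset to the matching workgroup offset.
static std::vector<std::vector<int64_t>>
addOffsets(const std::vector<std::vector<int64_t>> &sgOffsetsList,
           const std::vector<int64_t> &wgOffsets) {
  std::vector<std::vector<int64_t>> result;
  for (const auto &sgOffsets : sgOffsetsList) {
    std::vector<int64_t> newOffsets;
    std::size_t rank = sgOffsets.size();
    for (std::size_t i = 0; i < rank; i++) {
      std::size_t idx = wgOffsets.size() - rank + i; // right-align the ranks
      newOffsets.push_back(sgOffsets[i] + wgOffsets[idx]);
    }
    result.push_back(std::move(newOffsets));
  }
  return result;
}

int main() {
  // Four subgroups tiling 32x32 blocks of a workgroup at base offset (0, 0).
  auto out = addOffsets({{0, 0}, {0, 32}, {32, 0}, {32, 32}}, {0, 0});
  assert(out[3][0] == 32 && out[3][1] == 32);
  return 0;
}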

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 0 additions & 45 deletions
@@ -28,23 +28,8 @@ gpu.module @test_distribution {
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
     //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[C8:%.+]] = arith.constant 8 : index
-    //CHECK: [[C4:%.+]] = arith.constant 4 : index
-    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
     //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
     //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
-    //CHECK: [[C32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
-    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
-    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
-    //CHECK: [[C0:%.+]] = arith.constant 0 : index
-    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
-    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
-    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
-    //CHECK: [[C256:%.+]] = arith.constant 256 : index
-    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
-    //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
     //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
     %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -58,23 +43,8 @@ gpu.module @test_distribution {
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @store_nd_with_offsets(%src: memref<256x128xf32>) {
     //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[C8:%.+]] = arith.constant 8 : index
-    //CHECK: [[C4:%.+]] = arith.constant 4 : index
-    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
     //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
     //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
-    //CHECK: [[C32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
-    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
-    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
-    //CHECK: [[C0:%.+]] = arith.constant 0 : index
-    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
-    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
-    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
-    //CHECK: [[C256:%.+]] = arith.constant 256 : index
-    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
-    //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
     //CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -90,23 +60,8 @@ gpu.module @test_distribution {
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
     //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[C8:%.+]] = arith.constant 8 : index
-    //CHECK: [[C4:%.+]] = arith.constant 4 : index
-    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
     //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
     //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
-    //CHECK: [[C32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
-    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
-    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
-    //CHECK: [[C0:%.+]] = arith.constant 0 : index
-    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
-    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
-    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
-    //CHECK: [[C256:%.+]] = arith.constant 256 : index
-    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
-    //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
     //CHECK: xegpu.prefetch_nd %{{.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %cst0 = arith.constant 0 : index
     %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
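
For reference, the dropped CHECK lines were pinning down the per-subgroup index arithmetic the pass emits (delinearize the subgroup id, multiply by sg_data, add the base offset, wrap with index.remu); the trimmed tests now check only the op that consumes the result. A standalone sketch of that computation for the 256x128 shape with sg_layout = [8, 4] and sg_data = [32, 32] used in these tests; the helper name and the assumption that #map/#map1 are floordiv/mod by 4 are illustrative:

#include <cassert>
#include <cstdint>
#include <utility>

// Computes one subgroup's tile origin the way the removed CHECK lines did:
// split the linear subgroup id into (y, x), scale by the 32x32 tile, add the
// zero base offsets, then wrap modulo the 256x128 workgroup shape.
static std::pair<int64_t, int64_t> tileOrigin(int64_t sgId) {
  int64_t sgIdY = sgId / 4;           // affine.apply #map()[sgid]  (assumed floordiv)
  int64_t sgIdX = sgId % 4;           // affine.apply #map1()[sgid] (assumed mod)
  int64_t y = (sgIdY * 32 + 0) % 256; // index.mul, arith.addi, index.remu
  int64_t x = (sgIdX * 32 + 0) % 128;
  return {y, x};
}

int main() {
  assert((tileOrigin(0) == std::pair<int64_t, int64_t>(0, 0)));
  assert((tileOrigin(5) == std::pair<int64_t, int64_t>(32, 32))); // row 1, col 1
  return 0;
}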
