
Commit fdfc751

[MLIR][XeGPU] Distribute load_gather/store_scatter op from Wg To Sg (#154420)
This PR adds workgroup-to-subgroup distribution patterns for the scatter ops (LoadGatherOp and StoreScatterOp) when they take explicit offsets: an op carrying a workgroup-level layout is rewritten into per-subgroup ops whose shapes are derived from the layout's sg_layout and sg_data.
1 parent a7224dc commit fdfc751
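As a concrete example, taken from the new tests below: a workgroup-level xegpu.load producing vector<256x16xf16> with #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]> is rewritten into a per-subgroup load producing vector<32x4xf16>.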

4 files changed (+212 −2 lines)

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 12 additions & 0 deletions
@@ -887,6 +887,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
 
   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Type": $value, "Value": $source,
+                   "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                   "IntegerAttr": $chunk_size,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>

@@ -1016,6 +1022,12 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
 
   let builders = [
     OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $value, "Value": $dest,
+                   "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                   "IntegerAttr": $chunk_size,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 34 additions & 0 deletions
@@ -819,6 +819,22 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                          l1_hint, l2_hint, l3_hint);
 }
 
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+                         Type valueType, Value source,
+                         ArrayRef<OpFoldResult> offsets, Value mask,
+                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint) {
+  auto loc = source.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+        l2_hint, l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_StoreScatterOp
 //===----------------------------------------------------------------------===//
@@ -870,6 +886,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                            l2_hint, l3_hint);
 }
 
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+                           Value value, Value dest,
+                           ArrayRef<OpFoldResult> offsets, Value mask,
+                           IntegerAttr chunk_size,
+                           xegpu::CachePolicyAttr l1_hint,
+                           xegpu::CachePolicyAttr l2_hint,
+                           xegpu::CachePolicyAttr l3_hint) {
+  auto loc = dest.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  // Call the correct builder overload that does not expect result types.
+  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+        l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_UpdateOffsetOp
 //===----------------------------------------------------------------------===//
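Both new overloads share the same recipe: getValueOrCreateConstantIndexOp materializes each OpFoldResult (a constant attribute or an SSA value) as an index-typed Value, vector::FromElementsOp packs those values into a single vector<Nxindex> offset operand, and the existing vector-offset builder is reused from there.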

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 122 additions & 2 deletions
@@ -765,6 +765,110 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   }
 };
 
+// This pattern transforms the LoadGatherOp with explicit offsets to load
+// subgroup data
+struct WgToSgLoadGatherOpWithOffset
+    : public OpConversionPattern<xegpu::LoadGatherOp> {
+  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getOffsets())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType)
+      return failure();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+    // The offsets need to be distributed
+    auto offsetsVecType =
+        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
+    auto maskVecType =
+        dyn_cast<VectorType>(adaptor.getMask().front().getType());
+    if (!offsetsVecType || !maskVecType ||
+        offsetsVecType.getShape() != maskVecType.getShape()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "offsets have not been distributed");
+    }
+
+    SmallVector<Value> newLoadOps;
+    auto chunkSizeAttr =
+        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
+    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
+    for (auto [offsets, mask] :
+         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+      auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
+          loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+      xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
+                                     layout.dropSgLayoutAndData());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return success();
+  }
+};
+
+// This pattern transforms the StoreScatterOp with explicit offsets to store
+// subgroup data
+struct WgToSgStoreScatterOpWithOffset
+    : public OpConversionPattern<xegpu::StoreScatterOp> {
+  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getOffsets())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
+    if (!valueType)
+      return failure();
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getValue());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    // The offsets need to be distributed
+    auto offsetsVecType =
+        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
+    auto maskVecType =
+        dyn_cast<VectorType>(adaptor.getMask().front().getType());
+    if (!offsetsVecType || !maskVecType ||
+        offsetsVecType.getShape() != maskVecType.getShape()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "offsets have not been distributed");
+    }
+
+    auto chunkSizeOpt = op.getChunkSize();
+    int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
+    auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
+    for (auto [val, offs, mask] : llvm::zip(
+             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+      rewriter.create<xegpu::StoreScatterOp>(
+          loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+      // Update the layout attribute to drop sg_layout and sg_data.
+      if (auto newLayout = layout.dropSgLayoutAndData())
+        op->setAttr("layout", newLayout);
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
   using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
   LogicalResult
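The key step in both patterns is getSgShapeAndCount, which is not part of this diff. Below is a minimal standalone sketch, under stated assumptions, of the shape math it performs for the cases exercised by this patch (divide the workgroup shape elementwise by sg_layout, which matches the sg_data values in the tests):

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch only, not the MLIR helper itself: derive the per-subgroup tile shape
// from the workgroup shape and the layout's sg_layout. With wgShape = {256, 16}
// and sgLayout = {8, 4} this yields {32, 4}, matching sg_data in the tests.
static std::vector<int64_t> sgShapeSketch(const std::vector<int64_t> &wgShape,
                                          const std::vector<int64_t> &sgLayout) {
  assert(wgShape.size() == sgLayout.size());
  std::vector<int64_t> sgShape(wgShape.size());
  for (size_t i = 0; i < wgShape.size(); ++i)
    sgShape[i] = wgShape[i] / sgLayout[i]; // assumes even divisibility
  return sgShape;
}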
@@ -826,8 +930,9 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
       WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
       WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
       WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
-      WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
-      patterns.getContext());
+      WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
+      WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
+      WgToSgStoreMatrixOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -952,6 +1057,21 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
+      [=](xegpu::LoadGatherOp op) -> bool {
+        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
+        return isLegal(layout);
+      });
+
+  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
+      [=](xegpu::StoreScatterOp op) -> bool {
+        // StoreScatterOp has no results; check the layout attribute on the op.
+        auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
+        if (!layout)
+          return true;
+        return isLegal(layout);
+      });
+
   target.addDynamicallyLegalOp<vector::BroadcastOp>(
       [=](vector::BroadcastOp op) -> bool {
         return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
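Note the asymmetry between the two new legality callbacks: LoadGatherOp has a result, so its layout is read with getDistributeLayoutAttr, while StoreScatterOp has no results and is checked via a discardable "layout" attribute on the op itself. The isLegal predicate is defined earlier in runOnOperation and is not part of this diff; presumably it treats an op as legal once it no longer carries a workgroup-scoped layout, along these lines:

// Sketch of the assumed legality predicate (not shown in this diff): an op is
// legal once it has no layout at all, or a layout that is no longer
// workgroup-scoped (i.e. sg_layout/sg_data have been dropped).
auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
  return !layout || !layout.isForWorkgroup();
};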

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 44 additions & 0 deletions
@@ -264,6 +264,50 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: @load_gather
+  // CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
+  gpu.func @load_gather(%src : memref<?xf16>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x4xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<32x4xi1>
+    // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+    // CHECK-SAME: : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<0> : vector<256x16xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<1> : vector<256x16xi1>
+    %load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+      : memref<?xf16>, vector<256x16xindex>, vector<256x16xi1> -> vector<256x16xf16>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @store_scatter
+  // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>
+  gpu.func @store_scatter(%dest : memref<256xf16>) {
+    // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+    // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+    // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
+    %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
+    xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
+      : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @load_with_non_unit_chunk_size
+  // CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
+  gpu.func @load_with_non_unit_chunk_size(%src : memref<?xf16>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+    // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+    // CHECK-SAME: : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
+    %load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+      : memref<?xf16>, vector<256xindex>, vector<256xi1> -> vector<256x4xf16>
+    gpu.return
+  }
+
   // CHECK-LABEL: distribute_load_matrix
   // CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
   gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) {
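These FileCheck tests are presumably exercised through the file's existing RUN line, with an invocation along the lines of mlir-opt --xegpu-wg-to-sg-distribute; the exact command is not part of this diff.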
