12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -828,6 +828,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {

let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>,
OpBuilder<(ins "Type": $value, "Value": $source,
"ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
"IntegerAttr": $chunk_size,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
@@ -936,6 +942,12 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {

let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>,
OpBuilder<(ins "Value": $value, "Value": $dest,
"ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
"IntegerAttr": $chunk_size,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
34 changes: 34 additions & 0 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -796,6 +796,22 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
l1_hint, l2_hint, l3_hint);
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
Type valueType, Value source,
ArrayRef<OpFoldResult> offsets, Value mask,
IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
auto loc = source.getLoc();
int64_t size = static_cast<int64_t>(offsets.size());
auto type = VectorType::get(size, builder.getIndexType());
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
auto offset = vector::FromElementsOp::create(builder, loc, type, values);

build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
l2_hint, l3_hint);
}
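
For illustration, a hedged usage sketch of the new OpFoldResult-based builder; `rewriter`, `loc`, `resultTy`, `src`, and `mask` are placeholder names and are not part of this patch:

```cpp
// Sketch only: builds a gather load through the new overload declared above.
// The surrounding rewriter/loc/resultTy/src/mask are assumed to exist.
SmallVector<OpFoldResult> ofrOffsets = {rewriter.getIndexAttr(0),
                                        rewriter.getIndexAttr(16)};
xegpu::CachePolicyAttr noHint; // null cache hints are allowed
auto load = rewriter.create<xegpu::LoadGatherOp>(
    loc, resultTy, src, ofrOffsets, mask,
    /*chunk_size=*/rewriter.getI64IntegerAttr(1), noHint, noHint, noHint);
```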

//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
@@ -844,6 +860,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
l2_hint, l3_hint);
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
Value value, Value dest,
ArrayRef<OpFoldResult> offsets, Value mask,
IntegerAttr chunk_size,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
auto loc = dest.getLoc();
int64_t size = static_cast<int64_t>(offsets.size());
auto type = VectorType::get(size, builder.getIndexType());
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
auto offset = vector::FromElementsOp::create(builder, loc, type, values);

// Call the correct builder overload that does not expect result types.
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
l3_hint);
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
102 changes: 100 additions & 2 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -763,6 +763,88 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
}
};

// This pattern transforms the LoadGatherOp with explicit offsets to load
// subgroup data, similar to WgToSgLoadNdOpWithOffset.
struct WgToSgLoadGatherOpWithOffset
: public OpConversionPattern<xegpu::LoadGatherOp> {
using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (!op.getOffsets())
return failure();

Location loc = op.getLoc();
VectorType resultType = op.getResult().getType();
ArrayRef<int64_t> wgShape = resultType.getShape();

xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
if (!layout || !layout.getSgLayout())
return failure();

SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

SmallVector<Value> newLoadOps;
auto chunkSizeAttr =
rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
for (auto [offsets, mask] :
Contributor

Nit: the code here assumes the offsets have already been distributed by their defining op. That is not always true currently, e.g., when the offsets come from a function parameter or a non-splat arith constant. It would be better to check whether the offsets have actually been distributed (a sketch of such a guard follows this pattern).

Contributor

+1, could you add a test for this?

Contributor Author

With the current design, it's not possible to add negative tests.

llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
xegpu::setLayoutAttr(newLoadOp->getResult(0),
layout.dropSgLayoutAndData());
newLoadOps.push_back(newLoadOp);
}
rewriter.replaceOpWithMultiple(op, {newLoadOps});
return success();
}
};
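
One possible shape for the guard discussed in the review thread above; this is only a sketch, and it relies on the assumption that undistributed offsets show up as a single remapped value while the mask has already been split per subgroup:

```cpp
// Sketch only: reject the pattern when the offsets have not been distributed
// by their defining op (e.g. a function parameter or a non-splat
// arith.constant), so the zip over adaptor.getOffsets()/getMask() is sound.
if (adaptor.getOffsets().size() != adaptor.getMask().size())
  return rewriter.notifyMatchFailure(
      op, "expected offsets to be distributed before this pattern");
```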

// This pattern transforms the StoreScatterOp with explicit offsets to store
// subgroup data, similar to WgToSgStoreNdOpWithOffset.
struct WgToSgStoreScatterOpWithOffset
: public OpConversionPattern<xegpu::StoreScatterOp> {
using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (!op.getOffsets())
return failure();

Location loc = op.getLoc();
VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
if (!valueType)
return failure();

ArrayRef<int64_t> wgShape = valueType.getShape();
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue());
if (!layout || !layout.getSgLayout())
return failure();

auto chunkSizeOpt = op.getChunkSize();
int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
for (auto [val, offs, mask] : llvm::zip(
Contributor

Same check for the offsets as above.

adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
rewriter.create<xegpu::StoreScatterOp>(
loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
// Update the layout_result_0 attribute to drop sg_layout and sg_data.
if (auto layoutAttr =
Contributor

Better to use the getLayoutAttr and setLayoutAttr interface here.

Contributor Author

The store doesn't have a result, and the layout is attached to the op, so I'm not sure the getters/setters will work.

Contributor

Then what is the purpose of setting "layout_result_0" here?

Contributor Author (@nbpatel, Aug 28, 2025)

It is there to retain the lane layout and lane data if they are present... maybe we need to name it differently?

Contributor

Since StoreScatterOp doesn't have a result, there should be no layout_result_0 on this op in the input code, but it could have layout_operand_0 instead?

Contributor Author

Changed it to "layout", similar to the load/store matrix ops.

op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0")) {
if (auto newLayout = layoutAttr.dropSgLayoutAndData())
op->setAttr("layout_result_0", newLayout);
}
}
rewriter.eraseOp(op);
return success();
}
};
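
A small sketch of what the attribute handling could look like after the rename the author mentions in the thread above; the attribute name "layout" follows that reply, and `newStoreOp` stands in for the freshly created subgroup store:

```cpp
// Sketch only: the layout lives directly on the (result-less) store op, so
// drop the sg_layout/sg_data fields and keep any lane layout/data.
if (auto layoutAttr = op->getAttrOfType<xegpu::LayoutAttr>("layout")) {
  if (auto newLayout = layoutAttr.dropSgLayoutAndData())
    newStoreOp->setAttr("layout", newLayout);
}
```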

struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
LogicalResult
@@ -824,8 +906,9 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
patterns.getContext());
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
WgToSgStoreMatrixOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -950,6 +1033,21 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(xegpu::getLayoutAttr(op.getResult()));
});

target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
[=](xegpu::LoadGatherOp op) -> bool {
auto layout = xegpu::getLayoutAttr(op.getResult());
return isLegal(layout);
});

target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
[=](xegpu::StoreScatterOp op) -> bool {
// Check if the layout attribute is present on the result.
auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0");
if (!layout)
return true;
return isLegal(layout);
});

target.addDynamicallyLegalOp<vector::BroadcastOp>(
[=](vector::BroadcastOp op) -> bool {
return isLegal(xegpu::getLayoutAttr(op.getResult()));
41 changes: 41 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -264,6 +264,47 @@ gpu.module @test_distribution {
gpu.return
}

// CHECK-LABEL: @load_gather
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
gpu.func @load_gather(%src : memref<?xf16>) {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x4xindex>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<32x4xi1>
// CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<0> : vector<256x16xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<1> : vector<256x16xi1>
%load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
: memref<?xf16>, vector<256x16xindex>, vector<256x16xi1> -> vector<256x16xf16>
gpu.return
}

// CHECK-LABEL: @store_scatter
// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>
gpu.func @store_scatter(%dest : memref<256xf16>) {
// CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
// CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
%val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
: vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
gpu.return
}

// CHECK-LABEL: @load_with_chunk_size
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
gpu.func @load_with_chunk_size(%src : memref<?xf16>) {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
// CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
%load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}
: memref<?xf16>, vector<256xindex>, vector<256xi1> -> vector<256x4xf16>
gpu.return
}

// CHECK-LABEL: distribute_load_matrix
// CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) {