12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -828,6 +828,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {

let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>,
OpBuilder<(ins "Type": $value, "Value": $source,
"ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
"IntegerAttr": $chunk_size,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
@@ -936,6 +942,12 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {

let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>,
OpBuilder<(ins "Value": $value, "Value": $dest,
"ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
"IntegerAttr": $chunk_size,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
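For orientation, here is a hypothetical C++ call site for the new ArrayRef<OpFoldResult> overload declared above. This is a minimal sketch, not code from this PR; the builder b, loc, src, and mask values are assumed to exist in the surrounding code.

// All names here (b, loc, src, mask) are illustrative assumptions.
SmallVector<OpFoldResult> offsets = {b.getIndexAttr(0), b.getIndexAttr(8)};
auto cached = xegpu::CachePolicyAttr::get(b.getContext(),
                                          xegpu::CachePolicy::CACHED);
auto resTy = VectorType::get({2}, b.getF16Type());
// Resolves to the new overload, which materializes the OpFoldResult offsets
// into a vector<2xindex> value before delegating to the existing builder.
auto load = b.create<xegpu::LoadGatherOp>(
    loc, resTy, src, offsets, mask, /*chunk_size=*/b.getI64IntegerAttr(1),
    cached, cached, cached);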
34 changes: 34 additions & 0 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -796,6 +796,22 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
l1_hint, l2_hint, l3_hint);
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
Type valueType, Value source,
ArrayRef<OpFoldResult> offsets, Value mask,
IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
auto loc = source.getLoc();
int64_t size = static_cast<int64_t>(offsets.size());
auto type = VectorType::get(size, builder.getIndexType());
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
auto offset = vector::FromElementsOp::create(builder, loc, type, values);

build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
l2_hint, l3_hint);
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
@@ -844,6 +860,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
l2_hint, l3_hint);
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
Value value, Value dest,
ArrayRef<OpFoldResult> offsets, Value mask,
IntegerAttr chunk_size,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
auto loc = dest.getLoc();
int64_t size = static_cast<int64_t>(offsets.size());
auto type = VectorType::get(size, builder.getIndexType());
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
auto offset = vector::FromElementsOp::create(builder, loc, type, values);

// Call the correct builder overload that does not expect result types.
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
l3_hint);
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
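A matching sketch for the scatter-store overload, under the same assumptions (b, loc, data, dst, and mask are illustrative, not from the PR):

// Hypothetical call site; default-constructed cache-policy attributes stand
// in for "no hint" on the optional attributes.
SmallVector<OpFoldResult> offsets = {b.getIndexAttr(0), b.getIndexAttr(8)};
b.create<xegpu::StoreScatterOp>(loc, data, dst, offsets, mask,
                                /*chunk_size=*/b.getI64IntegerAttr(1),
                                /*l1_hint=*/xegpu::CachePolicyAttr(),
                                /*l2_hint=*/xegpu::CachePolicyAttr(),
                                /*l3_hint=*/xegpu::CachePolicyAttr());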
114 changes: 112 additions & 2 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -763,6 +763,100 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
}
};

// This pattern transforms the LoadGatherOp with explicit offsets to load
// subgroup data
struct WgToSgLoadGatherOpWithOffset
: public OpConversionPattern<xegpu::LoadGatherOp> {
using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (!op.getOffsets())
return failure();

Location loc = op.getLoc();
VectorType resultType = op.getResult().getType();
ArrayRef<int64_t> wgShape = resultType.getShape();

xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
if (!layout || !layout.isForWorkgroup())
return failure();

SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

// The offsets need to be distributed
if (dyn_cast<VectorType>(adaptor.getOffsets().front().getType())
.getShape() !=
dyn_cast<VectorType>(adaptor.getMask().front().getType()).getShape()) {
return rewriter.notifyMatchFailure(op,
"offsets have not been distributed");
}

Contributor: Nit: it would read better if we did the dyn_cast before the if; same for the following one.
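A minimal sketch of the reviewer's suggestion, with the casts hoisted out of the condition; illustrative only, not PR code:

auto offsetsTy =
    dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
auto maskTy = dyn_cast<VectorType>(adaptor.getMask().front().getType());
// Naming the types also lets us guard against failed casts, which the
// inline form above would dereference.
if (!offsetsTy || !maskTy || offsetsTy.getShape() != maskTy.getShape())
  return rewriter.notifyMatchFailure(op, "offsets have not been distributed");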

SmallVector<Value> newLoadOps;
auto chunkSizeAttr =
rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
for (auto [offsets, mask] :
llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
xegpu::setLayoutAttr(newLoadOp->getResult(0),
layout.dropSgLayoutAndData());
newLoadOps.push_back(newLoadOp);
}

Contributor: Nit: here the code assumes the offsets have been distributed by their defining op. That is not always true currently, e.g., when the offsets come from a function parameter or a non-splat arith constant. It would be better to check whether the offsets have actually been distributed.

Contributor: +1, could you add a test for this?

Contributor (author): With the current design, it's not possible to add negative tests.
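One possible shape for the requested guard, sketched under the assumption that an undistributed offsets operand reaches the adaptor as a single value still carrying the workgroup-level type; this is not code from the PR:

// Illustrative check: if the lone adaptor value still matches the original
// workgroup-level offsets shape, distribution has not happened yet.
auto wgOffsetsTy = dyn_cast<VectorType>(op.getOffsets().getType());
auto sgOffsetsTy =
    dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
if (wgOffsetsTy && sgOffsetsTy && adaptor.getOffsets().size() == 1 &&
    sgOffsetsTy.getShape() == wgOffsetsTy.getShape())
  return rewriter.notifyMatchFailure(op, "offsets not yet distributed");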
rewriter.replaceOpWithMultiple(op, {newLoadOps});
return success();
}
};

// This pattern transforms the StoreScatterOp with explicit offsets to store
// subgroup data
struct WgToSgStoreScatterOpWithOffset
: public OpConversionPattern<xegpu::StoreScatterOp> {
using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (!op.getOffsets())
return failure();

Location loc = op.getLoc();
VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
if (!valueType)
return failure();

xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue());
if (!layout || !layout.isForWorkgroup())
return failure();

// The offsets need to be distributed
if (dyn_cast<VectorType>(adaptor.getOffsets().front().getType())
.getShape() !=
dyn_cast<VectorType>(adaptor.getMask().front().getType()).getShape()) {
return rewriter.notifyMatchFailure(op,
"offsets have not been distributed");
}

auto chunkSizeOpt = op.getChunkSize();
int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
for (auto [val, offs, mask] : llvm::zip(
adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
rewriter.create<xegpu::StoreScatterOp>(
loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
// Update the layout attribute to drop sg_layout and sg_data.
if (auto newLayout = layout.dropSgLayoutAndData())
op->setAttr("layout", newLayout);
}

Contributor: Same check for the offsets as above.
rewriter.eraseOp(op);
return success();
}
};

struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
LogicalResult
@@ -824,8 +918,9 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
patterns.getContext());
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
WgToSgStoreMatrixOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -950,6 +1045,21 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(xegpu::getLayoutAttr(op.getResult()));
});

target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
[=](xegpu::LoadGatherOp op) -> bool {
auto layout = xegpu::getLayoutAttr(op.getResult());
return isLegal(layout);
});

target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
[=](xegpu::StoreScatterOp op) -> bool {
// Stores have no results, so read the discardable "layout" attribute
// from the op itself.
auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
if (!layout)
return true;
return isLegal(layout);
});

target.addDynamicallyLegalOp<vector::BroadcastOp>(
[=](vector::BroadcastOp op) -> bool {
return isLegal(xegpu::getLayoutAttr(op.getResult()));
44 changes: 44 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -264,6 +264,50 @@ gpu.module @test_distribution {
gpu.return
}

// CHECK-LABEL: @load_gather
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
gpu.func @load_gather(%src : memref<?xf16>) {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x4xindex>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<32x4xi1>
// CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
// CHECK-SAME: : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<0> : vector<256x16xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<1> : vector<256x16xi1>
%load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
: memref<?xf16>, vector<256x16xindex>, vector<256x16xi1> -> vector<256x16xf16>
gpu.return
}

// CHECK-LABEL: @store_scatter
// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>
gpu.func @store_scatter(%dest : memref<256xf16>) {
// CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
// CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
// CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
%val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
: vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
gpu.return
}

// CHECK-LABEL: @load_with_non_unit_chunk_size
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
gpu.func @load_with_non_unit_chunk_size(%src : memref<?xf16>) {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
// CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint<cached>}>
// CHECK-SAME: : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
%load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}
: memref<?xf16>, vector<256xindex>, vector<256xi1> -> vector<256x4xf16>
gpu.return
}

// CHECK-LABEL: distribute_load_matrix
// CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) {