Merged
Changes from 11 commits
18 changes: 17 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -272,6 +272,11 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {

let builders = [
OpBuilder<(ins "Value": $TensorDesc,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>,
OpBuilder<(ins "Value": $TensorDesc,
"ArrayRef<OpFoldResult>": $offsets,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
@@ -348,6 +353,12 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [

let builders = [
OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
"UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>,
OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
"ArrayRef<OpFoldResult>": $offsets,
"UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
@@ -419,7 +430,12 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
"xegpu::CachePolicyAttr": $l3_hint)>,
OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
"ArrayRef<OpFoldResult>": $offsets,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];


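Note on the new builders above: they take the offsets as ArrayRef<OpFoldResult>, so a caller can mix compile-time constants with SSA index values. A minimal usage sketch follows; names such as rewriter, loc, resTy, tdesc, and dynOff are assumed to exist in the calling pass and are not part of this patch.

// Sketch only: builds a load_nd through the new offset-taking builder.
// The index attribute becomes a static offset; dynOff stays a dynamic operand.
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0),
                                     OpFoldResult(dynOff)};
auto load = rewriter.create<xegpu::LoadNdOp>(
    loc, resTy, tdesc, offsets,
    /*packed=*/nullptr, /*transpose=*/nullptr,
    /*l1_hint=*/nullptr, /*l2_hint=*/nullptr, /*l3_hint=*/nullptr);
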
46 changes: 46 additions & 0 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -364,6 +364,21 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
l1_hint, l2_hint, l3_hint);
}

void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
l2_hint, l3_hint);
}

LogicalResult PrefetchNdOp::verify() {
auto tdescTy = getTensorDescType();
if (tdescTy.isScattered())
@@ -406,6 +421,22 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
l3_hint);
}

void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
UnitAttr packed, DenseI64ArrayAttr transpose,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
packed, transpose, l1_hint, l2_hint, l3_hint);
}

LogicalResult LoadNdOp::verify() {
auto tdescTy = getTensorDescType();
auto valueTy = getType();
@@ -512,6 +543,21 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
}

void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
l1_hint, l2_hint, l3_hint);
}

LogicalResult StoreNdOp::verify() {
auto dstTy = getTensorDescType(); // Tile
auto valTy = getValueType(); // Vector
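All three new build() overloads follow the same pattern: dispatchIndexOpFoldResults splits the mixed offsets into a static array attribute plus dynamic operands before delegating to the generated builder. An illustrative sketch of that split, assuming a builder and a dynamic index value dynOff:

// Sketch only: {8, dynOff} becomes staticOffsets = {8, ShapedType::kDynamic}
// and dynamicOffsets = {dynOff}, mirroring the mixed static/dynamic offset
// encoding used by memref/tensor ops.
SmallVector<OpFoldResult> offsets = {builder.getIndexAttr(8),
                                     OpFoldResult(dynOff)};
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
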
191 changes: 189 additions & 2 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -296,6 +296,192 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
}
};

template <typename OpTy, typename AdaptorTy, typename CreateFn>
LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
ConversionPatternRewriter &rewriter,
CreateFn &&createOp) {
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
return failure();

Location loc = op.getLoc();
Value tdesc = op.getTensorDesc();
auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
if (!tdescTy)
return failure();
auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
if (!layout)
return failure();

SmallVector<int64_t> sgLayout;
if (auto sgLayoutAttr = layout.getSgLayout())
sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
else
return rewriter.notifyMatchFailure(
op, "sgLayout attribute is required in layout");

ArrayRef<int64_t> wgShape = tdescTy.getShape();
SmallVector<int64_t> sgShape;
int count;
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);

// Get the subgroup ID
Value linearSgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);

int64_t startOfRange = -1, endOfRange = -1;
bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);

if (sgIdRangeSpecified) {
int64_t sgCount = endOfRange - startOfRange;
if (computeProduct(sgLayout) != sgCount)
return rewriter.notifyMatchFailure(
op, "sg_layout size must match the sg_id_range");
Value startOfRangeVal =
rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
linearSgId =
rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
}

auto maybeTdescOffsets =
layout.getOffsets(rewriter, loc, linearSgId, wgShape);
if (failed(maybeTdescOffsets))
return failure();

SmallVector<OpFoldResult> oldOffsets;
if (auto constOffsets = op.getConstOffsetsAttr()) {
for (auto attr : constOffsets.asArrayRef())
oldOffsets.push_back(rewriter.getIndexAttr(attr));
}
for (auto v : op.getOffsets())
oldOffsets.push_back(v);

return createOp(loc, sgShape, *maybeTdescOffsets, oldOffsets, adaptor,
rewriter, op);
}

// This pattern transforms the LoadNdOp with explicit offsets to load
// subgroup data.
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xegpu::LoadNdOp op,
typename OpConversionPattern<xegpu::LoadNdOp>::OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return distributeNdOpWithOffset(
op, adaptor, rewriter,
[](Location loc, SmallVector<int64_t> &sgShape,
ArrayRef<SmallVector<Value>> tdescOffsetsList,
SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
ConversionPatternRewriter &rewriter,
xegpu::LoadNdOp &op) -> LogicalResult {
SmallVector<Value> newLoadOps;
for (auto [tdescOffsets, tdesc] :
llvm::zip(tdescOffsetsList, adaptor.getTensorDesc())) {
SmallVector<OpFoldResult> newOffsets;
size_t rank = tdescOffsets.size();
for (size_t i = 0; i < rank; i++) {
size_t idx = oldOffsets.size() - rank + i;
Value add = rewriter.createOrFold<index::AddOp>(
loc, tdescOffsets[i],
getValueOrCreateConstantIndexOp(rewriter, loc,
oldOffsets[idx]));
newOffsets.push_back(add);
}
VectorType newResTy = VectorType::get(
sgShape, dyn_cast<xegpu::TensorDescType>(tdesc.getType())
.getElementType());
auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
loc, newResTy, tdesc, newOffsets,
/*packed=*/nullptr,
/*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
op.getL3HintAttr());
newLoadOps.push_back(newLoadOp);
}
rewriter.replaceOpWithMultiple(op, {newLoadOps});
return success();
});
}
};

// This pattern transforms the StoreNdOp with explicit offsets to store
// subgroup data.
struct WgToSgStoreNdOpWithOffset
: public OpConversionPattern<xegpu::StoreNdOp> {
using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xegpu::StoreNdOp op,
typename OpConversionPattern<xegpu::StoreNdOp>::OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return distributeNdOpWithOffset(
op, adaptor, rewriter,
[](Location loc, SmallVector<int64_t> &sgShape,
ArrayRef<SmallVector<Value>> tdescOffsetsList,
SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
ConversionPatternRewriter &rewriter,
xegpu::StoreNdOp &op) -> LogicalResult {
for (auto [tdescOffsets, tdesc, value] :
llvm::zip(tdescOffsetsList, adaptor.getTensorDesc(),
adaptor.getValue())) {
SmallVector<OpFoldResult> newOffsets;
size_t rank = tdescOffsets.size();
for (size_t i = 0; i < rank; i++) {
size_t idx = oldOffsets.size() - rank + i;
Value add = rewriter.createOrFold<index::AddOp>(
loc, tdescOffsets[i],
getValueOrCreateConstantIndexOp(rewriter, loc,
oldOffsets[idx]));
newOffsets.push_back(add);
}
rewriter.create<xegpu::StoreNdOp>(
loc, value, tdesc, newOffsets, op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
}
rewriter.eraseOp(op);
return success();
});
}
};

// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
// subgroup data.
struct WgToSgPrefetchNdOpWithOffset
: public OpConversionPattern<xegpu::PrefetchNdOp> {
using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xegpu::PrefetchNdOp op,
typename OpConversionPattern<xegpu::PrefetchNdOp>::OneToNOpAdaptor
adaptor,
ConversionPatternRewriter &rewriter) const override {
return distributeNdOpWithOffset(
op, adaptor, rewriter,
[](Location loc, SmallVector<int64_t> &sgShape,
ArrayRef<SmallVector<Value>> tdescOffsetsList,
SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
ConversionPatternRewriter &rewriter,
xegpu::PrefetchNdOp &op) -> LogicalResult {
for (auto [tdescOffsets, tdesc] :
llvm::zip(tdescOffsetsList, adaptor.getTensorDesc())) {
SmallVector<OpFoldResult> newOffsets;
size_t rank = tdescOffsets.size();
for (size_t i = 0; i < rank; i++) {
size_t idx = oldOffsets.size() - rank + i;
Value add = rewriter.createOrFold<index::AddOp>(
loc, tdescOffsets[i],
getValueOrCreateConstantIndexOp(rewriter, loc,
oldOffsets[idx]));
newOffsets.push_back(add);
}
rewriter.create<xegpu::PrefetchNdOp>(
loc, tdesc, newOffsets, op.getL1HintAttr(), op.getL2HintAttr(),
op.getL3HintAttr());
}
rewriter.eraseOp(op);
return success();
});
}
};

/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
/// offsets of the new subgroup src tensor descriptors.
@@ -654,8 +840,9 @@ namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns
.add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
patterns.getContext());
}
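The three *WithOffset patterns compute the per-subgroup offsets the same way: the subgroup-local offsets returned by layout.getOffsets() are added to the trailing offsets of the original workgroup-level op. A worked example using the shapes from the tests below; the numbers and the row-major delinearization of the subgroup id are illustrative assumptions, not taken from the patch.

// Sketch only: wgShape = [256, 128], sg_layout = [8, 4], sg_data = [16, 16].
// The subgroup with linear id 5 sits at grid position (5 / 4, 5 % 4) = (1, 1),
// so its first tile's offsets relative to the original [0, 0] are:
int64_t sgId = 5;
int64_t rowOff = (sgId / 4) * 16 + 0; // 16
int64_t colOff = (sgId % 4) * 16 + 0; // 16
// In round-robin distribution the subgroup then strides by
// sg_layout * sg_data = [128, 64] to reach its remaining tiles.
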
73 changes: 72 additions & 1 deletion mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -10,5 +10,76 @@ gpu.module @test_distribution {
%tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
}

// CHECK-LABEL: load_nd_tdesc_with_offset
gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
// CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
// CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-SAME-COUNT-4: -> vector<16x16xf32>
// CHECK-NOT: xegpu.load_nd
%tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc[0, 0]
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-> vector<256x128xf32>
gpu.return
}

// CHECK-LABEL: store_nd_with_offset
gpu.func @store_nd_with_offset(%src: memref<256x128xf32>) {
// CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]
// CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.store_nd
%tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc[0, 0]
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-> vector<256x128xf32>
xegpu.store_nd %load, %tdesc[0, 0]
: vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}

// CHECK-LABEL: prefetch_nd_tdesc_with_offset
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
// CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}]
// CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.prefetch_nd
%tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
xegpu.prefetch_nd %tdesc[0, 0]
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}

// CHECK-LABEL: dpas
// CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16>
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16>
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
// CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
// CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
// CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
// CHECK-NOT: xegpu.dpas
%tdesc_a = xegpu.create_nd_tdesc %a : memref<256x128xf16>
-> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a[0, 0]
: !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-> vector<256x128xf16>
%tdesc_b = xegpu.create_nd_tdesc %b : memref<128x256xf16>
-> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
%load_b = xegpu.load_nd %tdesc_b[0, 0]
: !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
-> vector<128x256xf16>
%dpas = xegpu.dpas %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
: vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
gpu.return
}
}
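
A note on the expected counts (a reading of the new tests, not part of the diff): a 256x128 workgroup shape with sg_data = [16, 16] splits into 16 x 8 = 128 subgroup tiles, while sg_layout = [8, 4] supplies 32 subgroups, so the round-robin distribution assigns 128 / 32 = 4 tiles to each subgroup, which is why the load_nd, store_nd, and prefetch_nd checks use CHECK-COUNT-4.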