18 changes: 17 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -272,6 +272,11 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {

let builders = [
OpBuilder<(ins "Value": $TensorDesc,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>,
OpBuilder<(ins "Value": $TensorDesc,
"ArrayRef<OpFoldResult>": $offsets,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
@@ -348,6 +353,12 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [

let builders = [
OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
"UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>,
OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
"ArrayRef<OpFoldResult>": $offsets,
"UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
@@ -419,7 +430,12 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
"xegpu::CachePolicyAttr": $l3_hint)>,
OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
"ArrayRef<OpFoldResult>": $offsets,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
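The new builder overloads accept the offsets as `ArrayRef<OpFoldResult>`, so call sites can mix constants and SSA values without assembling the static/dynamic operand pair themselves. A minimal sketch of a call site for the new StoreNdOp overload; the builder `b`, location `loc`, and the `vec`, `tdesc`, and `dynOff` values are hypothetical:

// Hypothetical call site for the offset-taking StoreNdOp builder.
// `b`, `loc`, `vec`, `tdesc`, and `dynOff` are assumed to exist.
SmallVector<OpFoldResult> offsets = {b.getIndexAttr(0),      // static offset
                                     OpFoldResult(dynOff)};  // dynamic offset
auto cached = xegpu::CachePolicyAttr::get(b.getContext(),
                                          xegpu::CachePolicy::CACHED);
b.create<xegpu::StoreNdOp>(loc, vec, tdesc, offsets, cached, cached, cached);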


46 changes: 46 additions & 0 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -385,6 +385,21 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
l1_hint, l2_hint, l3_hint);
}

void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
l2_hint, l3_hint);
}
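All three new build methods follow this shape: dispatchIndexOpFoldResults splits the mixed OpFoldResult offsets into the dynamic Value operands and an int64_t array that keeps the static constants, leaving a ShapedType::kDynamic sentinel at each dynamic position. A small illustration (values hypothetical):

// offsets        = [IntegerAttr(16), %i, IntegerAttr(0)]
// After dispatchIndexOpFoldResults:
//   dynamicOffsets = [%i]
//   staticOffsets  = [16, ShapedType::kDynamic, 0]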

LogicalResult PrefetchNdOp::verify() {
auto tdescTy = getTensorDescType();
if (tdescTy.isScattered())
@@ -427,6 +442,22 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
l3_hint);
}

void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
UnitAttr packed, DenseI64ArrayAttr transpose,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
packed, transpose, l1_hint, l2_hint, l3_hint);
}

LogicalResult LoadNdOp::verify() {
auto tdescTy = getTensorDescType();
auto valueTy = getType();
@@ -533,6 +564,21 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
}

void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
l1_hint, l2_hint, l3_hint);
}

LogicalResult StoreNdOp::verify() {
auto dstTy = getTensorDescType(); // Tile
auto valTy = getValueType(); // Vector
218 changes: 209 additions & 9 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -182,16 +182,16 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
layout.dropSgLayoutAndData());

SmallVector<Value> newCreateNdOps;
SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();

for (auto tdescOffsets : *maybeTdescOffsets) {
SmallVector<OpFoldResult> sgOffsets;
size_t rank = tdescOffsets.size();
for (size_t i = 0; i < rank; i++) {
size_t idx = wgOffsets.size() - rank + i;
size_t idx = origOffsets.size() - rank + i;
Value add = rewriter.createOrFold<index::AddOp>(
loc, tdescOffsets[i],
getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
sgOffsets.push_back(add);
}

@@ -296,6 +296,205 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
}
};

// Utility function to compute global offsets for subgroup operations.
// Returns a vector of new offsets for each subgroup, given the original
// op's offsets and each subgroup's relative offsets.
static SmallVector<SmallVector<OpFoldResult>>
computeOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
ArrayRef<OpFoldResult> origOffsets,
ConversionPatternRewriter &rewriter) {
SmallVector<SmallVector<OpFoldResult>> finalOffsets;
Location loc = op->getLoc();
for (const auto &sgOffsets : sgOffsetsList) {
SmallVector<OpFoldResult> newOffsets;
size_t rank = sgOffsets.size();
for (size_t i = 0; i < rank; i++) {
size_t idx = origOffsets.size() - rank + i;
Value add = rewriter.createOrFold<index::AddOp>(
loc, sgOffsets[i],
getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
newOffsets.push_back(add);
}
finalOffsets.push_back(std::move(newOffsets));
}
return finalOffsets;
}
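The addition is right-aligned: idx = origOffsets.size() - rank + i pairs each subgroup-relative offset with the trailing entries of the original offsets. A hypothetical worked case:

// origOffsets = [32, 0] (workgroup offsets), sgOffsets = [16, 48]
// i = 0: newOffsets[0] = 16 + 32 = 48
// i = 1: newOffsets[1] = 48 + 0  = 48
// finalOffsets entry for this subgroup: [48, 48]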

// Utility function to compute the subgroup shape (sgShape) and the
// per-subgroup offsets (sgOffsetList) for a given op.
template <typename OpTy, typename AdaptorTy>
LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor,
ConversionPatternRewriter &rewriter,
SmallVector<int64_t> &sgShape,
SmallVector<SmallVector<Value>> &sgOffsetList) {
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
return failure();

Location loc = op.getLoc();
Value tdesc = op.getTensorDesc();
auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
if (!tdescTy)
return failure();
auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
if (!layout)
return failure();

SmallVector<int64_t> sgLayout;
auto sgLayoutAttr = layout.getSgLayout();
if (!sgLayoutAttr)
return rewriter.notifyMatchFailure(
op, "sgLayout attribute is required in layout");
sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());

ArrayRef<int64_t> wgShape = tdescTy.getShape();
int count;
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);

// Get the subgroup ID
Value linearSgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);

int64_t startOfRange = -1, endOfRange = -1;
bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);

if (sgIdRangeSpecified) {
int64_t sgCount = endOfRange - startOfRange;
if (computeProduct(sgLayout) != sgCount)
return rewriter.notifyMatchFailure(
op, "sg_layout size must match the sg_id_range");
Value startOfRangeVal =
rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
linearSgId =
rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
}

auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape);
if (failed(sgOffsets))
return failure();

sgOffsetList = *sgOffsets;
return success();
}
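As a hedged illustration of the sg_id_range adjustment (the exact offset assignment is delegated to layout.getOffsets):

// sg_id_range = [4, 12) -> sgCount = 8, which must equal
// computeProduct(sg_layout = [2, 4]).
// gpu.subgroup_id = 9   -> linearSgId = 9 - 4 = 5.
// layout.getOffsets(..., 5, wgShape) then yields subgroup 5's base
// offsets; e.g. with sg_data = [16, 16], a row-major [2, 4] grid would
// place id 5 at (row 1, col 1), i.e. offsets (16, 16).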

template <typename OpTy>
SmallVector<OpFoldResult> getOffsets(OpTy op,
ConversionPatternRewriter &rewriter) {
SmallVector<OpFoldResult> origOffsets;
if (auto constOffsets = op.getConstOffsetsAttr()) {
for (auto attr : constOffsets.asArrayRef())
origOffsets.push_back(rewriter.getIndexAttr(attr));
}
for (auto v : op.getOffsets())
origOffsets.push_back(v);
return origOffsets;
}
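This helper reassembles the op's offsets as OpFoldResults. Two hypothetical cases, derived directly from the code above:

// const_offsets = [0, 32], no dynamic offsets -> [IndexAttr(0), IndexAttr(32)]
// no const_offsets, offsets = [%i, %j]        -> [%i, %j]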

// This pattern transforms the LoadNdOp with explicit offsets to load
// subgroup data.
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

SmallVector<int64_t> sgShape;
SmallVector<SmallVector<Value>> sgOffsetList;

// Do the distribution from workgroup to subgroup and get subgroup offsets
if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
return failure();

// Get the original workgroup offsets
SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);

// Calculate the final offsets for each subgroup
auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);

SmallVector<Value> newLoadOps;
for (auto [offsets, tdesc] :
llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
VectorType newResTy = VectorType::get(
sgShape,
dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType());
auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
op.getLoc(), newResTy, tdesc, offsets,
/*packed=*/nullptr,
/*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
op.getL3HintAttr());
newLoadOps.push_back(newLoadOp);
}
rewriter.replaceOpWithMultiple(op, {newLoadOps});
return success();
}
};
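Since the conversion is 1:N, the adaptor already carries one tensor descriptor per subgroup; zipping those with finalOffsets emits one load per subgroup, and replaceOpWithMultiple registers all of their results as the replacement for the single workgroup-level result. The store and prefetch patterns below follow the same scheme. With hypothetical shapes:

// wg load result: vector<32x64xf16>, sg_layout = [2, 4], sg_data = [16, 16]
// -> 8 subgroup loads, each of type vector<16x16xf16> (newResTy), all
//    recorded as replacements for the original single result.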

// This pattern transforms the StoreNdOp with explicit offsets to store
// subgroup data.
struct WgToSgStoreNdOpWithOffset
: public OpConversionPattern<xegpu::StoreNdOp> {
using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

SmallVector<int64_t> sgShape;
SmallVector<SmallVector<Value>> sgOffsetList;

// Do the distribution from workgroup to subgroup and get subgroup offsets
if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
return failure();

// Get the original workgroup offsets
SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);

// Calculate the final offsets for each subgroup
auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);

for (auto [offsets, tdesc, value] :
llvm::zip(finalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
op.getL3HintAttr());
}
rewriter.eraseOp(op);
return success();
}
};

// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
// subgroup data.
struct WgToSgPrefetchNdOpWithOffset
: public OpConversionPattern<xegpu::PrefetchNdOp> {
using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

SmallVector<int64_t> sgShape;
SmallVector<SmallVector<Value>> sgOffsetList;

// Do the distribution from workgroup to subgroup and get subgroup offsets
if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
return failure();

// Get the original workgroup offsets
SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);

// Calculate the final offsets for each subgroup
auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);

for (auto [offsets, tdesc] :
llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
rewriter.create<xegpu::PrefetchNdOp>(
op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
op.getL3HintAttr());
}
rewriter.eraseOp(op);
return success();
}
};

/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
/// offsets of the new subgroup src tensor descriptors.
@@ -690,12 +889,13 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
WgToSgConvertLayoutOp, WgToSgArithConstantOp>(
patterns.getContext());
patterns
.add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
WgToSgArithConstantOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
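For context, a minimal sketch of a pass body that could drive these patterns; the legality rule below is illustrative, not the actual xegpu-wg-to-sg-distribute pass:

// Hypothetical driver (e.g. inside runOnOperation of an existing pass).
void runWgToSg(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  RewritePatternSet patterns(ctx);
  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);

  ConversionTarget target(*ctx);
  // Illustrative legality: a load is already subgroup-level once its
  // descriptor type carries no sg_layout.
  target.addDynamicallyLegalOp<xegpu::LoadNdOp>([](xegpu::LoadNdOp op) {
    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(
        op.getTensorDescType().getLayout());
    return !layout || !layout.getSgLayout();
  });

  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    module.emitError("workgroup-to-subgroup distribution failed");
}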